diff --git a/jiuwen/core/context/config.py b/jiuwen/core/context/config.py
index 6a156f6fc7b872478afb709a5a2af8acd2df1641..e19045fd247f98f01846fd989c45ef306b6417e7 100644
--- a/jiuwen/core/context/config.py
+++ b/jiuwen/core/context/config.py
@@ -62,7 +62,7 @@ class Config(ABC):
         :param source_node_id: source node id
         :param target_node_id: target node id
         """
-        self._stream_edges[target_node_id].append(source_node_id)
+        self._stream_edges[source_node_id].append(target_node_id)
 
     def set_stream_edges(self, edges: dict[str, list[str]]) -> None:
         """
diff --git a/jiuwen/core/context/memory/__init__.py b/jiuwen/core/context/memory/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/jiuwen/core/context/memory/base.py b/jiuwen/core/context/memory/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..893a0e9a9734f0f1828ea5ede8dc0bada56c6c17
--- /dev/null
+++ b/jiuwen/core/context/memory/base.py
@@ -0,0 +1,88 @@
+#!/usr/bin/python3.10
+# coding: utf-8
+# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved
+from copy import deepcopy
+from typing import Union, Optional, Any, Callable
+
+from jiuwen.core.common.exception.exception import JiuWenBaseException
+from jiuwen.core.context.state import CommitState, StateLike, State
+from jiuwen.core.context.utils import extract_origin_key, get_value_by_nested_path, update_dict
+
+
+class InMemoryStateLike(StateLike):
+    """Dict-backed ``StateLike`` whose updates take effect immediately.
+
+    It deliberately does not share the name ``InMemoryState`` with the
+    aggregate at the bottom of this module: a later class of the same name
+    would shadow this one and make ``InMemoryCommitState()`` recurse forever.
+    """
+
+    def __init__(self):
+        self._state: dict = dict()
+
+    def get(self, key: Union[str, list, dict]) -> Optional[Any]:
+        """Resolve a key, or a dict/list of keys, against the stored data."""
+        if isinstance(key, str):
+            origin_key = extract_origin_key(key)
+            return get_value_by_nested_path(origin_key, self._state)
+        if isinstance(key, dict):
+            return {target_key: self.get(target_schema) for target_key, target_schema in key.items()}
+        if isinstance(key, list):
+            return [self.get(item) for item in key]
+        raise JiuWenBaseException(1, "key type is not support")
+
+    def get_by_transformer(self, transformer: Callable) -> Optional[Any]:
+        """Return ``transformer`` applied to the backing dict."""
+        return transformer(self._state)
+
+    def update(self, node_id: str, data: dict) -> None:
+        """Merge ``data`` (keys may be nested paths) into the stored data."""
+        # update_dict(update, source) mutates its second argument.
+        update_dict(data, self._state)
+
+
+class InMemoryCommitState(CommitState):
+    """``CommitState`` that queues updates per node until ``commit``."""
+
+    def __init__(self):
+        self._state = InMemoryStateLike()
+        self._updates: dict[str, list[dict]] = dict()
+
+    def update(self, node_id: str, data: dict) -> None:
+        """Queue ``data`` for ``node_id``; it is not readable before ``commit``."""
+        self._updates.setdefault(node_id, []).append(data)
+
+    def commit(self) -> None:
+        """Apply every queued update in insertion order and empty the queue."""
+        for key, updates in self._updates.items():
+            for update in updates:
+                self._state.update(key, update)
+        self._updates.clear()
+
+    def get_updates(self, node_id: str) -> list[dict]:
+        """Return a deep copy of the updates still queued for ``node_id``."""
+        return deepcopy(self._updates.get(node_id, []))
+
+    def rollback(self, failed_node_ids: list[str]) -> None:
+        """Drop the queued updates of every node in ``failed_node_ids``."""
+        for node_id in failed_node_ids:
+            self._updates[node_id] = []
+
+    def get_by_transformer(self, transformer: Callable) -> Optional[Any]:
+        """Return ``transformer`` applied to the committed backing dict."""
+        # Delegate so the transformer gets the dict, not the wrapper object.
+        return self._state.get_by_transformer(transformer)
+
+    def get(self, key: Union[str, list, dict]) -> Optional[Any]:
+        """Read ``key`` from committed data only."""
+        return self._state.get(key)
+
+
+class InMemoryState(State):
+    """``State`` whose four sub-states are all kept in memory."""
+
+    def __init__(self):
+        super().__init__(io_state=InMemoryCommitState(),
+                         global_state=InMemoryCommitState(),
+                         trace_state=InMemoryStateLike(),
+                         comp_state=InMemoryStateLike())
diff --git a/jiuwen/core/context/state.py b/jiuwen/core/context/state.py
index 5d46198adfdd584acb20f03e97c4624830c9e88b..beb6e0097c395aa37166df96674c6fc5b1ba3c99 100644
--- a/jiuwen/core/context/state.py
+++ b/jiuwen/core/context/state.py
@@ -2,27 +2,43 @@
 # coding: utf-8
 # Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved
 from abc import ABC, abstractmethod
-from typing import Any, Union, Optional
+from typing import Any, Union, Optional, Callable
 
 
 class StateLike(ABC):
     @abstractmethod
-    def get(self, key: Union[str, dict]) -> Optional[Any]:
+    def get(self, key: Union[str, list, dict]) -> Optional[Any]:
         pass
 
     @abstractmethod
-    def update(self, data: dict) -> None:
+    def get_by_transformer(self, transformer: Callable) -> Optional[Any]:
         pass
 
+    @abstractmethod
+    def update(self, node_id: str, data: dict) -> None:
+        pass
 
 class CommitState(StateLike):
     @abstractmethod
-    def commit(self) -> dict:
+    def commit(self) -> None:
+        pass
+
+    @abstractmethod
+    def rollback(self, failed_node_ids: list[str]) -> None:
         pass
 
+    @abstractmethod
+    def get_updates(self, node_id: str) -> list[dict]:
+        pass
 
 class State(ABC):
-    def __init__(self, io_state: CommitState, global_state: CommitState, trace_state: StateLike, comp_state: StateLike):
+    def __init__(
+            self,
+            io_state: CommitState,
+            global_state: CommitState,
+            trace_state: StateLike,
+            comp_state: StateLike
+    ):
         self._io_state = io_state
         self._global_state = global_state
         self._trace_state = trace_state
@@ -49,20 +65,25 @@ class State(ABC):
             return None
         return self._global_state.get(key)
 
-    def update(self, data: dict) -> None:
+    def update(self, node_id: str, data: dict) -> None:
         if self._global_state is None:
             return
-        self._global_state.update(data)
+        self._global_state.update(node_id, data)
 
-    def update_io(self, data: dict) -> None:
+    def update_io(self, node_id: str, data: dict) -> None:
         if self._io_state is None:
             return
-        self._io_state.update(data)
+        self._io_state.update(node_id, data)
 
-    def update_trace(self, data: dict) -> None:
+    def update_trace(self, node_id: str, data: dict) -> None:
         if self._trace_state is None:
             return
-        self._trace_state.update(data)
+        self._trace_state.update(node_id, data)
+
+    def update_comp(self, node_id: str, data: dict) -> None:
+        if self._comp_state is None:
+            return
+        self._comp_state.update(node_id, data)
 
     def set_user_inputs(self, inputs: dict) -> None:
         if self._io_state is None:
@@ -74,13 +95,13 @@ class State(ABC):
             return {}
         return self._io_state.get(input_schemas)
 
-    def get_outputs(self, node_id: str)-> Any:
+    def get_outputs(self, node_id: str) -> Any:
         if self._io_state is None:
             return {}
         return self._io_state.get(node_id)
 
-    def set_outputs(self, outputs: dict)-> None:
+    def set_outputs(self, node_id: str, outputs: dict) -> None:
         if self._io_state is None:
             return
+        return self._io_state.update(node_id, outputs)
 
-        return self._io_state.update(outputs)
diff --git a/jiuwen/core/context/utils.py b/jiuwen/core/context/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c01b43c4aef761ac3382b7790dabf3986f983dbe
--- /dev/null
+++ b/jiuwen/core/context/utils.py
@@ -0,0 +1,168 @@
+#!/usr/bin/python3.10
+# coding: utf-8
+# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved
+import re
+from typing import Optional, Any, Union
+
+from jiuwen.core.common.exception.exception import JiuWenBaseException
+
+REGEX_MAX_LENGTH = 1000
+NESTED_PATH_LIST_PATTERN = re.compile(r'^([\w]+)((?:\[\d+\])*)$')
+NESTED_PATH_SPLIT = '.'
+NESTED_PATH_LIST_SPLIT = "["
+
+
+def update_dict(update: dict, source: dict) -> None:
+    """
+    Merge ``update`` into ``source`` in place.
+
+    Note: keys of ``update`` may be nested paths (e.g. ``a.b[0].c``), while
+    ``source`` holds the expanded, plainly nested structure.
+
+    :param update: update dict, whose keys may be nested paths
+    :param source: source dict to mutate, whose keys must not be nested paths
+    """
+    for key, value in update.items():
+        current_key, current = root_to_path(key, source, create_if_absent=True)
+        update_by_key(current_key, value, current)
+
+
+def get_value_by_nested_path(nested_key: str, source: dict) -> Optional[Any]:
+    """
+    Read the value addressed by ``nested_key`` from ``source``.
+
+    :param nested_key: plain key or nested path such as ``a.b[0].c``
+    :param source: dict to read from; it is never modified
+    :return: the stored value, or None when any part of the path is absent
+    """
+    key, container = root_to_path(nested_key, source)
+    if container is None:
+        return None
+    try:
+        return container[key]
+    except (KeyError, IndexError):
+        # A flat key that is simply missing is reported like a missing path.
+        return None
+
+
+def split_nested_path(nested_key: str) -> list:
+    """
+    Split a nested path into segments.
+
+    Plain segments are returned as ``str``; indexed segments as a
+    ``(name, [indexes])`` tuple. Segments that do not match
+    ``NESTED_PATH_LIST_PATTERN`` are skipped.
+
+    :param nested_key: path
+    :return: e.g. a_1.b.c[1].d -> ["a_1", "b", ("c", [1]), "d"];
+             an empty list when the key has neither "." nor "["
+    """
+    if (NESTED_PATH_SPLIT not in nested_key) and (NESTED_PATH_LIST_SPLIT not in nested_key):
+        return []
+    final_list = []
+    for param in nested_key.split(NESTED_PATH_SPLIT):
+        # Reuse the module-level pattern instead of recompiling it per call.
+        match = NESTED_PATH_LIST_PATTERN.match(param)
+        if not match:
+            continue
+        idxes = [int(num) for num in re.findall(r'\d+', match.group(2))]
+        final_list.append((match.group(1), idxes) if idxes else match.group(1))
+    return final_list
+
+
+def extract_origin_key(key: str) -> str:
+    """
+    extract the origin key from given key if the given key is reference structure
+    e.g. "${start123.p2}" -> "start123.p2"
+    :param key: reference key
+    :return: origin key
+    """
+    if '$' not in key:
+        return key
+    pattern = re.compile(r"\${(.+?)\}")
+    match = pattern.search(key, endpos=REGEX_MAX_LENGTH)
+    if match:
+        return match.group(1)
+    return key
+
+
+def update_by_key(key: Union[str, int], new_value: Any, source: Union[dict, list]) -> None:
+    """
+    Store ``new_value`` at ``source[key]``.
+
+    Two dicts are merged recursively; anything else overwrites the old value.
+    """
+    try:
+        # Index access works for dicts and lists alike; ``key in source`` would
+        # test list *values* rather than positions.
+        existing = source[key]
+    except (KeyError, IndexError):
+        existing = None
+    if isinstance(existing, dict) and isinstance(new_value, dict):
+        update_dict(new_value, existing)
+    else:
+        source[key] = expand_nested_structure(new_value)
+
+
+def expand_nested_structure(data: Any) -> Any:
+    """Return a copy of ``data`` in which nested-path dict keys are expanded."""
+    if isinstance(data, (list, tuple)):
+        return [expand_nested_structure(item) for item in data]
+    if isinstance(data, dict):
+        result = {}
+        for key, value in data.items():
+            current_key, current = root_to_path(key, result, create_if_absent=True)
+            current[current_key] = expand_nested_structure(value)
+        return result
+    return data
+
+
+def root_to_path(nested_path: str, source: dict, create_if_absent: bool = False) -> tuple[Union[str, int], dict]:
+    """
+    Walk ``source`` along ``nested_path`` up to the container of the last segment.
+
+    :param nested_path: plain key or nested path
+    :param source: dict to walk
+    :param create_if_absent: create missing dicts/lists on the way when True
+    :return: ``(key, container)`` so that ``container[key]`` addresses the target,
+             or ``(None, None)`` when the path is absent and nothing may be created
+    """
+    paths = split_nested_path(nested_path)
+    if len(paths) == 0:
+        return (nested_path, source)
+    current = source
+    last = len(paths) - 1
+    for i, path in enumerate(paths):
+        if isinstance(path, str):
+            if path not in current:
+                if not create_if_absent:
+                    return (None, None)
+                current[path] = {}
+            if i == last:
+                return (path, current)
+            current = current[path]
+            continue
+        token, idxes = path
+        if token not in current:
+            if not create_if_absent:
+                return (None, None)
+            current[token] = []
+        located = root_to_index(idxes, current[token], create_if_absent)
+        if located is None:
+            # root_to_index reports a miss as None; normalise it to the tuple form.
+            return (None, None)
+        if i == last:
+            return located
+        idx, holder = located
+        if holder[idx] is None:
+            # A None slot is only padding; it cannot be descended into.
+            if not create_if_absent:
+                return (None, None)
+            holder[idx] = {}
+        current = holder[idx]
+    return (None, None)
+
+
+def root_to_index(idxes: list[int], source: list, create_if_absent: bool = False) -> Optional[tuple[int, list]]:
+    """
+    Walk a (possibly multi-dimensional) list along ``idxes``.
+
+    :param idxes: one index per list dimension
+    :param source: outermost list
+    :param create_if_absent: pad with None and create the slot when out of range
+    :return: ``(last_index, innermost_list)``, or None when an index is out of
+             range and ``create_if_absent`` is False
+    """
+    current = source
+    for idx in idxes[:-1]:
+        if idx >= len(current):
+            if not create_if_absent:
+                return None
+            # Pad the gap with None so that position ``idx`` exists.
+            current.extend([None] * (idx - len(current)))
+            current.append([])
+        elif current[idx] is None and create_if_absent:
+            current[idx] = []
+        current = current[idx]
+    last_idx = idxes[-1]
+    # Bounds are checked against the list being indexed, not the outermost one.
+    if last_idx >= len(current):
+        if not create_if_absent:
+            return None
+        current.extend([None] * (last_idx - len(current)))
+        current.append({})
+    return last_idx, current
+
+
+if __name__ == '__main__':
+    source = {}
+    # add the "nums" attribute under a.b
+    update_dict({"a.b.nums": [1, 2, 3]}, source)
+    print(source)
+    # add the "name" attribute under a.b
+    update_dict({
+        "a.b.name": "shanghai"
+    }, source)
+    print(source)
+    # add the "class" attribute under a.b
+    update_dict({"a.b": {"class":"hha"}}, source)
+    print(source)
+    # overwrite everything under a.b
+    update_dict({"a.b": [1,2,3]}, source)
+    print(source)
diff --git a/jiuwen/core/graph/vertex.py b/jiuwen/core/graph/vertex.py
index 6a9aef2c012fb55ba72edd5749f7e7dccc1b973c..f853d2b59f6d66181e5a94f3fcab20ad88e357fd 100644
--- a/jiuwen/core/graph/vertex.py
+++ b/jiuwen/core/graph/vertex.py
@@ -20,7 +20,7 @@ class Vertex:
 
     def __call__(self, state: GraphState) -> Output:
         if self._context is None:
-            raise JiuWenBaseException(1, "vertex is not initialized, node is is %s", self._node_id)
+            raise JiuWenBaseException(1, f"vertex is not initialized, node id is {self._node_id}")
         inputs = self.__pre_invoke__(state)
         is_stream = self.__is_stream__(state)
         try:
@@ -32,7 +32,7 @@ class Vertex:
                 self.__post_invoke__(results)
         except JiuWenBaseException as e:
             raise JiuWenBaseException(e.error_code, "failed to invoke, caused by " + e.message)
-        return {"source": self._node_id}
+        return {"source_node_id": self._node_id}
 
     def __pre_invoke__(self, state:GraphState) -> Optional[dict]:
         inputs_schema = self._context.config.get_inputs_schema(self._node_id)
diff --git a/jiuwen/core/stream/manager.py b/jiuwen/core/stream/manager.py
index 40639f59f9b26bf3bb3f07150111453545996428..51f90968851a8e943a4c05cf931c08ea8c152ee4 100644
--- 
a/jiuwen/core/stream/manager.py
+++ b/jiuwen/core/stream/manager.py
@@ -48,7 +48,7 @@ class StreamWriterManager:
         self._writers[key] = writer
 
     def get_writer(self, key: StreamMode) -> Optional[StreamWriter]:
-        return self._writers.get[key]
+        return self._writers.get(key)
 
     def get_output_writer(self) -> Optional[StreamWriter]:
         return self.get_writer(BaseStreamMode.OUTPUT)
diff --git a/jiuwen/core/stream/writer.py b/jiuwen/core/stream/writer.py
index 8fc0dd2f18480a1291108738b53ef677e6bfbe58..040a7c828b274b56b5b4b571c7f52e14f4b307e3 100644
--- a/jiuwen/core/stream/writer.py
+++ b/jiuwen/core/stream/writer.py
@@ -64,7 +64,7 @@ class TraceStreamWriter(StreamWriter[dict, TraceSchema]):
 
 class CustomSchema(BaseModel):
     def __init__(self, **kwargs):
-        super().__init_(**kwargs)
+        super().__init__(**kwargs)
 
     class Config:
         arbitrary_types_allowed = True
diff --git a/jiuwen/core/tracer/handler.py b/jiuwen/core/tracer/handler.py
index 42103d11cd6e08ec058979d82c9bd9e0a0c43498..1e57e5189715c3d783575ce9c9bc2bc64c3822d2 100644
--- a/jiuwen/core/tracer/handler.py
+++ b/jiuwen/core/tracer/handler.py
@@ -1,10 +1,10 @@
 import asyncio
 import copy
 import json
-import random
+import threading
 from abc import abstractmethod
 from datetime import datetime
 from enum import Enum
 from typing import Any
 
 from dateutil.tz import tzlocal
@@ -36,20 +36,22 @@ class TraceBaseHandler(BaseHandler):
 
     @abstractmethod
     def _format_data(self, span: Span) -> dict:
-        return {"event_name": self.event_name(), "playload": span}
+        return {"type": self.event_name(), "payload": span}
 
     async def _emit_stream_writer(self, span):
-        # TODO 替换为使用TraceStreamWriter进行输出
-        # await self._stream_writer.write(self._format_data(raw_data))
+        await self._stream_writer.write(self._format_data(span))
 
-        print(f"_emit_stream_writer: {self._format_data(span)}")
-        wait_time = random.randint(0, 10)
-        await asyncio.sleep(wait_time)
-        print(f"wait_time {wait_time}")
+    def _run_send_data_in_thread(self, span):
+        # Each worker thread drives its own private event loop.
+        new_loop = asyncio.new_event_loop()
+        try:
+            new_loop.run_until_complete(self.emit_stream_writer(span))
+        finally:
+            # Close the loop even when the write fails, otherwise it leaks.
+            new_loop.close()
 
     def _send_data(self, span):
-        print(f"send_data, {span}")
-        asyncio.create_task(self.emit_stream_writer(copy.deepcopy(span)))
+        threading.Thread(target=self._run_send_data_in_thread, args=(copy.deepcopy(span), )).start()
 
     def _get_elapsed_time(self, start_time: datetime, end_time: datetime) -> str:
         """get elapsed time"""
@@ -67,7 +69,7 @@ class TraceAgentHandler(TraceBaseHandler):
         return TracerHandlerName.TRACE_AGENT.value
 
     def _format_data(self, span: TraceAgentSpan) -> dict:
-        return {"event_name": self.event_name(), "playload": span.model_dump(by_alias=True)}
+        return {"type": self.event_name(), "payload": span.model_dump(by_alias=True)}
 
     def _update_start_trace_data(self, span: TraceAgentSpan, invoke_type: str, inputs: Any, instance_info: dict,
                                  **kwargs):
@@ -115,13 +117,13 @@ class TraceAgentHandler(TraceBaseHandler):
         self._send_data(span)
 
     @trigger_event
-    def on_chain_end(self, span: TraceAgentSpan, **kwargs):
-        self._update_end_trace_data(span=span, **kwargs)
+    def on_chain_end(self, span: TraceAgentSpan, outputs, **kwargs):
+        self._update_end_trace_data(span=span, outputs=outputs, **kwargs)
         self._send_data(span)
 
     @trigger_event
-    def on_chain_error(self, span: TraceAgentSpan, **kwargs):
-        self._update_error_trace_data(span=span, **kwargs)
+    def on_chain_error(self, span: TraceAgentSpan, error, **kwargs):
+        self._update_error_trace_data(span=span, error=error, **kwargs)
         self._send_data(span)
 
     @trigger_event
@@ -131,13 +133,13 @@ class TraceAgentHandler(TraceBaseHandler):
         self._send_data(span)
 
     @trigger_event
-    def on_llm_end(self, span: TraceAgentSpan, **kwargs):
-        self._update_end_trace_data(span=span, **kwargs)
+    def on_llm_end(self, span: TraceAgentSpan, outputs, **kwargs):
+        self._update_end_trace_data(span=span, outputs=outputs, **kwargs)
         self._send_data(span)
 
     @trigger_event
-    def on_llm_error(self, span: TraceAgentSpan, **kwargs):
-        self._update_error_trace_data(span=span, **kwargs)
+    def on_llm_error(self, span: TraceAgentSpan, error, **kwargs):
+        self._update_error_trace_data(span=span, error=error, **kwargs)
         self._send_data(span)
 
     @trigger_event
@@ -147,13 +149,13 @@ class TraceAgentHandler(TraceBaseHandler):
         self._send_data(span)
 
     @trigger_event
-    def on_prompt_end(self, span: TraceAgentSpan, **kwargs):
-        self._update_end_trace_data(span=span, **kwargs)
+    def on_prompt_end(self, span: TraceAgentSpan, outputs, **kwargs):
+        self._update_end_trace_data(span=span, outputs=outputs, **kwargs)
         self._send_data(span)
 
     @trigger_event
-    def on_prompt_error(self, span: TraceAgentSpan, **kwargs):
-        self._update_error_trace_data(span=span, **kwargs)
+    def on_prompt_error(self, span: TraceAgentSpan, error, **kwargs):
+        self._update_error_trace_data(span=span, error=error, **kwargs)
         self._send_data(span)
 
     @trigger_event
@@ -163,13 +165,13 @@ class TraceAgentHandler(TraceBaseHandler):
         self._send_data(span)
 
     @trigger_event
-    def on_plugin_end(self, span: TraceAgentSpan, **kwargs):
-        self._update_end_trace_data(span=span, **kwargs)
+    def on_plugin_end(self, span: TraceAgentSpan, outputs, **kwargs):
+        self._update_end_trace_data(span=span, outputs=outputs, **kwargs)
         self._send_data(span)
 
     @trigger_event
-    def on_plugin_error(self, span: TraceAgentSpan, **kwargs):
-        self._update_error_trace_data(span=span, **kwargs)
+    def on_plugin_error(self, span: TraceAgentSpan, error, **kwargs):
+        self._update_error_trace_data(span=span, error=error, **kwargs)
         self._send_data(span)
 
     @trigger_event
@@ -179,13 +181,13 @@ class TraceAgentHandler(TraceBaseHandler):
         self._send_data(span)
 
     @trigger_event
-    def on_retriever_end(self, span: TraceAgentSpan, **kwargs):
-        self._update_end_trace_data(span=span, **kwargs)
+    def on_retriever_end(self, span: TraceAgentSpan, outputs, **kwargs):
+        self._update_end_trace_data(span=span, outputs=outputs, **kwargs)
         self._send_data(span)
 
     @trigger_event
-    def on_retriever_error(self, span: TraceAgentSpan, **kwargs):
-        self._update_error_trace_data(span=span, **kwargs)
+    def on_retriever_error(self, span: TraceAgentSpan, error, **kwargs):
+        self._update_error_trace_data(span=span, error=error, **kwargs)
         self._send_data(span)
 
     @trigger_event
@@ -195,13 +197,13 @@ class TraceAgentHandler(TraceBaseHandler):
         self._send_data(span)
 
     @trigger_event
-    def on_evaluator_end(self, span: TraceAgentSpan, **kwargs):
-        self._update_end_trace_data(span=span, **kwargs)
+    def on_evaluator_end(self, span: TraceAgentSpan, outputs, **kwargs):
+        self._update_end_trace_data(span=span, outputs=outputs, **kwargs)
         self._send_data(span)
 
     @trigger_event
-    def on_evaluator_error(self, span: TraceAgentSpan, **kwargs):
-        self._update_error_trace_data(span=span, **kwargs)
+    def on_evaluator_error(self, span: TraceAgentSpan, error, **kwargs):
+        self._update_error_trace_data(span=span, error=error, **kwargs)
         self._send_data(span)
 
 
@@ -217,15 +219,17 @@ class TraceWorkflowHandler(TraceBaseHandler):
         return TracerHandlerName.TRACER_WORKFLOW.value
 
     def _format_data(self, span: TraceWorkflowSpan) -> dict:
-        status = NodeStatus.START.value
+        span.status = self._get_node_status(span)
+        return {"type": self.event_name(), "payload": span.model_dump(by_alias=True)}
+
+    def _get_node_status(self, span: TraceWorkflowSpan) -> str:
         if span.error:
-            status = NodeStatus.ERROR.value
+            return NodeStatus.ERROR.value
         if span.on_invoke_data:
-            status = NodeStatus.RUNNING.value if not span.outputs else NodeStatus.FINISH.value
+            return NodeStatus.RUNNING.value if not span.outputs else NodeStatus.FINISH.value
         if span.end_time:
-            status = NodeStatus.FINISH.value if span.outputs else NodeStatus.RUNNING.value
-        span.status = status
-        return {"event_name": self.event_name(), "playload": span.model_dump(by_alias=True)}
+            return NodeStatus.FINISH.value if span.outputs else NodeStatus.RUNNING.value
+        return NodeStatus.START.value
 
     @trigger_event
     def on_pre_invoke(self, span: TraceWorkflowSpan, inputs: Any, component_metadata: dict,
@@ -267,6 +271,8 @@ class TraceWorkflowHandler(TraceBaseHandler):
                 "elapsed_time": self._get_elapsed_time(span.start_time, end_time)
             }
         else:
+            if not isinstance(span.on_invoke_data, list):
+                span.on_invoke_data = []
             span.on_invoke_data.append(on_invoke_data)
         if span.component_type == "LLM":
             # TODO
diff --git a/jiuwen/core/tracer/span.py b/jiuwen/core/tracer/span.py
index 07c51aae6cb26f6a54299cd7827901844f037153..205f9547f3610a5d98f21cc1a515105e8f9792c9 100644
--- a/jiuwen/core/tracer/span.py
+++ b/jiuwen/core/tracer/span.py
@@ -1,7 +1,7 @@
 import uuid
 from datetime import datetime
 from typing import Optional, Dict, List, Callable
-from pydantic import Field, BaseModel
+from pydantic import ConfigDict, Field, BaseModel
 
 class Span(BaseModel):
     trace_id: str
@@ -14,6 +14,8 @@ class Span(BaseModel):
     parent_invoke_id: Optional[str] = Field(default=None, alias="parentInvokeId")
     child_invokes_id: List[str] = Field(default=[], alias="childInvokes")
 
+    model_config = ConfigDict(populate_by_name=True)
+
     def update(self, data: dict):
         for attr_name, value in data.items():
             if not hasattr(self, attr_name):
@@ -58,9 +60,9 @@ class SpanManager:
 
     def _create_span(self, span_class: Callable, parent_span = None):
         invoke_id = str(uuid.uuid4())
-        span = span_class(invoke_id=invoke_id, parent_id=parent_span.invoke_id if parent_span else None,
+        span = span_class(invoke_id=invoke_id, parent_invoke_id=parent_span.invoke_id if parent_span else None,
                           trace_id=self._trace_id)
-        
+
         if parent_span:
             parent_span.child_invokes_id.append(span.invoke_id)
             self.refresh_span_record(parent_span.invoke_id, {parent_span.invoke_id: parent_span})
@@ -71,7 +73,7 @@ class SpanManager:
     def create_agent_span(self, parent_span: Optional[TraceAgentSpan] = None) -> TraceAgentSpan:
         return self._create_span(TraceAgentSpan, parent_span)
 
-    def create_workflow_agent(self, parent_span: Optional[TraceWorkflowSpan] = None) -> TraceWorkflowSpan:
+    def create_workflow_span(self, parent_span: Optional[TraceWorkflowSpan] = None) -> 
TraceWorkflowSpan:
         return self._create_span(TraceWorkflowSpan, parent_span)
 
     def update_span(self, span: Span, data: dict):
diff --git a/jiuwen/core/workflow/base.py b/jiuwen/core/workflow/base.py
index bb3027342786cf69eb82db6611818d70201e481e..5686512e59131e41bd9ed27f289af6f75919c3ff 100644
--- a/jiuwen/core/workflow/base.py
+++ b/jiuwen/core/workflow/base.py
@@ -3,6 +3,7 @@
 # Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved
 import asyncio
 from dataclasses import dataclass, Field
+from enum import Enum
 from functools import partial
 from typing import Self, Dict, Any, Union, AsyncIterator, Iterator
 
@@ -30,6 +31,7 @@ class WorkflowChunk(BaseModel):
     is_final: bool = Field(default=False)
 
 
+
 class Workflow:
     def __init__(self, workflow_config: WorkflowConfig, graph: Graph = None):
         self._graph = graph
@@ -85,9 +87,8 @@ class Workflow:
     def add_stream_connection(self, src_comp_id: str, target_comp_id: str) -> Self:
         self._graph.add_edge(source_node_id=src_comp_id, target_node_id=target_comp_id)
-        if target_comp_id not in self._stream_edges:
-            self._stream_edges[target_comp_id] = [src_comp_id]
-        else:
-            self._stream_edges[target_comp_id].append(src_comp_id)
+        # Stream edges are keyed by the source component, so the existence
+        # check has to use the source id as well (not the target id).
+        self._stream_edges.setdefault(src_comp_id, []).append(target_comp_id)
         return self
 
     def add_conditional_connection(self, src_comp_id: str, router: Router) -> Self:
diff --git a/tests/tracer/test.py b/tests/tracer/test.py
index 2cd2b09e40dc447821992377ceeea076c25e6023..4ea8d80d2449def3ebee46a636da315ae0ebf7a2 100644
--- a/tests/tracer/test.py
+++ b/tests/tracer/test.py
@@ -1,33 +1,93 @@
 import asyncio
+import time
 import uuid
+import sys
+import types
+from unittest.mock import Mock
+
+
+fake_base = types.ModuleType("base")
+fake_base.logger = Mock()
+
+fake_exception_module = types.ModuleType("base")
+fake_exception_module.JiuWenBaseException = type("JiuWenBaseException", (Exception,), {})
+
+sys.modules["jiuwen.core.common.logging.base"] = fake_base
+sys.modules["jiuwen.core.common.exception.base"] = fake_exception_module
+
 
 from jiuwen.core.runtime.callback_manager import CallbackManager
+from jiuwen.core.stream.emitter import StreamEmitter
 from jiuwen.core.stream.manager import StreamWriterManager
 from jiuwen.core.tracer.handler import TraceAgentHandler, TraceWorkflowHandler
 from jiuwen.core.tracer.span import SpanManager
 
+
+
 def generate_tracer_id():
     """
     Generate tracer_id, which is also the execution_id.
     """
     return str(uuid.uuid4())
+
+
+trace_id = generate_tracer_id()
+callback_manager = CallbackManager()
+stream_writer_manager = StreamWriterManager(StreamEmitter())
+trace_agent_span_manager = SpanManager(trace_id)
+trace_workflow_span_manager = SpanManager(trace_id)
+trace_agent_handler = TraceAgentHandler(callback_manager, stream_writer_manager, trace_agent_span_manager)
+trace_workflow_handler = TraceWorkflowHandler(callback_manager, stream_writer_manager, trace_workflow_span_manager)
+callback_manager.register_handler({"tracer_agent": trace_agent_handler})
+callback_manager.register_handler({"tracer_workflow": trace_workflow_handler})
+tracer_agent_span = trace_agent_span_manager.create_agent_span()
+tracer_workflow_span = trace_workflow_span_manager.create_workflow_span()
 
-async def main():
-    trace_id = generate_tracer_id()
-    callback_manager = CallbackManager()
-    stream_writer_manager = StreamWriterManager()
-    trace_agent_span_manager = SpanManager(trace_id)
-    trace_workflow_span_manager = SpanManager(trace_id)
-    trace_agent_handler = TraceAgentHandler(callback_manager, stream_writer_manager, trace_agent_span_manager)
-    trace_workflow_handler = TraceWorkflowHandler(callback_manager, stream_writer_manager, trace_workflow_span_manager)
-    callback_manager.register_handler({"tracer_agent": trace_agent_handler})
-    callback_manager.register_handler({"tracer_workflow": trace_workflow_handler})
-    tracer_agent_span = trace_agent_span_manager.create_agent_span()
-    tracer_workflow_span = trace_workflow_span_manager.create_workflow_span()
+def tracer_agent():
     callback_manager.trigger("tracer_agent", "on_chain_start", span=tracer_agent_span, inputs={},
-                     instance_info={"class_name": "testagentnode"})
+                             instance_info={"class_name": "testagentnode"})
+
+def tracer_workflow():
     callback_manager.trigger("tracer_workflow", "on_pre_invoke", span=tracer_workflow_span, inputs={},
-                     component_metadata={"component_type": "testworkflownode"})
+                             component_metadata={"component_type": "testworkflownode"})
+
+async def stream_output():
+    async for data in stream_writer_manager.stream_output():
+        print(f"Received data: {data}\n")
+
+class MockAgent:
+    def invoke(self):
+        tracer_agent_span = trace_agent_span_manager.create_agent_span()
+        callback_manager.trigger("tracer_agent", "on_chain_start", span=tracer_agent_span, inputs={},
+                                 instance_info={"class_name": "Agent"})
+        # simulate work
+        time.sleep(2)
+        workflow = MockWorkflow()
+        workflow.invoke()
+        callback_manager.trigger("tracer_agent", "on_chain_end", span=tracer_agent_span, outputs={})
+
+class MockWorkflow:
+    def invoke(self):
+        tracer_workflow_span = trace_workflow_span_manager.create_workflow_span()
+        callback_manager.trigger("tracer_workflow", "on_pre_invoke", span=tracer_workflow_span, inputs={},
+                                 component_metadata={"component_type": "Workflow"})
+        # simulate work
+        time.sleep(2)
+        callback_manager.trigger("tracer_workflow", "on_invoke", span=tracer_workflow_span,
+                                 on_invoke_data={"on_invoke": "data"},
+                                 component_metadata={"component_type": "Workflow"})
+        # simulate work
+        time.sleep(2)
+        callback_manager.trigger("tracer_workflow", "on_post_invoke", span=tracer_workflow_span, inputs=None,
+                                 outputs={"outputs": "result"})
+
+
+
+
+async def test_agent_workflow_trace():
+    agent = MockAgent()
+    agent.invoke()
+    await stream_output()
 
 if __name__ == '__main__':
-    asyncio.run(main())
\ No newline at end of file
+    asyncio.run(test_agent_workflow_trace())
diff --git a/tests/unit_tests/stream/test_stream_output.py b/tests/unit_tests/stream/test_stream_output.py
index a0a6e7987cfbdc0edd9af9a251ad34ccc4290ebc..7c1eb1c87551af95288c3531ea82a21c9f2768b3 100644
--- a/tests/unit_tests/stream/test_stream_output.py
+++ b/tests/unit_tests/stream/test_stream_output.py
@@ -1,6 +1,6 @@
 import asyncio
 import unittest
-from typing import AsyncIterator, Any, Type
+from typing import AsyncIterator, Type
 
 from pydantic import BaseModel
 
@@ -155,8 +155,8 @@ class TestStreamOutput(unittest.IsolatedAsyncioTestCase):
 
         async def write_data():
             async for mock_data in mock_stream_output():
-                trace_writer = self.manager.get_trace_writer()
-                await trace_writer.write(mock_data)
+                mock_writer = self.manager.get_writer(MockStreamMode.MOCK)
+                await mock_writer.write(mock_data)
             await self.emitter.close()
 
         async def read_data():