From 08eee4e47842ad4db3c604aaa23214549a7dc209 Mon Sep 17 00:00:00 2001 From: CandiceGuo Date: Fri, 11 Jul 2025 09:39:41 +0800 Subject: [PATCH 1/4] =?UTF-8?q?feat:=20=E5=A2=9E=E5=8A=A0in-memory?= =?UTF-8?q?=E5=AE=9E=E7=8E=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- jiuwen/core/context/config.py | 2 +- jiuwen/core/context/memory/__init__.py | 0 jiuwen/core/context/memory/base.py | 78 ++++++++++++ jiuwen/core/context/state.py | 49 +++++--- jiuwen/core/context/utils.py | 168 +++++++++++++++++++++++++ jiuwen/core/graph/vertex.py | 4 +- jiuwen/core/workflow/base.py | 6 +- 7 files changed, 288 insertions(+), 19 deletions(-) create mode 100644 jiuwen/core/context/memory/__init__.py create mode 100644 jiuwen/core/context/memory/base.py create mode 100644 jiuwen/core/context/utils.py diff --git a/jiuwen/core/context/config.py b/jiuwen/core/context/config.py index 6a156f6..e19045f 100644 --- a/jiuwen/core/context/config.py +++ b/jiuwen/core/context/config.py @@ -62,7 +62,7 @@ class Config(ABC): :param source_node_id: source node id :param target_node_id: target node id """ - self._stream_edges[target_node_id].append(source_node_id) + self._stream_edges[source_node_id].append(target_node_id) def set_stream_edges(self, edges: dict[str, list[str]]) -> None: """ diff --git a/jiuwen/core/context/memory/__init__.py b/jiuwen/core/context/memory/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/jiuwen/core/context/memory/base.py b/jiuwen/core/context/memory/base.py new file mode 100644 index 0000000..893a0e9 --- /dev/null +++ b/jiuwen/core/context/memory/base.py @@ -0,0 +1,78 @@ +#!/usr/bin/python3.10 +# coding: utf-8 +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved +from copy import deepcopy +from typing import Union, Optional, Any, Callable + +from jiuwen.core.common.exception.exception import JiuWenBaseException +from jiuwen.core.context.state import CommitState, StateLike, State +from jiuwen.core.context.utils import extract_origin_key, get_value_by_nested_path, update_dict + + +class InMemoryState(StateLike): + def __init__(self): + self._state: dict = dict() + + def get(self, key: Union[str, list, dict]) -> Optional[Any]: + if isinstance(key, str): + origin_key = extract_origin_key(key) + return get_value_by_nested_path(origin_key, self._state) + elif isinstance(key, dict): + result = {} + for target_key, target_schema in key.items(): + result[target_key] = self.get(target_schema) + return result + elif isinstance(key, list): + result = [] + for item in key: + result.append(self.get(item)) + return result + else: + raise JiuWenBaseException(1, "key type is not support") + + def get_by_transformer(self, transformer: Callable) -> Optional[Any]: + return transformer(self._state) + + def update(self, node_id: str, data: dict) -> None: + update_dict(self._state, data) + + +class InMemoryCommitState(CommitState): + def __init__(self): + self._state = InMemoryState() + self._updates: dict[str, list[dict]] = dict() + + def update(self, node_id: str, data: dict) -> None: + if node_id not in self._updates: + self._updates[node_id] = [] + self._updates[node_id].append(data) + + def commit(self) -> None: + for key, updates in self._updates.items(): + for update in updates: + self._state.update(key, update) + self._updates.clear() + + def get_updates(self, node_id: str) -> list[dict]: + if node_id not in self._updates: + return [] + return deepcopy(self._updates[node_id]) + + def rollback(self, failed_node_ids: list[str]) -> None: + for node_id in failed_node_ids: + self._updates[node_id] = [] + + def get_by_transformer(self, transformer: Callable) -> Optional[Any]: + return transformer(self._state) + + def get(self, key: Union[str, list, dict]) -> Optional[Any]: + return self._state.get(key) + +class InMemoryState(State): + def __init__(self): + super().__init__(io_state=InMemoryCommitState(), + global_state=InMemoryCommitState(), + trace_state=InMemoryState(), + comp_state=InMemoryState()) + + diff --git a/jiuwen/core/context/state.py b/jiuwen/core/context/state.py index 5d46198..beb6e00 100644 --- a/jiuwen/core/context/state.py +++ b/jiuwen/core/context/state.py @@ -2,27 +2,43 @@ # coding: utf-8 # Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved from abc import ABC, abstractmethod -from typing import Any, Union, Optional +from typing import Any, Union, Optional, Callable class StateLike(ABC): @abstractmethod - def get(self, key: Union[str, dict]) -> Optional[Any]: + def get(self, key: Union[str, list, dict]) -> Optional[Any]: pass @abstractmethod - def update(self, data: dict) -> None: + def get_by_transformer(self, transformer: Callable) -> Optional[Any]: pass + @abstractmethod + def update(self, node_id: str, data: dict) -> None: + pass class CommitState(StateLike): @abstractmethod - def commit(self) -> dict: + def commit(self) -> None: + pass + + @abstractmethod + def rollback(self, failed_node_ids: list[str]) -> None: pass + @abstractmethod + def get_updates(self, node_id: str) -> list[dict]: + pass class State(ABC): - def __init__(self, io_state: CommitState, global_state: CommitState, trace_state: StateLike, comp_state: StateLike): + def __init__( + self, + io_state: CommitState, + global_state: CommitState, + trace_state: StateLike, + comp_state: StateLike + ): self._io_state = io_state self._global_state = global_state self._trace_state = trace_state @@ -49,20 +65,25 @@ class State(ABC): return None return self._global_state.get(key) - def update(self, data: dict) -> None: + def update(self, node_id: str, data: dict) -> None: if self._global_state is None: return - self._global_state.update(data) + self._global_state.update(node_id, data) - def update_io(self, data: dict) -> None: + def update_io(self, node_id: str, data: dict) -> None: if self._io_state is None: return - self._io_state.update(data) + self._io_state.update(node_id, data) - def update_trace(self, data: dict) -> None: + def update_trace(self, node_id: str, data: dict) -> None: if self._trace_state is None: return - self._trace_state.update(data) + self._trace_state.update(node_id, data) + + def update_comp(self, node_id: str, data: dict) -> None: + if self._comp_state is None: + return + self._comp_state.update(node_id, data) def set_user_inputs(self, inputs: dict) -> None: if self._io_state is None: @@ -74,13 +95,13 @@ class State(ABC): return {} return self._io_state.get(input_schemas) - def get_outputs(self, node_id: str)-> Any: + def get_outputs(self, node_id: str) -> Any: if self._io_state is None: return {} return self._io_state.get(node_id) - def set_outputs(self, outputs: dict)-> None: + def set_outputs(self, node_id: str, outputs: dict) -> None: if self._io_state is None: return + return self._io_state.update(node_id, outputs) - return self._io_state.update(outputs) diff --git a/jiuwen/core/context/utils.py b/jiuwen/core/context/utils.py new file mode 100644 index 0000000..c01b43c --- /dev/null +++ b/jiuwen/core/context/utils.py @@ -0,0 +1,168 @@ +#!/usr/bin/python3.10 +# coding: utf-8 +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved +import re +from typing import Optional, Any, Union + +from jiuwen.core.common.exception.exception import JiuWenBaseException + +REGEX_MAX_LENGTH = 1000 +NESTED_PATH_LIST_PATTERN = re.compile(r'^([\w]+)((?:\[\d+\])*)$') +NESTED_PATH_SPLIT = '.' +NESTED_PATH_LIST_SPLIT = "[" + + +def update_dict(update: dict, source: dict) -> None: + """ + update source dict by update dict + Note: source is unnested structure, update is nested structure + + :param update: update dict, which key is nested + :param source: source dict, which key must not be nested + """ + for key, value in update.items(): + current_key, current = root_to_path(key, source, create_if_absent=True) + update_by_key(current_key, value, current) + + +def get_value_by_nested_path(nested_key: str, source: dict) -> Optional[Any]: + result = root_to_path(nested_key, source) + if result is None: + return result + return result[1][result[0]] + + +def split_nested_path(nested_key: str) -> list: + ''' + split nested path + :param nested_key: path + :return: e.g. a_1.b.c[1].d -> ["a_1", "b", "c", 1, "d"] + ''' + + if (NESTED_PATH_SPLIT not in nested_key) and (NESTED_PATH_LIST_SPLIT not in nested_key): + return [] + final_list = [] + params = nested_key.split(NESTED_PATH_SPLIT) + pattern = re.compile(r'^([\w]+)((?:\[\d+\])*)$') + for param in params: + match = re.match(pattern, param) + if match: + index = match.group(2) + if len(index) > 0: + numbers = re.findall(r'\d+', index) + idxes = [int(num) if str.isdigit(num) else num for num in numbers] + if len(idxes) == 0: + raise JiuWenBaseException(1, "failed to split nested path") + final_list.append((match.group(1), idxes)) + else: + final_list.append(match.group(1)) + return final_list + + +def extract_origin_key(key: str) -> str: + """ + extract the origin key from given key if the given key is reference structure + e.g. "${start123.p2}" -> "start123.p2" + :param key: reference key + :return: origin key + """ + if '$' not in key: + return key + pattern = re.compile(r"\${(.+?)\}") + match = pattern.search(key, endpos=REGEX_MAX_LENGTH) + if match: + return match.group(1) + return key + +def update_by_key(key: Union[str, int], new_value: Any, source: dict) -> None: + if key not in source: + source[key] = expand_nested_structure(new_value) + return + if isinstance(source[key], dict) and isinstance(new_value, dict): + update_dict(new_value, source[key]) + else: + source[key] = expand_nested_structure(new_value) + + +def expand_nested_structure(data: Any) -> Any: + if isinstance(data, list) or isinstance(data, tuple): + result = [] + for item in data: + result.append(expand_nested_structure(item)) + return result + elif isinstance(data, dict): + result = {} + for key, value in data.items(): + current_key, current = root_to_path(key, result, create_if_absent=True) + current[current_key] = expand_nested_structure(value) + return result + else: + return data + + +def root_to_path(nested_path: str, source: dict, create_if_absent: bool = False) -> tuple[Union[str, int], dict]: + paths = split_nested_path(nested_path) + if len(paths) == 0: + return (nested_path, source) + current = source + for i in range(len(paths)): + path = paths[i] + if isinstance(path, str): + if path not in current: + if not create_if_absent: + return (None, None) + current[path] = {} + if i == len(paths) - 1: + return (path, current) + current = current[path] + else: + token = path[0] + if token not in current: + if not create_if_absent: + return (None, None) + current[token] = [] + current = current[token] + if i == len(paths) - 1: + return root_to_index(path[1], current, create_if_absent) + else: + idx, current = root_to_index(path[1], current, create_if_absent) + if current is None: + return (None, None) + current = current[idx] + return (None, None) + + +def root_to_index(idxes: list[int], source: dict, create_if_absent: bool = False) -> Optional[tuple[int, dict]]: + current = source + if len(idxes) > 1: + for idx in idxes[:-1]: + if idx >= len(current): + if not create_if_absent: + return None + current += [None] * (idx - len(source) - idx) + current.append([]) + current = current[idx] + if idxes[-1] >= len(source): + if not create_if_absent: + return None + current += [None] * (idxes[-1] - len(source)) + current.append({}) + return idxes[-1], current + + +if __name__ == '__main__': + source = {} + # 增加a.b: nums属性 + update_dict({"a.b.nums": [1, 2, 3]}, source) + print(source) + # 增加a.b: name属性 + update_dict({ + "a.b.name": "shanghai" + }, source) + print(source) + # 增加a.b: class属性 + update_dict({"a.b": {"class":"hha"}}, source) + print(source) + # 覆盖a.b所有 + update_dict({"a.b": [1,2,3]}, source) + print(source) diff --git a/jiuwen/core/graph/vertex.py b/jiuwen/core/graph/vertex.py index 6a9aef2..f853d2b 100644 --- a/jiuwen/core/graph/vertex.py +++ b/jiuwen/core/graph/vertex.py @@ -20,7 +20,7 @@ class Vertex: def __call__(self, state: GraphState) -> Output: if self._context is None: - raise JiuWenBaseException(1, "vertex is not initialized, node is is %s", self._node_id) + raise JiuWenBaseException(1, "vertex is not initialized, node is is " + self._node_id) inputs = self.__pre_invoke__(state) is_stream = self.__is_stream__(state) try: @@ -32,7 +32,7 @@ class Vertex: self.__post_invoke__(results) except JiuWenBaseException as e: raise JiuWenBaseException(e.error_code, "failed to invoke, caused by " + e.message) - return {"source": self._node_id} + return {"source_node_id": self._node_id} def __pre_invoke__(self, state:GraphState) -> Optional[dict]: inputs_schema = self._context.config.get_inputs_schema(self._node_id) diff --git a/jiuwen/core/workflow/base.py b/jiuwen/core/workflow/base.py index bb30273..5686512 100644 --- a/jiuwen/core/workflow/base.py +++ b/jiuwen/core/workflow/base.py @@ -3,6 +3,7 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved import asyncio from dataclasses import dataclass, Field +from enum import Enum from functools import partial from typing import Self, Dict, Any, Union, AsyncIterator, Iterator @@ -30,6 +31,7 @@ class WorkflowChunk(BaseModel): is_final: bool = Field(default=False) + class Workflow: def __init__(self, workflow_config: WorkflowConfig, graph: Graph = None): self._graph = graph @@ -85,9 +87,9 @@ class Workflow: def add_stream_connection(self, src_comp_id: str, target_comp_id: str) -> Self: self._graph.add_edge(source_node_id=src_comp_id, target_node_id=target_comp_id) if target_comp_id not in self._stream_edges: - self._stream_edges[target_comp_id] = [src_comp_id] + self._stream_edges[src_comp_id] = [target_comp_id] else: - self._stream_edges[target_comp_id].append(src_comp_id) + self._stream_edges[src_comp_id].append(target_comp_id) return self def add_conditional_connection(self, src_comp_id: str, router: Router) -> Self: -- Gitee From 911675f7531126c3c477e987b3915858b5fbb57f Mon Sep 17 00:00:00 2001 From: wang-guangge Date: Fri, 11 Jul 2025 18:01:12 +0800 Subject: [PATCH 2/4] =?UTF-8?q?feat:=20=E8=A1=A5=E5=85=85tracer=E6=B5=8B?= =?UTF-8?q?=E8=AF=95=E7=94=A8=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- jiuwen/core/stream/manager.py | 2 +- jiuwen/core/tracer/handler.py | 18 +++-- jiuwen/core/tracer/span.py | 2 +- tests/tracer/test.py | 68 +++++++++++++---- tests/unit_tests/tracer/test_tracer_output.py | 73 +++++++++++++++++++ 5 files changed, 138 insertions(+), 25 deletions(-) create mode 100644 tests/unit_tests/tracer/test_tracer_output.py diff --git a/jiuwen/core/stream/manager.py b/jiuwen/core/stream/manager.py index 40639f5..51f9096 100644 --- a/jiuwen/core/stream/manager.py +++ b/jiuwen/core/stream/manager.py @@ -48,7 +48,7 @@ class StreamWriterManager: self._writers[key] = writer def get_writer(self, key: StreamMode) -> Optional[StreamWriter]: - return self._writers.get[key] + return self._writers.get(key) def get_output_writer(self) -> Optional[StreamWriter]: return self.get_writer(BaseStreamMode.OUTPUT) diff --git a/jiuwen/core/tracer/handler.py b/jiuwen/core/tracer/handler.py index 42103d1..983d347 100644 --- a/jiuwen/core/tracer/handler.py +++ b/jiuwen/core/tracer/handler.py @@ -36,20 +36,24 @@ class TraceBaseHandler(BaseHandler): @abstractmethod def _format_data(self, span: Span) -> dict: - return {"event_name": self.event_name(), "playload": span} + return {"type": self.event_name(), "payload": span} async def _emit_stream_writer(self, span): # TODO 替换为使用TraceStreamWriter进行输出 - # await self._stream_writer.write(self._format_data(raw_data)) + await self._stream_writer.write(self._format_data(span)) print(f"_emit_stream_writer: {self._format_data(span)}") - wait_time = random.randint(0, 10) - await asyncio.sleep(wait_time) - print(f"wait_time {wait_time}") + # wait_time = random.randint(0, 10) + # await asyncio.sleep(wait_time) + # print(f"wait_time {wait_time}") def _send_data(self, span): print(f"send_data, {span}") asyncio.create_task(self.emit_stream_writer(copy.deepcopy(span))) + # loop = asyncio.get_event_loop() + # # loop.create_task(self.emit_stream_writer(copy.deepcopy(span))) + # asyncio.run_coroutine_threadsafe(self.emit_stream_writer(copy.deepcopy(span)), loop) + def _get_elapsed_time(self, start_time: datetime, end_time: datetime) -> str: """get elapsed time""" @@ -67,7 +71,7 @@ class TraceAgentHandler(TraceBaseHandler): return TracerHandlerName.TRACE_AGENT.value def _format_data(self, span: TraceAgentSpan) -> dict: - return {"event_name": self.event_name(), "playload": span.model_dump(by_alias=True)} + return {"type": self.event_name(), "payload": span.model_dump(by_alias=True)} def _update_start_trace_data(self, span: TraceAgentSpan, invoke_type: str, inputs: Any, instance_info: dict, **kwargs): @@ -225,7 +229,7 @@ class TraceWorkflowHandler(TraceBaseHandler): if span.end_time: status = NodeStatus.FINISH.value if span.outputs else NodeStatus.RUNNING.value span.status = status - return {"event_name": self.event_name(), "playload": span.model_dump(by_alias=True)} + return {"type": self.event_name(), "payload": span.model_dump(by_alias=True)} @trigger_event def on_pre_invoke(self, span: TraceWorkflowSpan, inputs: Any, component_metadata: dict, diff --git a/jiuwen/core/tracer/span.py b/jiuwen/core/tracer/span.py index 07c51aa..8c1b0d0 100644 --- a/jiuwen/core/tracer/span.py +++ b/jiuwen/core/tracer/span.py @@ -71,7 +71,7 @@ class SpanManager: def create_agent_span(self, parent_span: Optional[TraceAgentSpan] = None) -> TraceAgentSpan: return self._create_span(TraceAgentSpan, parent_span) - def create_workflow_agent(self, parent_span: Optional[TraceWorkflowSpan] = None) -> TraceWorkflowSpan: + def create_workflow_span(self, parent_span: Optional[TraceWorkflowSpan] = None) -> TraceWorkflowSpan: return self._create_span(TraceWorkflowSpan, parent_span) def update_span(self, span: Span, data: dict): diff --git a/tests/tracer/test.py b/tests/tracer/test.py index 2cd2b09..eb8f2d5 100644 --- a/tests/tracer/test.py +++ b/tests/tracer/test.py @@ -1,7 +1,23 @@ import asyncio import uuid +import sys +import types +import unittest +from unittest.mock import Mock + + +fake_base = types.ModuleType("base") +fake_base.logger = Mock() + +fake_exception_module = types.ModuleType("base") +fake_exception_module.JiuWenBaseException = Mock() + +sys.modules["jiuwen.core.common.logging.base"] = fake_base +sys.modules["jiuwen.core.common.exception.base"] = fake_exception_module + from jiuwen.core.runtime.callback_manager import CallbackManager +from jiuwen.core.stream.emitter import StreamEmitter from jiuwen.core.stream.manager import StreamWriterManager from jiuwen.core.tracer.handler import TraceAgentHandler, TraceWorkflowHandler from jiuwen.core.tracer.span import SpanManager @@ -11,23 +27,43 @@ def generate_tracer_id(): Generate tracer_id, which is also the execution_id. """ return str(uuid.uuid4()) + + +trace_id = generate_tracer_id() +callback_manager = CallbackManager() +stream_writer_manager = StreamWriterManager(StreamEmitter()) +trace_agent_span_manager = SpanManager(trace_id) +trace_workflow_span_manager = SpanManager(trace_id) +trace_agent_handler = TraceAgentHandler(callback_manager, stream_writer_manager, trace_agent_span_manager) +trace_workflow_handler = TraceWorkflowHandler(callback_manager, stream_writer_manager, trace_workflow_span_manager) +callback_manager.register_handler({"tracer_agent": trace_agent_handler}) +callback_manager.register_handler({"tracer_workflow": trace_workflow_handler}) +tracer_agent_span = trace_agent_span_manager.create_agent_span() +tracer_workflow_span = trace_workflow_span_manager.create_workflow_span() -async def main(): - trace_id = generate_tracer_id() - callback_manager = CallbackManager() - stream_writer_manager = StreamWriterManager() - trace_agent_span_manager = SpanManager(trace_id) - trace_workflow_span_manager = SpanManager(trace_id) - trace_agent_handler = TraceAgentHandler(callback_manager, stream_writer_manager, trace_agent_span_manager) - trace_workflow_handler = TraceWorkflowHandler(callback_manager, stream_writer_manager, trace_workflow_span_manager) - callback_manager.register_handler({"tracer_agent": trace_agent_handler}) - callback_manager.register_handler({"tracer_workflow": trace_workflow_handler}) - tracer_agent_span = trace_agent_span_manager.create_agent_span() - tracer_workflow_span = trace_workflow_span_manager.create_workflow_span() +def tracer_agent(): callback_manager.trigger("tracer_agent", "on_chain_start", span=tracer_agent_span, inputs={}, - instance_info={"class_name": "testagentnode"}) + instance_info={"class_name": "testagentnode"}) + +def tracer_workflow(): callback_manager.trigger("tracer_workflow", "on_pre_invoke", span=tracer_workflow_span, inputs={}, - component_metadata={"component_type": "testworkflownode"}) + component_metadata={"component_type": "testworkflownode"}) + +async def stream_output(): + async for data in stream_writer_manager.stream_output(): + print(f"Received data: {data}") + +async def test_stream_output(): + # loop = asyncio.get_event_loop() + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop=loop) + + # 创建 asyncio 任务 + task1 = loop.run_in_executor(None, tracer_agent) + task2 = loop.run_in_executor(None, tracer_workflow) + task3 = stream_output() + + # 并发运行所有任务 + await asyncio.gather(task1, task2, task3) -if __name__ == '__main__': - asyncio.run(main()) \ No newline at end of file +asyncio.run(test_stream_output()) diff --git a/tests/unit_tests/tracer/test_tracer_output.py b/tests/unit_tests/tracer/test_tracer_output.py new file mode 100644 index 0000000..415457b --- /dev/null +++ b/tests/unit_tests/tracer/test_tracer_output.py @@ -0,0 +1,73 @@ +import asyncio +import sys +import types +import unittest +from unittest.mock import Mock + + +fake_base = types.ModuleType("base") +fake_base.logger = Mock() + +fake_exception_module = types.ModuleType("base") +fake_exception_module.JiuWenBaseException = Mock() + +sys.modules["jiuwen.core.common.logging.base"] = fake_base +sys.modules["jiuwen.core.common.exception.base"] = fake_exception_module +import uuid +from jiuwen.core.stream.emitter import StreamEmitter +from jiuwen.core.runtime.callback_manager import CallbackManager +from jiuwen.core.stream.manager import StreamWriterManager +from jiuwen.core.tracer.handler import TraceAgentHandler, TraceWorkflowHandler +from jiuwen.core.tracer.span import SpanManager + + + +def generate_tracer_id(): + """ + Generate tracer_id, which is also the execution_id. + """ + return str(uuid.uuid4()) +class TestTracer(unittest.TestCase): + + def setUp(self): + trace_id = generate_tracer_id() + callback_manager = CallbackManager() + stream_writer_manager = StreamWriterManager(StreamEmitter()) + trace_agent_span_manager = SpanManager(trace_id) + trace_workflow_span_manager = SpanManager(trace_id) + trace_agent_handler = TraceAgentHandler(callback_manager, stream_writer_manager, trace_agent_span_manager) + trace_workflow_handler = TraceWorkflowHandler(callback_manager, stream_writer_manager, trace_workflow_span_manager) + callback_manager.register_handler({"tracer_agent": trace_agent_handler}) + callback_manager.register_handler({"tracer_workflow": trace_workflow_handler}) + tracer_agent_span = trace_agent_span_manager.create_agent_span() + tracer_workflow_span = trace_workflow_span_manager.create_workflow_span() + self.callback_manager = callback_manager + self.stream_writer_manager = stream_writer_manager + self.tracer_agent_span = tracer_agent_span + self.tracer_workflow_span = tracer_workflow_span + + def tracer_agent(self): + self.callback_manager.trigger("tracer_agent", "on_chain_start", span=self.tracer_agent_span, inputs={}, + instance_info={"class_name": "testagentnode"}) + + def tracer_workflow(self): + self.callback_manager.trigger("tracer_workflow", "on_pre_invoke", span=self.tracer_workflow_span, inputs={}, + component_metadata={"component_type": "testworkflownode"}) + + async def stream_output(self): + async for data in self.stream_writer_manager.stream_output(): + print(f"Received data: {data}") + + async def test_stream_output(self): + loop = asyncio.get_event_loop() + + # 创建 asyncio 任务 + task1 = loop.run_in_executor(None, self.run_tracer_agent) + task2 = loop.run_in_executor(None, self.run_tracer_workflow) + task3 = self.stream_output() + + # 并发运行所有任务 + await asyncio.gather(task1, task2, task3) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file -- Gitee From 775a23b03b8eedb8874a784030fc6713f93cee1d Mon Sep 17 00:00:00 2001 From: wang-guangge Date: Sat, 12 Jul 2025 18:47:32 +0800 Subject: [PATCH 3/4] =?UTF-8?q?fix:=20tracer=E8=B0=83=E8=AF=95=E4=BB=A3?= =?UTF-8?q?=E7=A0=81=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- jiuwen/core/tracer/handler.py | 81 +++++++++---------- jiuwen/core/tracer/span.py | 8 +- tests/tracer/test.py | 51 ++++++++---- tests/unit_tests/tracer/test_tracer_output.py | 73 ----------------- 4 files changed, 82 insertions(+), 131 deletions(-) delete mode 100644 tests/unit_tests/tracer/test_tracer_output.py diff --git a/jiuwen/core/tracer/handler.py b/jiuwen/core/tracer/handler.py index 983d347..1e57e51 100644 --- a/jiuwen/core/tracer/handler.py +++ b/jiuwen/core/tracer/handler.py @@ -1,10 +1,10 @@ import asyncio import copy import json -import random from abc import abstractmethod from datetime import datetime from enum import Enum +import threading from typing import Any from dateutil.tz import tzlocal @@ -39,20 +39,15 @@ class TraceBaseHandler(BaseHandler): return {"type": self.event_name(), "payload": span} async def _emit_stream_writer(self, span): - # TODO 替换为使用TraceStreamWriter进行输出 await self._stream_writer.write(self._format_data(span)) - print(f"_emit_stream_writer: {self._format_data(span)}") - # wait_time = random.randint(0, 10) - # await asyncio.sleep(wait_time) - # print(f"wait_time {wait_time}") + def _run_send_data_in_thread(self, span): + new_loop = asyncio.new_event_loop() + new_loop.run_until_complete(self.emit_stream_writer(span)) + new_loop.close() def _send_data(self, span): - print(f"send_data, {span}") - asyncio.create_task(self.emit_stream_writer(copy.deepcopy(span))) - # loop = asyncio.get_event_loop() - # # loop.create_task(self.emit_stream_writer(copy.deepcopy(span))) - # asyncio.run_coroutine_threadsafe(self.emit_stream_writer(copy.deepcopy(span)), loop) + threading.Thread(target=self._run_send_data_in_thread, args=(copy.deepcopy(span), )).start() def _get_elapsed_time(self, start_time: datetime, end_time: datetime) -> str: @@ -119,13 +114,13 @@ class TraceAgentHandler(TraceBaseHandler): self._send_data(span) @trigger_event - def on_chain_end(self, span: TraceAgentSpan, **kwargs): - self._update_end_trace_data(span=span, **kwargs) + def on_chain_end(self, span: TraceAgentSpan, outputs, **kwargs): + self._update_end_trace_data(span=span, outputs=outputs, **kwargs) self._send_data(span) @trigger_event - def on_chain_error(self, span: TraceAgentSpan, **kwargs): - self._update_error_trace_data(span=span, **kwargs) + def on_chain_error(self, span: TraceAgentSpan, error, **kwargs): + self._update_error_trace_data(span=span, error=error, **kwargs) self._send_data(span) @trigger_event @@ -135,13 +130,13 @@ class TraceAgentHandler(TraceBaseHandler): self._send_data(span) @trigger_event - def on_llm_end(self, span: TraceAgentSpan, **kwargs): - self._update_end_trace_data(span=span, **kwargs) + def on_llm_end(self, span: TraceAgentSpan, outputs, **kwargs): + self._update_end_trace_data(span=span, outputs=outputs, **kwargs) self._send_data(span) @trigger_event - def on_llm_error(self, span: TraceAgentSpan, **kwargs): - self._update_error_trace_data(span=span, **kwargs) + def on_llm_error(self, span: TraceAgentSpan, error, **kwargs): + self._update_error_trace_data(span=span, error=error, **kwargs) self._send_data(span) @trigger_event @@ -151,13 +146,13 @@ class TraceAgentHandler(TraceBaseHandler): self._send_data(span) @trigger_event - def on_prompt_end(self, span: TraceAgentSpan, **kwargs): - self._update_end_trace_data(span=span, **kwargs) + def on_prompt_end(self, span: TraceAgentSpan, outputs, **kwargs): + self._update_end_trace_data(span=span, outputs=outputs, **kwargs) self._send_data(span) @trigger_event - def on_prompt_error(self, span: TraceAgentSpan, **kwargs): - self._update_error_trace_data(span=span, **kwargs) + def on_prompt_error(self, span: TraceAgentSpan, error, **kwargs): + self._update_error_trace_data(span=span, error=error, **kwargs) self._send_data(span) @trigger_event @@ -167,13 +162,13 @@ class TraceAgentHandler(TraceBaseHandler): self._send_data(span) @trigger_event - def on_plugin_end(self, span: TraceAgentSpan, **kwargs): - self._update_end_trace_data(span=span, **kwargs) + def on_plugin_end(self, span: TraceAgentSpan, outputs, **kwargs): + self._update_end_trace_data(span=span, outputs=outputs, **kwargs) self._send_data(span) @trigger_event - def on_plugin_error(self, span: TraceAgentSpan, **kwargs): - self._update_error_trace_data(span=span, **kwargs) + def on_plugin_error(self, span: TraceAgentSpan, error, **kwargs): + self._update_error_trace_data(span=span, error=error, **kwargs) self._send_data(span) @trigger_event @@ -183,13 +178,13 @@ class TraceAgentHandler(TraceBaseHandler): self._send_data(span) @trigger_event - def on_retriever_end(self, span: TraceAgentSpan, **kwargs): - self._update_end_trace_data(span=span, **kwargs) + def on_retriever_end(self, span: TraceAgentSpan, outputs, **kwargs): + self._update_end_trace_data(span=span, outputs=outputs, **kwargs) self._send_data(span) @trigger_event - def on_retriever_error(self, span: TraceAgentSpan, **kwargs): - self._update_error_trace_data(span=span, **kwargs) + def on_retriever_error(self, span: TraceAgentSpan, error, **kwargs): + self._update_error_trace_data(span=span, error=error, **kwargs) self._send_data(span) @trigger_event @@ -199,13 +194,13 @@ class TraceAgentHandler(TraceBaseHandler): self._send_data(span) @trigger_event - def on_evaluator_end(self, span: TraceAgentSpan, **kwargs): - self._update_end_trace_data(span=span, **kwargs) + def on_evaluator_end(self, span: TraceAgentSpan, outputs, **kwargs): + self._update_end_trace_data(span=span, outputs=outputs, **kwargs) self._send_data(span) @trigger_event - def on_evaluator_error(self, span: TraceAgentSpan, **kwargs): - self._update_error_trace_data(span=span, **kwargs) + def on_evaluator_error(self, span: TraceAgentSpan, error, **kwargs): + self._update_error_trace_data(span=span, error=error, **kwargs) self._send_data(span) @@ -221,15 +216,17 @@ class TraceWorkflowHandler(TraceBaseHandler): return TracerHandlerName.TRACER_WORKFLOW.value def _format_data(self, span: TraceWorkflowSpan) -> dict: - status = NodeStatus.START.value + span.status = self._get_node_status(span) + return {"type": self.event_name(), "payload": span.model_dump(by_alias=True)} + + def _get_node_status(self, span: TraceWorkflowSpan) -> str: if span.error: - status = NodeStatus.ERROR.value + return NodeStatus.ERROR.value if span.on_invoke_data: - status = NodeStatus.RUNNING.value if not span.outputs else NodeStatus.FINISH.value + return NodeStatus.RUNNING.value if not span.outputs else NodeStatus.FINISH.value if span.end_time: - status = NodeStatus.FINISH.value if span.outputs else NodeStatus.RUNNING.value - span.status = status - return {"type": self.event_name(), "payload": span.model_dump(by_alias=True)} + return NodeStatus.FINISH.value if span.outputs else NodeStatus.RUNNING.value + return NodeStatus.START.value @trigger_event def on_pre_invoke(self, span: TraceWorkflowSpan, inputs: Any, component_metadata: dict, @@ -271,6 +268,8 @@ class TraceWorkflowHandler(TraceBaseHandler): "elapsed_time": self._get_elapsed_time(span.start_time, end_time) } else: + if not isinstance(span.on_invoke_data, list): + span.on_invoke_data = [] span.on_invoke_data.append(on_invoke_data) if span.component_type == "LLM": # TODO diff --git a/jiuwen/core/tracer/span.py b/jiuwen/core/tracer/span.py index 8c1b0d0..205f954 100644 --- a/jiuwen/core/tracer/span.py +++ b/jiuwen/core/tracer/span.py @@ -1,7 +1,7 @@ import uuid from datetime import datetime from typing import Optional, Dict, List, Callable -from pydantic import Field, BaseModel +from pydantic import ConfigDict, Field, BaseModel class Span(BaseModel): trace_id: str @@ -14,6 +14,8 @@ class Span(BaseModel): parent_invoke_id: Optional[str] = Field(default=None, alias="parentInvokeId") child_invokes_id: List[str] = Field(default=[], alias="childInvokes") + model_config = ConfigDict(populate_by_name=True) + def update(self, data: dict): for attr_name, value in data.items(): if not hasattr(self, attr_name): @@ -58,9 +60,9 @@ class SpanManager: def _create_span(self, span_class: Callable, parent_span = None): invoke_id = str(uuid.uuid4()) - span = span_class(invoke_id=invoke_id, parent_id=parent_span.invoke_id if parent_span else None, + span = span_class(invoke_id=invoke_id, parent_invoke_id=parent_span.invoke_id if parent_span else None, trace_id=self._trace_id) - + if parent_span: parent_span.child_invokes_id.append(span.invoke_id) self.refresh_span_record(parent_span.invoke_id, {parent_span.invoke_id: parent_span}) diff --git a/tests/tracer/test.py b/tests/tracer/test.py index eb8f2d5..4ea8d80 100644 --- a/tests/tracer/test.py +++ b/tests/tracer/test.py @@ -1,9 +1,9 @@ import asyncio +import time import uuid import sys import types -import unittest from unittest.mock import Mock @@ -22,6 +22,8 @@ from jiuwen.core.stream.manager import StreamWriterManager from jiuwen.core.tracer.handler import TraceAgentHandler, TraceWorkflowHandler from jiuwen.core.tracer.span import SpanManager + + def generate_tracer_id(): """ Generate tracer_id, which is also the execution_id. @@ -51,19 +53,40 @@ def tracer_workflow(): async def stream_output(): async for data in stream_writer_manager.stream_output(): - print(f"Received data: {data}") + print(f"Received data: {data}\n") -async def test_stream_output(): - # loop = asyncio.get_event_loop() - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop=loop) - - # 创建 asyncio 任务 - task1 = loop.run_in_executor(None, tracer_agent) - task2 = loop.run_in_executor(None, tracer_workflow) - task3 = stream_output() +class MockAgent: + def invoke(self): + tracer_agent_span = trace_agent_span_manager.create_agent_span() + callback_manager.trigger("tracer_agent", "on_chain_start", span=tracer_agent_span, inputs={}, + instance_info={"class_name": "Agent"}) + # 模拟运行 + time.sleep(2) + workflow = MockWorkflow() + workflow.invoke() + callback_manager.trigger("tracer_agent", "on_chain_end", span=tracer_agent_span, outputs={}) + +class MockWorkflow: + def invoke(self): + tracer_workflow_span = trace_workflow_span_manager.create_workflow_span() + callback_manager.trigger("tracer_workflow", "on_pre_invoke", span=tracer_workflow_span, inputs={}, + component_metadata={"component_type": "Workflow"}) + # 模拟运行 + time.sleep(2) + callback_manager.trigger("tracer_workflow", "on_invoke", span=tracer_workflow_span, + on_invoke_data={"on_invoke": "data"}, + component_metadata={"component_type": "Workflow"}) + # 模拟运行 + time.sleep(2) + callback_manager.trigger("tracer_workflow", "on_post_invoke", span=tracer_workflow_span, inputs=None, + outputs={"outputs": "result"}) + + + - # 并发运行所有任务 - await asyncio.gather(task1, task2, task3) +async def test_agent_workflow_trace(): + agent = MockAgent() + agent.invoke() + await stream_output() -asyncio.run(test_stream_output()) +asyncio.run(test_agent_workflow_trace()) diff --git a/tests/unit_tests/tracer/test_tracer_output.py b/tests/unit_tests/tracer/test_tracer_output.py deleted file mode 100644 index 415457b..0000000 --- a/tests/unit_tests/tracer/test_tracer_output.py +++ /dev/null @@ -1,73 +0,0 @@ -import asyncio -import sys -import types -import unittest -from unittest.mock import Mock - - -fake_base = types.ModuleType("base") -fake_base.logger = Mock() - -fake_exception_module = types.ModuleType("base") -fake_exception_module.JiuWenBaseException = Mock() - -sys.modules["jiuwen.core.common.logging.base"] = fake_base -sys.modules["jiuwen.core.common.exception.base"] = fake_exception_module -import uuid -from jiuwen.core.stream.emitter import StreamEmitter -from jiuwen.core.runtime.callback_manager import CallbackManager -from jiuwen.core.stream.manager import StreamWriterManager -from jiuwen.core.tracer.handler import TraceAgentHandler, TraceWorkflowHandler -from jiuwen.core.tracer.span import SpanManager - - - -def generate_tracer_id(): - """ - Generate tracer_id, which is also the execution_id. - """ - return str(uuid.uuid4()) -class TestTracer(unittest.TestCase): - - def setUp(self): - trace_id = generate_tracer_id() - callback_manager = CallbackManager() - stream_writer_manager = StreamWriterManager(StreamEmitter()) - trace_agent_span_manager = SpanManager(trace_id) - trace_workflow_span_manager = SpanManager(trace_id) - trace_agent_handler = TraceAgentHandler(callback_manager, stream_writer_manager, trace_agent_span_manager) - trace_workflow_handler = TraceWorkflowHandler(callback_manager, stream_writer_manager, trace_workflow_span_manager) - callback_manager.register_handler({"tracer_agent": trace_agent_handler}) - callback_manager.register_handler({"tracer_workflow": trace_workflow_handler}) - tracer_agent_span = trace_agent_span_manager.create_agent_span() - tracer_workflow_span = trace_workflow_span_manager.create_workflow_span() - self.callback_manager = callback_manager - self.stream_writer_manager = stream_writer_manager - self.tracer_agent_span = tracer_agent_span - self.tracer_workflow_span = tracer_workflow_span - - def tracer_agent(self): - self.callback_manager.trigger("tracer_agent", "on_chain_start", span=self.tracer_agent_span, inputs={}, - instance_info={"class_name": "testagentnode"}) - - def tracer_workflow(self): - self.callback_manager.trigger("tracer_workflow", "on_pre_invoke", span=self.tracer_workflow_span, inputs={}, - component_metadata={"component_type": "testworkflownode"}) - - async def stream_output(self): - async for data in self.stream_writer_manager.stream_output(): - print(f"Received data: {data}") - - async def test_stream_output(self): - loop = asyncio.get_event_loop() - - # 创建 asyncio 任务 - task1 = loop.run_in_executor(None, self.run_tracer_agent) - task2 = loop.run_in_executor(None, self.run_tracer_workflow) - task3 = self.stream_output() - - # 并发运行所有任务 - await asyncio.gather(task1, task2, task3) - -if __name__ == '__main__': - unittest.main() \ No newline at end of file -- Gitee From b70e1d6a13809c8a12ba392a2e7d56c9eaf6099f Mon Sep 17 00:00:00 2001 From: chenchunzhou Date: Sat, 12 Jul 2025 19:15:38 +0800 Subject: [PATCH 4/4] =?UTF-8?q?test:=20=E4=BF=AE=E5=A4=8Dstream=20writer?= =?UTF-8?q?=E5=8D=95=E5=85=83=E6=B5=8B=E8=AF=95=E7=94=A8=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- jiuwen/core/stream/writer.py | 2 +- tests/unit_tests/stream/test_stream_output.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/jiuwen/core/stream/writer.py b/jiuwen/core/stream/writer.py index 8fc0dd2..040a7c8 100644 --- a/jiuwen/core/stream/writer.py +++ b/jiuwen/core/stream/writer.py @@ -64,7 +64,7 @@ class TraceStreamWriter(StreamWriter[dict, TraceSchema]): class CustomSchema(BaseModel): def __init__(self, **kwargs): - super().__init_(**kwargs) + super().__init__(**kwargs) class Config: arbitrary_types_allowed = True diff --git a/tests/unit_tests/stream/test_stream_output.py b/tests/unit_tests/stream/test_stream_output.py index a0a6e79..7c1eb1c 100644 --- a/tests/unit_tests/stream/test_stream_output.py +++ b/tests/unit_tests/stream/test_stream_output.py @@ -1,6 +1,6 @@ import asyncio import unittest -from typing import AsyncIterator, Any, Type +from typing import AsyncIterator, Type from pydantic import BaseModel @@ -155,8 +155,8 @@ class TestStreamOutput(unittest.IsolatedAsyncioTestCase): async def write_data(): async for mock_data in mock_stream_output(): - trace_writer = self.manager.get_trace_writer() - await trace_writer.write(mock_data) + mock_writer = self.manager.get_writer(MockStreamMode.MOCK) + await mock_writer.write(mock_data) await self.emitter.close() async def read_data(): -- Gitee