diff --git a/jiuwen/core/graph/vertex.py b/jiuwen/core/graph/vertex.py index 2687912adbdfd9dc405ff8dfa4efefb4d3cd239e..5cf2a36c384929bfcab3cd2f9126ca0bd3a0117c 100644 --- a/jiuwen/core/graph/vertex.py +++ b/jiuwen/core/graph/vertex.py @@ -11,7 +11,6 @@ from jiuwen.core.context.utils import get_by_schema from jiuwen.core.graph.base import ExecutableGraph, INPUTS_KEY, CONFIG_KEY from jiuwen.core.graph.executable import Executable, Output from jiuwen.core.graph.graph_state import GraphState -from jiuwen.core.tracer.tracer import Tracer class Vertex: @@ -74,28 +73,24 @@ class Vertex: async def __trace_inputs__(self, inputs: Optional[dict]) -> None: # TODO 组件信息 - await self._context.tracer.trigger("tracer_workflow", "on_pre_invoke", invoke_id=self._context.executable_id, + parent_node_id=self._context.parent_id, inputs=inputs, component_metadata={"component_type": self._context.executable_id}) - self._context.state.update_trace(self._node_id, - self._context.tracer.tracer_workflow_span_manager.get_span(self._node_id)) + self._context.state.update_trace(self._context.executable_id, + self._context.tracer.get_workflow_span(self._context.executable_id, + self._context.parent_id)) if isinstance(self._executable, ExecWorkflowComponent): - self._origin_tracer = self._context.tracer - sub_tracer = Tracer(tracer_id=self._context.tracer._trace_id, parent_node_id=self._context.executable_id) - sub_tracer.init(self._context.stream_writer_manager, self._origin_tracer._callback_manager) - self._context.set_tracer(sub_tracer) + self._context.tracer.register_workflow_span_manager(self._context.executable_id) async def __trace_outputs__(self, outputs: Optional[dict] = None) -> None: - if isinstance(self._executable, ExecWorkflowComponent): - self._context.set_tracer(self._origin_tracer) - await self._context.tracer.trigger("tracer_workflow", "on_post_invoke", invoke_id=self._context.executable_id, + parent_node_id=self._context.parent_id, outputs=outputs) self._context.state.update_trace(self._context.executable_id, - self._context.tracer.tracer_workflow_span_manager.get_span( - self._context.executable_id)) + self._context.tracer.get_workflow_span(self._context.executable_id, + self._context.parent_id)) def __is_stream__(self, state: GraphState) -> bool: return False diff --git a/jiuwen/core/tracer/handler.py b/jiuwen/core/tracer/handler.py index 3dba858f8eeef6dc13ee729710594e9198ddab71..7249810a8d6dfe8b610e82709abd084165e7a20b 100644 --- a/jiuwen/core/tracer/handler.py +++ b/jiuwen/core/tracer/handler.py @@ -111,100 +111,100 @@ class TraceAgentHandler(TraceBaseHandler): self._span_manager.update_span(span, update_data) @trigger_event - def on_chain_start(self, span: TraceAgentSpan, inputs: Any, instance_info: dict, **kwargs): + async def on_chain_start(self, span: TraceAgentSpan, inputs: Any, instance_info: dict, **kwargs): self._update_start_trace_data(invoke_type=InvokeType.CHAIN.value, span=span, inputs=inputs, instance_info=instance_info, **kwargs) - self._send_data(span) + await self._send_data(span) @trigger_event - def on_chain_end(self, span: TraceAgentSpan, outputs, **kwargs): + async def on_chain_end(self, span: TraceAgentSpan, outputs, **kwargs): self._update_end_trace_data(span=span, outputs=outputs, **kwargs) - self._send_data(span) + await self._send_data(span) @trigger_event - def on_chain_error(self, span: TraceAgentSpan, error, **kwargs): + async def on_chain_error(self, span: TraceAgentSpan, error, **kwargs): self._update_error_trace_data(span=span, error=error, **kwargs) - self._send_data(span) + await self._send_data(span) @trigger_event - def on_llm_start(self, span: TraceAgentSpan, inputs: Any, instance_info: dict, **kwargs): + async def on_llm_start(self, span: TraceAgentSpan, inputs: Any, instance_info: dict, **kwargs): self._update_start_trace_data(invoke_type=InvokeType.LLM.value, span=span, inputs=inputs, instance_info=instance_info, **kwargs) - self._send_data(span) + await self._send_data(span) @trigger_event - def on_llm_end(self, span: TraceAgentSpan, outputs, **kwargs): + async def on_llm_end(self, span: TraceAgentSpan, outputs, **kwargs): self._update_end_trace_data(span=span, outputs=outputs, **kwargs) - self._send_data(span) + await self._send_data(span) @trigger_event - def on_llm_error(self, span: TraceAgentSpan, error, **kwargs): + async def on_llm_error(self, span: TraceAgentSpan, error, **kwargs): self._update_error_trace_data(span=span, error=error, **kwargs) - self._send_data(span) + await self._send_data(span) @trigger_event - def on_prompt_start(self, span: TraceAgentSpan, inputs: Any, instance_info: dict, **kwargs): + async def on_prompt_start(self, span: TraceAgentSpan, inputs: Any, instance_info: dict, **kwargs): self._update_start_trace_data(invoke_type=InvokeType.PROMPT.value, span=span, inputs=inputs, instance_info=instance_info, **kwargs) - self._send_data(span) + await self._send_data(span) @trigger_event - def on_prompt_end(self, span: TraceAgentSpan, outputs, **kwargs): + async def on_prompt_end(self, span: TraceAgentSpan, outputs, **kwargs): self._update_end_trace_data(span=span, outputs=outputs, **kwargs) - self._send_data(span) + await self._send_data(span) @trigger_event - def on_prompt_error(self, span: TraceAgentSpan, error, **kwargs): + async def on_prompt_error(self, span: TraceAgentSpan, error, **kwargs): self._update_error_trace_data(span=span, error=error, **kwargs) - self._send_data(span) + await self._send_data(span) @trigger_event - def on_plugin_start(self, span: TraceAgentSpan, inputs: Any, instance_info: dict, **kwargs): + async def on_plugin_start(self, span: TraceAgentSpan, inputs: Any, instance_info: dict, **kwargs): self._update_start_trace_data(invoke_type=InvokeType.PLUGIN.value, span=span, inputs=inputs, instance_info=instance_info, **kwargs) - self._send_data(span) + await self._send_data(span) @trigger_event - def on_plugin_end(self, span: TraceAgentSpan, outputs, **kwargs): + async def on_plugin_end(self, span: TraceAgentSpan, outputs, **kwargs): self._update_end_trace_data(span=span, outputs=outputs, **kwargs) - self._send_data(span) + await self._send_data(span) @trigger_event - def on_plugin_error(self, span: TraceAgentSpan, error, **kwargs): + async def on_plugin_error(self, span: TraceAgentSpan, error, **kwargs): self._update_error_trace_data(span=span, error=error, **kwargs) - self._send_data(span) + await self._send_data(span) @trigger_event - def on_retriever_start(self, span: TraceAgentSpan, inputs: Any, instance_info: dict, **kwargs): + async def on_retriever_start(self, span: TraceAgentSpan, inputs: Any, instance_info: dict, **kwargs): self._update_start_trace_data(invoke_type=InvokeType.RETRIEVER.value, span=span, inputs=inputs, instance_info=instance_info, **kwargs) - self._send_data(span) + await self._send_data(span) @trigger_event - def on_retriever_end(self, span: TraceAgentSpan, outputs, **kwargs): + async def on_retriever_end(self, span: TraceAgentSpan, outputs, **kwargs): self._update_end_trace_data(span=span, outputs=outputs, **kwargs) - self._send_data(span) + await self._send_data(span) @trigger_event - def on_retriever_error(self, span: TraceAgentSpan, error, **kwargs): + async def on_retriever_error(self, span: TraceAgentSpan, error, **kwargs): self._update_error_trace_data(span=span, error=error, **kwargs) - self._send_data(span) + await self._send_data(span) @trigger_event - def on_evaluator_start(self, span: TraceAgentSpan, inputs: Any, instance_info: dict, **kwargs): + async def on_evaluator_start(self, span: TraceAgentSpan, inputs: Any, instance_info: dict, **kwargs): self._update_start_trace_data(invoke_type=InvokeType.EVALUATOR.value, span=span, inputs=inputs, instance_info=instance_info, **kwargs) - self._send_data(span) + await self._send_data(span) @trigger_event - def on_evaluator_end(self, span: TraceAgentSpan, outputs, **kwargs): + async def on_evaluator_end(self, span: TraceAgentSpan, outputs, **kwargs): self._update_end_trace_data(span=span, outputs=outputs, **kwargs) - self._send_data(span) + await self._send_data(span) @trigger_event - def on_evaluator_error(self, span: TraceAgentSpan, error, **kwargs): + async def on_evaluator_error(self, span: TraceAgentSpan, error, **kwargs): self._update_error_trace_data(span=span, error=error, **kwargs) - self._send_data(span) + await self._send_data(span) class TraceWorkflowHandler(TraceBaseHandler): @@ -264,7 +264,7 @@ class TraceWorkflowHandler(TraceBaseHandler): await self._send_data(span) @trigger_event - async def on_invoke(self, invoke_id: str, on_invoke_data: dict, exception: dict = None, **kwargs): + async def on_invoke(self, invoke_id: str, on_invoke_data: dict, exception: Exception = None, **kwargs): span = self._get_tracer_workflow_span(invoke_id) update_data = {} end_time = datetime.now(tz=tzlocal()).replace(tzinfo=None) @@ -315,5 +315,3 @@ class TraceWorkflowHandler(TraceBaseHandler): if span.component_type == "End" and span.end_time: span.llm_invoke_data.clear() self._span_manager.update_span(span, {}) - - await asyncio.sleep(1) diff --git a/jiuwen/core/tracer/tracer.py b/jiuwen/core/tracer/tracer.py index 5ea7ff09e0a35716e746a1885df391027f650f29..d43ff4f50e1a4d25bf4ad0e097c799232dc8ff42 100644 --- a/jiuwen/core/tracer/tracer.py +++ b/jiuwen/core/tracer/tracer.py @@ -5,30 +5,42 @@ from jiuwen.core.tracer.span import SpanManager class Tracer: - def __init__(self, tracer_id=None, parent_node_id=""): - self._callback_manager = None - self._trace_id = str(uuid.uuid4()) if tracer_id is None else tracer_id + def __init__(self): + self._trace_id = str(uuid.uuid4()) self.tracer_agent_span_manager = SpanManager(self._trace_id) - self.tracer_workflow_span_manager = SpanManager(self._trace_id, parent_node_id=parent_node_id) - self._parent_node_id = parent_node_id + # 一个workflow对应一个span_manager + self.tracer_workflow_span_manager_dict = {} + self._callback_manager = None + self._stream_writer_manager = None def init(self, stream_writer_manager, callback_manager): - # 用于注册子workflow tracer handler,子workflow中使用新的tracer handler - if self._parent_node_id != "": - trace_workflow_handler = TraceWorkflowHandler(callback_manager, stream_writer_manager, - self.tracer_workflow_span_manager) - callback_manager.register_handler( - {TracerHandlerName.TRACER_WORKFLOW.value + "." + self._parent_node_id: trace_workflow_handler}) - else: - trace_agent_handler = TraceAgentHandler(callback_manager, stream_writer_manager, - self.tracer_agent_span_manager) - trace_workflow_handler = TraceWorkflowHandler(callback_manager, stream_writer_manager, - self.tracer_workflow_span_manager) - callback_manager.register_handler({TracerHandlerName.TRACE_AGENT.value: trace_agent_handler}) - callback_manager.register_handler({TracerHandlerName.TRACER_WORKFLOW.value: trace_workflow_handler}) - + trace_agent_handler = TraceAgentHandler(callback_manager, stream_writer_manager, + self.tracer_agent_span_manager) + parent_tracer_workflow_span_manager = SpanManager(self._trace_id) + trace_workflow_handler = TraceWorkflowHandler(callback_manager, stream_writer_manager, + parent_tracer_workflow_span_manager) + self.tracer_workflow_span_manager_dict[None] = parent_tracer_workflow_span_manager + callback_manager.register_handler({TracerHandlerName.TRACE_AGENT.value: trace_agent_handler}) + callback_manager.register_handler({TracerHandlerName.TRACER_WORKFLOW.value: trace_workflow_handler}) self._callback_manager = callback_manager + self._stream_writer_manager = stream_writer_manager + + def register_workflow_span_manager(self, parent_node_id: str): + tracer_workflow_span_manager = SpanManager(self._trace_id, parent_node_id=parent_node_id) + self.tracer_workflow_span_manager_dict[parent_node_id] = tracer_workflow_span_manager + trace_workflow_handler = TraceWorkflowHandler(self._callback_manager, self._stream_writer_manager, + tracer_workflow_span_manager) + self._callback_manager.register_handler( + {TracerHandlerName.TRACER_WORKFLOW.value + "." + parent_node_id: trace_workflow_handler}) + + def get_workflow_span(self, invoke_id: str, parent_node_id: str): + workflow_span_manager = self.tracer_workflow_span_manager_dict.get(parent_node_id, None) + if workflow_span_manager is None: + return None + return self.tracer_agent_span_manager.get_span(invoke_id) async def trigger(self, handler_class_name: str, event_name: str, **kwargs): - handler_class_name += "." + self._parent_node_id if self._parent_node_id != "" else "" + parent_node_id = kwargs.get("parent_node_id", None) + if parent_node_id is not None: + handler_class_name += "." + parent_node_id if parent_node_id != "" else "" await self._callback_manager.trigger(handler_class_name, event_name, **kwargs) diff --git a/tests/tracer/test.py b/tests/tracer/test.py deleted file mode 100644 index cc35eb2be4cfca13506f3d42b160cf01c5c5a729..0000000000000000000000000000000000000000 --- a/tests/tracer/test.py +++ /dev/null @@ -1,79 +0,0 @@ -import asyncio -import time -import uuid - -from jiuwen.core.common.logging.base import logger -from jiuwen.core.runtime.callback_manager import CallbackManager -from jiuwen.core.stream.emitter import StreamEmitter -from jiuwen.core.stream.manager import StreamWriterManager -from jiuwen.core.tracer.handler import TraceAgentHandler, TraceWorkflowHandler -from jiuwen.core.tracer.span import SpanManager - - - -def generate_tracer_id(): - """ - Generate tracer_id, which is also the execution_id. - """ - return str(uuid.uuid4()) - - -trace_id = generate_tracer_id() -callback_manager = CallbackManager() -stream_writer_manager = StreamWriterManager(StreamEmitter()) -trace_agent_span_manager = SpanManager(trace_id) -trace_workflow_span_manager = SpanManager(trace_id) -trace_agent_handler = TraceAgentHandler(callback_manager, stream_writer_manager, trace_agent_span_manager) -trace_workflow_handler = TraceWorkflowHandler(callback_manager, stream_writer_manager, trace_workflow_span_manager) -callback_manager.register_handler({"tracer_agent": trace_agent_handler}) -callback_manager.register_handler({"tracer_workflow": trace_workflow_handler}) -tracer_agent_span = trace_agent_span_manager.create_agent_span() -tracer_workflow_span = trace_workflow_span_manager.create_workflow_span() - -def tracer_agent(): - callback_manager.trigger("tracer_agent", "on_chain_start", span=tracer_agent_span, inputs={}, - instance_info={"class_name": "testagentnode"}) - -def tracer_workflow(): - callback_manager.trigger("tracer_workflow", "on_pre_invoke", span=tracer_workflow_span, inputs={}, - component_metadata={"component_type": "testworkflownode"}) - -async def stream_output(): - async for data in stream_writer_manager.stream_output(): - logger.info(f"Received data: {data}\n") - -class MockAgent: - def invoke(self): - tracer_agent_span = trace_agent_span_manager.create_agent_span() - callback_manager.trigger("tracer_agent", "on_chain_start", span=tracer_agent_span, inputs={}, - instance_info={"class_name": "Agent"}) - # 模拟运行 - time.sleep(2) - workflow = MockWorkflow() - workflow.invoke() - callback_manager.trigger("tracer_agent", "on_chain_end", span=tracer_agent_span, outputs={}) - -class MockWorkflow: - def invoke(self): - tracer_workflow_span = trace_workflow_span_manager.create_workflow_span() - callback_manager.trigger("tracer_workflow", "on_pre_invoke", span=tracer_workflow_span, inputs={}, - component_metadata={"component_type": "Workflow"}) - # 模拟运行 - time.sleep(2) - callback_manager.trigger("tracer_workflow", "on_invoke", span=tracer_workflow_span, - on_invoke_data={"on_invoke": "data"}, - component_metadata={"component_type": "Workflow"}) - # 模拟运行 - time.sleep(2) - callback_manager.trigger("tracer_workflow", "on_post_invoke", span=tracer_workflow_span, inputs=None, - outputs={"outputs": "result"}) - - - - -async def test_agent_workflow_trace(): - agent = MockAgent() - agent.invoke() - await stream_output() - -asyncio.run(test_agent_workflow_trace()) diff --git a/tests/unit_tests/tracer/test_agent.py b/tests/unit_tests/tracer/test_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..6389fcfe335d109bcb653b05666c584968bc19ba --- /dev/null +++ b/tests/unit_tests/tracer/test_agent.py @@ -0,0 +1,172 @@ +import asyncio +import unittest + +from jiuwen.core.common.logging.base import logger +from jiuwen.core.context.config import Config +from jiuwen.core.context.context import Context +from jiuwen.core.context.memory.base import InMemoryState +from jiuwen.core.stream.emitter import StreamEmitter +from jiuwen.core.stream.manager import StreamWriterManager +from jiuwen.core.stream.writer import TraceSchema, CustomSchema +from jiuwen.core.tracer.tracer import Tracer +from tests.unit_tests.tracer.test_mock_node_with_tracer import StreamNodeWithTracer +from tests.unit_tests.tracer.test_workflow import record_tracer_info, create_flow +from tests.unit_tests.workflow.test_mock_node import MockEndNode, MockStartNode + + +class MockLLM: + def __init__(self, tracer): + self.tracer = tracer + + async def stream(self, span): + try: + await self.tracer.trigger("tracer_agent", "on_llm_start", span=span, inputs={"llm": "mock llm"}, + instance_info={"class_name": "Openai"}) + await asyncio.sleep(2) + except Exception as e: + await self.tracer.trigger("tracer_agent", "on_llm_error", span=span, error=e, + ) + raise e + finally: + await self.tracer.trigger("tracer_agent", "on_llm_end", span=span, outputs={"outputs": "mock llm"}, + ) + + +class MockPlugin: + def __init__(self, tracer): + self.tracer = tracer + + async def stream(self, span): + try: + await self.tracer.trigger("tracer_agent", "on_plugin_start", span=span, inputs={"llm": "mock plugin"}, + instance_info={"class_name": "RestFulAPI"}) + await asyncio.sleep(2) + except Exception as e: + await self.tracer.trigger("tracer_agent", "on_plugin_error", span=span, error=e, + ) + raise e + finally: + await self.tracer.trigger("tracer_agent", "on_plugin_end", span=span, outputs={"outputs": "mock plugin"}, + ) + + +class MockAgent(unittest.TestCase): + """ + Agent(llm -> plugin -> workflow) + """ + + def setUp(self): + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + self.tracer_chunks = [] + + def tearDown(self): + record_tracer_info(self.tracer_chunks, "test_agent_workflow_seq_exec_stream_workflow_with_tracer.json") + + async def run_workflow_seq_exec_stream_workflow_with_tracer(self, tracer: Tracer): + """ + start -> a -> b -> end + """ + + # workflow与agent共用一个tracer + context = Context(config=Config(), state=InMemoryState(), store=None) + context.set_tracer(tracer) + + # async def stream_workflow(): + flow = create_flow() + flow.set_start_comp("start", MockStartNode("start"), + inputs_schema={ + "a": "${a}", + "b": "${b}", + "c": 1, + "d": [1, 2, 3]}) + + node_a_expected_datas = [ + {"node_id": "a", "id": 1, "data": "1"}, + {"node_id": "a", "id": 2, "data": "2"}, + ] + node_a_expected_datas_model = [CustomSchema(**item) for item in node_a_expected_datas] + flow.add_workflow_comp("a", StreamNodeWithTracer("a", node_a_expected_datas), + inputs_schema={ + "aa": "${start.a}", + "ac": "${start.c}"}) + + node_b_expected_datas = [ + {"node_id": "b", "id": 1, "data": "1"}, + {"node_id": "b", "id": 2, "data": "2"}, + ] + node_b_expected_datas_model = [CustomSchema(**item) for item in node_b_expected_datas] + flow.add_workflow_comp("b", StreamNodeWithTracer("b", node_b_expected_datas), + inputs_schema={ + "ba": "${a.aa}", + "bc": "${a.ac}"}) + + flow.set_end_comp("end", MockEndNode("end"), + inputs_schema={ + "result": "${b.ba}"}) + + flow.add_connection("start", "a") + flow.add_connection("a", "b") + flow.add_connection("b", "end") + + expected_datas_model = { + "a": node_a_expected_datas_model, + "b": node_b_expected_datas_model + } + index_dict = {key: 0 for key in expected_datas_model.keys()} + + async for chunk in flow.stream({"a": 1, "b": "haha"}, context): + if not isinstance(chunk, TraceSchema): + node_id = chunk.node_id + index = index_dict[node_id] + assert chunk == expected_datas_model[node_id][index], f"Mismatch at node {node_id} index {index}" + logger.info(f"stream chunk: {chunk}") + index_dict[node_id] = index_dict[node_id] + 1 + else: + print(f"stream chunk: {chunk}") + self.tracer_chunks.append(chunk) + + async def run_agent_workflow_seq_exec_stream_workflow_with_tracer(self): + # context手动初始化tracer,agent和workflow共用一个tracer + context = Context(config=Config(), state=InMemoryState(), store=None) + context.set_stream_writer_manager(StreamWriterManager(StreamEmitter())) + tracer = Tracer() + tracer.init(context.stream_writer_manager, context.callback_manager) + context.set_tracer(tracer) + self.tracer = tracer + + agent_span = self.tracer.tracer_agent_span_manager.create_agent_span() + try: + await self.tracer.trigger("tracer_agent", "on_chain_start", span=agent_span, + inputs={"intput": "mock chain"}, + instance_info={"class_name": "Agent"}) # class_name为必选参数 + + # 模拟需要运行llm、plugin + for runner in [MockLLM(self.tracer), MockPlugin(self.tracer)]: + runner_span = self.tracer.tracer_agent_span_manager.create_agent_span(agent_span) # 用于记录父子span关系 + await runner.stream(runner_span) + + # 模拟运行workflow + await self.run_workflow_seq_exec_stream_workflow_with_tracer(context.tracer) + + await self.tracer.trigger("tracer_agent", "on_chain_end", span=agent_span, + outputs={"outputs": "mock chain"}, + ) + + except Exception as e: + await self.tracer.trigger("tracer_agent", "on_chain_error", span=agent_span, error=e, + ) + raise e + finally: + await context.stream_writer_manager.stream_emitter.close() + + async def get_stream_output(self): + async for item in self.tracer._stream_writer_manager.stream_output(need_close=True): + self.tracer_chunks.append(item) + + def test_agent_workflow_seq_exec_stream_workflow_with_tracer(self): + async def main(): + await self.run_agent_workflow_seq_exec_stream_workflow_with_tracer() + await self.get_stream_output() + + self.loop.run_until_complete(main()) diff --git a/tests/unit_tests/tracer/test_mock_node_with_tracer.py b/tests/unit_tests/tracer/test_mock_node_with_tracer.py index 96ee4628e5594562b85f56ea185e37964874dfee..508a322d27875c5633f376fb0b375da1f7499286 100644 --- a/tests/unit_tests/tracer/test_mock_node_with_tracer.py +++ b/tests/unit_tests/tracer/test_mock_node_with_tracer.py @@ -15,10 +15,25 @@ class StreamNodeWithTracer(MockNodeBase): async def invoke(self, inputs: Input, context: Context) -> Output: context.state.set_outputs(self.node_id, inputs) - await context.tracer.trigger("tracer_workflow", "on_invoke", invoke_id=context.executable_id, - on_invoke_data={"on_invoke_data": "mock with" + str(inputs)}) - context.state.update_trace(context.executable_id, - context.tracer.tracer_workflow_span_manager.get_span(context.executable_id)) + try: + await context.tracer.trigger("tracer_workflow", "on_invoke", invoke_id=context.executable_id, + parent_node_id=context.parent_id, + on_invoke_data={"on_invoke_data": "mock with" + str(inputs)}) + context.state.update_trace(context.executable_id, + context.tracer.get_workflow_span(context.executable_id, + context.parent_id)) + + # 运行时操作 + + except Exception as e: + await context.tracer.trigger("tracer_workflow", "on_invoke", invoke_id=context.executable_id, + parent_node_id=context.parent_id, + error=e) + context.state.update_trace(context.executable_id, + context.tracer.get_workflow_span(context.executable_id, + context.parent_id)) + raise e + await asyncio.sleep(random.randint(0, 5)) for data in self._datas: await asyncio.sleep(1) diff --git a/tests/unit_tests/tracer/test_workflow.py b/tests/unit_tests/tracer/test_workflow.py index 2a5421ffe41a639edbf5e352fbfa5a970a626561..d34eabbe68dc56106912b89a6e1b0e5ce905b1bc 100644 --- a/tests/unit_tests/tracer/test_workflow.py +++ b/tests/unit_tests/tracer/test_workflow.py @@ -1,3 +1,4 @@ +import copy import json import sys import types @@ -351,13 +352,87 @@ class WorkflowTest(unittest.TestCase): elif payload.get("invokeId") == "end": assert payload.get("parentInvokeId") == "b", f"b node parent_invoke_id should be a" assert payload.get("parentNodeId") == "", f"b node parent_node_id should be ''" - elif payload.get("invokeId") == "sub_start": + elif payload.get("invokeId") == "a.sub_start": assert payload.get("parentInvokeId") == None, f"sub_start node parent_invoke_id should be None" assert payload.get("parentNodeId") == "a", f"sub_start node parent_node_id should be a" - elif payload.get("invokeId") == "sub_a": - assert payload.get("parentInvokeId") == "sub_start", f"sub_a node parent_invoke_id should be sub_start" + elif payload.get("invokeId") == "a.sub_a": + assert payload.get("parentInvokeId") == "a.sub_start", f"sub_a node parent_invoke_id should be sub_start" assert payload.get("parentNodeId") == "a", f"sub_a node parent_node_id should be a" - elif payload.get("invokeId") == "sub_end": - assert payload.get("parentInvokeId") == "sub_a", f"sub_end node parent_invoke_id should be sub_a" + elif payload.get("invokeId") == "a.sub_end": + assert payload.get("parentInvokeId") == "a.sub_a", f"sub_end node parent_invoke_id should be sub_a" assert payload.get("parentNodeId") == "a", f"sub_end node parent_node_id should be a" record_tracer_info(tracer_chunks, "test_nested_stream_workflow_with_tracer.json") + + def test_nested_parallel_stream_workflow_with_tracer(self): + """ + main_workflow: start -> a(sub_workflow) | b(sub_workflow) -> end + sub_workflow: sub_start -> sub_a -> sub_end + """ + tracer_chunks = [] + + async def stream_workflow(): + # sub_workflow: start->a(stream out)->end + sub_workflow = create_flow() + sub_workflow.set_start_comp("sub_start", MockStartNode("start"), + inputs_schema={ + "a": "${a}", + "b": "${b}", + "c": 1, + "d": [1, 2, 3]}) + expected_datas = [ + {"node_id": "sub_start", "id": 1, "data": "1"}, + {"node_id": "sub_start", "id": 2, "data": "2"}, + ] + expected_datas_model = [CustomSchema(**item) for item in expected_datas] + + sub_workflow.add_workflow_comp("sub_a", StreamNodeWithTracer("a", expected_datas), + inputs_schema={ + "aa": "${sub_start.a}", + "ac": "${sub_start.c}"}) + sub_workflow.set_end_comp("sub_end", MockEndNode("end"), + inputs_schema={ + "result": "${sub_a.aa}"}) + sub_workflow.add_connection("sub_start", "sub_a") + sub_workflow.add_connection("sub_a", "sub_end") + + sub_workflow_2 = copy.deepcopy(sub_workflow) + + # main_workflow: start->a(sub workflow) | b(sub workflow) ->end + main_workflow = create_flow() + main_workflow.set_start_comp("start", MockStartNode("start"), + inputs_schema={ + "a": "${a}", + "b": "${b}", + "c": 1, + "d": [1, 2, 3]}) + + main_workflow.add_workflow_comp("a", ExecWorkflowComponent("a", sub_workflow), + inputs_schema={ + "aa": "${start.a}", + "ac": "${start.c}"}) + + node_b_expected_datas = [ + {"node_id": "b", "id": 1, "data": "1"}, + {"node_id": "b", "id": 2, "data": "2"}, + ] + + main_workflow.add_workflow_comp("b", ExecWorkflowComponent("b", sub_workflow_2), + inputs_schema={ + "aa": "${start.a}", + "ac": "${start.c}"}) + + main_workflow.set_end_comp("end", MockEndNode("end"), + inputs_schema={ + "result": "${a.aa}"}) + main_workflow.add_connection("start", "a") + main_workflow.add_connection("a", "end") + main_workflow.add_connection("start", "b") + main_workflow.add_connection("b", "end") + + async for chunk in main_workflow.stream({"a": 1, "b": "haha"}, create_context_with_tracer()): + if isinstance(chunk, TraceSchema): + print(f"stream chunk: {chunk}") + tracer_chunks.append(chunk) + + self.loop.run_until_complete(stream_workflow()) + record_tracer_info(tracer_chunks, "test_nested_parallel_stream_workflow_with_tracer.json")