From 405467e9a7c78a8e34486ded6fd79083ee50f51e Mon Sep 17 00:00:00 2001 From: CandiceGuo Date: Sat, 19 Jul 2025 18:06:28 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20Context,=20State=E6=94=B9=E4=B8=BA?= =?UTF-8?q?=E4=BB=A3=E7=90=86=E6=A8=A1=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- jiuwen/core/component/condition/array.py | 16 +- jiuwen/core/component/condition/expression.py | 2 +- jiuwen/core/component/condition/number.py | 12 +- jiuwen/core/component/llm_comp.py | 4 +- .../loop_callback/intermediate_loop_var.py | 10 +- jiuwen/core/component/loop_callback/output.py | 26 +-- jiuwen/core/component/loop_comp.py | 22 +- jiuwen/core/component/questioner_comp.py | 4 +- jiuwen/core/component/set_variable_comp.py | 8 +- jiuwen/core/context/config.py | 109 ++++++++-- jiuwen/core/context/context.py | 159 +++++++++----- jiuwen/core/context/memory/base.py | 32 ++- jiuwen/core/context/state.py | 197 +++++++++++++----- jiuwen/core/context/store.py | 8 - jiuwen/core/context/utils.py | 24 +-- jiuwen/core/graph/base.py | 2 +- jiuwen/core/graph/vertex.py | 80 ++++--- jiuwen/core/tracer/tracer.py | 10 + jiuwen/core/workflow/base.py | 36 ++-- tests/unit_tests/core/context/test_context.py | 66 ++++++ .../tracer/test_mock_node_with_tracer.py | 12 +- ...st_nested_stream_workflow_with_tracer.json | 16 ++ ...llel_exec_stream_workflow_with_tracer.json | 10 + tests/unit_tests/tracer/test_workflow.py | 4 +- tests/unit_tests/workflow/test_mock_node.py | 6 +- tests/unit_tests/workflow/test_workflow.py | 5 +- 26 files changed, 608 insertions(+), 272 deletions(-) create mode 100644 tests/unit_tests/core/context/test_context.py create mode 100644 tests/unit_tests/tracer/test_nested_stream_workflow_with_tracer.json create mode 100644 tests/unit_tests/tracer/test_parallel_exec_stream_workflow_with_tracer.json diff --git a/jiuwen/core/component/condition/array.py b/jiuwen/core/component/condition/array.py index eb900bf..1869785 100644 --- a/jiuwen/core/component/condition/array.py +++ b/jiuwen/core/component/condition/array.py @@ -4,7 +4,7 @@ from typing import Union, Any from jiuwen.core.component.condition.condition import Condition -from jiuwen.core.context.context import Context +from jiuwen.core.context.context import Context, ExecutableContext from jiuwen.core.context.utils import extract_origin_key, NESTED_PATH_SPLIT DEFAULT_MAX_LOOP_NUMBER = 1000 @@ -13,18 +13,18 @@ DEFAULT_MAX_LOOP_NUMBER = 1000 class ArrayCondition(Condition): def __init__(self, context: Context, node_id: str, arrays: dict[str, Union[str, list[Any]]], index_path: str = None, array_root: str = None): - self._context = context.create_executable_context(node_id) + self._context = ExecutableContext(context, node_id) self._node_id = node_id self._arrays = arrays self._index_path = index_path if index_path else node_id + NESTED_PATH_SPLIT + "index" self._arrays_root = array_root if array_root else node_id + NESTED_PATH_SPLIT + "arrLoopVar" def init(self): - self._context.state.update_io({self._index_path: -1}) - self._context.state.update_io({self._arrays_root: {}}) + self._context.state().update_io({self._index_path: -1}) + self._context.state().update_io({self._arrays_root: {}}) def __call__(self) -> bool: - current_idx = self._context.state.get(self._index_path) + 1 + current_idx = self._context.state().get(self._index_path) + 1 min_length = DEFAULT_MAX_LOOP_NUMBER updates: dict[str, Any] = {} for key, array_info in self._arrays.items(): @@ -35,7 +35,7 @@ class ArrayCondition(Condition): elif isinstance(array_info, str): ref_str = extract_origin_key(array_info) if ref_str != "": - arr = self._context.state.get(ref_str) + arr = self._context.state().get(ref_str) else: raise RuntimeError("error value: " + array_info + " is not a array path") else: @@ -45,6 +45,6 @@ class ArrayCondition(Condition): return False updates[key_path] = arr[current_idx] - self._context.state.update_io({self._index_path: current_idx}) - self._context.state.update_io(updates) + self._context.state().update_io({self._index_path: current_idx}) + self._context.state().update_io(updates) return True diff --git a/jiuwen/core/component/condition/expression.py b/jiuwen/core/component/condition/expression.py index c770eda..b862a4b 100644 --- a/jiuwen/core/component/condition/expression.py +++ b/jiuwen/core/component/condition/expression.py @@ -20,7 +20,7 @@ class ExpressionCondition(Condition): matches = re.findall(pattern, self._expression) inputs = {} for match in matches: - inputs[match] = self._context.state.get(match[2:-1]) + inputs[match] = self._context.state().get(match[2:-1]) return self._evalueate_expression(self._expression, inputs) def _evalueate_expression(self, expression, inputs) -> bool: diff --git a/jiuwen/core/component/condition/number.py b/jiuwen/core/component/condition/number.py index 97bf896..074f63c 100644 --- a/jiuwen/core/component/condition/number.py +++ b/jiuwen/core/component/condition/number.py @@ -4,30 +4,30 @@ from typing import Union from jiuwen.core.component.condition.condition import Condition -from jiuwen.core.context.context import Context +from jiuwen.core.context.context import Context, ExecutableContext from jiuwen.core.context.utils import NESTED_PATH_SPLIT class NumberCondition(Condition): def __init__(self, context: Context, node_id: str, limit: Union[str, int], index_path: str = None): - self._context = context.create_executable_context(node_id) + self._context = ExecutableContext(context, node_id) self._index_path = index_path if index_path else node_id + NESTED_PATH_SPLIT + "index" self._limit = limit self._node_id = node_id def init(self): - self._context.state.update_io({self._index_path: -1}) + self._context.state().update_io({self._index_path: -1}) def __call__(self) -> bool: - current_idx = self._context.state.get(self._index_path) + 1 + current_idx = self._context.state().get(self._index_path) + 1 limit_num: int if isinstance(self._limit, int): limit_num = self._limit else: - limit_num = self._context.state.get(self._limit) + limit_num = self._context.state().get(self._limit) result = current_idx < limit_num if result: - self._context.state.update_io({self._index_path: current_idx}) + self._context.state().update_io({self._index_path: current_idx}) return result diff --git a/jiuwen/core/component/llm_comp.py b/jiuwen/core/component/llm_comp.py index bca67fe..468e39c 100644 --- a/jiuwen/core/component/llm_comp.py +++ b/jiuwen/core/component/llm_comp.py @@ -142,10 +142,10 @@ class LLMExecutable(Executable): try: self._set_context(context) model_inputs = self._prepare_model_inputs(inputs) - logger.info("[%s] model inputs %s", self._context.executable_id, model_inputs) + logger.info("[%s] model inputs %s", self._context.executable_id(), model_inputs) llm_response = await self._llm.ainvoke(model_inputs) response = llm_response.content - logger.info("[%s] model outputs %s", self._context.executable_id, response) + logger.info("[%s] model outputs %s", self._context.executable_id(), response) return self._create_output(response) except JiuWenBaseException: raise diff --git a/jiuwen/core/component/loop_callback/intermediate_loop_var.py b/jiuwen/core/component/loop_callback/intermediate_loop_var.py index b503499..a364a83 100644 --- a/jiuwen/core/component/loop_callback/intermediate_loop_var.py +++ b/jiuwen/core/component/loop_callback/intermediate_loop_var.py @@ -4,14 +4,14 @@ from typing import Union, Any from jiuwen.core.component.loop_callback.loop_callback import LoopCallback -from jiuwen.core.context.context import Context +from jiuwen.core.context.context import Context, ExecutableContext from jiuwen.core.context.utils import NESTED_PATH_SPLIT, is_ref_path, extract_origin_key class IntermediateLoopVarCallback(LoopCallback): def __init__(self, context: Context, node_id: str, intermediate_loop_var: dict[str, Union[str, Any]], intermediate_loop_var_root: str = None): - self._context = context.create_executable_context(node_id) + self._context = ExecutableContext(context, node_id) self._node_id = node_id self._intermediate_loop_var = intermediate_loop_var self._intermediate_loop_var_root = intermediate_loop_var_root if intermediate_loop_var_root \ @@ -24,15 +24,15 @@ class IntermediateLoopVarCallback(LoopCallback): if isinstance(value, str): if is_ref_path(value): ref_str = extract_origin_key(value) - update = self._context.state.get(ref_str) + update = self._context.state().get(ref_str) else: update = value else: update = value - self._context.state.update_io({path: update}) + self._context.state().update_io({path: update}) def out_loop(self): - self._context.state.update_io({self._intermediate_loop_var_root: {}}) + self._context.state().update_io({self._intermediate_loop_var_root: {}}) def start_round(self): pass diff --git a/jiuwen/core/component/loop_callback/output.py b/jiuwen/core/component/loop_callback/output.py index 9ffee87..c09302b 100644 --- a/jiuwen/core/component/loop_callback/output.py +++ b/jiuwen/core/component/loop_callback/output.py @@ -4,7 +4,7 @@ from typing import Any from jiuwen.core.component.loop_callback.loop_callback import LoopCallback -from jiuwen.core.context.context import Context +from jiuwen.core.context.context import Context, ExecutableContext from jiuwen.core.context.utils import is_ref_path, extract_origin_key, NESTED_PATH_SPLIT @@ -12,7 +12,7 @@ class OutputCallback(LoopCallback): def __init__(self, context: Context, node_id: str, outputs_format: dict[str, Any], round_result_root: str = None, result_root: str = None, intermediate_loop_var_root: str = None): self._node_id = node_id - self._context = context.create_executable_context(node_id) + self._context = ExecutableContext(context, node_id) self._outputs_format = outputs_format self._round_result_root = round_result_root if round_result_root else node_id + NESTED_PATH_SPLIT + "round" self._result_root = result_root if result_root else node_id @@ -29,34 +29,34 @@ class OutputCallback(LoopCallback): def first_in_loop(self): _results: list[(str, Any)] = [] self._generate_results(_results) - self._context.state.update({self._round_result_root: _results}) + self._context.state().update({self._round_result_root: _results}) def out_loop(self): - results: list[(str, Any)] = self._context.state.get(self._round_result_root) + results: list[(str, Any)] = self._context.state().get(self._round_result_root) if not isinstance(results, list): raise RuntimeError("error results in loop process") for (path, value) in results: - self._context.state.update_io({path: value}) - self._context.state.commit() - result = self._context.state.get_io(self._outputs_format) - self._context.state.update({self._round_result_root : {}}) - self._context.state.set_outputs(self._node_id, result) + self._context.state().update_io({path: value}) + self._context.state().commit() + result = self._context.state().get_io(self._outputs_format) + self._context.state().update({self._round_result_root : {}}) + self._context.state().update_io(result) def start_round(self): pass def end_round(self): - results: list[(str, Any)] = self._context.state.get(self._round_result_root) + results: list[(str, Any)] = self._context.state().get(self._round_result_root) if not isinstance(results, list): raise RuntimeError("error results in round process") for value in results: path = value[0] if path.startswith(self._intermediate_loop_var_root): - value[1] = self._context.state.get(path) + value[1] = self._context.state().get(path) elif isinstance(value, list): if value[1] is None: value[1] = [] - value[1].append(self._context.state.get(path)) + value[1].append(self._context.state().get(path)) else: raise RuntimeError("error process in loop: " + path + ", " + str(value)) - self._context.state.update({self._round_result_root : results}) + self._context.state().update({self._round_result_root : results}) diff --git a/jiuwen/core/component/loop_comp.py b/jiuwen/core/component/loop_comp.py index 74a1ae0..bcda4b3 100644 --- a/jiuwen/core/component/loop_comp.py +++ b/jiuwen/core/component/loop_comp.py @@ -11,7 +11,7 @@ from jiuwen.core.component.condition.condition import Condition, AlwaysTrue, Fun from jiuwen.core.component.condition.expression import ExpressionCondition from jiuwen.core.component.loop_callback.loop_callback import LoopCallback from jiuwen.core.context.config import Transformer, CompIOConfig -from jiuwen.core.context.context import Context +from jiuwen.core.context.context import Context, ExecutableContext from jiuwen.core.context.utils import NESTED_PATH_SPLIT from jiuwen.core.graph.base import Graph, Router, ExecutableGraph from jiuwen.core.graph.executable import Output, Input, Executable @@ -45,7 +45,7 @@ class LoopGroup: inputs_schema: dict = None, outputs_schema: dict = None, inputs_transformer: Transformer = None, outputs_transformer: Transformer = None) -> Self: component.add_component(self._graph, node_id, wait_for_all=wait_for_all) - self._context.config.set_comp_io_config(node_id, CompIOConfig(inputs_schema=inputs_schema, + self._context.config().set_comp_io_config(node_id, CompIOConfig(inputs_schema=inputs_schema, outputs_schema=outputs_schema, inputs_transformer=inputs_transformer, outputs_transformer=outputs_transformer)) @@ -89,7 +89,7 @@ class LoopComponent(WorkflowComponent, LoopController): if context_root is None: context_root = node_id - self._context = context.create_executable_context(self._node_id) + self._context = ExecutableContext(context, self._node_id) self._callbacks: list[LoopCallback] = [] self._context_root = context_root @@ -129,8 +129,8 @@ class LoopComponent(WorkflowComponent, LoopController): self._compiled = self._graph.compile(self._context) def init(self): - self._context.state.update({self._context_root + NESTED_PATH_SPLIT + BROKEN: False}) - self._context.state.update({self._context_root + NESTED_PATH_SPLIT + FIRST_IN_LOOP: True}) + self._context.state().update({self._context_root + NESTED_PATH_SPLIT + BROKEN: False}) + self._context.state().update({self._context_root + NESTED_PATH_SPLIT + FIRST_IN_LOOP: True}) self._condition.init() def to_executable(self) -> Executable: @@ -153,7 +153,7 @@ class LoopComponent(WorkflowComponent, LoopController): else: callback.out_loop() - self._context.state.commit() + self._context.state().commit() if continue_loop: return self._in_loop @@ -161,19 +161,19 @@ class LoopComponent(WorkflowComponent, LoopController): return self._out_loop def first_in_loop(self) -> bool: - _first_in_loop = self._context.state.get(self._context_root + NESTED_PATH_SPLIT + FIRST_IN_LOOP) + _first_in_loop = self._context.state().get(self._context_root + NESTED_PATH_SPLIT + FIRST_IN_LOOP) if isinstance(_first_in_loop, bool): if _first_in_loop: - self._context.state.update({self._context_root + NESTED_PATH_SPLIT + FIRST_IN_LOOP: False}) + self._context.state().update({self._context_root + NESTED_PATH_SPLIT + FIRST_IN_LOOP: False}) return _first_in_loop - self._context.state.update({self._context_root + NESTED_PATH_SPLIT + FIRST_IN_LOOP: False}) + self._context.state().update({self._context_root + NESTED_PATH_SPLIT + FIRST_IN_LOOP: False}) return True def is_broken(self) -> bool: - _is_broken = self._context.state.get(self._context_root + NESTED_PATH_SPLIT + BROKEN) + _is_broken = self._context.state().get(self._context_root + NESTED_PATH_SPLIT + BROKEN) if isinstance(_is_broken, bool): return _is_broken return False def break_loop(self): - self._context.state.update({self._context_root + NESTED_PATH_SPLIT + BROKEN: True}) + self._context.state().update({self._context_root + NESTED_PATH_SPLIT + BROKEN: True}) diff --git a/jiuwen/core/component/questioner_comp.py b/jiuwen/core/component/questioner_comp.py index 7e8cbdf..7dce7ff 100644 --- a/jiuwen/core/component/questioner_comp.py +++ b/jiuwen/core/component/questioner_comp.py @@ -384,7 +384,7 @@ class QuestionerExecutable(Executable): @staticmethod def _load_state_from_context(context) -> QuestionerState: - state_dict = context.state.get(QUESTIONER_STATE_KEY) + state_dict = context.state().get(QUESTIONER_STATE_KEY) if state_dict: return QuestionerState.deserialize(state_dict) return QuestionerState() @@ -392,7 +392,7 @@ class QuestionerExecutable(Executable): @staticmethod def _store_state_to_context(state: QuestionerState, context): state_dict = state.serialize() - context.state.update({QUESTIONER_STATE_KEY: state_dict}) + context.state().update({QUESTIONER_STATE_KEY: state_dict}) def state(self, state: QuestionerState): self._state = state diff --git a/jiuwen/core/component/set_variable_comp.py b/jiuwen/core/component/set_variable_comp.py index 55dc169..af98e29 100644 --- a/jiuwen/core/component/set_variable_comp.py +++ b/jiuwen/core/component/set_variable_comp.py @@ -6,7 +6,7 @@ from functools import partial from typing import AsyncIterator, Iterator, Any from jiuwen.core.component.base import WorkflowComponent -from jiuwen.core.context.context import Context +from jiuwen.core.context.context import Context, ExecutableContext from jiuwen.core.context.utils import extract_origin_key, is_ref_path from jiuwen.core.graph.executable import Executable, Input, Output @@ -14,7 +14,7 @@ from jiuwen.core.graph.executable import Executable, Input, Output class SetVariableComponent(WorkflowComponent, Executable): def __init__(self, node_id: str, context: Context, variable_mapping: dict[str, Any]): - self._context = context.create_executable_context(node_id) + self._context = ExecutableContext(context, node_id) self._node_id = node_id self._variable_mapping = variable_mapping @@ -25,9 +25,9 @@ class SetVariableComponent(WorkflowComponent, Executable): left_ref_str = left if isinstance(right, str) and is_ref_path(right): ref_str = extract_origin_key(right) - self._context.state.update_io({left_ref_str: self._context.state.get(ref_str)}) + self._context.state().update_io({left_ref_str: self._context.state().get(ref_str)}) continue - self._context.state.update_io({left_ref_str: right}) + self._context.state().update_io({left_ref_str: right}) return None diff --git a/jiuwen/core/context/config.py b/jiuwen/core/context/config.py index 9c62009..3c074a4 100644 --- a/jiuwen/core/context/config.py +++ b/jiuwen/core/context/config.py @@ -1,16 +1,18 @@ #!/usr/bin/python3.10 # coding: utf-8 # Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved -from abc import ABC +from abc import ABC, abstractmethod from typing import TypedDict, Any, Optional from jiuwen.core.component.common.configs.workflow_config import WorkflowMetadata from jiuwen.core.context.state import Transformer + class MetadataLike(TypedDict): name: str event: str + class CompIOConfig(ABC): def __init__(self, inputs_schema: dict = None, outputs_schema: dict = None, @@ -21,12 +23,14 @@ class CompIOConfig(ABC): self.inputs_transformer = inputs_transformer self.outputs_transformer = outputs_transformer + class WorkflowConfig: metadata: WorkflowMetadata = WorkflowMetadata() comp_configs: dict[str, CompIOConfig] = {} stream_edges: dict[str, list[str]] = {} -class Config(ABC): + +class BaseConfig(ABC): """ Config is the class defines the basic infos of workflow """ @@ -37,11 +41,95 @@ class Config(ABC): """ self._callback_metadata: dict[str, MetadataLike] = {} self._env: dict = {} - self._workflow_config: WorkflowConfig = WorkflowConfig() + self._workflow_configs: dict[str, WorkflowConfig] = {} self.__load_envs__() + def add_workflow_config(self, workflow_executable_id: str, workflow_config: WorkflowConfig) -> None: + self._workflow_configs[workflow_executable_id] = workflow_config + + def set_envs(self, envs: dict[str, str]) -> None: + """ + set environment variables + :param envs: envs + """ + self._env.update(envs) + + def get_env(self, key: str) -> Any: + """ + get environment variable by given key + :param key: environment variable key + :return: environment variable value + """ + if key in self._env: + return self._env[key] + else: + return None + + def __load_envs__(self) -> None: + pass + + +class WorkflowExecutableConfigLike(ABC): + @abstractmethod + def set_comp_io_config(self, node_id: str, comp_io_config: CompIOConfig) -> None: + pass + + @abstractmethod + def get_inputs_schema(self, node_id: str) -> dict: + pass + + @abstractmethod + def get_outputs_schema(self, node_id: str) -> dict: + pass + + @abstractmethod + def get_input_transformer(self, node_id: str) -> Optional[Transformer]: + pass + + @abstractmethod + def get_output_transformer(self, node_id: str) -> Optional[Transformer]: + pass + + @abstractmethod + def set_stream_edge(self, source_node_id: str, target_node_id: str) -> None: + pass + + @abstractmethod + def set_stream_edges(self, edges: dict[str, list[str]]) -> None: + pass + + @abstractmethod + def is_stream_edge(self, source_node_id: str, target_node_id: str) -> bool: + pass + + +class ConfigLike(ABC): + @abstractmethod + def set_envs(self, envs: dict[str, str]) -> None: + pass + + @abstractmethod + def get_env(self, key: str) -> Any: + pass + + def set_workflow_config(self, workflow_config: WorkflowConfig) -> None: + pass + + +class Config(ConfigLike, WorkflowExecutableConfigLike): + def __init__(self, base_config: BaseConfig = None, workflow_executable_id: str = ''): + self._workflow_executable_id = workflow_executable_id + self._base_config = base_config if base_config is not None else BaseConfig() + self._workflow_config: WorkflowConfig = WorkflowConfig() + self._base_config.add_workflow_config(self._workflow_executable_id, self._workflow_config) + + def base_config(self) -> BaseConfig: + return self._base_config + def set_workflow_config(self, workflow_config: WorkflowConfig) -> None: - self._workflow_config = workflow_config + self._workflow_config.metadata = workflow_config.metadata + self._workflow_config.comp_configs.update(workflow_config.comp_configs) + self._workflow_config.stream_edges.update(workflow_config.stream_edges) def set_comp_io_config(self, node_id: str, comp_io_config: CompIOConfig) -> None: """ @@ -117,14 +205,15 @@ class Config(ABC): :param target_node_id: target node id :return: true if is stream edge """ - return (target_node_id in source_node_id) and (source_node_id in self._workflow_config.stream_edges[target_node_id]) + return (target_node_id in source_node_id) and ( + source_node_id in self._workflow_config.stream_edges[target_node_id]) def set_envs(self, envs: dict[str, str]) -> None: """ set environment variables :param envs: envs """ - self._env.update(envs) + self._base_config.set_envs(envs) def get_env(self, key: str) -> Any: """ @@ -132,10 +221,4 @@ class Config(ABC): :param key: environment variable key :return: environment variable value """ - if key in self._env: - return self._env[key] - else: - return None - - def __load_envs__(self) -> None: - pass + return self._base_config.get_env(key) diff --git a/jiuwen/core/context/context.py b/jiuwen/core/context/context.py index 685ece8..8fd2b83 100644 --- a/jiuwen/core/context/context.py +++ b/jiuwen/core/context/context.py @@ -1,11 +1,11 @@ #!/usr/bin/python3.10 # coding: utf-8 # Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved -from abc import ABC -from typing import Any, Self +from abc import ABC, abstractmethod +from typing import Any from jiuwen.core.context.config import Config -from jiuwen.core.context.state import State +from jiuwen.core.context.state import State, ExecutableState from jiuwen.core.context.store import Store from jiuwen.core.runtime.callback_manager import CallbackManager from jiuwen.core.stream.manager import StreamWriterManager @@ -13,72 +13,135 @@ from jiuwen.core.tracer.tracer import Tracer class Context(ABC): - def __init__(self, config: Config, state: State, store: Store = None): - self._config = config - self._state = state - self._store = store - self._tracer = None - self._callback_manager = CallbackManager() - self._stream_writer_manager: StreamWriterManager = None - self._controller_context_manager = None + @abstractmethod + def config(self) -> Config: + pass - def set_stream_writer_manager(self, stream_writer_manager: StreamWriterManager): - if self._stream_writer_manager is not None: - return - self._stream_writer_manager = stream_writer_manager + @abstractmethod + def state(self) -> State: + pass - def set_tracer(self, tracer: Tracer): - self._tracer = tracer + @abstractmethod + def store(self) -> Store: + pass - def set_controller_context_manager(self, controller_context_manager): - self._controller_context_manager = controller_context_manager + @abstractmethod + def tracer(self) -> Any: + pass + + @abstractmethod + def stream_writer_manager(self) -> StreamWriterManager: + pass + + @abstractmethod + def callback_manager(self) -> CallbackManager: + pass + + @abstractmethod + def controller_context_manager(self) -> Any: + pass + + +class WorkflowContext(Context): + def __init__(self, config: Config = Config(), state: State = None, store: Store = None): + self.__config = config + self.__state = state + self.__store = store + self.__tracer = None + self.__callback_manager = CallbackManager() + + def set_tracer(self, tracer: Tracer): + self.__tracer = tracer - @property def config(self) -> Config: - return self._config + return self.__config - @property def state(self) -> State: - return self._state + return self.__state - @property def store(self) -> Store: - return self._store + return self.__store - @property def tracer(self) -> Any: - return self._tracer - - @property - def stream_writer_manager(self) -> StreamWriterManager: - return self._stream_writer_manager + return self.__tracer - @property def callback_manager(self) -> CallbackManager: - return self._callback_manager + return self.__callback_manager + + def stream_writer_manager(self) -> StreamWriterManager: + return None - def create_executable_context(self, node_id: str) -> Self: - context = ExecutableContext(self, node_id) - context.set_stream_writer_manager(self._stream_writer_manager) - context.set_tracer(self.tracer) - return context + def controller_context_manager(self) -> Any: + return None class ExecutableContext(Context): - def __init__(self, context: Context, node_id: str): + def __init__(self, context: Context, node_id: str = '', is_workflow: bool = False): + super().__init__() + self._context = context + self._tracer = context.tracer() self._node_id = node_id - self._parent_id = context.executable_id if isinstance(context, ExecutableContext) else None - self._executable_id = self._parent_id + "." + node_id if self._parent_id is not None else node_id - super().__init__(context.config, context.state.create_executable_state(self._executable_id), context.store) + if isinstance(context, ExecutableContext): + self._parent_id = context.executable_id() + self._executable_id = (self._parent_id + "." + node_id) if self._parent_id is not None and len(self._parent_id) > 0 else node_id + self._stream_writer_manager = context.stream_writer_manager() + self._controller_context_manager = context.controller_context_manager() + self._workflow_executable_id = self._executable_id if is_workflow else context.workflow_executable_id() + self._workflow_id = self._node_id if is_workflow else context.workflow_id() + self._workflow_config = Config(context.config().base_config(), self._executable_id) if is_workflow else context.config() + else: + self._executable_id = node_id + self._stream_writer_manager = None + self._controller_context_manager = None + self._parent_id = '' + self._workflow_executable_id = '' + self._workflow_id = '' + self._workflow_config = Config(context.config().base_config(), self._executable_id) + self._state = ExecutableState(context.state().base_state(), self._executable_id) - @property - def node_id(self): + def config(self) -> Config: + return self._workflow_config + + def state(self) -> State: + return self._state + + def tracer(self) -> Tracer: + return self._tracer + + def store(self) -> Store: + return self._context.store() + + def node_id(self) -> str: return self._node_id - @property - def executable_id(self): + def executable_id(self) -> str: return self._executable_id - @property - def parent_id(self): + def parent_id(self) -> str: return self._parent_id + + def workflow_executable_id(self) -> str: + return self._workflow_executable_id + + def workflow_id(self) -> str: + return self._workflow_id + + def stream_writer_manager(self) -> StreamWriterManager: + return self._stream_writer_manager + + def callback_manager(self) -> CallbackManager: + return self._context.callback_manager() + + def controller_context_manager(self) -> Any: + return self._context.controller_context_manager() + + def set_tracer(self, new_tracer: Tracer) -> None: + self._tracer = new_tracer + + def set_stream_writer_manager(self, stream_writer_manager: StreamWriterManager): + if self._stream_writer_manager is not None: + return + self._stream_writer_manager = stream_writer_manager + + def set_controller_context_manager(self, controller_context_manager): + self._controller_context_manager = controller_context_manager diff --git a/jiuwen/core/context/memory/base.py b/jiuwen/core/context/memory/base.py index 6e28660..b2af4d3 100644 --- a/jiuwen/core/context/memory/base.py +++ b/jiuwen/core/context/memory/base.py @@ -5,9 +5,10 @@ from copy import deepcopy from typing import Union, Optional, Any, Callable from jiuwen.core.common.exception.exception import JiuWenBaseException -from jiuwen.core.context.state import Transformer -from jiuwen.core.context.state import CommitState, StateLike, State -from jiuwen.core.context.utils import update_dict, get_by_schema +from jiuwen.core.context.state import Transformer, BaseState, ExecutableState +from jiuwen.core.context.state import CommitState, StateLike +from jiuwen.core.context.store import Store +from jiuwen.core.context.utils import update_dict, get_by_schema, get_value_by_nested_path class InMemoryStateLike(StateLike): @@ -17,6 +18,9 @@ class InMemoryStateLike(StateLike): def get(self, key: Union[str, list, dict]) -> Optional[Any]: return get_by_schema(key, self._state) + def get_by_prefix(self, key: Union[str, list, dict], prefix: str = None) -> Optional[Any]: + return get_by_schema(key, get_value_by_nested_path(prefix, self._state)) + def get_by_transformer(self, transformer: Callable) -> Optional[Any]: return transformer(self._state) @@ -60,11 +64,23 @@ class InMemoryCommitState(CommitState): def get(self, key: Union[str, list, dict]) -> Optional[Any]: return self._state.get(key) -class InMemoryState(State): + def get_by_prefix(self, key: Union[str, list, dict], prefix: str = None) -> Optional[Any]: + return self._state.get_by_prefix(key, prefix) + + +class InMemoryState(ExecutableState): + def __init__(self): + super().__init__(BaseState(io_state=InMemoryCommitState(), + global_state=InMemoryCommitState(), + trace_state=dict(), + comp_state=InMemoryCommitState())) + +class InMemoryStore(Store): def __init__(self): - super().__init__(io_state=InMemoryCommitState(), - global_state=InMemoryCommitState(), - trace_state=dict(), - comp_state=InMemoryCommitState()) + self._data = {} + def read(self, key: Union[str, dict]) -> Optional[Any]: + return get_by_schema(self._data, key) + def write(self, value: dict) -> None: + update_dict(self._data, value) diff --git a/jiuwen/core/context/state.py b/jiuwen/core/context/state.py index ac5d11f..b579934 100644 --- a/jiuwen/core/context/state.py +++ b/jiuwen/core/context/state.py @@ -5,7 +5,7 @@ import uuid from abc import ABC, abstractmethod from typing import Any, Union, Optional, Callable, Self -from jiuwen.core.common.logging.base import logger +from urllib3.util.wait import select_wait_for_socket class ReadableStateLike(ABC): @@ -13,7 +13,6 @@ class ReadableStateLike(ABC): def get(self, key: Union[str, list, dict]) -> Optional[Any]: pass - Transformer = Callable[[ReadableStateLike], Any] @@ -22,6 +21,10 @@ class StateLike(ReadableStateLike): def get_by_transformer(self, transformer: Transformer) -> Optional[Any]: pass + @abstractmethod + def get_by_prefix(self, key: Union[str, list, dict], prefix: str = None) -> Optional[Any]: + pass + @abstractmethod def update(self, node_id: str, data: dict) -> None: pass @@ -41,96 +44,180 @@ class CommitState(StateLike): pass -class State(ABC): +class BaseState(ABC): def __init__( self, io_state: CommitState, global_state: CommitState, comp_state: CommitState, - trace_state: dict = {}, - node_id: str = None + trace_state: dict ): self._io_state = io_state self._global_state = global_state self._trace_state = trace_state self._comp_state = comp_state - self._node_id = node_id + + def io_state(self) -> CommitState: + return self._io_state + + def global_state(self) -> CommitState: + return self._global_state + + def comp_state(self) -> CommitState: + return self._comp_state + + def trace_state(self) -> dict: + return self._trace_state + + +class State(ABC): + @abstractmethod + def get(self, key: Union[str, list, dict]) -> Optional[Any]: + pass + + @abstractmethod + def get_io(self, key: Union[str, list, dict]) -> Optional[Any]: + pass + + @abstractmethod + def get_comp(self, key: Union[str, list, dict]) -> Optional[Any]: + pass + + @abstractmethod + def get_trace(self, key: Union[str, list, dict]) -> Optional[Any]: + pass + + @abstractmethod + def update(self, data: dict) -> None: + pass + + @abstractmethod + def update_io(self, data: dict) -> None: + pass + + @abstractmethod + def update_trace(self, span) -> None: + pass + + @abstractmethod + def update_comp(self, data: dict) -> None: + pass + + @abstractmethod + def get_inputs_by_transformer(self, transformer: Callable) -> dict: + pass + + @abstractmethod + def set_user_inputs(self, inputs: Any) -> None: + pass + + @abstractmethod + def get_outputs(self, node_id: str) -> Any: + pass + + @abstractmethod + def commit(self) -> None: + pass + + @abstractmethod + def rollback(self, node_id: str) -> None: + pass + + @abstractmethod + def get_updates(self, node_id: str) -> list[dict]: + pass + + +class ExecutableState(State): + def __init__(self, state: BaseState, node_id: str = None): + self._base_state = state.base_state() if isinstance(state, ExecutableState) else state + self._node_id = node_id if node_id is not None else "" + + def base_state(self) -> BaseState: + return self._base_state def get(self, key: Union[str, list, dict]) -> Optional[Any]: - if self._global_state is None: + if self._base_state.global_state() is None: return None - value = self._global_state.get(key) + value = self._base_state.global_state().get(key) if value is None: - return self._io_state.get(key) + return self._base_state.io_state().get(key) return value + def get_io(self, key: Union[str, list, dict]) -> Optional[Any]: + if self._base_state.io_state() is None: + return None + return self._base_state.io_state().get(key) + + def get_comp(self, key: Union[str, list, dict]) -> Optional[Any]: + if self._base_state.comp_state() is None: + return None + if key == '': + return self._base_state.comp_state().get(self._node_id) + return self._base_state.comp_state().get_by_prefix(key, self._node_id) + + def get_trace(self, key: Union[str, list, dict]) -> Optional[Any]: + if self._base_state.trace_state() is None: + return None + return self._base_state.trace_state().get(key) + def update(self, data: dict) -> None: - if self._global_state is None: + if self._base_state.global_state() is None: return - self._global_state.update(self._node_id, data) + self._base_state.global_state().update(self._node_id, data) def update_io(self, data: dict) -> None: - if self._io_state is None: + if self._base_state.io_state() is None: return - self._io_state.update(self._node_id, data) + self._base_state.io_state().update(self._node_id, {self._node_id:data}) - def get_io(self, key: Union[str, list, dict]) -> Optional[Any]: - if self._io_state is None: + def update_trace(self, span) -> None: + if self._base_state.trace_state() is None: return - return self._io_state.get(key) - - def update_trace(self, invoke_id: str, span): - self._trace_state.update({invoke_id: span}) + self._base_state.trace_state()[self._node_id] = span def update_comp(self, data: dict) -> None: - if self._comp_state is None: + if self._base_state.io_state() is None: return - self._comp_state.update(self._node_id, data) + self._base_state.comp_state().update(self._node_id, {self._node_id:data}) - def get_comp(self, key: Union[str, list, dict]) -> Optional[Any]: - if self._comp_state is None: - return - return self._comp_state.get(key) + def get_inputs_by_transformer(self, transformer: Transformer) -> Optional[dict]: + if self._base_state.io_state() is None: + return None + else: + return self._base_state.io_state().get_by_transformer(transformer) + + def get_outputs(self, node_id: str) -> Any: + if self._base_state.io_state() is None: + return None + else: + return self._base_state.io_state().get(node_id) def set_user_inputs(self, inputs: Any) -> None: - if self._io_state is None or inputs is None: + if self._base_state.io_state() is None or inputs is None: return - self._io_state.update("", inputs) - self._global_state.update("", inputs) + self._base_state.io_state().update(self._node_id, inputs) + self._base_state.global_state().update(self._node_id, inputs) self.commit() - def get_inputs_by_transformer(self, transformer: Callable) -> dict: - if self._io_state is None: - return {} - return self._io_state.get_by_transformer(transformer) - - def get_outputs(self, node_id: str) -> Any: - if self._io_state is None: - return {} - return self._io_state.get(node_id) - def set_outputs(self, node_id: str, outputs: dict) -> None: - if self._io_state is None or outputs is None: + if self._base_state.io_state() is None or outputs is None: return - return self._io_state.update(node_id, {node_id: outputs}) - - def create_executable_state(self, node_id: str) -> Self: - return State(io_state=self._io_state, global_state=self._global_state, comp_state=self._comp_state, - trace_state=self._trace_state, node_id=node_id) + return self._base_state.io_state().update(node_id, {node_id: outputs}) def commit(self) -> None: - self._io_state.commit() - self._comp_state.commit() - self._global_state.commit() + self._base_state.io_state().commit() + self._base_state.comp_state().commit() + self._base_state.global_state().commit() - def rollback(self) -> None: - self._comp_state.rollback(self._node_id) - self._io_state.rollback(self._node_id) - self._global_state.rollback(self._node_id) + def rollback(self, node_id: str) -> None: + self._base_state.comp_state().rollback(self._node_id) + self._base_state.io_state().rollback(self._node_id) + self._base_state.global_state().rollback(self._node_id) - def get_updates(self) -> dict: + def get_updates(self, node_id: str) -> dict: return { - "io": self._io_state.get_updates(self._node_id), - "global": self._global_state.get_updates(self._node_id), - "comp": self._comp_state.get_updates(self._node_id) + "io": self._base_state.io_state().get_updates(self._node_id), + "global": self._base_state.global_state().get_updates(self._node_id), + "comp": self._base_state.comp_state().get_updates(self._node_id) } diff --git a/jiuwen/core/context/store.py b/jiuwen/core/context/store.py index 81ffc47..12c0b79 100644 --- a/jiuwen/core/context/store.py +++ b/jiuwen/core/context/store.py @@ -25,11 +25,3 @@ class FileStore(Store): def write(self, value: dict) -> None: pass - - -class MemoryStore(Store): - def read(self, key: Union[str, dict]) -> Optional[Any]: - pass - - def write(self, value: dict) -> None: - pass diff --git a/jiuwen/core/context/utils.py b/jiuwen/core/context/utils.py index af1ff07..1d402c6 100644 --- a/jiuwen/core/context/utils.py +++ b/jiuwen/core/context/utils.py @@ -24,24 +24,14 @@ def update_dict(update: dict, source: dict) -> None: current_key, current = root_to_path(key, source, create_if_absent=True) update_by_key(current_key, value, current) - def get(self, key: Union[str, list, dict]) -> Optional[Any]: - if isinstance(key, str): - origin_key = extract_origin_key(key) - return get_value_by_nested_path(origin_key, self._state) - elif isinstance(key, dict): - result = {} - for target_key, target_schema in key.items(): - result[target_key] = self.get(target_schema) - return result - elif isinstance(key, list): - result = [] - for item in key: - result.append(self.get(item)) - return result - else: - return key - def get_by_schema(schema: Union[str, list, dict], data: dict) -> Any: + """ + get dict values by schema + + :param schema: + :param data: + :return: + """ if schema is None or data is None: return None if isinstance(schema, str): diff --git a/jiuwen/core/graph/base.py b/jiuwen/core/graph/base.py index ce81820..3a4da9a 100644 --- a/jiuwen/core/graph/base.py +++ b/jiuwen/core/graph/base.py @@ -14,7 +14,7 @@ CONFIG_KEY = "config" class ExecutableGraph(Executable[Input, Output]): async def invoke(self, inputs: Input, context: Context) -> Output: - context.state.set_user_inputs(inputs.get(INPUTS_KEY)) + context.state().set_user_inputs(inputs.get(INPUTS_KEY)) results = await self._invoke(inputs.get(CONFIG_KEY)) return results diff --git a/jiuwen/core/graph/vertex.py b/jiuwen/core/graph/vertex.py index 2687912..5f1ef05 100644 --- a/jiuwen/core/graph/vertex.py +++ b/jiuwen/core/graph/vertex.py @@ -18,84 +18,80 @@ class Vertex: def __init__(self, node_id: str, executable: Executable = None): self._node_id = node_id self._executable = executable - self._context: ExecutableContext = None def init(self, context: Context) -> bool: - self._context = context.create_executable_context(self._node_id) + self._context = context return True async def __call__(self, state: GraphState, config: Any = None) -> Output: if self._context is None or self._executable is None: raise JiuWenBaseException(1, "vertex is not initialized, node is is " + self._node_id) - inputs = await self.__pre_invoke__() - logger.info("vertex[%s] inputs %s", self._context.executable_id, inputs) + context = ExecutableContext(self._context , self._node_id) + inputs = await self.__pre_invoke__(context, config) + logger.info("vertex[%s] inputs %s", context.executable_id(), inputs) is_stream = self.__is_stream__(state) - if isinstance(self._executable, ExecWorkflowComponent) or isinstance(self._executable, ExecutableGraph): - inputs = {INPUTS_KEY: inputs, CONFIG_KEY: config} - try: if is_stream: - result_iter = await self._executable.stream(inputs, context=self._context) + result_iter = await self._executable.stream(inputs, context=context) self.__post_stream__(result_iter) else: - results = await self._executable.invoke(inputs, context=self._context) - await self.__post_invoke__(results) + results = await self._executable.invoke(inputs, context=context) + outputs = await self.__post_invoke__(context, results) + logger.info("vertex[%s] outputs %s", context.executable_id(), outputs) except JiuWenBaseException as e: raise JiuWenBaseException(e.error_code, "failed to invoke, caused by " + e.message) return {"source_node_id": [self._node_id]} - async def __pre_invoke__(self) -> Optional[dict]: - inputs_transformer = self._context.config.get_input_transformer(self._node_id) + async def __pre_invoke__(self, context: ExecutableContext, config: Any) -> Optional[dict]: + inputs_transformer = context.config().get_input_transformer(self._node_id) if inputs_transformer is None: - inputs_schema = self._context.config.get_inputs_schema(self._node_id) - inputs = self._context.state.get_io(inputs_schema) + inputs_schema = context.config().get_inputs_schema(self._node_id) + inputs = context.state().get_io(inputs_schema) else: - inputs = self._context.state.get_inputs_by_transformer(inputs_transformer) - if self._context.tracer is not None: - await self.__trace_inputs__(inputs) + inputs = context.state().get_inputs_by_transformer(inputs_transformer) + if context.tracer() is not None: + await self.__trace_inputs__(context, inputs) + if isinstance(self._executable, ExecWorkflowComponent) or isinstance(self._executable, ExecutableGraph): + inputs = {INPUTS_KEY: inputs, CONFIG_KEY: config} return inputs - async def __post_invoke__(self, results: Optional[dict]) -> None: - output_transformer = self._context.config.get_output_transformer(self._node_id) + async def __post_invoke__(self, context: ExecutableContext, results: Optional[dict]) -> Any: + output_transformer = context.config().get_output_transformer(self._node_id) if output_transformer is None: - output_schema = self._context.config.get_outputs_schema(self._node_id) + output_schema = context.config().get_outputs_schema(self._node_id) results = get_by_schema(output_schema, results) if output_schema else results else: results = output_transformer(results) - logger.info("vertex[%s] outputs %s", self._context.executable_id, results) - self._context.state.set_outputs(self._node_id, results) - # todo: need move to checkpoint - self._context.state.commit() - if self._context.tracer is not None: - await self.__trace_outputs__(results) + context.state().update_io(results) + context.state().commit() + if context.tracer() is not None: + await self.__trace_outputs__(results, context) + return results def __post_stream__(self, results_iter: Any) -> None: pass - async def __trace_inputs__(self, inputs: Optional[dict]) -> None: + async def __trace_inputs__(self, context: ExecutableContext, inputs: Optional[dict]) -> None: # TODO 组件信息 - - await self._context.tracer.trigger("tracer_workflow", "on_pre_invoke", invoke_id=self._context.executable_id, + await context.tracer().trigger("tracer_workflow", "on_pre_invoke", invoke_id=context.executable_id(), inputs=inputs, - component_metadata={"component_type": self._context.executable_id}) - self._context.state.update_trace(self._node_id, - self._context.tracer.tracer_workflow_span_manager.get_span(self._node_id)) + component_metadata={"component_type": context.executable_id()}) + context.state().update_trace(context.tracer().tracer_workflow_span_manager.get_span(self._node_id)) if isinstance(self._executable, ExecWorkflowComponent): - self._origin_tracer = self._context.tracer - sub_tracer = Tracer(tracer_id=self._context.tracer._trace_id, parent_node_id=self._context.executable_id) - sub_tracer.init(self._context.stream_writer_manager, self._origin_tracer._callback_manager) - self._context.set_tracer(sub_tracer) + self._origin_tracer = context.tracer() + sub_tracer = Tracer(tracer_id=context.tracer()._trace_id, parent_node_id=context.executable_id()) + sub_tracer.init(context.stream_writer_manager(), self._origin_tracer._callback_manager) + context.set_tracer(sub_tracer) - async def __trace_outputs__(self, outputs: Optional[dict] = None) -> None: + async def __trace_outputs__(self, outputs: Optional[dict] = None, context: ExecutableContext = None) -> None: if isinstance(self._executable, ExecWorkflowComponent): - self._context.set_tracer(self._origin_tracer) + context.set_tracer(self._origin_tracer) - await self._context.tracer.trigger("tracer_workflow", "on_post_invoke", invoke_id=self._context.executable_id, + await context.tracer().trigger("tracer_workflow", "on_post_invoke", invoke_id=context.executable_id(), outputs=outputs) - self._context.state.update_trace(self._context.executable_id, - self._context.tracer.tracer_workflow_span_manager.get_span( - self._context.executable_id)) + context.state().update_trace(context.tracer().tracer_workflow_span_manager.get_span( + context.executable_id())) def __is_stream__(self, state: GraphState) -> bool: return False diff --git a/jiuwen/core/tracer/tracer.py b/jiuwen/core/tracer/tracer.py index 5ea7ff0..ec6db09 100644 --- a/jiuwen/core/tracer/tracer.py +++ b/jiuwen/core/tracer/tracer.py @@ -4,6 +4,10 @@ from jiuwen.core.tracer.handler import TraceAgentHandler, TraceWorkflowHandler, from jiuwen.core.tracer.span import SpanManager +class Tracer: + pass + + class Tracer: def __init__(self, tracer_id=None, parent_node_id=""): self._callback_manager = None @@ -11,6 +15,7 @@ class Tracer: self.tracer_agent_span_manager = SpanManager(self._trace_id) self.tracer_workflow_span_manager = SpanManager(self._trace_id, parent_node_id=parent_node_id) self._parent_node_id = parent_node_id + self._children: dict[str, Tracer] = {} def init(self, stream_writer_manager, callback_manager): # 用于注册子workflow tracer handler,子workflow中使用新的tracer handler @@ -32,3 +37,8 @@ class Tracer: async def trigger(self, handler_class_name: str, event_name: str, **kwargs): handler_class_name += "." + self._parent_node_id if self._parent_node_id != "" else "" await self._callback_manager.trigger(handler_class_name, event_name, **kwargs) + + def add_child(self, trace_id: str, child: Tracer)->None: + self._children[trace_id] = child + + diff --git a/jiuwen/core/workflow/base.py b/jiuwen/core/workflow/base.py index 10116ed..331b6c2 100644 --- a/jiuwen/core/workflow/base.py +++ b/jiuwen/core/workflow/base.py @@ -9,7 +9,7 @@ from pydantic import BaseModel from jiuwen.core.common.logging.base import logger from jiuwen.core.component.base import WorkflowComponent, StartComponent, EndComponent from jiuwen.core.context.config import CompIOConfig, Transformer, WorkflowConfig -from jiuwen.core.context.context import Context +from jiuwen.core.context.context import Context, ExecutableContext from jiuwen.core.graph.base import Graph, Router, INPUTS_KEY, CONFIG_KEY from jiuwen.core.graph.executable import Executable, Input, Output from jiuwen.core.stream.base import StreamMode, BaseStreamMode @@ -112,10 +112,10 @@ class Workflow: async def invoke(self, inputs: Input, context: Context, config: Any = None) -> Output: logger.info("begin to invoke, input=%s", inputs) - context.config.set_workflow_config(self._workflow_config) - compiled_graph = self._graph.compile(context) + executable_context = self._create_executable_context(context) + compiled_graph = self._graph.compile(executable_context) await compiled_graph.invoke({INPUTS_KEY: inputs, CONFIG_KEY: config}, context) - results = context.state.get_outputs(self._end_comp_id) + results = context.state().get_outputs(self._end_comp_id) logger.info("end to invoke, results=%s", results) return results @@ -125,23 +125,31 @@ class Workflow: context: Context, stream_modes: list[StreamMode] = None ) -> AsyncIterator[WorkflowChunk]: - context.config.set_workflow_config(self._workflow_config) - context.set_stream_writer_manager(StreamWriterManager(stream_emitter=StreamEmitter(), modes=stream_modes)) - if context.tracer is None and (stream_modes is None or BaseStreamMode.TRACE in stream_modes): - tracer = Tracer() - tracer.init(context.stream_writer_manager, context.callback_manager) - context.set_tracer(tracer) + context = self._create_executable_context(context, True, stream_modes) compiled_graph = self._graph.compile(context) - context.state.set_user_inputs(inputs) - context.state.commit() + context.state().set_user_inputs(inputs) + context.state().commit() async def stream_process(): await compiled_graph.invoke(inputs, context) - await context.stream_writer_manager.stream_emitter.close() + await context.stream_writer_manager().stream_emitter.close() asyncio.create_task(stream_process()) - async for chunk in context.stream_writer_manager.stream_output(): + async for chunk in context.stream_writer_manager().stream_output(): yield chunk def _convert_to_component(self, executable: Executable) -> WorkflowComponent: pass + + def _create_executable_context(self, context: Context, stream_out: bool = False, + stream_modes: list[StreamMode] = None) -> ExecutableContext: + executable_context = ExecutableContext(context=context, is_workflow=True) if not isinstance(context, ExecutableContext) else context + executable_context.config().set_workflow_config(self._workflow_config) + if stream_out: + executable_context.set_stream_writer_manager( + StreamWriterManager(stream_emitter=StreamEmitter(), modes=stream_modes)) + if executable_context.tracer() is None and (stream_modes is None or BaseStreamMode.TRACE in stream_modes): + tracer = Tracer() + tracer.init(executable_context.stream_writer_manager(), executable_context.callback_manager()) + executable_context.set_tracer(tracer) + return executable_context diff --git a/tests/unit_tests/core/context/test_context.py b/tests/unit_tests/core/context/test_context.py new file mode 100644 index 0000000..6669d96 --- /dev/null +++ b/tests/unit_tests/core/context/test_context.py @@ -0,0 +1,66 @@ +import unittest + +from jiuwen.core.context.config import Config +from jiuwen.core.context.context import WorkflowContext, ExecutableContext +from jiuwen.core.context.memory.base import InMemoryState + + +class ContextTest(unittest.TestCase): + def assert_context(self, context: ExecutableContext, node_id: str, executable_id: str, workflow_id: str, + workflow_executable_id: str): + assert context.node_id() == node_id + assert context.executable_id() == executable_id + assert context.workflow_id() == workflow_id + assert context.workflow_executable_id() == workflow_executable_id + + + def test_basic(self): + # Workflow context/ + context = WorkflowContext(config=Config(), state=InMemoryState(), store=None) + w1_context = ExecutableContext(context=context) + w1_context.state().set_user_inputs({'a': 1, 'b': 2}) + + self.assert_context(w1_context, '', '', '', '') + assert w1_context.state().get('a') == 1 + assert w1_context.state().get('b') == 2 + + w1_node1_context = ExecutableContext(context=w1_context, node_id='node1') + w1_node1_context.state().update({'a': 2}) + w1_node1_context.state().update_io({"n1": 1, "n2": 2}) + w1_node1_context.state().update_comp({"url": "0.0.0.1"}) + w1_node1_context.state().commit() + + self.assert_context(w1_node1_context, 'node1', 'node1', '', '') + assert w1_node1_context.state().get('a') == 2 + assert w1_node1_context.state().get_comp('url') == '0.0.0.1' + assert w1_node1_context.state().get_io({"n": ["${node1.n1}", "${node1.n2}"]}) == {"n": [1, 2]} + + w1_w2_context = ExecutableContext(context=w1_context, node_id='workflow2', is_workflow=True) + w1_w2_context.state().update({'a': 3}) # a is overwritten + w1_w2_context.state().update_io({"n1": 1, "n2": 2}) + w1_w2_context.state().update_comp({"url": "0.0.0.2"}) # node scope variable + w1_w2_context.state().commit() + + self.assert_context(w1_w2_context, 'workflow2', 'workflow2', 'workflow2', 'workflow2') + assert w1_w2_context.state().get('a') == 3 + assert w1_w2_context.state().get_comp('url') == '0.0.0.2' + assert w1_w2_context.state().get_io({"n": ["${workflow2.n1}", "${workflow2.n2}"]}) == {"n": [1, 2]} + + w1_w2_node1_context = ExecutableContext(context=w1_w2_context, node_id='node1') + w1_w2_node1_context.state().update({'a': 4}) # a is overwritten + w1_w2_node1_context.state().update_io({"n1": 1, "n2": 2}) + w1_w2_node1_context.state().update_comp({"url": "0.0.0.3"}) # node scope variable + w1_w2_node1_context.state().commit() + + self.assert_context(w1_w2_node1_context, 'node1', 'workflow2.node1', 'workflow2', 'workflow2') + assert w1_w2_node1_context.state().get('a') == 4 + assert w1_w2_node1_context.state().get_comp('url') == '0.0.0.3' + assert w1_w2_node1_context.state().get_io({"n": ["${workflow2.node1.n1}", "${workflow2.node1.n2}"]}) == { + "n": [1, 2]} + + w1_w2_w3_context = ExecutableContext(context=w1_w2_context, node_id='workflow3', is_workflow=True) + self.assert_context(w1_w2_w3_context, 'workflow3', 'workflow2.workflow3', 'workflow3', 'workflow2.workflow3') + + w1_w2_w3_node1_context = ExecutableContext(context=w1_w2_w3_context, node_id='node1') + self.assert_context(w1_w2_w3_node1_context, 'node1', 'workflow2.workflow3.node1', 'workflow3', + 'workflow2.workflow3') diff --git a/tests/unit_tests/tracer/test_mock_node_with_tracer.py b/tests/unit_tests/tracer/test_mock_node_with_tracer.py index 96ee462..d5a0d92 100644 --- a/tests/unit_tests/tracer/test_mock_node_with_tracer.py +++ b/tests/unit_tests/tracer/test_mock_node_with_tracer.py @@ -1,5 +1,4 @@ import asyncio -import random from jiuwen.core.context.context import Context from jiuwen.core.graph.executable import Input, Output @@ -14,15 +13,14 @@ class StreamNodeWithTracer(MockNodeBase): self._datas: list[dict] = datas async def invoke(self, inputs: Input, context: Context) -> Output: - context.state.set_outputs(self.node_id, inputs) - await context.tracer.trigger("tracer_workflow", "on_invoke", invoke_id=context.executable_id, + context.state().update_io(inputs) + await context.tracer().trigger("tracer_workflow", "on_invoke",invoke_id=context.executable_id(), on_invoke_data={"on_invoke_data": "mock with" + str(inputs)}) - context.state.update_trace(context.executable_id, - context.tracer.tracer_workflow_span_manager.get_span(context.executable_id)) - await asyncio.sleep(random.randint(0, 5)) + context.state().update_trace(context.tracer().tracer_workflow_span_manager.get_span(context.executable_id())) + await asyncio.sleep(5) for data in self._datas: await asyncio.sleep(1) - await context.stream_writer_manager.get_custom_writer().write(data) + await context.stream_writer_manager().get_custom_writer().write(data) print("StreamNode: output = " + str(inputs)) return inputs diff --git a/tests/unit_tests/tracer/test_nested_stream_workflow_with_tracer.json b/tests/unit_tests/tracer/test_nested_stream_workflow_with_tracer.json new file mode 100644 index 0000000..0c329d4 --- /dev/null +++ b/tests/unit_tests/tracer/test_nested_stream_workflow_with_tracer.json @@ -0,0 +1,16 @@ +{"type": "tracer_workflow", "payload": {"trace_id": "2879cd06-b142-43bb-b5d7-789635a303f2", "startTime": "2025-07-21 15:24:45.855777", "endTime": null, "inputs": {"a": 1, "b": "haha", "c": 1, "d": [1, 2, 3]}, "outputs": null, "error": null, "invokeId": "start", "parentInvokeId": null, "executionId": "2879cd06-b142-43bb-b5d7-789635a303f2", "conversationId": "", "onInvokeData": null, "agentId": "", "componentId": "", "componentName": "", "componentType": "start", "agentParentInvokeId": "", "metaData": "{\"component_id\": \"\", \"component_name\": \"\", \"component_type\": \"start\"}", "loopNodeId": null, "loopIndex": null, "status": "start", "parentNodeId": ""}} +{"type": "tracer_workflow", "payload": {"trace_id": "2879cd06-b142-43bb-b5d7-789635a303f2", "startTime": "2025-07-21 15:24:45.855777", "endTime": "2025-07-21 15:24:45.857741", "inputs": {"a": 1, "b": "haha", "c": 1, "d": [1, 2, 3]}, "outputs": {"a": 1, "b": "haha", "c": 1, "d": [1, 2, 3]}, "error": null, "invokeId": "start", "parentInvokeId": null, "executionId": "2879cd06-b142-43bb-b5d7-789635a303f2", "conversationId": "", "onInvokeData": null, "agentId": "", "componentId": "", "componentName": "", "componentType": "start", "agentParentInvokeId": "", "metaData": "{\"component_id\": \"\", \"component_name\": \"\", \"component_type\": \"start\"}", "loopNodeId": null, "loopIndex": null, "status": "finish", "parentNodeId": ""}} +{"type": "tracer_workflow", "payload": {"trace_id": "2879cd06-b142-43bb-b5d7-789635a303f2", "startTime": "2025-07-21 15:24:46.852855", "endTime": null, "inputs": {"aa": 1, "ac": 1}, "outputs": null, "error": null, "invokeId": "a", "parentInvokeId": "start", "executionId": "2879cd06-b142-43bb-b5d7-789635a303f2", "conversationId": "", "onInvokeData": null, "agentId": "", "componentId": "", "componentName": "", "componentType": "a", "agentParentInvokeId": "", "metaData": "{\"component_id\": \"\", \"component_name\": \"\", \"component_type\": \"a\"}", "loopNodeId": null, "loopIndex": null, "status": "start", "parentNodeId": ""}} +{"type": "tracer_workflow", "payload": {"trace_id": "2879cd06-b142-43bb-b5d7-789635a303f2", "startTime": "2025-07-21 15:24:46.852855", "endTime": null, "inputs": {"ba": "haha", "bc": [1, 2, 3]}, "outputs": null, "error": null, "invokeId": "b", "parentInvokeId": "a", "executionId": "2879cd06-b142-43bb-b5d7-789635a303f2", "conversationId": "", "onInvokeData": null, "agentId": "", "componentId": "", "componentName": "", "componentType": "b", "agentParentInvokeId": "", "metaData": "{\"component_id\": \"\", \"component_name\": \"\", \"component_type\": \"b\"}", "loopNodeId": null, "loopIndex": null, "status": "start", "parentNodeId": ""}} +{"type": "tracer_workflow", "payload": {"trace_id": "2879cd06-b142-43bb-b5d7-789635a303f2", "startTime": "2025-07-21 15:24:46.852855", "endTime": null, "inputs": {"ba": "haha", "bc": [1, 2, 3]}, "outputs": null, "error": null, "invokeId": "b", "parentInvokeId": "a", "executionId": "2879cd06-b142-43bb-b5d7-789635a303f2", "conversationId": "", "onInvokeData": [{"on_invoke_data": "mock with{'ba': 'haha', 'bc': [1, 2, 3]}"}], "agentId": "", "componentId": "", "componentName": "", "componentType": "b", "agentParentInvokeId": "", "metaData": "{\"component_id\": \"\", \"component_name\": \"\", \"component_type\": \"b\"}", "loopNodeId": null, "loopIndex": null, "status": "running", "parentNodeId": ""}} +{"type": "tracer_workflow", "payload": {"trace_id": "2879cd06-b142-43bb-b5d7-789635a303f2", "startTime": "2025-07-21 15:24:46.858106", "endTime": null, "inputs": {"a": 1, "b": "haha", "c": 1, "d": [1, 2, 3]}, "outputs": null, "error": null, "invokeId": "a.sub_start", "parentInvokeId": null, "executionId": "2879cd06-b142-43bb-b5d7-789635a303f2", "conversationId": "", "onInvokeData": null, "agentId": "", "componentId": "", "componentName": "", "componentType": "a.sub_start", "agentParentInvokeId": "", "metaData": "{\"component_id\": \"\", \"component_name\": \"\", \"component_type\": \"a.sub_start\"}", "loopNodeId": null, "loopIndex": null, "status": "start", "parentNodeId": "a"}} +{"type": "tracer_workflow", "payload": {"trace_id": "2879cd06-b142-43bb-b5d7-789635a303f2", "startTime": "2025-07-21 15:24:46.858106", "endTime": "2025-07-21 15:24:46.858840", "inputs": {"a": 1, "b": "haha", "c": 1, "d": [1, 2, 3]}, "outputs": {"a": 1, "b": "haha", "c": 1, "d": [1, 2, 3]}, "error": null, "invokeId": "a.sub_start", "parentInvokeId": null, "executionId": "2879cd06-b142-43bb-b5d7-789635a303f2", "conversationId": "", "onInvokeData": null, "agentId": "", "componentId": "", "componentName": "", "componentType": "a.sub_start", "agentParentInvokeId": "", "metaData": "{\"component_id\": \"\", \"component_name\": \"\", \"component_type\": \"a.sub_start\"}", "loopNodeId": null, "loopIndex": null, "status": "finish", "parentNodeId": "a"}} +{"type": "tracer_workflow", "payload": {"trace_id": "2879cd06-b142-43bb-b5d7-789635a303f2", "startTime": "2025-07-21 15:24:47.878381", "endTime": null, "inputs": {"aa": null, "ac": null}, "outputs": null, "error": null, "invokeId": "a.sub_a", "parentInvokeId": "a.sub_start", "executionId": "2879cd06-b142-43bb-b5d7-789635a303f2", "conversationId": "", "onInvokeData": null, "agentId": "", "componentId": "", "componentName": "", "componentType": "a.sub_a", "agentParentInvokeId": "", "metaData": "{\"component_id\": \"\", \"component_name\": \"\", \"component_type\": \"a.sub_a\"}", "loopNodeId": null, "loopIndex": null, "status": "start", "parentNodeId": "a"}} +{"type": "tracer_workflow", "payload": {"trace_id": "2879cd06-b142-43bb-b5d7-789635a303f2", "startTime": "2025-07-21 15:24:47.878381", "endTime": null, "inputs": {"aa": null, "ac": null}, "outputs": null, "error": null, "invokeId": "a.sub_a", "parentInvokeId": "a.sub_start", "executionId": "2879cd06-b142-43bb-b5d7-789635a303f2", "conversationId": "", "onInvokeData": [{"on_invoke_data": "mock with{'aa': None, 'ac': None}"}], "agentId": "", "componentId": "", "componentName": "", "componentType": "a.sub_a", "agentParentInvokeId": "", "metaData": "{\"component_id\": \"\", \"component_name\": \"\", \"component_type\": \"a.sub_a\"}", "loopNodeId": null, "loopIndex": null, "status": "running", "parentNodeId": "a"}} +{"type": "tracer_workflow", "payload": {"trace_id": "2879cd06-b142-43bb-b5d7-789635a303f2", "startTime": "2025-07-21 15:24:46.852855", "endTime": "2025-07-21 15:24:53.902170", "inputs": {"ba": "haha", "bc": [1, 2, 3]}, "outputs": {"ba": "haha", "bc": [1, 2, 3]}, "error": null, "invokeId": "b", "parentInvokeId": "a", "executionId": "2879cd06-b142-43bb-b5d7-789635a303f2", "conversationId": "", "onInvokeData": [{"on_invoke_data": "mock with{'ba': 'haha', 'bc': [1, 2, 3]}"}], "agentId": "", "componentId": "", "componentName": "", "componentType": "b", "agentParentInvokeId": "", "metaData": "{\"component_id\": \"\", \"component_name\": \"\", \"component_type\": \"b\"}", "loopNodeId": null, "loopIndex": null, "status": "finish", "parentNodeId": ""}} +{"type": "tracer_workflow", "payload": {"trace_id": "2879cd06-b142-43bb-b5d7-789635a303f2", "startTime": "2025-07-21 15:24:47.878381", "endTime": "2025-07-21 15:24:54.910614", "inputs": {"aa": null, "ac": null}, "outputs": {"aa": null, "ac": null}, "error": null, "invokeId": "a.sub_a", "parentInvokeId": "a.sub_start", "executionId": "2879cd06-b142-43bb-b5d7-789635a303f2", "conversationId": "", "onInvokeData": [{"on_invoke_data": "mock with{'aa': None, 'ac': None}"}], "agentId": "", "componentId": "", "componentName": "", "componentType": "a.sub_a", "agentParentInvokeId": "", "metaData": "{\"component_id\": \"\", \"component_name\": \"\", \"component_type\": \"a.sub_a\"}", "loopNodeId": null, "loopIndex": null, "status": "finish", "parentNodeId": "a"}} +{"type": "tracer_workflow", "payload": {"trace_id": "2879cd06-b142-43bb-b5d7-789635a303f2", "startTime": "2025-07-21 15:24:55.922850", "endTime": null, "inputs": {"result": null}, "outputs": null, "error": null, "invokeId": "a.sub_end", "parentInvokeId": "a.sub_a", "executionId": "2879cd06-b142-43bb-b5d7-789635a303f2", "conversationId": "", "onInvokeData": null, "agentId": "", "componentId": "", "componentName": "", "componentType": "a.sub_end", "agentParentInvokeId": "", "metaData": "{\"component_id\": \"\", \"component_name\": \"\", \"component_type\": \"a.sub_end\"}", "loopNodeId": null, "loopIndex": null, "status": "start", "parentNodeId": "a"}} +{"type": "tracer_workflow", "payload": {"trace_id": "2879cd06-b142-43bb-b5d7-789635a303f2", "startTime": "2025-07-21 15:24:55.922850", "endTime": "2025-07-21 15:24:55.924846", "inputs": {"result": null}, "outputs": {"result": null}, "error": null, "invokeId": "a.sub_end", "parentInvokeId": "a.sub_a", "executionId": "2879cd06-b142-43bb-b5d7-789635a303f2", "conversationId": "", "onInvokeData": null, "agentId": "", "componentId": "", "componentName": "", "componentType": "a.sub_end", "agentParentInvokeId": "", "metaData": "{\"component_id\": \"\", \"component_name\": \"\", \"component_type\": \"a.sub_end\"}", "loopNodeId": null, "loopIndex": null, "status": "finish", "parentNodeId": "a"}} +{"type": "tracer_workflow", "payload": {"trace_id": "2879cd06-b142-43bb-b5d7-789635a303f2", "startTime": "2025-07-21 15:24:46.852855", "endTime": "2025-07-21 15:24:56.943741", "inputs": {"aa": 1, "ac": 1}, "outputs": null, "error": null, "invokeId": "a", "parentInvokeId": "start", "executionId": "2879cd06-b142-43bb-b5d7-789635a303f2", "conversationId": "", "onInvokeData": null, "agentId": "", "componentId": "", "componentName": "", "componentType": "a", "agentParentInvokeId": "", "metaData": "{\"component_id\": \"\", \"component_name\": \"\", \"component_type\": \"a\"}", "loopNodeId": null, "loopIndex": null, "status": "running", "parentNodeId": ""}} +{"type": "tracer_workflow", "payload": {"trace_id": "2879cd06-b142-43bb-b5d7-789635a303f2", "startTime": "2025-07-21 15:24:57.968939", "endTime": null, "inputs": {"result": null}, "outputs": null, "error": null, "invokeId": "end", "parentInvokeId": "b", "executionId": "2879cd06-b142-43bb-b5d7-789635a303f2", "conversationId": "", "onInvokeData": null, "agentId": "", "componentId": "", "componentName": "", "componentType": "end", "agentParentInvokeId": "", "metaData": "{\"component_id\": \"\", \"component_name\": \"\", \"component_type\": \"end\"}", "loopNodeId": null, "loopIndex": null, "status": "start", "parentNodeId": ""}} +{"type": "tracer_workflow", "payload": {"trace_id": "2879cd06-b142-43bb-b5d7-789635a303f2", "startTime": "2025-07-21 15:24:57.968939", "endTime": "2025-07-21 15:24:57.970948", "inputs": {"result": null}, "outputs": {"result": null}, "error": null, "invokeId": "end", "parentInvokeId": "b", "executionId": "2879cd06-b142-43bb-b5d7-789635a303f2", "conversationId": "", "onInvokeData": null, "agentId": "", "componentId": "", "componentName": "", "componentType": "end", "agentParentInvokeId": "", "metaData": "{\"component_id\": \"\", \"component_name\": \"\", \"component_type\": \"end\"}", "loopNodeId": null, "loopIndex": null, "status": "finish", "parentNodeId": ""}} diff --git a/tests/unit_tests/tracer/test_parallel_exec_stream_workflow_with_tracer.json b/tests/unit_tests/tracer/test_parallel_exec_stream_workflow_with_tracer.json new file mode 100644 index 0000000..c56eda6 --- /dev/null +++ b/tests/unit_tests/tracer/test_parallel_exec_stream_workflow_with_tracer.json @@ -0,0 +1,10 @@ +{"type": "tracer_workflow", "payload": {"trace_id": "3e60c854-187a-40e3-8148-dd1d5d248faf", "startTime": "2025-07-21 15:24:59.008499", "endTime": null, "inputs": {"a": null, "b": null, "c": 1, "d": [1, 2, 3]}, "outputs": null, "error": null, "invokeId": "start", "parentInvokeId": null, "executionId": "3e60c854-187a-40e3-8148-dd1d5d248faf", "conversationId": "", "onInvokeData": null, "agentId": "", "componentId": "", "componentName": "", "componentType": "start", "agentParentInvokeId": "", "metaData": "{\"component_id\": \"\", \"component_name\": \"\", \"component_type\": \"start\"}", "loopNodeId": null, "loopIndex": null, "status": "start", "parentNodeId": ""}} +{"type": "tracer_workflow", "payload": {"trace_id": "3e60c854-187a-40e3-8148-dd1d5d248faf", "startTime": "2025-07-21 15:24:59.008499", "endTime": "2025-07-21 15:24:59.008499", "inputs": {"a": null, "b": null, "c": 1, "d": [1, 2, 3]}, "outputs": {"a": null, "b": null, "c": 1, "d": [1, 2, 3]}, "error": null, "invokeId": "start", "parentInvokeId": null, "executionId": "3e60c854-187a-40e3-8148-dd1d5d248faf", "conversationId": "", "onInvokeData": null, "agentId": "", "componentId": "", "componentName": "", "componentType": "start", "agentParentInvokeId": "", "metaData": "{\"component_id\": \"\", \"component_name\": \"\", \"component_type\": \"start\"}", "loopNodeId": null, "loopIndex": null, "status": "finish", "parentNodeId": ""}} +{"type": "tracer_workflow", "payload": {"trace_id": "3e60c854-187a-40e3-8148-dd1d5d248faf", "startTime": "2025-07-21 15:25:00.030429", "endTime": null, "inputs": {"aa": null, "ac": 1}, "outputs": null, "error": null, "invokeId": "a", "parentInvokeId": "start", "executionId": "3e60c854-187a-40e3-8148-dd1d5d248faf", "conversationId": "", "onInvokeData": null, "agentId": "", "componentId": "", "componentName": "", "componentType": "a", "agentParentInvokeId": "", "metaData": "{\"component_id\": \"\", \"component_name\": \"\", \"component_type\": \"a\"}", "loopNodeId": null, "loopIndex": null, "status": "start", "parentNodeId": ""}} +{"type": "tracer_workflow", "payload": {"trace_id": "3e60c854-187a-40e3-8148-dd1d5d248faf", "startTime": "2025-07-21 15:25:00.031563", "endTime": null, "inputs": {"ba": null, "bc": [1, 2, 3]}, "outputs": null, "error": null, "invokeId": "b", "parentInvokeId": "a", "executionId": "3e60c854-187a-40e3-8148-dd1d5d248faf", "conversationId": "", "onInvokeData": null, "agentId": "", "componentId": "", "componentName": "", "componentType": "b", "agentParentInvokeId": "", "metaData": "{\"component_id\": \"\", \"component_name\": \"\", \"component_type\": \"b\"}", "loopNodeId": null, "loopIndex": null, "status": "start", "parentNodeId": ""}} +{"type": "tracer_workflow", "payload": {"trace_id": "3e60c854-187a-40e3-8148-dd1d5d248faf", "startTime": "2025-07-21 15:25:00.030429", "endTime": null, "inputs": {"aa": null, "ac": 1}, "outputs": null, "error": null, "invokeId": "a", "parentInvokeId": "start", "executionId": "3e60c854-187a-40e3-8148-dd1d5d248faf", "conversationId": "", "onInvokeData": [{"on_invoke_data": "mock with{'aa': None, 'ac': 1}"}], "agentId": "", "componentId": "", "componentName": "", "componentType": "a", "agentParentInvokeId": "", "metaData": "{\"component_id\": \"\", \"component_name\": \"\", \"component_type\": \"a\"}", "loopNodeId": null, "loopIndex": null, "status": "running", "parentNodeId": ""}} +{"type": "tracer_workflow", "payload": {"trace_id": "3e60c854-187a-40e3-8148-dd1d5d248faf", "startTime": "2025-07-21 15:25:00.031563", "endTime": null, "inputs": {"ba": null, "bc": [1, 2, 3]}, "outputs": null, "error": null, "invokeId": "b", "parentInvokeId": "a", "executionId": "3e60c854-187a-40e3-8148-dd1d5d248faf", "conversationId": "", "onInvokeData": [{"on_invoke_data": "mock with{'ba': None, 'bc': [1, 2, 3]}"}], "agentId": "", "componentId": "", "componentName": "", "componentType": "b", "agentParentInvokeId": "", "metaData": "{\"component_id\": \"\", \"component_name\": \"\", \"component_type\": \"b\"}", "loopNodeId": null, "loopIndex": null, "status": "running", "parentNodeId": ""}} +{"type": "tracer_workflow", "payload": {"trace_id": "3e60c854-187a-40e3-8148-dd1d5d248faf", "startTime": "2025-07-21 15:25:00.030429", "endTime": "2025-07-21 15:25:07.068615", "inputs": {"aa": null, "ac": 1}, "outputs": {"aa": null, "ac": 1}, "error": null, "invokeId": "a", "parentInvokeId": "start", "executionId": "3e60c854-187a-40e3-8148-dd1d5d248faf", "conversationId": "", "onInvokeData": [{"on_invoke_data": "mock with{'aa': None, 'ac': 1}"}], "agentId": "", "componentId": "", "componentName": "", "componentType": "a", "agentParentInvokeId": "", "metaData": "{\"component_id\": \"\", \"component_name\": \"\", \"component_type\": \"a\"}", "loopNodeId": null, "loopIndex": null, "status": "finish", "parentNodeId": ""}} +{"type": "tracer_workflow", "payload": {"trace_id": "3e60c854-187a-40e3-8148-dd1d5d248faf", "startTime": "2025-07-21 15:25:00.031563", "endTime": "2025-07-21 15:25:07.070618", "inputs": {"ba": null, "bc": [1, 2, 3]}, "outputs": {"ba": null, "bc": [1, 2, 3]}, "error": null, "invokeId": "b", "parentInvokeId": "a", "executionId": "3e60c854-187a-40e3-8148-dd1d5d248faf", "conversationId": "", "onInvokeData": [{"on_invoke_data": "mock with{'ba': None, 'bc': [1, 2, 3]}"}], "agentId": "", "componentId": "", "componentName": "", "componentType": "b", "agentParentInvokeId": "", "metaData": "{\"component_id\": \"\", \"component_name\": \"\", \"component_type\": \"b\"}", "loopNodeId": null, "loopIndex": null, "status": "finish", "parentNodeId": ""}} +{"type": "tracer_workflow", "payload": {"trace_id": "3e60c854-187a-40e3-8148-dd1d5d248faf", "startTime": "2025-07-21 15:25:08.087818", "endTime": null, "inputs": {"result": null}, "outputs": null, "error": null, "invokeId": "end", "parentInvokeId": "b", "executionId": "3e60c854-187a-40e3-8148-dd1d5d248faf", "conversationId": "", "onInvokeData": null, "agentId": "", "componentId": "", "componentName": "", "componentType": "end", "agentParentInvokeId": "", "metaData": "{\"component_id\": \"\", \"component_name\": \"\", \"component_type\": \"end\"}", "loopNodeId": null, "loopIndex": null, "status": "start", "parentNodeId": ""}} +{"type": "tracer_workflow", "payload": {"trace_id": "3e60c854-187a-40e3-8148-dd1d5d248faf", "startTime": "2025-07-21 15:25:08.087818", "endTime": "2025-07-21 15:25:08.088794", "inputs": {"result": null}, "outputs": {"result": null}, "error": null, "invokeId": "end", "parentInvokeId": "b", "executionId": "3e60c854-187a-40e3-8148-dd1d5d248faf", "conversationId": "", "onInvokeData": null, "agentId": "", "componentId": "", "componentName": "", "componentType": "end", "agentParentInvokeId": "", "metaData": "{\"component_id\": \"\", \"component_name\": \"\", \"component_type\": \"end\"}", "loopNodeId": null, "loopIndex": null, "status": "finish", "parentNodeId": ""}} diff --git a/tests/unit_tests/tracer/test_workflow.py b/tests/unit_tests/tracer/test_workflow.py index 2a5421f..5023200 100644 --- a/tests/unit_tests/tracer/test_workflow.py +++ b/tests/unit_tests/tracer/test_workflow.py @@ -22,7 +22,7 @@ import unittest from collections.abc import Callable from jiuwen.core.context.config import Config -from jiuwen.core.context.context import Context +from jiuwen.core.context.context import Context, WorkflowContext from jiuwen.core.context.memory.base import InMemoryState from jiuwen.core.graph.base import Graph from jiuwen.core.workflow.base import WorkflowConfig, Workflow @@ -33,7 +33,7 @@ from jiuwen.core.stream.writer import TraceSchema def create_context_with_tracer() -> Context: - return Context(config=Config(), state=InMemoryState(), store=None) + return WorkflowContext(config=Config(), state=InMemoryState(), store=None) def create_graph() -> Graph: diff --git a/tests/unit_tests/workflow/test_mock_node.py b/tests/unit_tests/workflow/test_mock_node.py index a632393..b3bf932 100644 --- a/tests/unit_tests/workflow/test_mock_node.py +++ b/tests/unit_tests/workflow/test_mock_node.py @@ -39,7 +39,7 @@ class MockStartNode(StartComponent, MockNodeBase): super().__init__(node_id) async def invoke(self, inputs: Input, context: Context) -> Output: - context.state.set_outputs(self.node_id, inputs) + context.state().update_io(inputs) logger.info("start: output{%s} ", inputs) return inputs @@ -93,7 +93,7 @@ class StreamNode(MockNodeBase): for data in self._datas: await asyncio.sleep(0.1) logger.info(f"StreamNode[{self._node_id}], stream frame: {data}") - await context.stream_writer_manager.get_custom_writer().write(data) + await context.stream_writer_manager().get_custom_writer().write(data) logger.info(f"StreamNode[{self._node_id}], batch output: {inputs}") return inputs @@ -106,7 +106,7 @@ class StreamNodeWithSubWorkflow(MockNodeBase): async def invoke(self, inputs: Input, context: Context) -> Output: async for chunk in self._sub_workflow.stream({"a": 1, "b": "haha"}, context): logger.info(f"StreamNodeWithSubWorkflow[{self._node_id}], stream frame: {chunk}") - await context.stream_writer_manager.get_custom_writer().write(chunk) + await context.stream_writer_manager().get_custom_writer().write(chunk) logger.info(f"StreamNodeWithSubWorkflow[{self._node_id}], batch output: {inputs}") return inputs diff --git a/tests/unit_tests/workflow/test_workflow.py b/tests/unit_tests/workflow/test_workflow.py index 0ae569f..1800252 100644 --- a/tests/unit_tests/workflow/test_workflow.py +++ b/tests/unit_tests/workflow/test_workflow.py @@ -17,7 +17,7 @@ from jiuwen.core.component.loop_callback.output import OutputCallback from jiuwen.core.component.loop_comp import LoopGroup, LoopComponent from jiuwen.core.component.set_variable_comp import SetVariableComponent from jiuwen.core.context.config import Config -from jiuwen.core.context.context import Context +from jiuwen.core.context.context import Context, WorkflowContext from jiuwen.core.context.memory.base import InMemoryState from jiuwen.core.graph.base import Graph from jiuwen.core.graph.graph_state import GraphState @@ -29,7 +29,7 @@ from tests.unit_tests.workflow.test_mock_node import MockStartNode, MockEndNode, def create_context() -> Context: - return Context(config=Config(), state=InMemoryState(), store=None) + return WorkflowContext(config=Config(), state=InMemoryState(), store=None) def create_graph() -> Graph: @@ -536,6 +536,7 @@ class WorkflowTest(unittest.TestCase): flow1.add_workflow_comp("composite", CompositeWorkflowNode("composite", flow2), inputs_schema={"result": "${start.a2}"}) + flow1.set_end_comp("end", MockEndNode("end"), inputs_schema={"b1": "${a1.value}", "b2": "${composite.result}"}) flow1.add_connection("start", "a1") flow1.add_connection("start", "composite") -- Gitee