From 5373744972d4bf4d386fea7b15d465189a4cf0c6 Mon Sep 17 00:00:00 2001 From: CandiceGuo Date: Mon, 14 Jul 2025 11:04:11 +0800 Subject: [PATCH 1/7] =?UTF-8?q?fix:=20=E8=81=94=E8=B0=83workflow=E4=BF=AE?= =?UTF-8?q?=E6=94=B9=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- jiuwen/core/context/memory/base.py | 8 ++--- jiuwen/core/workflow/base.py | 15 ++++---- jiuwen/graph/pregel/graph.py | 25 ++++++++------ tests/unit_tests/workflow/test_mock_node.py | 38 +++++++++++++++++++++ tests/unit_tests/workflow/test_workflow.py | 35 +++++++++++++++++++ 5 files changed, 98 insertions(+), 23 deletions(-) create mode 100644 tests/unit_tests/workflow/test_mock_node.py create mode 100644 tests/unit_tests/workflow/test_workflow.py diff --git a/jiuwen/core/context/memory/base.py b/jiuwen/core/context/memory/base.py index 893a0e9..04c3752 100644 --- a/jiuwen/core/context/memory/base.py +++ b/jiuwen/core/context/memory/base.py @@ -9,7 +9,7 @@ from jiuwen.core.context.state import CommitState, StateLike, State from jiuwen.core.context.utils import extract_origin_key, get_value_by_nested_path, update_dict -class InMemoryState(StateLike): +class InMemoryStateLike(StateLike): def __init__(self): self._state: dict = dict() @@ -39,7 +39,7 @@ class InMemoryState(StateLike): class InMemoryCommitState(CommitState): def __init__(self): - self._state = InMemoryState() + self._state = InMemoryStateLike() self._updates: dict[str, list[dict]] = dict() def update(self, node_id: str, data: dict) -> None: @@ -72,7 +72,7 @@ class InMemoryState(State): def __init__(self): super().__init__(io_state=InMemoryCommitState(), global_state=InMemoryCommitState(), - trace_state=InMemoryState(), - comp_state=InMemoryState()) + trace_state=InMemoryStateLike(), + comp_state=InMemoryStateLike()) diff --git a/jiuwen/core/workflow/base.py b/jiuwen/core/workflow/base.py index 5686512..f9afa32 100644 --- a/jiuwen/core/workflow/base.py +++ b/jiuwen/core/workflow/base.py @@ -17,19 +17,18 @@ from jiuwen.core.graph.executable import Executable, Input, Output class WorkflowConfig(BaseModel): - metadata: BaseModel = Field(default=None) + metadata: BaseModel class WorkflowOutput(BaseModel): - result: str = Field(default="") + result: str class WorkflowChunk(BaseModel): - chunk_id: str = Field(default="") - payload: str = Field(default="") - metadata: Dict[str, Any] = Field(default_factory=dict) - is_final: bool = Field(default=False) - + chunk_id: str + payload: str + metadata: Dict[str, Any] + is_final: bool class Workflow: @@ -67,7 +66,7 @@ class Workflow: self._comp_io_schemas[start_comp_id] = (inputs_schema, output_schema) return self - def set_end_component( + def set_end_comp( self, end_comp_id: str, component: EndComponent, diff --git a/jiuwen/graph/pregel/graph.py b/jiuwen/graph/pregel/graph.py index b145919..ef0c529 100644 --- a/jiuwen/graph/pregel/graph.py +++ b/jiuwen/graph/pregel/graph.py @@ -4,15 +4,18 @@ from typing import Union, Self, Iterator, AsyncIterator from langgraph.graph import StateGraph +from langgraph.graph.state import CompiledStateGraph -from jiuwen.core.graph.base import Graph -from jiuwen.core.graph.state import State +from jiuwen.core.context.context import Context +from jiuwen.core.graph.base import Graph, Router, ExecutableGraph +from jiuwen.core.graph.executable import Executable, Input, Output +from jiuwen.core.graph.graph_state import GraphState from jiuwen.core.graph.vertex import Vertex class PregelGraph(Graph): def __init__(self): - self.pregel: StateGraph = StateGraph(State) + self.pregel: StateGraph = StateGraph(GraphState) self.compiledStateGraph = None self.edges: list[Union[str, list[str]], str] = [] self.waits: set[str] = set() @@ -25,8 +28,8 @@ class PregelGraph(Graph): self.pregel.set_finish_point(node_id) return self - def add_node(self, node_id: str, node: Executable, *, wait_for_all: bool = False, inputs: dict = None) -> Self: - self.pregel.add_node(node_id, Vertex(node_id, node, inputs)) + def add_node(self, node_id: str, node: Executable, *, wait_for_all: bool = False) -> Self: + self.pregel.add_node(node_id, Vertex(node_id, node)) if wait_for_all: self.waits.add(node_id) return self @@ -39,7 +42,7 @@ class PregelGraph(Graph): self.pregel.add_conditional_edges(source, router) return self - def compile(self) -> ExecutableGraph: + def compile(self, context: Context) -> ExecutableGraph: if self.compiledStateGraph is None: self._pre_compile() self.compiledStateGraph = self.pregel.compile() @@ -65,17 +68,17 @@ class PregelGraph(Graph): class CompiledGraph(ExecutableGraph): - def __init__(self, compiledStateGraph: CompiledGraph): + def __init__(self, compiledStateGraph: CompiledStateGraph): self._compiledStateGraph = compiledStateGraph def invoke(self, inputs: Input, context: Context) -> Output: - return self._compiledStateGraph.invoke({'context': context}) + return self._compiledStateGraph.invoke({}) def stream(self, inputs: Input, context: Context) -> Iterator[Output]: - return self._compiledStateGraph.stream({'context': context}) + return self._compiledStateGraph.stream({}) async def ainvoke(self, inputs: Input, context: Context) -> Output: - return self._compiledStateGraph.ainvoke({'context': context}) + return self._compiledStateGraph.ainvoke({}) async def astream(self, inputs: Input, context: Context) -> AsyncIterator[Output]: - return self._compiledStateGraph.astream({'context': context}) + return self._compiledStateGraph.astream({}) diff --git a/tests/unit_tests/workflow/test_mock_node.py b/tests/unit_tests/workflow/test_mock_node.py new file mode 100644 index 0000000..8c6ea95 --- /dev/null +++ b/tests/unit_tests/workflow/test_mock_node.py @@ -0,0 +1,38 @@ +from typing import Iterator, AsyncIterator, Callable + +from jiuwen.core.component.base import WorkflowComponent, StartComponent, EndComponent +from jiuwen.core.context.context import Context +from jiuwen.core.graph.executable import Executable, Input, Output + + +class MockNodeBase(Executable, WorkflowComponent): + def __init__(self, node_id: str): + super().__init__() + self.node_id = node_id + + def stream(self, inputs: Input, context: Context) -> Iterator[Output]: + yield self.invoke(inputs, context) + + async def ainvoke(self, inputs: Input, context: Context) -> Output: + yield await self.invoke(inputs, context) + + async def astream(self, inputs: Input, context: Context) -> AsyncIterator[Output]: + yield await self.invoke(inputs, context) + + def interrupt(self, message: dict): + return + + +class MockStartNode(StartComponent, MockNodeBase): + def __init__(self, node_id: str): + super().__init__(node_id) + + def invoke(self, inputs: Input, context: Context) -> Output: + return inputs + +class MockEndNode(EndComponent, MockNodeBase): + def __init__(self, node_id: str): + super().__init__(node_id) + + def invoke(self, inputs: Input, context: Context) -> Output: + return inputs diff --git a/tests/unit_tests/workflow/test_workflow.py b/tests/unit_tests/workflow/test_workflow.py new file mode 100644 index 0000000..b0ee178 --- /dev/null +++ b/tests/unit_tests/workflow/test_workflow.py @@ -0,0 +1,35 @@ +from sqlalchemy.testing.suite.test_reflection import metadata + +from jiuwen.core.component.base import StartComponent, EndComponent +from jiuwen.core.context.config import Config +from jiuwen.core.context.context import Context +from jiuwen.core.context.memory.base import InMemoryState +from jiuwen.core.graph.executable import Input, Output +from jiuwen.core.workflow.base import WorkflowConfig, Workflow +from jiuwen.graph.pregel.graph import PregelGraph +from tests.unit_tests.workflow.test_mock_node import MockNodeBase, MockStartNode, MockEndNode + + +class Node1(MockNodeBase): + def __init__(self, node_id: str): + super().__init__(node_id) + + def invoke(self, inputs: Input, context: Context) -> Output: + return {} + + +def test_workflow_base(): + workflow_config = WorkflowConfig(metadata = {}) + graph = PregelGraph() + + flow = Workflow(workflow_config=workflow_config, graph=graph) + flow.add_workflow_comp("a", Node1("a"), inputs_schema={}) + flow.set_start_comp("start", MockStartNode("start")) + flow.set_end_comp("end", MockEndNode("end")) + + context = Context(config = Config(), state = InMemoryState(), store = None, tracer = None) + flow.invoke({}, context) + + +if __name__ == "__main__": + test_workflow_base() -- Gitee From 16de379e0b55eb9fb9fb69df61dd757f096b5921 Mon Sep 17 00:00:00 2001 From: caoyuzhe Date: Mon, 14 Jul 2025 11:56:22 +0800 Subject: [PATCH 2/7] =?UTF-8?q?fix:=20=E8=81=94=E8=B0=83workflow=E9=80=82?= =?UTF-8?q?=E9=85=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1、新增更新变量组件 2、增加循环组件ut 3、部分代码适配,未完成 --- jiuwen/core/component/break_comp.py | 10 ++++ jiuwen/core/component/condition/array.py | 1 + jiuwen/core/component/condition/expression.py | 1 + jiuwen/core/component/condition/number.py | 1 + .../loop_callback/intermediate_loop_var.py | 13 +++-- jiuwen/core/component/loop_callback/output.py | 14 ++--- jiuwen/core/component/loop_comp.py | 20 ++++--- jiuwen/core/component/set_variable_comp.py | 44 ++++++++++++++ jiuwen/core/context/utils.py | 2 + tests/unit_tests/workflow/test_loop.py | 58 +++++++++++++++++++ tests/unit_tests/workflow/test_node.py | 52 +++++++++++++++++ 11 files changed, 194 insertions(+), 22 deletions(-) create mode 100644 jiuwen/core/component/set_variable_comp.py create mode 100644 tests/unit_tests/workflow/test_loop.py create mode 100644 tests/unit_tests/workflow/test_node.py diff --git a/jiuwen/core/component/break_comp.py b/jiuwen/core/component/break_comp.py index e0c159e..e191288 100644 --- a/jiuwen/core/component/break_comp.py +++ b/jiuwen/core/component/break_comp.py @@ -2,6 +2,11 @@ # -*- coding: UTF-8 -*- # Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. from abc import abstractmethod, ABC +from typing import Iterator, AsyncIterator + +from jiuwen.core.component.base import WorkflowComponent +from jiuwen.core.context.context import Context +from jiuwen.core.graph.executable import Executable, Input, Output class LoopController(ABC): @@ -15,9 +20,14 @@ class LoopController(ABC): class BreakComponent(WorkflowComponent, Executable): + def __init__(self): + super().__init__() self._loop_controller = None + def interrupt(self, message: dict): + pass + def set_controller(self, loop_controller: LoopController): self._loop_controller = loop_controller diff --git a/jiuwen/core/component/condition/array.py b/jiuwen/core/component/condition/array.py index 8f2400e..4bf057f 100644 --- a/jiuwen/core/component/condition/array.py +++ b/jiuwen/core/component/condition/array.py @@ -5,6 +5,7 @@ from typing import Union, Any from jiuwen.core.common.constants.constant import BPMN_VARIABLE_POOL_SEPARATOR from jiuwen.core.component.condition.condition import Condition +from jiuwen.core.context.context import Context DEFAULT_MAX_LOOP_NUMBER = 1000 diff --git a/jiuwen/core/component/condition/expression.py b/jiuwen/core/component/condition/expression.py index 2c44265..45515d2 100644 --- a/jiuwen/core/component/condition/expression.py +++ b/jiuwen/core/component/condition/expression.py @@ -4,6 +4,7 @@ import re from jiuwen.core.component.condition.condition import Condition +from jiuwen.core.context.context import Context class ExpressionCondition(Condition): diff --git a/jiuwen/core/component/condition/number.py b/jiuwen/core/component/condition/number.py index 412c59f..8bf3b59 100644 --- a/jiuwen/core/component/condition/number.py +++ b/jiuwen/core/component/condition/number.py @@ -5,6 +5,7 @@ from typing import Union from jiuwen.core.common.constants.constant import BPMN_VARIABLE_POOL_SEPARATOR from jiuwen.core.component.condition.condition import Condition +from jiuwen.core.context.context import Context class NumberCondition(Condition): diff --git a/jiuwen/core/component/loop_callback/intermediate_loop_var.py b/jiuwen/core/component/loop_callback/intermediate_loop_var.py index e0d117e..b789e23 100644 --- a/jiuwen/core/component/loop_callback/intermediate_loop_var.py +++ b/jiuwen/core/component/loop_callback/intermediate_loop_var.py @@ -3,8 +3,9 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. from typing import Union, Any -from jiuwen.core.common.constants.constant import BPMN_VARIABLE_POOL_SEPARATOR from jiuwen.core.component.loop_callback.loop_callback import LoopCallback +from jiuwen.core.context.context import Context +from jiuwen.core.context.utils import NESTED_PATH_SPLIT, is_ref_path, extract_origin_key class IntermediateLoopVarCallback(LoopCallback): @@ -13,21 +14,21 @@ class IntermediateLoopVarCallback(LoopCallback): self._context = context self._intermediate_loop_var = intermediate_loop_var self._intermediate_loop_var_root = intermediate_loop_var_root if intermediate_loop_var_root \ - else node_id + BPMN_VARIABLE_POOL_SEPARATOR + "intermediateLoopVar" + else node_id + NESTED_PATH_SPLIT + "intermediateLoopVar" def first_in_loop(self): for key, value in self._intermediate_loop_var.items(): - path = self._intermediate_loop_var_root + BPMN_VARIABLE_POOL_SEPARATOR + key + path = self._intermediate_loop_var_root + NESTED_PATH_SPLIT + key updates: Any if isinstance(value, str): - ref_str = get_ref_str(value) - if ref_str != "": + if is_ref_path(value): + ref_str = extract_origin_key(value) update = self._context.store.read(ref_str) else: update = value else: update = value - self._context.store.write(path, update) + self._context.store.write({path: update}) def out_loop(self): self._context.store.write(self._intermediate_loop_var_root, {}) diff --git a/jiuwen/core/component/loop_callback/output.py b/jiuwen/core/component/loop_callback/output.py index 0f8f90f..e94252a 100644 --- a/jiuwen/core/component/loop_callback/output.py +++ b/jiuwen/core/component/loop_callback/output.py @@ -3,8 +3,9 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. from typing import Any -from jiuwen.core.common.constants.constant import BPMN_VARIABLE_POOL_SEPARATOR from jiuwen.core.component.loop_callback.loop_callback import LoopCallback +from jiuwen.core.context.context import Context +from jiuwen.core.context.utils import is_ref_path, extract_origin_key, NESTED_PATH_SPLIT class OutputCallback(LoopCallback): @@ -12,16 +13,15 @@ class OutputCallback(LoopCallback): round_result_root: str = None, result_root: str = None, intermediate_loop_var_root: str = None): self._context = context self._outputs_format = outputs_format - self._round_result_root = round_result_root if round_result_root else node_id + BPMN_VARIABLE_POOL_SEPARATOR + "round" + self._round_result_root = round_result_root if round_result_root else node_id + NESTED_PATH_SPLIT + "round" self._result_root = result_root if result_root else node_id - self._intermediate_loop_var_root = intermediate_loop_var_root if intermediate_loop_var_root else node_id + BPMN_VARIABLE_POOL_SEPARATOR + "intermediateLoopVar" + self._intermediate_loop_var_root = intermediate_loop_var_root if intermediate_loop_var_root else node_id + NESTED_PATH_SPLIT + "intermediateLoopVar" def _generate_results(self, results: dict[str, list[Any]]): for key, value in self._outputs_format.items(): - if isinstance(value, str): - ref_str = get_ref_str(value) - if ref_str != "": - results[ref_str] = [] + if isinstance(value, str) and is_ref_path(value): + ref_str = extract_origin_key(value) + results[ref_str] = [] elif isinstance(value, dict): self._generate_results(results) diff --git a/jiuwen/core/component/loop_comp.py b/jiuwen/core/component/loop_comp.py index 6c04192..66d45d7 100644 --- a/jiuwen/core/component/loop_comp.py +++ b/jiuwen/core/component/loop_comp.py @@ -8,11 +8,15 @@ from typing import Iterator, AsyncIterator, Self, Union, Callable from langgraph.constants import END, START from jiuwen.core.common.constants.constant import BPMN_VARIABLE_POOL_SEPARATOR -from jiuwen.core.component.break_comp import BreakComponent + +from jiuwen.core.component.base import WorkflowComponent +from jiuwen.core.component.break_comp import BreakComponent, LoopController from jiuwen.core.component.condition.condition import Condition, AlwaysTrue, FuncCondition from jiuwen.core.component.condition.expression import ExpressionCondition from jiuwen.core.component.loop_callback.loop_callback import LoopCallback -from jiuwen.core.graph.base import Graph +from jiuwen.core.context.context import Context +from jiuwen.core.graph.base import Graph, Router, ExecutableGraph +from jiuwen.core.graph.executable import Output, Input, Executable from jiuwen.graph.factory import GraphFactory @@ -37,9 +41,8 @@ class LoopGroup: self._context = context self._graph = graph if graph else GraphFactory().create_graph() - def add_component(self, node_id: str, component: WorkflowComponent, *, wait_for_all: bool = False, - inputs: dict = None) -> Self: - component.add_component(self._graph, node_id, inputs, wait_for_all=wait_for_all) + def add_component(self, node_id: str, component: WorkflowComponent, *, wait_for_all: bool = False) -> Self: + component.add_component(self._graph, node_id, wait_for_all=wait_for_all) return self def start_nodes(self, nodes: list[str]) -> Self: @@ -71,10 +74,9 @@ FIRST_IN_LOOP = "_first_in_loop" class LoopComponent(WorkflowComponent, LoopController): def __init__(self, context: Context, node_id: str, body: Union[Executable, LoopGroup], - condition: Union[str, Callable[[], bool], Condition] = None, - context_root: str = None, - break_nodes: list[BreakComponent] = None, - callbacks: list[LoopCallback] = None, graph: Graph = None): + condition: Union[str, Callable[[], bool], Condition] = None, context_root: str = None, + break_nodes: list[BreakComponent] = None, callbacks: list[LoopCallback] = None, graph: Graph = None): + super().__init__() if context is None: raise ValueError("context cannot be None") if context_root is None: diff --git a/jiuwen/core/component/set_variable_comp.py b/jiuwen/core/component/set_variable_comp.py new file mode 100644 index 0000000..7ea4d2c --- /dev/null +++ b/jiuwen/core/component/set_variable_comp.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. +import asyncio +from functools import partial +from typing import AsyncIterator, Iterator + +from jiuwen.core.component.base import WorkflowComponent +from jiuwen.core.context.context import Context +from jiuwen.core.graph.executable import Executable, Input, Output + + +class SetVariableComponent(WorkflowComponent, Executable): + def __init__(self, context: Context, varivable_mapping: dict[str, Any]): + self._context = context + self._varivable_mapping = varivable_mapping + + def invoke(self, inputs: Input, context: Context) -> Output: + for left, right in self._varivable_mapping.items(): + left_ref_str = get_ref_str(left) + if left_ref_str == "": + left_ref_str = left + if isinstance(right, str): + ref_str = get_ref_str(right) + if ref_str == "": + self._context.store.write(left_ref_str, self._context.store.read(ref_str)) + continue + self._context.store.write(left_ref_str, right) + + return None + + async def ainvoke(self, inputs: Input, context: Context) -> Output: + return await asyncio.get_running_loop().run_in_executor( + None, partial(self.invoke, context), inputs + ) + + def stream(self, inputs: Input, context: Context) -> Iterator[Output]: + yield self.invoke(inputs, context) + + async def astream(self, inputs: Input, context: Context) -> AsyncIterator[Output]: + yield await self.ainvoke(inputs, context) + + def interrupt(self, message: dict): + pass diff --git a/jiuwen/core/context/utils.py b/jiuwen/core/context/utils.py index c01b43c..7e1fbc5 100644 --- a/jiuwen/core/context/utils.py +++ b/jiuwen/core/context/utils.py @@ -58,6 +58,8 @@ def split_nested_path(nested_key: str) -> list: final_list.append(match.group(1)) return final_list +def is_ref_path(path: str) -> bool: + return len(path) > 3 and path.startswith("${") and path.endswith("}") def extract_origin_key(key: str) -> str: """ diff --git a/tests/unit_tests/workflow/test_loop.py b/tests/unit_tests/workflow/test_loop.py new file mode 100644 index 0000000..c862b1b --- /dev/null +++ b/tests/unit_tests/workflow/test_loop.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. +from jiuwen.core.component.base import StartComponent, EndComponent +from jiuwen.core.component.condition.array import ArrayCondition +from jiuwen.core.component.loop_callback.intermediate_loop_var import IntermediateLoopVarCallback +from jiuwen.core.component.loop_callback.output import OutputCallback +from jiuwen.core.component.loop_comp import LoopGroup, LoopComponent +from jiuwen.core.component.set_variable_comp import SetVariableComponent +from jiuwen.core.context.context import Context +from jiuwen.core.workflow.base import Workflow +from tests.unit_tests.workflow.test_node import CommonNode, AddTenNode + + +def test_loop_array_condition(): + context = Context() + + flow = Workflow() + s = StartComponent() + e = EndComponent() + flow.set_start_comp("s", s) + flow.set_end_component("e", e, inputs_schema={"array_result": "${b.array_result}", "user_var": "${b.user_var}"}) + flow.add_workflow_comp("a", CommonNode("a"), inputs_schema={"array": "${user.inputs.input_array}"}) + flow.add_workflow_comp("b", CommonNode("b"), + inputs_schema={"array_result": "${l.results}", "user_var": "${l.user_var}"}) + + loop_group = LoopGroup(context) + node1 = AddTenNode("1") + loop_group.add_component("1", node1, inputs={"source": "${l.arrLoopVar.item}"}) + node2 = AddTenNode("2") + loop_group.add_component("2", node2, inputs={"source": "${l.intermediateLoopVar.user_var}"}) + node3 = SetVariableComponent(context, {"${l.intermediateLoopVar.user_var}": "${2.result}"}) + loop_group.add_component("3", node3) + loop_group.start_nodes(["1"]) + loop_group.end_nodes(["3"]) + loop_group.add_connection("1", "2") + loop_group.add_connection("2", "3") + + loop_condition = ArrayCondition(context, "l", {"item": "${a.loop_array}"}) + output_callback = OutputCallback(context, "l", + {"results": "${1.result}", "user_var": "${l.intermediateLoopVar.user_var}"}) + intermediate_callback = IntermediateLoopVarCallback(context, "l", + {"user_var": "${user.inputs.input_number}"}) + + loop = LoopComponent(context, "l", loop_group, loop_condition, + callbacks=[output_callback, intermediate_callback]) + flow.add_workflow_comp("l", loop) + + flow.add_connection("s", "a") + flow.add_connection("a", "l") + flow.add_connection("l", "b") + flow.add_connection("b", "e") + + result = flow.invoke({"input_array": [1, 2, 3], "input_number": 1}, context=context) + assert result == {"array_result": [11, 12, 13], "user_var": 31} + + result = flow.invoke({"input_array": [4, 5], "input_number": 2}, context=context) + assert result == {"array_result": [14, 15], "user_var": 22} diff --git a/tests/unit_tests/workflow/test_node.py b/tests/unit_tests/workflow/test_node.py new file mode 100644 index 0000000..cacc0b1 --- /dev/null +++ b/tests/unit_tests/workflow/test_node.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. +from typing import AsyncIterator, Iterator + +from jiuwen.core.component.base import WorkflowComponent +from jiuwen.core.context.context import Context +from jiuwen.core.graph.executable import Executable, Input, Output + + +class CommonNode(Executable, WorkflowComponent): + + def __init__(self, node_id: str): + super().__init__() + self.node_id = node_id + + def invoke(self, inputs: Input, context: Context) -> Output: + return inputs + + async def ainvoke(self, inputs: Input, context: Context) -> Output: + pass + + def stream(self, inputs: Input, context: Context) -> Iterator[Output]: + yield self.invoke(inputs, context) + + async def astream(self, inputs: Input, context: Context) -> AsyncIterator[Output]: + yield await self.ainvoke(inputs, context) + + def interrupt(self, message: dict): + pass + + +class AddTenNode(Executable, WorkflowComponent): + + def __init__(self, node_id: str): + super().__init__() + self.node_id = node_id + + def invoke(self, inputs: Input, context: Context) -> Output: + return {"result": inputs["source"] + 10} + + async def ainvoke(self, inputs: Input, context: Context) -> Output: + pass + + def stream(self, inputs: Input, context: Context) -> Iterator[Output]: + yield self.invoke(inputs, context) + + async def astream(self, inputs: Input, context: Context) -> AsyncIterator[Output]: + yield await self.ainvoke(inputs, context) + + def interrupt(self, message: dict): + pass -- Gitee From 6defed3c537aa9816f20fc832e8bfa8b44a3e3de Mon Sep 17 00:00:00 2001 From: CandiceGuo Date: Mon, 14 Jul 2025 12:00:55 +0800 Subject: [PATCH 3/7] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8Dworkflow=E8=81=94?= =?UTF-8?q?=E8=B0=83=E7=BC=96=E8=AF=91=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- jiuwen/core/component/condition/array.py | 23 ++--- jiuwen/core/component/condition/expression.py | 2 +- jiuwen/core/component/condition/number.py | 9 +- .../loop_callback/intermediate_loop_var.py | 7 +- jiuwen/core/component/loop_callback/output.py | 21 ++--- jiuwen/core/component/loop_comp.py | 34 +++++--- jiuwen/core/component/set_variable_comp.py | 18 ++-- jiuwen/core/workflow/base.py | 4 +- jiuwen/graph/pregel/graph.py | 3 + tests/unit_tests/workflow/test_loop.py | 58 ------------- tests/unit_tests/workflow/test_mock_node.py | 8 ++ tests/unit_tests/workflow/test_workflow.py | 87 +++++++++++++++---- 12 files changed, 144 insertions(+), 130 deletions(-) delete mode 100644 tests/unit_tests/workflow/test_loop.py diff --git a/jiuwen/core/component/condition/array.py b/jiuwen/core/component/condition/array.py index 4bf057f..0f97247 100644 --- a/jiuwen/core/component/condition/array.py +++ b/jiuwen/core/component/condition/array.py @@ -3,9 +3,9 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. from typing import Union, Any -from jiuwen.core.common.constants.constant import BPMN_VARIABLE_POOL_SEPARATOR from jiuwen.core.component.condition.condition import Condition from jiuwen.core.context.context import Context +from jiuwen.core.context.utils import extract_origin_key, NESTED_PATH_SPLIT DEFAULT_MAX_LOOP_NUMBER = 1000 @@ -14,27 +14,28 @@ class ArrayCondition(Condition): def __init__(self, context: Context, node_id: str, arrays: dict[str, Union[str, list[Any]]], index_path: str = None, array_root: str = None): self._context = context + self._node_id = node_id self._arrays = arrays - self._index_path = index_path if index_path else node_id + BPMN_VARIABLE_POOL_SEPARATOR + "index" - self._arrays_root = array_root if array_root else node_id + BPMN_VARIABLE_POOL_SEPARATOR + "arrLoopVar" + self._index_path = index_path if index_path else node_id + NESTED_PATH_SPLIT + "index" + self._arrays_root = array_root if array_root else node_id + NESTED_PATH_SPLIT + "arrLoopVar" def init(self): - self._context.store.write(self._index_path, 0) - self._context.store.write(self._arrays_root, {}) + self._context.state.update(self._node_id, {self._index_path: 0}) + self._context.state.update(self._node_id, {self._arrays_root: {}}) def __call__(self) -> bool: - current_idx = self._context.store.read(self._index_path) + current_idx = self._context.state.get(self._index_path) min_length = DEFAULT_MAX_LOOP_NUMBER updates: dict[str, Any] = {} for key, array_info in self._arrays.items(): - key_path = self._arrays_root + BPMN_VARIABLE_POOL_SEPARATOR + key + key_path = self._arrays_root + NESTED_PATH_SPLIT + key arr: list[Any] = [] if isinstance(array_info, list): arr = array_info elif isinstance(array_info, str): - ref_str = get_ref_str(array_info) + ref_str = extract_origin_key(array_info) if ref_str != "": - arr = self._context.store.read(ref_str) + arr = self._context.state.get(ref_str) else: raise RuntimeError("error value: " + array_info + " is not a array path") else: @@ -44,8 +45,8 @@ class ArrayCondition(Condition): return False updates[key_path] = arr[current_idx] - self._context.store.write(self._index_path, current_idx + 1) + self._context.state.update(self._node_id, {self._index_path: current_idx + 1}) for path, update in updates.items(): - self._context.store.write(path, update) + self._context.state.update(self._node_id, update) return True diff --git a/jiuwen/core/component/condition/expression.py b/jiuwen/core/component/condition/expression.py index 45515d2..c770eda 100644 --- a/jiuwen/core/component/condition/expression.py +++ b/jiuwen/core/component/condition/expression.py @@ -20,7 +20,7 @@ class ExpressionCondition(Condition): matches = re.findall(pattern, self._expression) inputs = {} for match in matches: - inputs[match] = self._context.store.read(match[2:-1]) + inputs[match] = self._context.state.get(match[2:-1]) return self._evalueate_expression(self._expression, inputs) def _evalueate_expression(self, expression, inputs) -> bool: diff --git a/jiuwen/core/component/condition/number.py b/jiuwen/core/component/condition/number.py index 8bf3b59..68ab057 100644 --- a/jiuwen/core/component/condition/number.py +++ b/jiuwen/core/component/condition/number.py @@ -13,20 +13,21 @@ class NumberCondition(Condition): self._context = context self._index_path = index_path if index_path else node_id + BPMN_VARIABLE_POOL_SEPARATOR + "index" self._limit = limit + self._node_id = node_id def init(self): - self._context.store.write(self._index_path, 0) + self._context.state.update(self._node_id, {self._index_path: 0}) def __call__(self) -> bool: - current_idx = self._context.store.read(self._index_path) + current_idx = self._context.state.get(self._index_path) limit_num: int if isinstance(self._limit, int): limit_num = self._limit else: - limit_num = self._context.store.read(self._limit) + limit_num = self._context.state.get(self._limit) result = current_idx < limit_num if result: - self._context.store.write(self._index_path, current_idx + 1) + self._context.state.update(self._node_id, {self._index_path: current_idx + 1}) return result diff --git a/jiuwen/core/component/loop_callback/intermediate_loop_var.py b/jiuwen/core/component/loop_callback/intermediate_loop_var.py index b789e23..f090c26 100644 --- a/jiuwen/core/component/loop_callback/intermediate_loop_var.py +++ b/jiuwen/core/component/loop_callback/intermediate_loop_var.py @@ -12,6 +12,7 @@ class IntermediateLoopVarCallback(LoopCallback): def __init__(self, context: Context, node_id: str, intermediate_loop_var: dict[str, Union[str, Any]], intermediate_loop_var_root: str = None): self._context = context + self._node_id = node_id self._intermediate_loop_var = intermediate_loop_var self._intermediate_loop_var_root = intermediate_loop_var_root if intermediate_loop_var_root \ else node_id + NESTED_PATH_SPLIT + "intermediateLoopVar" @@ -23,15 +24,15 @@ class IntermediateLoopVarCallback(LoopCallback): if isinstance(value, str): if is_ref_path(value): ref_str = extract_origin_key(value) - update = self._context.store.read(ref_str) + update = self._context.state.get(ref_str) else: update = value else: update = value - self._context.store.write({path: update}) + self._context.state.update(self._node_id, {path: update}) def out_loop(self): - self._context.store.write(self._intermediate_loop_var_root, {}) + self._context.state.update(self._node_id, {self._intermediate_loop_var_root: {}}) def start_round(self): pass diff --git a/jiuwen/core/component/loop_callback/output.py b/jiuwen/core/component/loop_callback/output.py index e94252a..ae0b8d1 100644 --- a/jiuwen/core/component/loop_callback/output.py +++ b/jiuwen/core/component/loop_callback/output.py @@ -11,6 +11,7 @@ from jiuwen.core.context.utils import is_ref_path, extract_origin_key, NESTED_PA class OutputCallback(LoopCallback): def __init__(self, context: Context, node_id: str, outputs_format: dict[str, Any], round_result_root: str = None, result_root: str = None, intermediate_loop_var_root: str = None): + self._node_id = node_id self._context = context self._outputs_format = outputs_format self._round_result_root = round_result_root if round_result_root else node_id + NESTED_PATH_SPLIT + "round" @@ -28,30 +29,30 @@ class OutputCallback(LoopCallback): def first_in_loop(self): _results: dict[str, list[Any]] = {} self._generate_results(_results) - self._context.store.write(self._round_result_root, _results) + self._context.state.update(self._round_result_root, _results) def out_loop(self): - results: dict[str, list[Any]] = self._context.store.read(self._round_result_root) + results: dict[str, list[Any]] = self._context.state.get(self._round_result_root) if not isinstance(results, dict): raise RuntimeError("error results in loop process") for path, array in results.items(): - self._context.store.write(path, array) - result = filter_input(self._outputs_format, self._context.store) - self._context.store.write(self._round_result_root, {}) - set_output(result, self._result_root, self._context.store) + self._context.state.update(self._node_id, {path: array}) + result = self._context.state.get_inputs(self._outputs_format) + self._context.state.update(self._node_id, {self._round_result_root : {}}) + self._context.state.set_outputs(self._node_id, {self._result_root : result}) def start_round(self): pass def end_round(self): - results: dict[str, list[Any]] = self._context.store.read(self._round_result_root) + results: dict[str, list[Any]] = self._context.state.get(self._round_result_root) if not isinstance(results, dict): raise RuntimeError("error results in round process") for path, value in results.items(): if path.startswith(self._intermediate_loop_var_root): - results[path] = self._context.store.read(path) + results[path] = self._context.state.get(path) elif isinstance(value, list): - value.append(self._context.store.read(path)) + value.append(self._context.state.get(path)) else: raise RuntimeError("error process in loop: " + path + ", " + str(value)) - self._context.store.write(self._round_result_root, results) + self._context.state.update(self._node_id, {self._round_result_root : results}) diff --git a/jiuwen/core/component/loop_comp.py b/jiuwen/core/component/loop_comp.py index 66d45d7..081d852 100644 --- a/jiuwen/core/component/loop_comp.py +++ b/jiuwen/core/component/loop_comp.py @@ -7,14 +7,13 @@ from typing import Iterator, AsyncIterator, Self, Union, Callable from langgraph.constants import END, START -from jiuwen.core.common.constants.constant import BPMN_VARIABLE_POOL_SEPARATOR - from jiuwen.core.component.base import WorkflowComponent from jiuwen.core.component.break_comp import BreakComponent, LoopController from jiuwen.core.component.condition.condition import Condition, AlwaysTrue, FuncCondition from jiuwen.core.component.condition.expression import ExpressionCondition from jiuwen.core.component.loop_callback.loop_callback import LoopCallback from jiuwen.core.context.context import Context +from jiuwen.core.context.utils import NESTED_PATH_SPLIT from jiuwen.core.graph.base import Graph, Router, ExecutableGraph from jiuwen.core.graph.executable import Output, Input, Executable from jiuwen.graph.factory import GraphFactory @@ -35,14 +34,19 @@ class EmptyExecutable(Executable): async def astream(self, inputs: Input, context: Context) -> AsyncIterator[Output]: yield await self.ainvoke(inputs, context) + def interrupt(self, message: dict): + return + class LoopGroup: def __init__(self, context: Context, graph: Graph = None): self._context = context self._graph = graph if graph else GraphFactory().create_graph() - def add_component(self, node_id: str, component: WorkflowComponent, *, wait_for_all: bool = False) -> Self: + def add_component(self, node_id: str, component: WorkflowComponent, *, wait_for_all: bool = False, + inputs_schema: dict = None, outputs_schema: dict = None) -> Self: component.add_component(self._graph, node_id, wait_for_all=wait_for_all) + self._context.config.set_io_schema(node_id, (inputs_schema, outputs_schema)) return self def start_nodes(self, nodes: list[str]) -> Self: @@ -55,7 +59,7 @@ class LoopGroup: self._graph.end_node(node) return self - def add_condition(self, start_node_id: Union[str, list[str]], end_node_id: str) -> Self: + def add_connection(self, start_node_id: Union[str, list[str]], end_node_id: str) -> Self: self._graph.add_edge(start_node_id, end_node_id) return self @@ -64,7 +68,7 @@ class LoopGroup: return self def compile(self) -> ExecutableGraph: - return self._graph.compile() + return self._graph.compile(self._context) BROKEN = "_broken" @@ -77,6 +81,7 @@ class LoopComponent(WorkflowComponent, LoopController): condition: Union[str, Callable[[], bool], Condition] = None, context_root: str = None, break_nodes: list[BreakComponent] = None, callbacks: list[LoopCallback] = None, graph: Graph = None): super().__init__() + self._node_id = node_id if context is None: raise ValueError("context cannot be None") if context_root is None: @@ -98,7 +103,7 @@ class LoopComponent(WorkflowComponent, LoopController): if break_nodes: for break_node in break_nodes: - break_node.set_condition(self) + break_node.set_controller(self) if callbacks: for callback in callbacks: @@ -119,11 +124,11 @@ class LoopComponent(WorkflowComponent, LoopController): self._graph.add_conditional_edges(condition_node_id, self) self.init() - self._compiled = self._graph.compile() + self._compiled = self._graph.compile(self._context) def init(self): - self._context.store.write(self._context_root + BPMN_VARIABLE_POOL_SEPARATOR + BROKEN, False) - self._context.store.write(self._context_root + BPMN_VARIABLE_POOL_SEPARATOR + FIRST_IN_LOOP, True) + self._context.state.update(self._node_id, {self._context_root + NESTED_PATH_SPLIT + BROKEN: False}) + self._context.state.update(self._node_id, {self._context_root + NESTED_PATH_SPLIT + FIRST_IN_LOOP: True}) self._condition.init() def to_executable(self) -> Executable: @@ -153,19 +158,20 @@ class LoopComponent(WorkflowComponent, LoopController): return self._out_loop def first_in_loop(self) -> bool: - _first_in_loop = self._context.store.read(self._context_root + BPMN_VARIABLE_POOL_SEPARATOR + FIRST_IN_LOOP) + _first_in_loop = self._context.state.get(self._context_root + NESTED_PATH_SPLIT + FIRST_IN_LOOP) if isinstance(_first_in_loop, bool): if _first_in_loop: - self._context.store.write(self._context_root + BPMN_VARIABLE_POOL_SEPARATOR + FIRST_IN_LOOP, False) + self._context.state.update(self._node_id, + {self._context_root + NESTED_PATH_SPLIT + FIRST_IN_LOOP: False}) return _first_in_loop - self._context.store.write(self._context_root + BPMN_VARIABLE_POOL_SEPARATOR + FIRST_IN_LOOP, False) + self._context.state.update(self._node_id, {self._context_root + NESTED_PATH_SPLIT + FIRST_IN_LOOP: False}) return True def is_broken(self) -> bool: - _is_broken = self._context.store.read(self._context_root + BPMN_VARIABLE_POOL_SEPARATOR + BROKEN) + _is_broken = self._context.state.get(self._context_root + NESTED_PATH_SPLIT + BROKEN) if isinstance(_is_broken, bool): return _is_broken return False def break_loop(self): - self._context.store.write(self._context_root + BPMN_VARIABLE_POOL_SEPARATOR + BROKEN, True) + self._context.state.update(self._node_id, {self._context_root + NESTED_PATH_SPLIT + BROKEN: True}) diff --git a/jiuwen/core/component/set_variable_comp.py b/jiuwen/core/component/set_variable_comp.py index 7ea4d2c..c2a22bd 100644 --- a/jiuwen/core/component/set_variable_comp.py +++ b/jiuwen/core/component/set_variable_comp.py @@ -3,29 +3,31 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. import asyncio from functools import partial -from typing import AsyncIterator, Iterator +from typing import AsyncIterator, Iterator, Any from jiuwen.core.component.base import WorkflowComponent from jiuwen.core.context.context import Context +from jiuwen.core.context.utils import extract_origin_key from jiuwen.core.graph.executable import Executable, Input, Output class SetVariableComponent(WorkflowComponent, Executable): - def __init__(self, context: Context, varivable_mapping: dict[str, Any]): + def __init__(self, node_id: str, context: Context, variable_mapping: dict[str, Any]): self._context = context - self._varivable_mapping = varivable_mapping + self._node_id = node_id + self._variable_mapping = variable_mapping def invoke(self, inputs: Input, context: Context) -> Output: - for left, right in self._varivable_mapping.items(): - left_ref_str = get_ref_str(left) + for left, right in self._variable_mapping.items(): + left_ref_str = extract_origin_key(left) if left_ref_str == "": left_ref_str = left if isinstance(right, str): - ref_str = get_ref_str(right) + ref_str = extract_origin_key(right) if ref_str == "": - self._context.store.write(left_ref_str, self._context.store.read(ref_str)) + self._context.state.update(self._node_id, self._context.state.get(ref_str)) continue - self._context.store.write(left_ref_str, right) + self._context.state.update(self._node_id, {left_ref_str : right}) return None diff --git a/jiuwen/core/workflow/base.py b/jiuwen/core/workflow/base.py index f9afa32..33e7042 100644 --- a/jiuwen/core/workflow/base.py +++ b/jiuwen/core/workflow/base.py @@ -80,11 +80,11 @@ class Workflow: return self def add_connection(self, src_comp_id: str, target_comp_id: str) -> Self: - self._graph.add_edge(source_node_id=src_comp_id, target_node_id=target_comp_id) + self._graph.add_edge(src_comp_id, target_comp_id) return self def add_stream_connection(self, src_comp_id: str, target_comp_id: str) -> Self: - self._graph.add_edge(source_node_id=src_comp_id, target_node_id=target_comp_id) + self._graph.add_edge(src_comp_id, target_comp_id) if target_comp_id not in self._stream_edges: self._stream_edges[src_comp_id] = [target_comp_id] else: diff --git a/jiuwen/graph/pregel/graph.py b/jiuwen/graph/pregel/graph.py index ef0c529..b8da6a9 100644 --- a/jiuwen/graph/pregel/graph.py +++ b/jiuwen/graph/pregel/graph.py @@ -82,3 +82,6 @@ class CompiledGraph(ExecutableGraph): async def astream(self, inputs: Input, context: Context) -> AsyncIterator[Output]: return self._compiledStateGraph.astream({}) + + def interrupt(self, message: dict): + return diff --git a/tests/unit_tests/workflow/test_loop.py b/tests/unit_tests/workflow/test_loop.py deleted file mode 100644 index c862b1b..0000000 --- a/tests/unit_tests/workflow/test_loop.py +++ /dev/null @@ -1,58 +0,0 @@ -#!/usr/bin/env python -# -*- coding: UTF-8 -*- -# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. -from jiuwen.core.component.base import StartComponent, EndComponent -from jiuwen.core.component.condition.array import ArrayCondition -from jiuwen.core.component.loop_callback.intermediate_loop_var import IntermediateLoopVarCallback -from jiuwen.core.component.loop_callback.output import OutputCallback -from jiuwen.core.component.loop_comp import LoopGroup, LoopComponent -from jiuwen.core.component.set_variable_comp import SetVariableComponent -from jiuwen.core.context.context import Context -from jiuwen.core.workflow.base import Workflow -from tests.unit_tests.workflow.test_node import CommonNode, AddTenNode - - -def test_loop_array_condition(): - context = Context() - - flow = Workflow() - s = StartComponent() - e = EndComponent() - flow.set_start_comp("s", s) - flow.set_end_component("e", e, inputs_schema={"array_result": "${b.array_result}", "user_var": "${b.user_var}"}) - flow.add_workflow_comp("a", CommonNode("a"), inputs_schema={"array": "${user.inputs.input_array}"}) - flow.add_workflow_comp("b", CommonNode("b"), - inputs_schema={"array_result": "${l.results}", "user_var": "${l.user_var}"}) - - loop_group = LoopGroup(context) - node1 = AddTenNode("1") - loop_group.add_component("1", node1, inputs={"source": "${l.arrLoopVar.item}"}) - node2 = AddTenNode("2") - loop_group.add_component("2", node2, inputs={"source": "${l.intermediateLoopVar.user_var}"}) - node3 = SetVariableComponent(context, {"${l.intermediateLoopVar.user_var}": "${2.result}"}) - loop_group.add_component("3", node3) - loop_group.start_nodes(["1"]) - loop_group.end_nodes(["3"]) - loop_group.add_connection("1", "2") - loop_group.add_connection("2", "3") - - loop_condition = ArrayCondition(context, "l", {"item": "${a.loop_array}"}) - output_callback = OutputCallback(context, "l", - {"results": "${1.result}", "user_var": "${l.intermediateLoopVar.user_var}"}) - intermediate_callback = IntermediateLoopVarCallback(context, "l", - {"user_var": "${user.inputs.input_number}"}) - - loop = LoopComponent(context, "l", loop_group, loop_condition, - callbacks=[output_callback, intermediate_callback]) - flow.add_workflow_comp("l", loop) - - flow.add_connection("s", "a") - flow.add_connection("a", "l") - flow.add_connection("l", "b") - flow.add_connection("b", "e") - - result = flow.invoke({"input_array": [1, 2, 3], "input_number": 1}, context=context) - assert result == {"array_result": [11, 12, 13], "user_var": 31} - - result = flow.invoke({"input_array": [4, 5], "input_number": 2}, context=context) - assert result == {"array_result": [14, 15], "user_var": 22} diff --git a/tests/unit_tests/workflow/test_mock_node.py b/tests/unit_tests/workflow/test_mock_node.py index 8c6ea95..be83f06 100644 --- a/tests/unit_tests/workflow/test_mock_node.py +++ b/tests/unit_tests/workflow/test_mock_node.py @@ -36,3 +36,11 @@ class MockEndNode(EndComponent, MockNodeBase): def invoke(self, inputs: Input, context: Context) -> Output: return inputs + + +class Node1(MockNodeBase): + def __init__(self, node_id: str): + super().__init__(node_id) + + def invoke(self, inputs: Input, context: Context) -> Output: + return {} \ No newline at end of file diff --git a/tests/unit_tests/workflow/test_workflow.py b/tests/unit_tests/workflow/test_workflow.py index b0ee178..0a556f2 100644 --- a/tests/unit_tests/workflow/test_workflow.py +++ b/tests/unit_tests/workflow/test_workflow.py @@ -1,35 +1,84 @@ +import unittest +from venv import create + from sqlalchemy.testing.suite.test_reflection import metadata -from jiuwen.core.component.base import StartComponent, EndComponent +from jiuwen.core.component.condition.array import ArrayCondition +from jiuwen.core.component.loop_callback.intermediate_loop_var import IntermediateLoopVarCallback +from jiuwen.core.component.loop_callback.output import OutputCallback +from jiuwen.core.component.loop_comp import LoopGroup, LoopComponent +from jiuwen.core.component.set_variable_comp import SetVariableComponent from jiuwen.core.context.config import Config from jiuwen.core.context.context import Context from jiuwen.core.context.memory.base import InMemoryState -from jiuwen.core.graph.executable import Input, Output +from jiuwen.core.graph.base import Graph from jiuwen.core.workflow.base import WorkflowConfig, Workflow from jiuwen.graph.pregel.graph import PregelGraph -from tests.unit_tests.workflow.test_mock_node import MockNodeBase, MockStartNode, MockEndNode +from test_node import AddTenNode, CommonNode +from tests.unit_tests.workflow.test_mock_node import MockNodeBase, MockStartNode, MockEndNode, Node1 + + +def create_context() -> Context: + return Context(config=Config(), state=InMemoryState(), store=None, tracer=None) + + +def create_graph() -> Graph: + return PregelGraph() + + +def create_flow() -> Workflow: + return Workflow(workflow_config=DEFAULT_WORKFLOW_CONFIG, graph=create_graph()) + + +DEFAULT_WORKFLOW_CONFIG = WorkflowConfig(metadata={}) -class Node1(MockNodeBase): - def __init__(self, node_id: str): - super().__init__(node_id) +class WorkflowTest(unittest.TestCase): + def test_simple_workflow(self): + flow = create_flow() + flow.add_workflow_comp("a", Node1("a"), inputs_schema={}) + flow.set_start_comp("start", MockStartNode("start")) + flow.set_end_comp("end", MockEndNode("end")) + flow.invoke({}, create_context()) - def invoke(self, inputs: Input, context: Context) -> Output: - return {} + def test_workflow_with_loop(self): + flow = create_flow() + flow.set_start_comp("s", MockStartNode("s")) + flow.set_end_comp("e", MockEndNode("e"), + inputs_schema={"array_result": "${b.array_result}", "user_var": "${b.user_var}"}) + flow.add_workflow_comp("a", CommonNode("a"), + inputs_schema={"array": "${user.inputs.input_array}"}) + flow.add_workflow_comp("b", CommonNode("b"), + inputs_schema={"array_result": "${l.results}", "user_var": "${l.user_var}"}) + # create loop: (1->2->3) + context = create_context() + loop_group = LoopGroup(context) + loop_group.add_component("1", AddTenNode("1"), inputs_schema={"source": "${l.arrLoopVar.item}"}) + loop_group.add_component("2", AddTenNode("2"), inputs_schema={"source": "${l.intermediateLoopVar.user_var}"}) + loop_group.add_component("3", SetVariableComponent("3", context, {"${l.intermediateLoopVar.user_var}": "${2.result}"})) + loop_group.start_nodes(["1"]) + loop_group.end_nodes(["3"]) + loop_group.add_connection("1", "2") + loop_group.add_connection("2", "3") + output_callback = OutputCallback(context, "l", + {"results": "${1.result}", "user_var": "${l.intermediateLoopVar.user_var}"}) + intermediate_callback = IntermediateLoopVarCallback(context, "l", + {"user_var": "${user.inputs.input_number}"}) -def test_workflow_base(): - workflow_config = WorkflowConfig(metadata = {}) - graph = PregelGraph() + loop = LoopComponent(context, "l", loop_group, ArrayCondition(context, "l", {"item": "${a.loop_array}"}), + callbacks=[output_callback, intermediate_callback]) - flow = Workflow(workflow_config=workflow_config, graph=graph) - flow.add_workflow_comp("a", Node1("a"), inputs_schema={}) - flow.set_start_comp("start", MockStartNode("start")) - flow.set_end_comp("end", MockEndNode("end")) + flow.add_workflow_comp("l", loop) - context = Context(config = Config(), state = InMemoryState(), store = None, tracer = None) - flow.invoke({}, context) + # s->a->(1->2->3)->b->e + flow.add_connection("s", "a") + flow.add_connection("a", "l") + flow.add_connection("l", "b") + flow.add_connection("b", "e") + result = flow.invoke({"input_array": [1, 2, 3], "input_number": 1}, context=context) + assert result == {"array_result": [11, 12, 13], "user_var": 31} -if __name__ == "__main__": - test_workflow_base() + result = flow.invoke({"input_array": [4, 5], "input_number": 2}, context=context) + assert result == {"array_result": [14, 15], "user_var": 22} -- Gitee From b95fa59b75023ceb660c027ac4667a75d817be2f Mon Sep 17 00:00:00 2001 From: caoyuzhe Date: Mon, 14 Jul 2025 17:31:55 +0800 Subject: [PATCH 4/7] =?UTF-8?q?fix=EF=BC=9Aworkflow=E8=81=94=E8=B0=83?= =?UTF-8?q?=E9=80=82=E9=85=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- jiuwen/core/component/condition/array.py | 10 ++++------ jiuwen/core/component/condition/number.py | 4 ++-- .../loop_callback/intermediate_loop_var.py | 4 ++-- jiuwen/core/component/loop_callback/output.py | 14 +++++++------- jiuwen/core/component/loop_comp.py | 2 ++ jiuwen/core/component/set_variable_comp.py | 3 +++ jiuwen/core/context/context.py | 1 + jiuwen/core/context/memory/base.py | 4 +++- jiuwen/core/context/state.py | 13 ++++++++----- jiuwen/core/context/utils.py | 7 +++++-- jiuwen/core/graph/vertex.py | 8 +++++--- jiuwen/core/workflow/base.py | 1 + jiuwen/graph/pregel/graph.py | 15 ++++++++++----- pyproject.toml | 3 ++- tests/unit_tests/workflow/test_mock_node.py | 5 +++++ tests/unit_tests/workflow/test_node.py | 5 +++++ tests/unit_tests/workflow/test_workflow.py | 2 +- 17 files changed, 66 insertions(+), 35 deletions(-) diff --git a/jiuwen/core/component/condition/array.py b/jiuwen/core/component/condition/array.py index 0f97247..178949f 100644 --- a/jiuwen/core/component/condition/array.py +++ b/jiuwen/core/component/condition/array.py @@ -20,8 +20,8 @@ class ArrayCondition(Condition): self._arrays_root = array_root if array_root else node_id + NESTED_PATH_SPLIT + "arrLoopVar" def init(self): - self._context.state.update(self._node_id, {self._index_path: 0}) - self._context.state.update(self._node_id, {self._arrays_root: {}}) + self._context.state.io_state.update(self._node_id, {self._index_path: 0}) + self._context.state.io_state.update(self._node_id, {self._arrays_root: {}}) def __call__(self) -> bool: current_idx = self._context.state.get(self._index_path) @@ -45,8 +45,6 @@ class ArrayCondition(Condition): return False updates[key_path] = arr[current_idx] - self._context.state.update(self._node_id, {self._index_path: current_idx + 1}) - for path, update in updates.items(): - self._context.state.update(self._node_id, update) - + self._context.state.io_state.update(self._node_id, {self._index_path: current_idx + 1}) + self._context.state.io_state.update(self._node_id, updates) return True diff --git a/jiuwen/core/component/condition/number.py b/jiuwen/core/component/condition/number.py index 68ab057..b150784 100644 --- a/jiuwen/core/component/condition/number.py +++ b/jiuwen/core/component/condition/number.py @@ -16,7 +16,7 @@ class NumberCondition(Condition): self._node_id = node_id def init(self): - self._context.state.update(self._node_id, {self._index_path: 0}) + self._context.state.io_state.update(self._node_id, {self._index_path: 0}) def __call__(self) -> bool: current_idx = self._context.state.get(self._index_path) @@ -28,6 +28,6 @@ class NumberCondition(Condition): result = current_idx < limit_num if result: - self._context.state.update(self._node_id, {self._index_path: current_idx + 1}) + self._context.state.io_state.update(self._node_id, {self._index_path: current_idx + 1}) return result diff --git a/jiuwen/core/component/loop_callback/intermediate_loop_var.py b/jiuwen/core/component/loop_callback/intermediate_loop_var.py index f090c26..88c3cc4 100644 --- a/jiuwen/core/component/loop_callback/intermediate_loop_var.py +++ b/jiuwen/core/component/loop_callback/intermediate_loop_var.py @@ -29,10 +29,10 @@ class IntermediateLoopVarCallback(LoopCallback): update = value else: update = value - self._context.state.update(self._node_id, {path: update}) + self._context.state.io_state.update(self._node_id, {path: update}) def out_loop(self): - self._context.state.update(self._node_id, {self._intermediate_loop_var_root: {}}) + self._context.state.io_state.update(self._node_id, {self._intermediate_loop_var_root: {}}) def start_round(self): pass diff --git a/jiuwen/core/component/loop_callback/output.py b/jiuwen/core/component/loop_callback/output.py index ae0b8d1..d193271 100644 --- a/jiuwen/core/component/loop_callback/output.py +++ b/jiuwen/core/component/loop_callback/output.py @@ -18,24 +18,24 @@ class OutputCallback(LoopCallback): self._result_root = result_root if result_root else node_id self._intermediate_loop_var_root = intermediate_loop_var_root if intermediate_loop_var_root else node_id + NESTED_PATH_SPLIT + "intermediateLoopVar" - def _generate_results(self, results: dict[str, list[Any]]): + def _generate_results(self, results: list[(str, list[Any])]): for key, value in self._outputs_format.items(): if isinstance(value, str) and is_ref_path(value): ref_str = extract_origin_key(value) - results[ref_str] = [] + results.append ((ref_str, [])) elif isinstance(value, dict): self._generate_results(results) def first_in_loop(self): - _results: dict[str, list[Any]] = {} + _results: list[(str, list[Any])] = [] self._generate_results(_results) - self._context.state.update(self._round_result_root, _results) + self._context.state.update(self._node_id, {self._round_result_root: _results}) def out_loop(self): - results: dict[str, list[Any]] = self._context.state.get(self._round_result_root) - if not isinstance(results, dict): + results: list[(str, list[Any])] = self._context.state.get(self._round_result_root) + if not isinstance(results, list): raise RuntimeError("error results in loop process") - for path, array in results.items(): + for (path, array) in results: self._context.state.update(self._node_id, {path: array}) result = self._context.state.get_inputs(self._outputs_format) self._context.state.update(self._node_id, {self._round_result_root : {}}) diff --git a/jiuwen/core/component/loop_comp.py b/jiuwen/core/component/loop_comp.py index 081d852..cf609a8 100644 --- a/jiuwen/core/component/loop_comp.py +++ b/jiuwen/core/component/loop_comp.py @@ -151,6 +151,8 @@ class LoopComponent(WorkflowComponent, LoopController): else: callback.out_loop() + self._context.state.io_state.commit() + self._context.state.global_state.commit() if continue_loop: return self._in_loop diff --git a/jiuwen/core/component/set_variable_comp.py b/jiuwen/core/component/set_variable_comp.py index c2a22bd..6c223c4 100644 --- a/jiuwen/core/component/set_variable_comp.py +++ b/jiuwen/core/component/set_variable_comp.py @@ -44,3 +44,6 @@ class SetVariableComponent(WorkflowComponent, Executable): def interrupt(self, message: dict): pass + + def to_executable(self) -> Executable: + return self \ No newline at end of file diff --git a/jiuwen/core/context/context.py b/jiuwen/core/context/context.py index cf1fc69..0da9a2f 100644 --- a/jiuwen/core/context/context.py +++ b/jiuwen/core/context/context.py @@ -28,6 +28,7 @@ class Context(ABC): return False self._workflow_config = workflow_config self._stream_modes = stream_modes + return True @property def config(self) -> Config: diff --git a/jiuwen/core/context/memory/base.py b/jiuwen/core/context/memory/base.py index 04c3752..8106d70 100644 --- a/jiuwen/core/context/memory/base.py +++ b/jiuwen/core/context/memory/base.py @@ -34,7 +34,7 @@ class InMemoryStateLike(StateLike): return transformer(self._state) def update(self, node_id: str, data: dict) -> None: - update_dict(self._state, data) + update_dict(data, self._state) class InMemoryCommitState(CommitState): @@ -43,6 +43,8 @@ class InMemoryCommitState(CommitState): self._updates: dict[str, list[dict]] = dict() def update(self, node_id: str, data: dict) -> None: + if data == 1: + print("tmp") if node_id not in self._updates: self._updates[node_id] = [] self._updates[node_id].append(data) diff --git a/jiuwen/core/context/state.py b/jiuwen/core/context/state.py index beb6e00..a181855 100644 --- a/jiuwen/core/context/state.py +++ b/jiuwen/core/context/state.py @@ -63,7 +63,10 @@ class State(ABC): def get(self, key: Union[str, dict]) -> Optional[Any]: if self._global_state is None: return None - return self._global_state.get(key) + value = self._global_state.get(key) + if value is None: + return self._io_state.get(key) + return value def update(self, node_id: str, data: dict) -> None: if self._global_state is None: @@ -85,10 +88,10 @@ class State(ABC): return self._comp_state.update(node_id, data) - def set_user_inputs(self, inputs: dict) -> None: + def set_user_inputs(self, inputs: Any) -> None: if self._io_state is None: return - self._io_state.update({"user": {"inputs": inputs}}) + self._io_state.update("user", {"user.inputs": inputs}) def get_inputs(self, input_schemas: dict) -> dict: if self._io_state is None: @@ -101,7 +104,7 @@ class State(ABC): return self._io_state.get(node_id) def set_outputs(self, node_id: str, outputs: dict) -> None: - if self._io_state is None: + if self._io_state is None or outputs is None: return - return self._io_state.update(node_id, outputs) + return self._io_state.update(node_id, {node_id: outputs}) diff --git a/jiuwen/core/context/utils.py b/jiuwen/core/context/utils.py index 7e1fbc5..75a68b6 100644 --- a/jiuwen/core/context/utils.py +++ b/jiuwen/core/context/utils.py @@ -20,6 +20,9 @@ def update_dict(update: dict, source: dict) -> None: :param update: update dict, which key is nested :param source: source dict, which key must not be nested """ + if update == 1: + print("error") + for key, value in update.items(): current_key, current = root_to_path(key, source, create_if_absent=True) update_by_key(current_key, value, current) @@ -27,8 +30,8 @@ def update_dict(update: dict, source: dict) -> None: def get_value_by_nested_path(nested_key: str, source: dict) -> Optional[Any]: result = root_to_path(nested_key, source) - if result is None: - return result + if result[1] is None: + return None return result[1][result[0]] diff --git a/jiuwen/core/graph/vertex.py b/jiuwen/core/graph/vertex.py index f853d2b..c740368 100644 --- a/jiuwen/core/graph/vertex.py +++ b/jiuwen/core/graph/vertex.py @@ -19,7 +19,7 @@ class Vertex: return True def __call__(self, state: GraphState) -> Output: - if self._context is None: + if self._context is None or self._executable is None: raise JiuWenBaseException(1, "vertex is not initialized, node is is " + self._node_id) inputs = self.__pre_invoke__(state) is_stream = self.__is_stream__(state) @@ -36,13 +36,15 @@ class Vertex: def __pre_invoke__(self, state:GraphState) -> Optional[dict]: inputs_schema = self._context.config.get_inputs_schema(self._node_id) - inputs = self._context.state.get_inputs(inputs_schema) + inputs = self._context.state.get_inputs(inputs_schema) if inputs_schema else None if self._context.tracer is not None: self.__trace_inputs__(inputs) return inputs def __post_invoke__(self, results: Optional[dict]) -> None: - self._context.state.set_outputs(results) + self._context.state.set_outputs(self._node_id, results) + self._context.state.io_state.commit() + self._context.state.global_state.commit() pass def __post_stream__(self, results_iter: Any) -> None: diff --git a/jiuwen/core/workflow/base.py b/jiuwen/core/workflow/base.py index 33e7042..755d233 100644 --- a/jiuwen/core/workflow/base.py +++ b/jiuwen/core/workflow/base.py @@ -101,6 +101,7 @@ class Workflow: return None compiled_graph = self._graph.compile(context) context.state.set_user_inputs(inputs) + context.state.io_state.commit() compiled_graph.invoke(inputs, context) return context.state.get_outputs(self._end_comp_id) diff --git a/jiuwen/graph/pregel/graph.py b/jiuwen/graph/pregel/graph.py index b8da6a9..2f97ea0 100644 --- a/jiuwen/graph/pregel/graph.py +++ b/jiuwen/graph/pregel/graph.py @@ -19,6 +19,7 @@ class PregelGraph(Graph): self.compiledStateGraph = None self.edges: list[Union[str, list[str]], str] = [] self.waits: set[str] = set() + self.nodes: list[Vertex] = [] def start_node(self, node_id: str) -> Self: self.pregel.set_entry_point(node_id) @@ -29,7 +30,9 @@ class PregelGraph(Graph): return self def add_node(self, node_id: str, node: Executable, *, wait_for_all: bool = False) -> Self: - self.pregel.add_node(node_id, Vertex(node_id, node)) + vertex_node = Vertex(node_id, node) + self.nodes.append(vertex_node) + self.pregel.add_node(node_id, vertex_node) if wait_for_all: self.waits.add(node_id) return self @@ -43,6 +46,8 @@ class PregelGraph(Graph): return self def compile(self, context: Context) -> ExecutableGraph: + for node in self.nodes: + node.init(context) if self.compiledStateGraph is None: self._pre_compile() self.compiledStateGraph = self.pregel.compile() @@ -72,16 +77,16 @@ class CompiledGraph(ExecutableGraph): self._compiledStateGraph = compiledStateGraph def invoke(self, inputs: Input, context: Context) -> Output: - return self._compiledStateGraph.invoke({}) + return self._compiledStateGraph.invoke({"source_node_id": ""}) def stream(self, inputs: Input, context: Context) -> Iterator[Output]: - return self._compiledStateGraph.stream({}) + return self._compiledStateGraph.stream({"source_node_id": ""}) async def ainvoke(self, inputs: Input, context: Context) -> Output: - return self._compiledStateGraph.ainvoke({}) + return self._compiledStateGraph.ainvoke({"source_node_id": ""}) async def astream(self, inputs: Input, context: Context) -> AsyncIterator[Output]: - return self._compiledStateGraph.astream({}) + return self._compiledStateGraph.astream({"source_node_id": ""}) def interrupt(self, message: dict): return diff --git a/pyproject.toml b/pyproject.toml index 9bc920f..ad584eb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,6 +9,7 @@ dependencies = [ "langgraph==0.2.35", "mcp==1.7.1", "pydantic==2.10.6", + "sqlalchemy>=2.0.41", ] [tool.uv] @@ -29,4 +30,4 @@ dev = [ ] [tool.coverage.run] -omit = ["tests/*"] \ No newline at end of file +omit = ["tests/*"] diff --git a/tests/unit_tests/workflow/test_mock_node.py b/tests/unit_tests/workflow/test_mock_node.py index be83f06..90b9ce1 100644 --- a/tests/unit_tests/workflow/test_mock_node.py +++ b/tests/unit_tests/workflow/test_mock_node.py @@ -6,6 +6,9 @@ from jiuwen.core.graph.executable import Executable, Input, Output class MockNodeBase(Executable, WorkflowComponent): + def invoke(self, inputs: Input, context: Context) -> Output: + pass + def __init__(self, node_id: str): super().__init__() self.node_id = node_id @@ -22,6 +25,8 @@ class MockNodeBase(Executable, WorkflowComponent): def interrupt(self, message: dict): return + def to_executable(self) -> Executable: + return self class MockStartNode(StartComponent, MockNodeBase): def __init__(self, node_id: str): diff --git a/tests/unit_tests/workflow/test_node.py b/tests/unit_tests/workflow/test_node.py index cacc0b1..52dcc8b 100644 --- a/tests/unit_tests/workflow/test_node.py +++ b/tests/unit_tests/workflow/test_node.py @@ -29,6 +29,8 @@ class CommonNode(Executable, WorkflowComponent): def interrupt(self, message: dict): pass + def to_executable(self) -> Executable: + return self class AddTenNode(Executable, WorkflowComponent): @@ -50,3 +52,6 @@ class AddTenNode(Executable, WorkflowComponent): def interrupt(self, message: dict): pass + + def to_executable(self) -> Executable: + return self \ No newline at end of file diff --git a/tests/unit_tests/workflow/test_workflow.py b/tests/unit_tests/workflow/test_workflow.py index 0a556f2..741f8c4 100644 --- a/tests/unit_tests/workflow/test_workflow.py +++ b/tests/unit_tests/workflow/test_workflow.py @@ -66,7 +66,7 @@ class WorkflowTest(unittest.TestCase): intermediate_callback = IntermediateLoopVarCallback(context, "l", {"user_var": "${user.inputs.input_number}"}) - loop = LoopComponent(context, "l", loop_group, ArrayCondition(context, "l", {"item": "${a.loop_array}"}), + loop = LoopComponent(context, "l", loop_group, ArrayCondition(context, "l", {"item": "${a.array}"}), callbacks=[output_callback, intermediate_callback]) flow.add_workflow_comp("l", loop) -- Gitee From 75050a90d2419aa468606b089f69e320ceeee00b Mon Sep 17 00:00:00 2001 From: CandiceGuo Date: Mon, 14 Jul 2025 18:42:03 +0800 Subject: [PATCH 5/7] =?UTF-8?q?fix:=20=E5=A2=9E=E5=8A=A0=E5=9F=BA=E7=A1=80?= =?UTF-8?q?workflow=E7=94=A8=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- jiuwen/core/context/memory/base.py | 5 +---- jiuwen/core/context/utils.py | 5 ++--- jiuwen/core/graph/vertex.py | 4 ++-- jiuwen/graph/pregel/graph.py | 8 ++++---- tests/unit_tests/workflow/test_mock_node.py | 9 ++++++++- tests/unit_tests/workflow/test_workflow.py | 14 ++++++++++---- 6 files changed, 27 insertions(+), 18 deletions(-) diff --git a/jiuwen/core/context/memory/base.py b/jiuwen/core/context/memory/base.py index 8106d70..b1c767a 100644 --- a/jiuwen/core/context/memory/base.py +++ b/jiuwen/core/context/memory/base.py @@ -4,7 +4,6 @@ from copy import deepcopy from typing import Union, Optional, Any, Callable -from jiuwen.core.common.exception.exception import JiuWenBaseException from jiuwen.core.context.state import CommitState, StateLike, State from jiuwen.core.context.utils import extract_origin_key, get_value_by_nested_path, update_dict @@ -28,7 +27,7 @@ class InMemoryStateLike(StateLike): result.append(self.get(item)) return result else: - raise JiuWenBaseException(1, "key type is not support") + return key def get_by_transformer(self, transformer: Callable) -> Optional[Any]: return transformer(self._state) @@ -43,8 +42,6 @@ class InMemoryCommitState(CommitState): self._updates: dict[str, list[dict]] = dict() def update(self, node_id: str, data: dict) -> None: - if data == 1: - print("tmp") if node_id not in self._updates: self._updates[node_id] = [] self._updates[node_id].append(data) diff --git a/jiuwen/core/context/utils.py b/jiuwen/core/context/utils.py index 75a68b6..a317378 100644 --- a/jiuwen/core/context/utils.py +++ b/jiuwen/core/context/utils.py @@ -20,9 +20,6 @@ def update_dict(update: dict, source: dict) -> None: :param update: update dict, which key is nested :param source: source dict, which key must not be nested """ - if update == 1: - print("error") - for key, value in update.items(): current_key, current = root_to_path(key, source, create_if_absent=True) update_by_key(current_key, value, current) @@ -32,6 +29,8 @@ def get_value_by_nested_path(nested_key: str, source: dict) -> Optional[Any]: result = root_to_path(nested_key, source) if result[1] is None: return None + if result[0] not in result[1]: + return None return result[1][result[0]] diff --git a/jiuwen/core/graph/vertex.py b/jiuwen/core/graph/vertex.py index c740368..c9bc19f 100644 --- a/jiuwen/core/graph/vertex.py +++ b/jiuwen/core/graph/vertex.py @@ -21,7 +21,7 @@ class Vertex: def __call__(self, state: GraphState) -> Output: if self._context is None or self._executable is None: raise JiuWenBaseException(1, "vertex is not initialized, node is is " + self._node_id) - inputs = self.__pre_invoke__(state) + inputs = self.__pre_invoke__() is_stream = self.__is_stream__(state) try: if is_stream: @@ -34,7 +34,7 @@ class Vertex: raise JiuWenBaseException(e.error_code, "failed to invoke, caused by " + e.message) return {"source_node_id": self._node_id} - def __pre_invoke__(self, state:GraphState) -> Optional[dict]: + def __pre_invoke__(self) -> Optional[dict]: inputs_schema = self._context.config.get_inputs_schema(self._node_id) inputs = self._context.state.get_inputs(inputs_schema) if inputs_schema else None if self._context.tracer is not None: diff --git a/jiuwen/graph/pregel/graph.py b/jiuwen/graph/pregel/graph.py index 2f97ea0..ffee283 100644 --- a/jiuwen/graph/pregel/graph.py +++ b/jiuwen/graph/pregel/graph.py @@ -77,16 +77,16 @@ class CompiledGraph(ExecutableGraph): self._compiledStateGraph = compiledStateGraph def invoke(self, inputs: Input, context: Context) -> Output: - return self._compiledStateGraph.invoke({"source_node_id": ""}) + return self._compiledStateGraph.invoke({"source_node_id": None}) def stream(self, inputs: Input, context: Context) -> Iterator[Output]: - return self._compiledStateGraph.stream({"source_node_id": ""}) + return self._compiledStateGraph.stream({"source_node_id": None}) async def ainvoke(self, inputs: Input, context: Context) -> Output: - return self._compiledStateGraph.ainvoke({"source_node_id": ""}) + return self._compiledStateGraph.ainvoke({"source_node_id": None}) async def astream(self, inputs: Input, context: Context) -> AsyncIterator[Output]: - return self._compiledStateGraph.astream({"source_node_id": ""}) + return self._compiledStateGraph.astream({"source_node_id": None}) def interrupt(self, message: dict): return diff --git a/tests/unit_tests/workflow/test_mock_node.py b/tests/unit_tests/workflow/test_mock_node.py index 90b9ce1..311be9f 100644 --- a/tests/unit_tests/workflow/test_mock_node.py +++ b/tests/unit_tests/workflow/test_mock_node.py @@ -33,13 +33,18 @@ class MockStartNode(StartComponent, MockNodeBase): super().__init__(node_id) def invoke(self, inputs: Input, context: Context) -> Output: + context.state.set_outputs(self.node_id, inputs) + print("start: output = " + str(inputs)) return inputs class MockEndNode(EndComponent, MockNodeBase): def __init__(self, node_id: str): super().__init__(node_id) + self.node_id = node_id def invoke(self, inputs: Input, context: Context) -> Output: + context.state.set_outputs(self.node_id, inputs) + print("endNode: output = " + str(inputs)) return inputs @@ -48,4 +53,6 @@ class Node1(MockNodeBase): super().__init__(node_id) def invoke(self, inputs: Input, context: Context) -> Output: - return {} \ No newline at end of file + context.state.set_outputs(self.node_id, inputs) + print("node1: output = " + str(inputs)) + return inputs \ No newline at end of file diff --git a/tests/unit_tests/workflow/test_workflow.py b/tests/unit_tests/workflow/test_workflow.py index 741f8c4..3744a6d 100644 --- a/tests/unit_tests/workflow/test_workflow.py +++ b/tests/unit_tests/workflow/test_workflow.py @@ -36,10 +36,16 @@ DEFAULT_WORKFLOW_CONFIG = WorkflowConfig(metadata={}) class WorkflowTest(unittest.TestCase): def test_simple_workflow(self): flow = create_flow() - flow.add_workflow_comp("a", Node1("a"), inputs_schema={}) - flow.set_start_comp("start", MockStartNode("start")) - flow.set_end_comp("end", MockEndNode("end")) - flow.invoke({}, create_context()) + flow.add_workflow_comp("a", Node1("a"), + inputs_schema={"aa": "${start.a}", "c" : "${start.c}"}) + flow.set_start_comp("start", MockStartNode("start"), + inputs_schema={"a" : "${user.inputs.a}", "b" : "${user.inputs.b}", "c": 1, "d" : [1,2,3]}) + flow.set_end_comp("end", MockEndNode("end"), + inputs_schema={"result": "${a.aa}"}) + flow.add_connection("start", "a") + flow.add_connection("a", "end") + result = flow.invoke({"a": 1, "b": "bvalue"}, create_context()) + assert result["result"] == 1 def test_workflow_with_loop(self): flow = create_flow() -- Gitee From d9f905a37625510d4dfe8e593768a30ebe5d45af Mon Sep 17 00:00:00 2001 From: CandiceGuo Date: Mon, 14 Jul 2025 19:33:24 +0800 Subject: [PATCH 6/7] =?UTF-8?q?fix:=20=E5=A2=9E=E5=8A=A0condition=E5=9F=BA?= =?UTF-8?q?=E6=9C=AC=E7=94=A8=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- jiuwen/core/workflow/base.py | 2 +- jiuwen/graph/pregel/graph.py | 8 +-- tests/unit_tests/workflow/test_workflow.py | 82 +++++++++++++++++++--- 3 files changed, 78 insertions(+), 14 deletions(-) diff --git a/jiuwen/core/workflow/base.py b/jiuwen/core/workflow/base.py index 755d233..76973b9 100644 --- a/jiuwen/core/workflow/base.py +++ b/jiuwen/core/workflow/base.py @@ -107,7 +107,7 @@ class Workflow: async def ainvoke(self, inputs: Input, context: Context) -> Output: return await asyncio.get_running_loop().run_in_executor( - None, partial(self.invoke, context), inputs + None, partial(self.invoke, context = context), inputs ) def stream( diff --git a/jiuwen/graph/pregel/graph.py b/jiuwen/graph/pregel/graph.py index ffee283..32c8a1c 100644 --- a/jiuwen/graph/pregel/graph.py +++ b/jiuwen/graph/pregel/graph.py @@ -37,12 +37,12 @@ class PregelGraph(Graph): self.waits.add(node_id) return self - def add_edge(self, start_node_id: Union[str, list[str]], end_node_id: str) -> Self: - self.edges.append((start_node_id, end_node_id)) + def add_edge(self, source_node_id: Union[str, list[str]], target_node_id: str) -> Self: + self.edges.append((source_node_id, target_node_id)) return self - def add_conditional_edges(self, source: str, router: Router) -> Self: - self.pregel.add_conditional_edges(source, router) + def add_conditional_edges(self, source_node_id: str, router: Router) -> Self: + self.pregel.add_conditional_edges(source_node_id, router) return self def compile(self, context: Context) -> ExecutableGraph: diff --git a/tests/unit_tests/workflow/test_workflow.py b/tests/unit_tests/workflow/test_workflow.py index 3744a6d..0767d0f 100644 --- a/tests/unit_tests/workflow/test_workflow.py +++ b/tests/unit_tests/workflow/test_workflow.py @@ -1,7 +1,8 @@ +import asyncio import unittest -from venv import create +from collections.abc import Callable -from sqlalchemy.testing.suite.test_reflection import metadata +import random from jiuwen.core.component.condition.array import ArrayCondition from jiuwen.core.component.loop_callback.intermediate_loop_var import IntermediateLoopVarCallback @@ -12,6 +13,7 @@ from jiuwen.core.context.config import Config from jiuwen.core.context.context import Context from jiuwen.core.context.memory.base import InMemoryState from jiuwen.core.graph.base import Graph +from jiuwen.core.graph.graph_state import GraphState from jiuwen.core.workflow.base import WorkflowConfig, Workflow from jiuwen.graph.pregel.graph import PregelGraph from test_node import AddTenNode, CommonNode @@ -34,18 +36,79 @@ DEFAULT_WORKFLOW_CONFIG = WorkflowConfig(metadata={}) class WorkflowTest(unittest.TestCase): + def assert_workflow_invoke(self, inputs: dict, context: Context, flow: Workflow, expect_results: dict = None, + checker: Callable = None): + results = flow.invoke(inputs=inputs, context=context) + if expect_results is not None: + assert results == expect_results + elif checker is not None: + checker(results) + + def assert_workflow_ainvoke(self, inputs: dict, context: Context, flow: Workflow, expect_results: dict = None, + checker: Callable = None): + loop = asyncio.get_event_loop() + feature = asyncio.ensure_future(flow.ainvoke(inputs=inputs, context=context)) + loop.run_until_complete(feature) + if expect_results is not None: + assert feature.result() == expect_results + elif checker is not None: + checker(feature.result()) + def test_simple_workflow(self): + """ + graph : start->a->end + """ flow = create_flow() - flow.add_workflow_comp("a", Node1("a"), - inputs_schema={"aa": "${start.a}", "c" : "${start.c}"}) flow.set_start_comp("start", MockStartNode("start"), - inputs_schema={"a" : "${user.inputs.a}", "b" : "${user.inputs.b}", "c": 1, "d" : [1,2,3]}) + inputs_schema={ + "a": "${user.inputs.a}", + "b": "${user.inputs.b}", + "c": 1, + "d": [1, 2, 3]}) + flow.add_workflow_comp("a", Node1("a"), + inputs_schema={ + "aa": "${start.a}", + "ac": "${start.c}"}) flow.set_end_comp("end", MockEndNode("end"), - inputs_schema={"result": "${a.aa}"}) + inputs_schema={ + "result": "${a.aa}"}) flow.add_connection("start", "a") flow.add_connection("a", "end") - result = flow.invoke({"a": 1, "b": "bvalue"}, create_context()) - assert result["result"] == 1 + self.assert_workflow_invoke({"a": 1, "b": "haha"}, create_context(), flow, expect_results={"result": 1}) + self.assert_workflow_ainvoke({"a": 1, "b": "haha"}, create_context(), flow, expect_results={"result": 1}) + + def test_simple_workflow_with_condition(self): + """ + start -> condition[a,b] -> end + :return: + """ + flow = create_flow() + flow.set_start_comp("start", MockStartNode("start"), + inputs_schema={"a": "${user.inputs.a}", + "b": "${user.inputs.b}", + "c": 1, + "d": [1, 2, 3]}) + + def router(state: GraphState): + condition_nodes = ["a", "b"] + randomIdx = random.randint(1, 2) + return condition_nodes[randomIdx - 1] + + flow.add_conditional_connection("start", router=router) + flow.add_workflow_comp("a", Node1("a"), inputs_schema={"a": "${start.a}", "b": "${start.c}"}) + flow.add_workflow_comp("b", Node1("b"), inputs_schema={"b": "${start.b}"}) + flow.set_end_comp("end", MockEndNode("end"), {"result1": "${a.a}", "result2": "${b.b}"}) + flow.add_connection("a", "end") + flow.add_connection("b", "end") + + def checker(results): + if "result1" in results: + assert results["result1"] == 1 + elif "result2" in results: + assert results["result2"] == "haha" + + self.assert_workflow_invoke({"a": 1, "b": "haha"}, create_context(), flow, checker=checker) + self.assert_workflow_ainvoke({"a": 1, "b": "haha"}, create_context(), flow, checker=checker) def test_workflow_with_loop(self): flow = create_flow() @@ -62,7 +125,8 @@ class WorkflowTest(unittest.TestCase): loop_group = LoopGroup(context) loop_group.add_component("1", AddTenNode("1"), inputs_schema={"source": "${l.arrLoopVar.item}"}) loop_group.add_component("2", AddTenNode("2"), inputs_schema={"source": "${l.intermediateLoopVar.user_var}"}) - loop_group.add_component("3", SetVariableComponent("3", context, {"${l.intermediateLoopVar.user_var}": "${2.result}"})) + loop_group.add_component("3", SetVariableComponent("3", context, + {"${l.intermediateLoopVar.user_var}": "${2.result}"})) loop_group.start_nodes(["1"]) loop_group.end_nodes(["3"]) loop_group.add_connection("1", "2") -- Gitee From 65c3c4950604f5ee63427232b98942c26953c8a9 Mon Sep 17 00:00:00 2001 From: caoyuzhe Date: Mon, 14 Jul 2025 19:40:03 +0800 Subject: [PATCH 7/7] =?UTF-8?q?fix=EF=BC=9A=E5=88=86=E6=94=AF=E7=BB=84?= =?UTF-8?q?=E4=BB=B6=E5=92=8C=E5=BE=AA=E7=8E=AF=E7=BB=84=E4=BB=B6=E5=8A=9F?= =?UTF-8?q?=E8=83=BD=E8=81=94=E8=B0=83=EF=BC=88=E9=83=A8=E5=88=86=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- jiuwen/core/component/branch_comp.py | 13 +++++--- jiuwen/core/component/branch_router.py | 1 + jiuwen/core/component/loop_callback/output.py | 28 +++++++++-------- jiuwen/core/component/set_variable_comp.py | 11 ++++--- tests/unit_tests/workflow/test_workflow.py | 30 +++++++++++++++++++ 5 files changed, 61 insertions(+), 22 deletions(-) diff --git a/jiuwen/core/component/branch_comp.py b/jiuwen/core/component/branch_comp.py index 12937dc..4f28de8 100644 --- a/jiuwen/core/component/branch_comp.py +++ b/jiuwen/core/component/branch_comp.py @@ -6,12 +6,17 @@ from contextvars import Context from functools import partial from typing import Callable, Union, Hashable, Iterator, AsyncIterator +from jiuwen.core.component.base import WorkflowComponent from jiuwen.core.component.branch_router import BranchRouter from jiuwen.core.component.condition.condition import Condition from jiuwen.core.graph.base import Graph +from jiuwen.core.graph.executable import Executable, Input, Output class BranchComponent(WorkflowComponent, Executable): + def interrupt(self, message: dict): + pass + def __init__(self, context: Context, executable: Executable = None): self._router = BranchRouter(context) self._executable = executable @@ -29,8 +34,8 @@ class BranchComponent(WorkflowComponent, Executable): return self def invoke(self, inputs: Input, context: Context) -> Output: - if self._executable is None: - return self._executor.invoke(inputs, context) + if self._executable: + return self._executable.invoke(inputs, context) return inputs async def ainvoke(self, inputs: Input, context: Context) -> Output: @@ -43,6 +48,6 @@ class BranchComponent(WorkflowComponent, Executable): async def astream(self, inputs: Input, context: Context) -> AsyncIterator[Output]: yield await self.ainvoke(inputs, context) - def add_component(self, graph: Graph, node_id: str, wait_for_all: bool = False, inputs: dict = None): - graph.add_node(node_id, self.to_executable(), wait_for_all=wait_for_all, inputs=inputs) + def add_component(self, graph: Graph, node_id: str, wait_for_all: bool = False): + graph.add_node(node_id, self.to_executable(), wait_for_all=wait_for_all) graph.add_conditional_edges(node_id, self.router()) diff --git a/jiuwen/core/component/branch_router.py b/jiuwen/core/component/branch_router.py index 6b25d7d..0c50e66 100644 --- a/jiuwen/core/component/branch_router.py +++ b/jiuwen/core/component/branch_router.py @@ -5,6 +5,7 @@ from typing import Callable, Union from jiuwen.core.component.condition.condition import Condition, FuncCondition from jiuwen.core.component.condition.expression import ExpressionCondition +from jiuwen.core.context.context import Context class Branch: diff --git a/jiuwen/core/component/loop_callback/output.py b/jiuwen/core/component/loop_callback/output.py index d193271..d9b46e6 100644 --- a/jiuwen/core/component/loop_callback/output.py +++ b/jiuwen/core/component/loop_callback/output.py @@ -18,41 +18,45 @@ class OutputCallback(LoopCallback): self._result_root = result_root if result_root else node_id self._intermediate_loop_var_root = intermediate_loop_var_root if intermediate_loop_var_root else node_id + NESTED_PATH_SPLIT + "intermediateLoopVar" - def _generate_results(self, results: list[(str, list[Any])]): + def _generate_results(self, results: list[(str, Any)]): for key, value in self._outputs_format.items(): if isinstance(value, str) and is_ref_path(value): ref_str = extract_origin_key(value) - results.append ((ref_str, [])) + results.append ((ref_str, None)) elif isinstance(value, dict): self._generate_results(results) def first_in_loop(self): - _results: list[(str, list[Any])] = [] + _results: list[(str, Any)] = [] self._generate_results(_results) self._context.state.update(self._node_id, {self._round_result_root: _results}) def out_loop(self): - results: list[(str, list[Any])] = self._context.state.get(self._round_result_root) + results: list[(str, Any)] = self._context.state.get(self._round_result_root) if not isinstance(results, list): raise RuntimeError("error results in loop process") - for (path, array) in results: - self._context.state.update(self._node_id, {path: array}) + for (path, value) in results: + self._context.state.io_state.update(self._node_id, {path: value}) + self._context.state.io_state.commit() result = self._context.state.get_inputs(self._outputs_format) self._context.state.update(self._node_id, {self._round_result_root : {}}) - self._context.state.set_outputs(self._node_id, {self._result_root : result}) + self._context.state.set_outputs(self._node_id, result) def start_round(self): pass def end_round(self): - results: dict[str, list[Any]] = self._context.state.get(self._round_result_root) - if not isinstance(results, dict): + results: list[(str, Any)] = self._context.state.get(self._round_result_root) + if not isinstance(results, list): raise RuntimeError("error results in round process") - for path, value in results.items(): + for value in results: + path = value[0] if path.startswith(self._intermediate_loop_var_root): - results[path] = self._context.state.get(path) + value[1] = self._context.state.get(path) elif isinstance(value, list): - value.append(self._context.state.get(path)) + if value[1] is None: + value[1] = [] + value[1].append(self._context.state.get(path)) else: raise RuntimeError("error process in loop: " + path + ", " + str(value)) self._context.state.update(self._node_id, {self._round_result_root : results}) diff --git a/jiuwen/core/component/set_variable_comp.py b/jiuwen/core/component/set_variable_comp.py index 6c223c4..fdcbdbb 100644 --- a/jiuwen/core/component/set_variable_comp.py +++ b/jiuwen/core/component/set_variable_comp.py @@ -7,7 +7,7 @@ from typing import AsyncIterator, Iterator, Any from jiuwen.core.component.base import WorkflowComponent from jiuwen.core.context.context import Context -from jiuwen.core.context.utils import extract_origin_key +from jiuwen.core.context.utils import extract_origin_key, is_ref_path from jiuwen.core.graph.executable import Executable, Input, Output @@ -22,12 +22,11 @@ class SetVariableComponent(WorkflowComponent, Executable): left_ref_str = extract_origin_key(left) if left_ref_str == "": left_ref_str = left - if isinstance(right, str): + if isinstance(right, str) and is_ref_path(right): ref_str = extract_origin_key(right) - if ref_str == "": - self._context.state.update(self._node_id, self._context.state.get(ref_str)) - continue - self._context.state.update(self._node_id, {left_ref_str : right}) + self._context.state.io_state.update(self._node_id, {left_ref_str : self._context.state.get(ref_str)}) + continue + self._context.state.io_state.update(self._node_id, {left_ref_str : right}) return None diff --git a/tests/unit_tests/workflow/test_workflow.py b/tests/unit_tests/workflow/test_workflow.py index 0767d0f..9d641db 100644 --- a/tests/unit_tests/workflow/test_workflow.py +++ b/tests/unit_tests/workflow/test_workflow.py @@ -4,6 +4,7 @@ from collections.abc import Callable import random +from jiuwen.core.component.branch_comp import BranchComponent from jiuwen.core.component.condition.array import ArrayCondition from jiuwen.core.component.loop_callback.intermediate_loop_var import IntermediateLoopVarCallback from jiuwen.core.component.loop_callback.output import OutputCallback @@ -110,6 +111,35 @@ class WorkflowTest(unittest.TestCase): self.assert_workflow_invoke({"a": 1, "b": "haha"}, create_context(), flow, checker=checker) self.assert_workflow_ainvoke({"a": 1, "b": "haha"}, create_context(), flow, checker=checker) + def test_workflow_with_branch(self): + context = create_context() + flow = create_flow() + flow.set_start_comp("start", MockStartNode("start")) + flow.set_end_comp("end", MockEndNode("end"), + inputs_schema={"a": "${a.result}", "b": "${b.result}"}) + + sw = BranchComponent(context) + sw.add_branch("${user.inputs.a} <= 10", ["b"], "1") + sw.add_branch("${user.inputs.a} > 10", ["a"], "2") + + flow.add_workflow_comp("sw", sw) + + flow.add_workflow_comp("a", CommonNode("a"), + inputs_schema={"result": "${user.inputs.a}"}) + + flow.add_workflow_comp("b", AddTenNode("b"), + inputs_schema={"source": "${user.inputs.a}"}) + + flow.add_connection("start", "sw") + flow.add_connection("a", "end") + flow.add_connection("b", "end") + + result = flow.invoke({"a": 2}, context=context) + assert result["b"] == 12 + + result = flow.invoke({"a": 15}, context=context) + assert result["a"] == 15 + def test_workflow_with_loop(self): flow = create_flow() flow.set_start_comp("s", MockStartNode("s")) -- Gitee