diff --git a/jiuwen/core/component/branch_comp.py b/jiuwen/core/component/branch_comp.py index 12937dca3a0bdb718d46fbc4b173f8abd959049f..4f28de8cfb012bd02916beadcad4b7821cdcde15 100644 --- a/jiuwen/core/component/branch_comp.py +++ b/jiuwen/core/component/branch_comp.py @@ -6,12 +6,17 @@ from contextvars import Context from functools import partial from typing import Callable, Union, Hashable, Iterator, AsyncIterator +from jiuwen.core.component.base import WorkflowComponent from jiuwen.core.component.branch_router import BranchRouter from jiuwen.core.component.condition.condition import Condition from jiuwen.core.graph.base import Graph +from jiuwen.core.graph.executable import Executable, Input, Output class BranchComponent(WorkflowComponent, Executable): + def interrupt(self, message: dict): + pass + def __init__(self, context: Context, executable: Executable = None): self._router = BranchRouter(context) self._executable = executable @@ -29,8 +34,8 @@ class BranchComponent(WorkflowComponent, Executable): return self def invoke(self, inputs: Input, context: Context) -> Output: - if self._executable is None: - return self._executor.invoke(inputs, context) + if self._executable: + return self._executable.invoke(inputs, context) return inputs async def ainvoke(self, inputs: Input, context: Context) -> Output: @@ -43,6 +48,6 @@ class BranchComponent(WorkflowComponent, Executable): async def astream(self, inputs: Input, context: Context) -> AsyncIterator[Output]: yield await self.ainvoke(inputs, context) - def add_component(self, graph: Graph, node_id: str, wait_for_all: bool = False, inputs: dict = None): - graph.add_node(node_id, self.to_executable(), wait_for_all=wait_for_all, inputs=inputs) + def add_component(self, graph: Graph, node_id: str, wait_for_all: bool = False): + graph.add_node(node_id, self.to_executable(), wait_for_all=wait_for_all) graph.add_conditional_edges(node_id, self.router()) diff --git a/jiuwen/core/component/branch_router.py b/jiuwen/core/component/branch_router.py index 6b25d7d3efb546b02bfdd94515642f3ac6adaff5..0c50e66f8ceb3e328cad8399a8742a5e201062c4 100644 --- a/jiuwen/core/component/branch_router.py +++ b/jiuwen/core/component/branch_router.py @@ -5,6 +5,7 @@ from typing import Callable, Union from jiuwen.core.component.condition.condition import Condition, FuncCondition from jiuwen.core.component.condition.expression import ExpressionCondition +from jiuwen.core.context.context import Context class Branch: diff --git a/jiuwen/core/component/break_comp.py b/jiuwen/core/component/break_comp.py index e0c159e99ca0dab70ea292a75e5847c6b301a8fd..e191288d7386000431f43ecaa2122c30bf994bb2 100644 --- a/jiuwen/core/component/break_comp.py +++ b/jiuwen/core/component/break_comp.py @@ -2,6 +2,11 @@ # -*- coding: UTF-8 -*- # Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. from abc import abstractmethod, ABC +from typing import Iterator, AsyncIterator + +from jiuwen.core.component.base import WorkflowComponent +from jiuwen.core.context.context import Context +from jiuwen.core.graph.executable import Executable, Input, Output class LoopController(ABC): @@ -15,9 +20,14 @@ class LoopController(ABC): class BreakComponent(WorkflowComponent, Executable): + def __init__(self): + super().__init__() self._loop_controller = None + def interrupt(self, message: dict): + pass + def set_controller(self, loop_controller: LoopController): self._loop_controller = loop_controller diff --git a/jiuwen/core/component/condition/array.py b/jiuwen/core/component/condition/array.py index 8f2400eadd0373ae1605d74ecc044fed2b528101..178949fdc6ed9bb4c7e0012c40bd8e24b72c6ed9 100644 --- a/jiuwen/core/component/condition/array.py +++ b/jiuwen/core/component/condition/array.py @@ -3,8 +3,9 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. from typing import Union, Any -from jiuwen.core.common.constants.constant import BPMN_VARIABLE_POOL_SEPARATOR from jiuwen.core.component.condition.condition import Condition +from jiuwen.core.context.context import Context +from jiuwen.core.context.utils import extract_origin_key, NESTED_PATH_SPLIT DEFAULT_MAX_LOOP_NUMBER = 1000 @@ -13,27 +14,28 @@ class ArrayCondition(Condition): def __init__(self, context: Context, node_id: str, arrays: dict[str, Union[str, list[Any]]], index_path: str = None, array_root: str = None): self._context = context + self._node_id = node_id self._arrays = arrays - self._index_path = index_path if index_path else node_id + BPMN_VARIABLE_POOL_SEPARATOR + "index" - self._arrays_root = array_root if array_root else node_id + BPMN_VARIABLE_POOL_SEPARATOR + "arrLoopVar" + self._index_path = index_path if index_path else node_id + NESTED_PATH_SPLIT + "index" + self._arrays_root = array_root if array_root else node_id + NESTED_PATH_SPLIT + "arrLoopVar" def init(self): - self._context.store.write(self._index_path, 0) - self._context.store.write(self._arrays_root, {}) + self._context.state.io_state.update(self._node_id, {self._index_path: 0}) + self._context.state.io_state.update(self._node_id, {self._arrays_root: {}}) def __call__(self) -> bool: - current_idx = self._context.store.read(self._index_path) + current_idx = self._context.state.get(self._index_path) min_length = DEFAULT_MAX_LOOP_NUMBER updates: dict[str, Any] = {} for key, array_info in self._arrays.items(): - key_path = self._arrays_root + BPMN_VARIABLE_POOL_SEPARATOR + key + key_path = self._arrays_root + NESTED_PATH_SPLIT + key arr: list[Any] = [] if isinstance(array_info, list): arr = array_info elif isinstance(array_info, str): - ref_str = get_ref_str(array_info) + ref_str = extract_origin_key(array_info) if ref_str != "": - arr = self._context.store.read(ref_str) + arr = self._context.state.get(ref_str) else: raise RuntimeError("error value: " + array_info + " is not a array path") else: @@ -43,8 +45,6 @@ class ArrayCondition(Condition): return False updates[key_path] = arr[current_idx] - self._context.store.write(self._index_path, current_idx + 1) - for path, update in updates.items(): - self._context.store.write(path, update) - + self._context.state.io_state.update(self._node_id, {self._index_path: current_idx + 1}) + self._context.state.io_state.update(self._node_id, updates) return True diff --git a/jiuwen/core/component/condition/expression.py b/jiuwen/core/component/condition/expression.py index 2c44265f696e014ddde2135db08bf2c3ea5dbbf4..c770eda76079572efaff8dbc6cc48d35b9137e5b 100644 --- a/jiuwen/core/component/condition/expression.py +++ b/jiuwen/core/component/condition/expression.py @@ -4,6 +4,7 @@ import re from jiuwen.core.component.condition.condition import Condition +from jiuwen.core.context.context import Context class ExpressionCondition(Condition): @@ -19,7 +20,7 @@ class ExpressionCondition(Condition): matches = re.findall(pattern, self._expression) inputs = {} for match in matches: - inputs[match] = self._context.store.read(match[2:-1]) + inputs[match] = self._context.state.get(match[2:-1]) return self._evalueate_expression(self._expression, inputs) def _evalueate_expression(self, expression, inputs) -> bool: diff --git a/jiuwen/core/component/condition/number.py b/jiuwen/core/component/condition/number.py index 412c59f22899e9ad0a6a700dd4b156b4aa47c9ae..b1507844759d451e8715102ed54a514439aca5c4 100644 --- a/jiuwen/core/component/condition/number.py +++ b/jiuwen/core/component/condition/number.py @@ -5,6 +5,7 @@ from typing import Union from jiuwen.core.common.constants.constant import BPMN_VARIABLE_POOL_SEPARATOR from jiuwen.core.component.condition.condition import Condition +from jiuwen.core.context.context import Context class NumberCondition(Condition): @@ -12,20 +13,21 @@ class NumberCondition(Condition): self._context = context self._index_path = index_path if index_path else node_id + BPMN_VARIABLE_POOL_SEPARATOR + "index" self._limit = limit + self._node_id = node_id def init(self): - self._context.store.write(self._index_path, 0) + self._context.state.io_state.update(self._node_id, {self._index_path: 0}) def __call__(self) -> bool: - current_idx = self._context.store.read(self._index_path) + current_idx = self._context.state.get(self._index_path) limit_num: int if isinstance(self._limit, int): limit_num = self._limit else: - limit_num = self._context.store.read(self._limit) + limit_num = self._context.state.get(self._limit) result = current_idx < limit_num if result: - self._context.store.write(self._index_path, current_idx + 1) + self._context.state.io_state.update(self._node_id, {self._index_path: current_idx + 1}) return result diff --git a/jiuwen/core/component/loop_callback/intermediate_loop_var.py b/jiuwen/core/component/loop_callback/intermediate_loop_var.py index e0d117ed779ffe91cca4732ca3675a5e415cb968..88c3cc4f8f3fedae9bae93510986961cc2552628 100644 --- a/jiuwen/core/component/loop_callback/intermediate_loop_var.py +++ b/jiuwen/core/component/loop_callback/intermediate_loop_var.py @@ -3,34 +3,36 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. from typing import Union, Any -from jiuwen.core.common.constants.constant import BPMN_VARIABLE_POOL_SEPARATOR from jiuwen.core.component.loop_callback.loop_callback import LoopCallback +from jiuwen.core.context.context import Context +from jiuwen.core.context.utils import NESTED_PATH_SPLIT, is_ref_path, extract_origin_key class IntermediateLoopVarCallback(LoopCallback): def __init__(self, context: Context, node_id: str, intermediate_loop_var: dict[str, Union[str, Any]], intermediate_loop_var_root: str = None): self._context = context + self._node_id = node_id self._intermediate_loop_var = intermediate_loop_var self._intermediate_loop_var_root = intermediate_loop_var_root if intermediate_loop_var_root \ - else node_id + BPMN_VARIABLE_POOL_SEPARATOR + "intermediateLoopVar" + else node_id + NESTED_PATH_SPLIT + "intermediateLoopVar" def first_in_loop(self): for key, value in self._intermediate_loop_var.items(): - path = self._intermediate_loop_var_root + BPMN_VARIABLE_POOL_SEPARATOR + key + path = self._intermediate_loop_var_root + NESTED_PATH_SPLIT + key updates: Any if isinstance(value, str): - ref_str = get_ref_str(value) - if ref_str != "": - update = self._context.store.read(ref_str) + if is_ref_path(value): + ref_str = extract_origin_key(value) + update = self._context.state.get(ref_str) else: update = value else: update = value - self._context.store.write(path, update) + self._context.state.io_state.update(self._node_id, {path: update}) def out_loop(self): - self._context.store.write(self._intermediate_loop_var_root, {}) + self._context.state.io_state.update(self._node_id, {self._intermediate_loop_var_root: {}}) def start_round(self): pass diff --git a/jiuwen/core/component/loop_callback/output.py b/jiuwen/core/component/loop_callback/output.py index 0f8f90f4673f76cdb51ca7e4a4bd736faa5aa6ad..d9b46e6cce853f56de9b4242894b1bcab63e8ff4 100644 --- a/jiuwen/core/component/loop_callback/output.py +++ b/jiuwen/core/component/loop_callback/output.py @@ -3,55 +3,60 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. from typing import Any -from jiuwen.core.common.constants.constant import BPMN_VARIABLE_POOL_SEPARATOR from jiuwen.core.component.loop_callback.loop_callback import LoopCallback +from jiuwen.core.context.context import Context +from jiuwen.core.context.utils import is_ref_path, extract_origin_key, NESTED_PATH_SPLIT class OutputCallback(LoopCallback): def __init__(self, context: Context, node_id: str, outputs_format: dict[str, Any], round_result_root: str = None, result_root: str = None, intermediate_loop_var_root: str = None): + self._node_id = node_id self._context = context self._outputs_format = outputs_format - self._round_result_root = round_result_root if round_result_root else node_id + BPMN_VARIABLE_POOL_SEPARATOR + "round" + self._round_result_root = round_result_root if round_result_root else node_id + NESTED_PATH_SPLIT + "round" self._result_root = result_root if result_root else node_id - self._intermediate_loop_var_root = intermediate_loop_var_root if intermediate_loop_var_root else node_id + BPMN_VARIABLE_POOL_SEPARATOR + "intermediateLoopVar" + self._intermediate_loop_var_root = intermediate_loop_var_root if intermediate_loop_var_root else node_id + NESTED_PATH_SPLIT + "intermediateLoopVar" - def _generate_results(self, results: dict[str, list[Any]]): + def _generate_results(self, results: list[(str, Any)]): for key, value in self._outputs_format.items(): - if isinstance(value, str): - ref_str = get_ref_str(value) - if ref_str != "": - results[ref_str] = [] + if isinstance(value, str) and is_ref_path(value): + ref_str = extract_origin_key(value) + results.append ((ref_str, None)) elif isinstance(value, dict): self._generate_results(results) def first_in_loop(self): - _results: dict[str, list[Any]] = {} + _results: list[(str, Any)] = [] self._generate_results(_results) - self._context.store.write(self._round_result_root, _results) + self._context.state.update(self._node_id, {self._round_result_root: _results}) def out_loop(self): - results: dict[str, list[Any]] = self._context.store.read(self._round_result_root) - if not isinstance(results, dict): + results: list[(str, Any)] = self._context.state.get(self._round_result_root) + if not isinstance(results, list): raise RuntimeError("error results in loop process") - for path, array in results.items(): - self._context.store.write(path, array) - result = filter_input(self._outputs_format, self._context.store) - self._context.store.write(self._round_result_root, {}) - set_output(result, self._result_root, self._context.store) + for (path, value) in results: + self._context.state.io_state.update(self._node_id, {path: value}) + self._context.state.io_state.commit() + result = self._context.state.get_inputs(self._outputs_format) + self._context.state.update(self._node_id, {self._round_result_root : {}}) + self._context.state.set_outputs(self._node_id, result) def start_round(self): pass def end_round(self): - results: dict[str, list[Any]] = self._context.store.read(self._round_result_root) - if not isinstance(results, dict): + results: list[(str, Any)] = self._context.state.get(self._round_result_root) + if not isinstance(results, list): raise RuntimeError("error results in round process") - for path, value in results.items(): + for value in results: + path = value[0] if path.startswith(self._intermediate_loop_var_root): - results[path] = self._context.store.read(path) + value[1] = self._context.state.get(path) elif isinstance(value, list): - value.append(self._context.store.read(path)) + if value[1] is None: + value[1] = [] + value[1].append(self._context.state.get(path)) else: raise RuntimeError("error process in loop: " + path + ", " + str(value)) - self._context.store.write(self._round_result_root, results) + self._context.state.update(self._node_id, {self._round_result_root : results}) diff --git a/jiuwen/core/component/loop_comp.py b/jiuwen/core/component/loop_comp.py index 6c041929cbfccc01c0b1555a747d01c46134f5f6..cf609a83040a8adfe5c605c95f520130d290325a 100644 --- a/jiuwen/core/component/loop_comp.py +++ b/jiuwen/core/component/loop_comp.py @@ -7,12 +7,15 @@ from typing import Iterator, AsyncIterator, Self, Union, Callable from langgraph.constants import END, START -from jiuwen.core.common.constants.constant import BPMN_VARIABLE_POOL_SEPARATOR -from jiuwen.core.component.break_comp import BreakComponent +from jiuwen.core.component.base import WorkflowComponent +from jiuwen.core.component.break_comp import BreakComponent, LoopController from jiuwen.core.component.condition.condition import Condition, AlwaysTrue, FuncCondition from jiuwen.core.component.condition.expression import ExpressionCondition from jiuwen.core.component.loop_callback.loop_callback import LoopCallback -from jiuwen.core.graph.base import Graph +from jiuwen.core.context.context import Context +from jiuwen.core.context.utils import NESTED_PATH_SPLIT +from jiuwen.core.graph.base import Graph, Router, ExecutableGraph +from jiuwen.core.graph.executable import Output, Input, Executable from jiuwen.graph.factory import GraphFactory @@ -31,6 +34,9 @@ class EmptyExecutable(Executable): async def astream(self, inputs: Input, context: Context) -> AsyncIterator[Output]: yield await self.ainvoke(inputs, context) + def interrupt(self, message: dict): + return + class LoopGroup: def __init__(self, context: Context, graph: Graph = None): @@ -38,8 +44,9 @@ class LoopGroup: self._graph = graph if graph else GraphFactory().create_graph() def add_component(self, node_id: str, component: WorkflowComponent, *, wait_for_all: bool = False, - inputs: dict = None) -> Self: - component.add_component(self._graph, node_id, inputs, wait_for_all=wait_for_all) + inputs_schema: dict = None, outputs_schema: dict = None) -> Self: + component.add_component(self._graph, node_id, wait_for_all=wait_for_all) + self._context.config.set_io_schema(node_id, (inputs_schema, outputs_schema)) return self def start_nodes(self, nodes: list[str]) -> Self: @@ -52,7 +59,7 @@ class LoopGroup: self._graph.end_node(node) return self - def add_condition(self, start_node_id: Union[str, list[str]], end_node_id: str) -> Self: + def add_connection(self, start_node_id: Union[str, list[str]], end_node_id: str) -> Self: self._graph.add_edge(start_node_id, end_node_id) return self @@ -61,7 +68,7 @@ class LoopGroup: return self def compile(self) -> ExecutableGraph: - return self._graph.compile() + return self._graph.compile(self._context) BROKEN = "_broken" @@ -71,10 +78,10 @@ FIRST_IN_LOOP = "_first_in_loop" class LoopComponent(WorkflowComponent, LoopController): def __init__(self, context: Context, node_id: str, body: Union[Executable, LoopGroup], - condition: Union[str, Callable[[], bool], Condition] = None, - context_root: str = None, - break_nodes: list[BreakComponent] = None, - callbacks: list[LoopCallback] = None, graph: Graph = None): + condition: Union[str, Callable[[], bool], Condition] = None, context_root: str = None, + break_nodes: list[BreakComponent] = None, callbacks: list[LoopCallback] = None, graph: Graph = None): + super().__init__() + self._node_id = node_id if context is None: raise ValueError("context cannot be None") if context_root is None: @@ -96,7 +103,7 @@ class LoopComponent(WorkflowComponent, LoopController): if break_nodes: for break_node in break_nodes: - break_node.set_condition(self) + break_node.set_controller(self) if callbacks: for callback in callbacks: @@ -117,11 +124,11 @@ class LoopComponent(WorkflowComponent, LoopController): self._graph.add_conditional_edges(condition_node_id, self) self.init() - self._compiled = self._graph.compile() + self._compiled = self._graph.compile(self._context) def init(self): - self._context.store.write(self._context_root + BPMN_VARIABLE_POOL_SEPARATOR + BROKEN, False) - self._context.store.write(self._context_root + BPMN_VARIABLE_POOL_SEPARATOR + FIRST_IN_LOOP, True) + self._context.state.update(self._node_id, {self._context_root + NESTED_PATH_SPLIT + BROKEN: False}) + self._context.state.update(self._node_id, {self._context_root + NESTED_PATH_SPLIT + FIRST_IN_LOOP: True}) self._condition.init() def to_executable(self) -> Executable: @@ -144,6 +151,8 @@ class LoopComponent(WorkflowComponent, LoopController): else: callback.out_loop() + self._context.state.io_state.commit() + self._context.state.global_state.commit() if continue_loop: return self._in_loop @@ -151,19 +160,20 @@ class LoopComponent(WorkflowComponent, LoopController): return self._out_loop def first_in_loop(self) -> bool: - _first_in_loop = self._context.store.read(self._context_root + BPMN_VARIABLE_POOL_SEPARATOR + FIRST_IN_LOOP) + _first_in_loop = self._context.state.get(self._context_root + NESTED_PATH_SPLIT + FIRST_IN_LOOP) if isinstance(_first_in_loop, bool): if _first_in_loop: - self._context.store.write(self._context_root + BPMN_VARIABLE_POOL_SEPARATOR + FIRST_IN_LOOP, False) + self._context.state.update(self._node_id, + {self._context_root + NESTED_PATH_SPLIT + FIRST_IN_LOOP: False}) return _first_in_loop - self._context.store.write(self._context_root + BPMN_VARIABLE_POOL_SEPARATOR + FIRST_IN_LOOP, False) + self._context.state.update(self._node_id, {self._context_root + NESTED_PATH_SPLIT + FIRST_IN_LOOP: False}) return True def is_broken(self) -> bool: - _is_broken = self._context.store.read(self._context_root + BPMN_VARIABLE_POOL_SEPARATOR + BROKEN) + _is_broken = self._context.state.get(self._context_root + NESTED_PATH_SPLIT + BROKEN) if isinstance(_is_broken, bool): return _is_broken return False def break_loop(self): - self._context.store.write(self._context_root + BPMN_VARIABLE_POOL_SEPARATOR + BROKEN, True) + self._context.state.update(self._node_id, {self._context_root + NESTED_PATH_SPLIT + BROKEN: True}) diff --git a/jiuwen/core/component/set_variable_comp.py b/jiuwen/core/component/set_variable_comp.py new file mode 100644 index 0000000000000000000000000000000000000000..fdcbdbbc6c9f32df15280607ef038cafc13f1ed0 --- /dev/null +++ b/jiuwen/core/component/set_variable_comp.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. +import asyncio +from functools import partial +from typing import AsyncIterator, Iterator, Any + +from jiuwen.core.component.base import WorkflowComponent +from jiuwen.core.context.context import Context +from jiuwen.core.context.utils import extract_origin_key, is_ref_path +from jiuwen.core.graph.executable import Executable, Input, Output + + +class SetVariableComponent(WorkflowComponent, Executable): + def __init__(self, node_id: str, context: Context, variable_mapping: dict[str, Any]): + self._context = context + self._node_id = node_id + self._variable_mapping = variable_mapping + + def invoke(self, inputs: Input, context: Context) -> Output: + for left, right in self._variable_mapping.items(): + left_ref_str = extract_origin_key(left) + if left_ref_str == "": + left_ref_str = left + if isinstance(right, str) and is_ref_path(right): + ref_str = extract_origin_key(right) + self._context.state.io_state.update(self._node_id, {left_ref_str : self._context.state.get(ref_str)}) + continue + self._context.state.io_state.update(self._node_id, {left_ref_str : right}) + + return None + + async def ainvoke(self, inputs: Input, context: Context) -> Output: + return await asyncio.get_running_loop().run_in_executor( + None, partial(self.invoke, context), inputs + ) + + def stream(self, inputs: Input, context: Context) -> Iterator[Output]: + yield self.invoke(inputs, context) + + async def astream(self, inputs: Input, context: Context) -> AsyncIterator[Output]: + yield await self.ainvoke(inputs, context) + + def interrupt(self, message: dict): + pass + + def to_executable(self) -> Executable: + return self \ No newline at end of file diff --git a/jiuwen/core/context/context.py b/jiuwen/core/context/context.py index cf1fc69b77dc480563f978e49eb73cd8a47a6b90..0da9a2f60998919adac58cfd32dbd271e76c1ccd 100644 --- a/jiuwen/core/context/context.py +++ b/jiuwen/core/context/context.py @@ -28,6 +28,7 @@ class Context(ABC): return False self._workflow_config = workflow_config self._stream_modes = stream_modes + return True @property def config(self) -> Config: diff --git a/jiuwen/core/context/memory/base.py b/jiuwen/core/context/memory/base.py index 893a0e9a9734f0f1828ea5ede8dc0bada56c6c17..b1c767a2ddd691518c469795896bb059b818dfd5 100644 --- a/jiuwen/core/context/memory/base.py +++ b/jiuwen/core/context/memory/base.py @@ -4,12 +4,11 @@ from copy import deepcopy from typing import Union, Optional, Any, Callable -from jiuwen.core.common.exception.exception import JiuWenBaseException from jiuwen.core.context.state import CommitState, StateLike, State from jiuwen.core.context.utils import extract_origin_key, get_value_by_nested_path, update_dict -class InMemoryState(StateLike): +class InMemoryStateLike(StateLike): def __init__(self): self._state: dict = dict() @@ -28,18 +27,18 @@ class InMemoryState(StateLike): result.append(self.get(item)) return result else: - raise JiuWenBaseException(1, "key type is not support") + return key def get_by_transformer(self, transformer: Callable) -> Optional[Any]: return transformer(self._state) def update(self, node_id: str, data: dict) -> None: - update_dict(self._state, data) + update_dict(data, self._state) class InMemoryCommitState(CommitState): def __init__(self): - self._state = InMemoryState() + self._state = InMemoryStateLike() self._updates: dict[str, list[dict]] = dict() def update(self, node_id: str, data: dict) -> None: @@ -72,7 +71,7 @@ class InMemoryState(State): def __init__(self): super().__init__(io_state=InMemoryCommitState(), global_state=InMemoryCommitState(), - trace_state=InMemoryState(), - comp_state=InMemoryState()) + trace_state=InMemoryStateLike(), + comp_state=InMemoryStateLike()) diff --git a/jiuwen/core/context/state.py b/jiuwen/core/context/state.py index beb6e0097c395aa37166df96674c6fc5b1ba3c99..a181855f61f0660a15ee706b171d9d8689d32cba 100644 --- a/jiuwen/core/context/state.py +++ b/jiuwen/core/context/state.py @@ -63,7 +63,10 @@ class State(ABC): def get(self, key: Union[str, dict]) -> Optional[Any]: if self._global_state is None: return None - return self._global_state.get(key) + value = self._global_state.get(key) + if value is None: + return self._io_state.get(key) + return value def update(self, node_id: str, data: dict) -> None: if self._global_state is None: @@ -85,10 +88,10 @@ class State(ABC): return self._comp_state.update(node_id, data) - def set_user_inputs(self, inputs: dict) -> None: + def set_user_inputs(self, inputs: Any) -> None: if self._io_state is None: return - self._io_state.update({"user": {"inputs": inputs}}) + self._io_state.update("user", {"user.inputs": inputs}) def get_inputs(self, input_schemas: dict) -> dict: if self._io_state is None: @@ -101,7 +104,7 @@ class State(ABC): return self._io_state.get(node_id) def set_outputs(self, node_id: str, outputs: dict) -> None: - if self._io_state is None: + if self._io_state is None or outputs is None: return - return self._io_state.update(node_id, outputs) + return self._io_state.update(node_id, {node_id: outputs}) diff --git a/jiuwen/core/context/utils.py b/jiuwen/core/context/utils.py index c01b43c4aef761ac3382b7790dabf3986f983dbe..a3173786d4fe2e8404000d3d8a81fd5bb589b68c 100644 --- a/jiuwen/core/context/utils.py +++ b/jiuwen/core/context/utils.py @@ -27,8 +27,10 @@ def update_dict(update: dict, source: dict) -> None: def get_value_by_nested_path(nested_key: str, source: dict) -> Optional[Any]: result = root_to_path(nested_key, source) - if result is None: - return result + if result[1] is None: + return None + if result[0] not in result[1]: + return None return result[1][result[0]] @@ -58,6 +60,8 @@ def split_nested_path(nested_key: str) -> list: final_list.append(match.group(1)) return final_list +def is_ref_path(path: str) -> bool: + return len(path) > 3 and path.startswith("${") and path.endswith("}") def extract_origin_key(key: str) -> str: """ diff --git a/jiuwen/core/graph/vertex.py b/jiuwen/core/graph/vertex.py index f853d2b59f6d66181e5a94f3fcab20ad88e357fd..c9bc19f50319d612ab05fe3b4906adc882638098 100644 --- a/jiuwen/core/graph/vertex.py +++ b/jiuwen/core/graph/vertex.py @@ -19,9 +19,9 @@ class Vertex: return True def __call__(self, state: GraphState) -> Output: - if self._context is None: + if self._context is None or self._executable is None: raise JiuWenBaseException(1, "vertex is not initialized, node is is " + self._node_id) - inputs = self.__pre_invoke__(state) + inputs = self.__pre_invoke__() is_stream = self.__is_stream__(state) try: if is_stream: @@ -34,15 +34,17 @@ class Vertex: raise JiuWenBaseException(e.error_code, "failed to invoke, caused by " + e.message) return {"source_node_id": self._node_id} - def __pre_invoke__(self, state:GraphState) -> Optional[dict]: + def __pre_invoke__(self) -> Optional[dict]: inputs_schema = self._context.config.get_inputs_schema(self._node_id) - inputs = self._context.state.get_inputs(inputs_schema) + inputs = self._context.state.get_inputs(inputs_schema) if inputs_schema else None if self._context.tracer is not None: self.__trace_inputs__(inputs) return inputs def __post_invoke__(self, results: Optional[dict]) -> None: - self._context.state.set_outputs(results) + self._context.state.set_outputs(self._node_id, results) + self._context.state.io_state.commit() + self._context.state.global_state.commit() pass def __post_stream__(self, results_iter: Any) -> None: diff --git a/jiuwen/core/workflow/base.py b/jiuwen/core/workflow/base.py index 5686512e59131e41bd9ed27f289af6f75919c3ff..76973b9a508c8ca1205860d674716020017838d5 100644 --- a/jiuwen/core/workflow/base.py +++ b/jiuwen/core/workflow/base.py @@ -17,19 +17,18 @@ from jiuwen.core.graph.executable import Executable, Input, Output class WorkflowConfig(BaseModel): - metadata: BaseModel = Field(default=None) + metadata: BaseModel class WorkflowOutput(BaseModel): - result: str = Field(default="") + result: str class WorkflowChunk(BaseModel): - chunk_id: str = Field(default="") - payload: str = Field(default="") - metadata: Dict[str, Any] = Field(default_factory=dict) - is_final: bool = Field(default=False) - + chunk_id: str + payload: str + metadata: Dict[str, Any] + is_final: bool class Workflow: @@ -67,7 +66,7 @@ class Workflow: self._comp_io_schemas[start_comp_id] = (inputs_schema, output_schema) return self - def set_end_component( + def set_end_comp( self, end_comp_id: str, component: EndComponent, @@ -81,11 +80,11 @@ class Workflow: return self def add_connection(self, src_comp_id: str, target_comp_id: str) -> Self: - self._graph.add_edge(source_node_id=src_comp_id, target_node_id=target_comp_id) + self._graph.add_edge(src_comp_id, target_comp_id) return self def add_stream_connection(self, src_comp_id: str, target_comp_id: str) -> Self: - self._graph.add_edge(source_node_id=src_comp_id, target_node_id=target_comp_id) + self._graph.add_edge(src_comp_id, target_comp_id) if target_comp_id not in self._stream_edges: self._stream_edges[src_comp_id] = [target_comp_id] else: @@ -102,12 +101,13 @@ class Workflow: return None compiled_graph = self._graph.compile(context) context.state.set_user_inputs(inputs) + context.state.io_state.commit() compiled_graph.invoke(inputs, context) return context.state.get_outputs(self._end_comp_id) async def ainvoke(self, inputs: Input, context: Context) -> Output: return await asyncio.get_running_loop().run_in_executor( - None, partial(self.invoke, context), inputs + None, partial(self.invoke, context = context), inputs ) def stream( diff --git a/jiuwen/graph/pregel/graph.py b/jiuwen/graph/pregel/graph.py index b1459199c5cf60482af7ba0896d2ec12673a82fe..32c8a1cbedfb7e06300f705cfafd261f139b94e6 100644 --- a/jiuwen/graph/pregel/graph.py +++ b/jiuwen/graph/pregel/graph.py @@ -4,18 +4,22 @@ from typing import Union, Self, Iterator, AsyncIterator from langgraph.graph import StateGraph +from langgraph.graph.state import CompiledStateGraph -from jiuwen.core.graph.base import Graph -from jiuwen.core.graph.state import State +from jiuwen.core.context.context import Context +from jiuwen.core.graph.base import Graph, Router, ExecutableGraph +from jiuwen.core.graph.executable import Executable, Input, Output +from jiuwen.core.graph.graph_state import GraphState from jiuwen.core.graph.vertex import Vertex class PregelGraph(Graph): def __init__(self): - self.pregel: StateGraph = StateGraph(State) + self.pregel: StateGraph = StateGraph(GraphState) self.compiledStateGraph = None self.edges: list[Union[str, list[str]], str] = [] self.waits: set[str] = set() + self.nodes: list[Vertex] = [] def start_node(self, node_id: str) -> Self: self.pregel.set_entry_point(node_id) @@ -25,21 +29,25 @@ class PregelGraph(Graph): self.pregel.set_finish_point(node_id) return self - def add_node(self, node_id: str, node: Executable, *, wait_for_all: bool = False, inputs: dict = None) -> Self: - self.pregel.add_node(node_id, Vertex(node_id, node, inputs)) + def add_node(self, node_id: str, node: Executable, *, wait_for_all: bool = False) -> Self: + vertex_node = Vertex(node_id, node) + self.nodes.append(vertex_node) + self.pregel.add_node(node_id, vertex_node) if wait_for_all: self.waits.add(node_id) return self - def add_edge(self, start_node_id: Union[str, list[str]], end_node_id: str) -> Self: - self.edges.append((start_node_id, end_node_id)) + def add_edge(self, source_node_id: Union[str, list[str]], target_node_id: str) -> Self: + self.edges.append((source_node_id, target_node_id)) return self - def add_conditional_edges(self, source: str, router: Router) -> Self: - self.pregel.add_conditional_edges(source, router) + def add_conditional_edges(self, source_node_id: str, router: Router) -> Self: + self.pregel.add_conditional_edges(source_node_id, router) return self - def compile(self) -> ExecutableGraph: + def compile(self, context: Context) -> ExecutableGraph: + for node in self.nodes: + node.init(context) if self.compiledStateGraph is None: self._pre_compile() self.compiledStateGraph = self.pregel.compile() @@ -65,17 +73,20 @@ class PregelGraph(Graph): class CompiledGraph(ExecutableGraph): - def __init__(self, compiledStateGraph: CompiledGraph): + def __init__(self, compiledStateGraph: CompiledStateGraph): self._compiledStateGraph = compiledStateGraph def invoke(self, inputs: Input, context: Context) -> Output: - return self._compiledStateGraph.invoke({'context': context}) + return self._compiledStateGraph.invoke({"source_node_id": None}) def stream(self, inputs: Input, context: Context) -> Iterator[Output]: - return self._compiledStateGraph.stream({'context': context}) + return self._compiledStateGraph.stream({"source_node_id": None}) async def ainvoke(self, inputs: Input, context: Context) -> Output: - return self._compiledStateGraph.ainvoke({'context': context}) + return self._compiledStateGraph.ainvoke({"source_node_id": None}) async def astream(self, inputs: Input, context: Context) -> AsyncIterator[Output]: - return self._compiledStateGraph.astream({'context': context}) + return self._compiledStateGraph.astream({"source_node_id": None}) + + def interrupt(self, message: dict): + return diff --git a/pyproject.toml b/pyproject.toml index 9bc920fdf0f19380c424e8a81966e7b472335b2c..ad584ebe390351390b4f863e8624d49af761b013 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,6 +9,7 @@ dependencies = [ "langgraph==0.2.35", "mcp==1.7.1", "pydantic==2.10.6", + "sqlalchemy>=2.0.41", ] [tool.uv] @@ -29,4 +30,4 @@ dev = [ ] [tool.coverage.run] -omit = ["tests/*"] \ No newline at end of file +omit = ["tests/*"] diff --git a/tests/unit_tests/workflow/test_mock_node.py b/tests/unit_tests/workflow/test_mock_node.py new file mode 100644 index 0000000000000000000000000000000000000000..311be9ffe0168af84ed3d7db97f8917463603b0a --- /dev/null +++ b/tests/unit_tests/workflow/test_mock_node.py @@ -0,0 +1,58 @@ +from typing import Iterator, AsyncIterator, Callable + +from jiuwen.core.component.base import WorkflowComponent, StartComponent, EndComponent +from jiuwen.core.context.context import Context +from jiuwen.core.graph.executable import Executable, Input, Output + + +class MockNodeBase(Executable, WorkflowComponent): + def invoke(self, inputs: Input, context: Context) -> Output: + pass + + def __init__(self, node_id: str): + super().__init__() + self.node_id = node_id + + def stream(self, inputs: Input, context: Context) -> Iterator[Output]: + yield self.invoke(inputs, context) + + async def ainvoke(self, inputs: Input, context: Context) -> Output: + yield await self.invoke(inputs, context) + + async def astream(self, inputs: Input, context: Context) -> AsyncIterator[Output]: + yield await self.invoke(inputs, context) + + def interrupt(self, message: dict): + return + + def to_executable(self) -> Executable: + return self + +class MockStartNode(StartComponent, MockNodeBase): + def __init__(self, node_id: str): + super().__init__(node_id) + + def invoke(self, inputs: Input, context: Context) -> Output: + context.state.set_outputs(self.node_id, inputs) + print("start: output = " + str(inputs)) + return inputs + +class MockEndNode(EndComponent, MockNodeBase): + def __init__(self, node_id: str): + super().__init__(node_id) + self.node_id = node_id + + def invoke(self, inputs: Input, context: Context) -> Output: + context.state.set_outputs(self.node_id, inputs) + print("endNode: output = " + str(inputs)) + return inputs + + +class Node1(MockNodeBase): + def __init__(self, node_id: str): + super().__init__(node_id) + + def invoke(self, inputs: Input, context: Context) -> Output: + context.state.set_outputs(self.node_id, inputs) + print("node1: output = " + str(inputs)) + return inputs \ No newline at end of file diff --git a/tests/unit_tests/workflow/test_node.py b/tests/unit_tests/workflow/test_node.py new file mode 100644 index 0000000000000000000000000000000000000000..52dcc8bf18c56e732fe3137e1671f0afed4973d0 --- /dev/null +++ b/tests/unit_tests/workflow/test_node.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. +from typing import AsyncIterator, Iterator + +from jiuwen.core.component.base import WorkflowComponent +from jiuwen.core.context.context import Context +from jiuwen.core.graph.executable import Executable, Input, Output + + +class CommonNode(Executable, WorkflowComponent): + + def __init__(self, node_id: str): + super().__init__() + self.node_id = node_id + + def invoke(self, inputs: Input, context: Context) -> Output: + return inputs + + async def ainvoke(self, inputs: Input, context: Context) -> Output: + pass + + def stream(self, inputs: Input, context: Context) -> Iterator[Output]: + yield self.invoke(inputs, context) + + async def astream(self, inputs: Input, context: Context) -> AsyncIterator[Output]: + yield await self.ainvoke(inputs, context) + + def interrupt(self, message: dict): + pass + + def to_executable(self) -> Executable: + return self + +class AddTenNode(Executable, WorkflowComponent): + + def __init__(self, node_id: str): + super().__init__() + self.node_id = node_id + + def invoke(self, inputs: Input, context: Context) -> Output: + return {"result": inputs["source"] + 10} + + async def ainvoke(self, inputs: Input, context: Context) -> Output: + pass + + def stream(self, inputs: Input, context: Context) -> Iterator[Output]: + yield self.invoke(inputs, context) + + async def astream(self, inputs: Input, context: Context) -> AsyncIterator[Output]: + yield await self.ainvoke(inputs, context) + + def interrupt(self, message: dict): + pass + + def to_executable(self) -> Executable: + return self \ No newline at end of file diff --git a/tests/unit_tests/workflow/test_workflow.py b/tests/unit_tests/workflow/test_workflow.py new file mode 100644 index 0000000000000000000000000000000000000000..9d641db7fb1ea3d9c20827dd4eaed5006e7fbc34 --- /dev/null +++ b/tests/unit_tests/workflow/test_workflow.py @@ -0,0 +1,184 @@ +import asyncio +import unittest +from collections.abc import Callable + +import random + +from jiuwen.core.component.branch_comp import BranchComponent +from jiuwen.core.component.condition.array import ArrayCondition +from jiuwen.core.component.loop_callback.intermediate_loop_var import IntermediateLoopVarCallback +from jiuwen.core.component.loop_callback.output import OutputCallback +from jiuwen.core.component.loop_comp import LoopGroup, LoopComponent +from jiuwen.core.component.set_variable_comp import SetVariableComponent +from jiuwen.core.context.config import Config +from jiuwen.core.context.context import Context +from jiuwen.core.context.memory.base import InMemoryState +from jiuwen.core.graph.base import Graph +from jiuwen.core.graph.graph_state import GraphState +from jiuwen.core.workflow.base import WorkflowConfig, Workflow +from jiuwen.graph.pregel.graph import PregelGraph +from test_node import AddTenNode, CommonNode +from tests.unit_tests.workflow.test_mock_node import MockNodeBase, MockStartNode, MockEndNode, Node1 + + +def create_context() -> Context: + return Context(config=Config(), state=InMemoryState(), store=None, tracer=None) + + +def create_graph() -> Graph: + return PregelGraph() + + +def create_flow() -> Workflow: + return Workflow(workflow_config=DEFAULT_WORKFLOW_CONFIG, graph=create_graph()) + + +DEFAULT_WORKFLOW_CONFIG = WorkflowConfig(metadata={}) + + +class WorkflowTest(unittest.TestCase): + def assert_workflow_invoke(self, inputs: dict, context: Context, flow: Workflow, expect_results: dict = None, + checker: Callable = None): + results = flow.invoke(inputs=inputs, context=context) + if expect_results is not None: + assert results == expect_results + elif checker is not None: + checker(results) + + def assert_workflow_ainvoke(self, inputs: dict, context: Context, flow: Workflow, expect_results: dict = None, + checker: Callable = None): + loop = asyncio.get_event_loop() + feature = asyncio.ensure_future(flow.ainvoke(inputs=inputs, context=context)) + loop.run_until_complete(feature) + if expect_results is not None: + assert feature.result() == expect_results + elif checker is not None: + checker(feature.result()) + + def test_simple_workflow(self): + """ + graph : start->a->end + """ + flow = create_flow() + flow.set_start_comp("start", MockStartNode("start"), + inputs_schema={ + "a": "${user.inputs.a}", + "b": "${user.inputs.b}", + "c": 1, + "d": [1, 2, 3]}) + flow.add_workflow_comp("a", Node1("a"), + inputs_schema={ + "aa": "${start.a}", + "ac": "${start.c}"}) + flow.set_end_comp("end", MockEndNode("end"), + inputs_schema={ + "result": "${a.aa}"}) + flow.add_connection("start", "a") + flow.add_connection("a", "end") + self.assert_workflow_invoke({"a": 1, "b": "haha"}, create_context(), flow, expect_results={"result": 1}) + self.assert_workflow_ainvoke({"a": 1, "b": "haha"}, create_context(), flow, expect_results={"result": 1}) + + def test_simple_workflow_with_condition(self): + """ + start -> condition[a,b] -> end + :return: + """ + flow = create_flow() + flow.set_start_comp("start", MockStartNode("start"), + inputs_schema={"a": "${user.inputs.a}", + "b": "${user.inputs.b}", + "c": 1, + "d": [1, 2, 3]}) + + def router(state: GraphState): + condition_nodes = ["a", "b"] + randomIdx = random.randint(1, 2) + return condition_nodes[randomIdx - 1] + + flow.add_conditional_connection("start", router=router) + flow.add_workflow_comp("a", Node1("a"), inputs_schema={"a": "${start.a}", "b": "${start.c}"}) + flow.add_workflow_comp("b", Node1("b"), inputs_schema={"b": "${start.b}"}) + flow.set_end_comp("end", MockEndNode("end"), {"result1": "${a.a}", "result2": "${b.b}"}) + flow.add_connection("a", "end") + flow.add_connection("b", "end") + + def checker(results): + if "result1" in results: + assert results["result1"] == 1 + elif "result2" in results: + assert results["result2"] == "haha" + + self.assert_workflow_invoke({"a": 1, "b": "haha"}, create_context(), flow, checker=checker) + self.assert_workflow_ainvoke({"a": 1, "b": "haha"}, create_context(), flow, checker=checker) + + def test_workflow_with_branch(self): + context = create_context() + flow = create_flow() + flow.set_start_comp("start", MockStartNode("start")) + flow.set_end_comp("end", MockEndNode("end"), + inputs_schema={"a": "${a.result}", "b": "${b.result}"}) + + sw = BranchComponent(context) + sw.add_branch("${user.inputs.a} <= 10", ["b"], "1") + sw.add_branch("${user.inputs.a} > 10", ["a"], "2") + + flow.add_workflow_comp("sw", sw) + + flow.add_workflow_comp("a", CommonNode("a"), + inputs_schema={"result": "${user.inputs.a}"}) + + flow.add_workflow_comp("b", AddTenNode("b"), + inputs_schema={"source": "${user.inputs.a}"}) + + flow.add_connection("start", "sw") + flow.add_connection("a", "end") + flow.add_connection("b", "end") + + result = flow.invoke({"a": 2}, context=context) + assert result["b"] == 12 + + result = flow.invoke({"a": 15}, context=context) + assert result["a"] == 15 + + def test_workflow_with_loop(self): + flow = create_flow() + flow.set_start_comp("s", MockStartNode("s")) + flow.set_end_comp("e", MockEndNode("e"), + inputs_schema={"array_result": "${b.array_result}", "user_var": "${b.user_var}"}) + flow.add_workflow_comp("a", CommonNode("a"), + inputs_schema={"array": "${user.inputs.input_array}"}) + flow.add_workflow_comp("b", CommonNode("b"), + inputs_schema={"array_result": "${l.results}", "user_var": "${l.user_var}"}) + + # create loop: (1->2->3) + context = create_context() + loop_group = LoopGroup(context) + loop_group.add_component("1", AddTenNode("1"), inputs_schema={"source": "${l.arrLoopVar.item}"}) + loop_group.add_component("2", AddTenNode("2"), inputs_schema={"source": "${l.intermediateLoopVar.user_var}"}) + loop_group.add_component("3", SetVariableComponent("3", context, + {"${l.intermediateLoopVar.user_var}": "${2.result}"})) + loop_group.start_nodes(["1"]) + loop_group.end_nodes(["3"]) + loop_group.add_connection("1", "2") + loop_group.add_connection("2", "3") + output_callback = OutputCallback(context, "l", + {"results": "${1.result}", "user_var": "${l.intermediateLoopVar.user_var}"}) + intermediate_callback = IntermediateLoopVarCallback(context, "l", + {"user_var": "${user.inputs.input_number}"}) + + loop = LoopComponent(context, "l", loop_group, ArrayCondition(context, "l", {"item": "${a.array}"}), + callbacks=[output_callback, intermediate_callback]) + + flow.add_workflow_comp("l", loop) + + # s->a->(1->2->3)->b->e + flow.add_connection("s", "a") + flow.add_connection("a", "l") + flow.add_connection("l", "b") + flow.add_connection("b", "e") + + result = flow.invoke({"input_array": [1, 2, 3], "input_number": 1}, context=context) + assert result == {"array_result": [11, 12, 13], "user_var": 31} + + result = flow.invoke({"input_array": [4, 5], "input_number": 2}, context=context) + assert result == {"array_result": [14, 15], "user_var": 22}