From 7c6f0983b9bbfa4c7e948853f5a233973a988126 Mon Sep 17 00:00:00 2001
From: Jiayuan Kang
Date: Tue, 15 Jul 2025 14:10:31 +0800
Subject: [PATCH] feat: support asynchronous invocation in the LLM component
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Support asynchronous invocation in the LLM component.

#ICLT7L: component feature implementation
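
Illustrative usage sketch (not part of the diff below): the snippet shows how a
caller is expected to drive the new async entry points. The concrete config,
context and input payload are assumptions for illustration only; the
invoke/stream signatures match the code in this patch.

    import asyncio

    from jiuwen.core.component.llm_comp import LLMCompConfig, LLMExecutable

    async def run_llm_component(config: LLMCompConfig, inputs: dict, context) -> None:
        executable = LLMExecutable(config)
        # single-shot call: now a coroutine, must be awaited
        output = await executable.invoke(inputs, context)
        print(output)
        # streaming call: now an async generator, consumed with "async for"
        async for chunk in executable.stream(inputs, context):
            print(chunk)

    # asyncio.run(run_llm_component(config, {"userFields": {"name": "pytest"}}, context))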
---
 jiuwen/core/common/configs/model_config.py  |   2 +
 jiuwen/core/common/constants/constant.py    |   4 +
 jiuwen/core/common/exception/status_code.py |   5 +
 jiuwen/core/common/utils/utils.py           |   2 +-
 jiuwen/core/component/llm_comp.py           |  55 ++++++----
 tests/unit_tests/workflow/test_llm_comp.py  | 108 ++++++++++++++++++++
 6 files changed, 157 insertions(+), 19 deletions(-)
 create mode 100644 jiuwen/core/common/constants/constant.py
 create mode 100644 tests/unit_tests/workflow/test_llm_comp.py

diff --git a/jiuwen/core/common/configs/model_config.py b/jiuwen/core/common/configs/model_config.py
index 0f91d7a..9a7965e 100644
--- a/jiuwen/core/common/configs/model_config.py
+++ b/jiuwen/core/common/configs/model_config.py
@@ -3,6 +3,8 @@
 # Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved
 from dataclasses import field, dataclass
 
+from jiuwen.core.utils.llm.base import BaseModelInfo
+
 
 @dataclass
 class ModelConfig:

diff --git a/jiuwen/core/common/constants/constant.py b/jiuwen/core/common/constants/constant.py
new file mode 100644
index 0000000..f3c7141
--- /dev/null
+++ b/jiuwen/core/common/constants/constant.py
@@ -0,0 +1,4 @@
+# IR userFields key
+USER_FIELDS = "userFields"
+# IR systemFields key
+SYSTEM_FIELDS = "systemFields"

diff --git a/jiuwen/core/common/exception/status_code.py b/jiuwen/core/common/exception/status_code.py
index 4d00abb..6b42062 100644
--- a/jiuwen/core/common/exception/status_code.py
+++ b/jiuwen/core/common/exception/status_code.py
@@ -28,6 +28,11 @@ class StatusCode(Enum):
     PROMPT_TEMPLATE_DUPLICATED_ERROR = (102101, "Template duplicated")
     PROMPT_TEMPLATE_NOT_FOUND_ERROR = (102102, "Template not found")
 
+    # LLM component 101561-101590
+    WORKFLOW_LLM_INIT_ERROR = (101561, "LLM component initialization error, msg = {msg}")
+    WORKFLOW_LLM_TEMPLATE_ASSEMBLE_ERROR = (101562, "LLM component template assemble error")
+    WORKFLOW_LLM_STREAMING_OUTPUT_ERROR = (101563, "Get model streaming output error, msg = {msg}")
+
     @property
     def code(self):
         return self.value[0]

diff --git a/jiuwen/core/common/utils/utils.py b/jiuwen/core/common/utils/utils.py
index 2101fcb..f08a04f 100644
--- a/jiuwen/core/common/utils/utils.py
+++ b/jiuwen/core/common/utils/utils.py
@@ -158,7 +158,7 @@ class OutputFormatter:
         formatters = {
             "text": OutputFormatter._format_text_response,
             "markdown": OutputFormatter._format_text_response,
-            "json": OutputFormatter._format_json.response
+            "json": OutputFormatter._format_json_response
         }
 
         formatter = formatters.get(response_type)

diff --git a/jiuwen/core/component/llm_comp.py b/jiuwen/core/component/llm_comp.py
index bef0600..6ec6544 100644
--- a/jiuwen/core/component/llm_comp.py
+++ b/jiuwen/core/component/llm_comp.py
@@ -5,12 +5,18 @@
 from dataclasses import dataclass, field
 from typing import List, Any, Dict, Optional, Iterator, AsyncIterator
 from jiuwen.core.common.configs.model_config import ModelConfig
+from jiuwen.core.common.constants.constant import USER_FIELDS
 from jiuwen.core.common.enum.enum import WorkflowLLMResponseType, MessageRole
 from jiuwen.core.common.exception.exception import JiuWenBaseException
 from jiuwen.core.common.exception.status_code import StatusCode
+from jiuwen.core.common.utils.utils import WorkflowLLMUtils, OutputFormatter
 from jiuwen.core.component.base import ComponentConfig, WorkflowComponent
 from jiuwen.core.context.context import Context
 from jiuwen.core.graph.executable import Executable, Input, Output
+from jiuwen.core.utils.llm.base import BaseChatModel
+from jiuwen.core.utils.llm.model_utils.model_factory import ModelFactory
+from jiuwen.core.utils.prompt.template.template import Template
+from jiuwen.core.utils.prompt.template.template_manager import TemplateManager
 
 WORKFLOW_CHAT_HISTORY = "workflow_chat_history"
 CHAT_HISTORY_MAX_TURN = 3
@@ -35,6 +41,7 @@ RESPONSE_FORMAT_TO_PROMPT_MAP = {
     }
 }
 
+
 @dataclass
 class LLMCompConfig(ComponentConfig):
     model: 'ModelConfig' = None
@@ -54,7 +61,7 @@ class LLMExecutable(Executable):
         self._initialized: bool = False
         self._context = None
 
-    def invoke(self, inputs: Input, context: Context) -> Output:
+    async def invoke(self, inputs: Input, context: Context) -> Output:
         try:
             self._set_context(context)
             model_inputs = self._prepare_model_inputs(inputs)
@@ -66,30 +73,37 @@ class LLMExecutable(Executable):
             raise JiuWenBaseException(error_code=StatusCode.WORKFLOW_LLM_INIT_ERROR.code,
                                       message=StatusCode.WORKFLOW_LLM_INIT_ERROR.errmsg.format(msg=str(e))) from e
 
-    async def ainvoke(self, inputs: Input, context: Context) -> Output:
-        pass
-
-    def stream(self, inputs: Input, context: Context) -> Iterator[Output]:
+    async def stream(self, inputs: Input, context: Context) -> AsyncIterator[Output]:
         try:
             self._set_context(context)
-
             response_format_type = self._get_response_format().get(_TYPE)
-
+
             if response_format_type == WorkflowLLMResponseType.JSON.value:
-                yield from self._invoke_for_json_format(inputs)
+                async for out in self._invoke_for_json_format(inputs):
+                    yield out
             else:
-                yield from self._stream_with_chunks(inputs)
+                async for out in self._stream_with_chunks(inputs):
+                    yield out
         except JiuWenBaseException:
-            raise 
+            raise
         except Exception as e:
             raise JiuWenBaseException(
                 error_code=StatusCode.WORKFLOW_LLM_STREAMING_OUTPUT_ERROR.code,
                 message=StatusCode.WORKFLOW_LLM_STREAMING_OUTPUT_ERROR.errmsg.format(msg=str(e))
             ) from e
 
-    async def astream(self, inputs: Input, context: Context) -> AsyncIterator[Output]:
+    async def collect(self, inputs: AsyncIterator[Input], context: Context) -> Output:
+        pass
+
+    async def transform(self, inputs: AsyncIterator[Input], context: Context) -> AsyncIterator[Output]:
         pass
 
+    async def interrupt(self, message: dict):
+        raise InterruptException(
+            error_code=StatusCode.CONTROLLER_INTERRUPTED_ERROR.code,
+            message=json.dumps(message, ensure_ascii=False)
+        )
+
     def _initialize_if_needed(self):
         if not self._initialized:
             try:
@@ -117,7 +131,7 @@ class LLMExecutable(Executable):
         full_input = ""
         for history in chat_history[-CHAT_HISTORY_MAX_TURN:]:
             full_input += "{}:{}\n".format(ROLE_MAP.get(history.get("role", "user"), "用户"),
-                                            history.get("content"))
+                                           history.get("content"))
 
         inputs.update({"CHAT_HISTORY": full_input})
         return processed_inputs
@@ -130,7 +144,7 @@ class LLMExecutable(Executable):
                 message=StatusCode.WORKFLOW_LLM_INIT_ERROR.errmsg.format(msg="Failed to retrieve llm template content")
             )
         default_template = Template(name="default", content=str(user_prompt[0].get("content")))
-        return PromptEngine.format(inputs, default_template)
+        return default_template.format(inputs)
 
     def _get_model_input(self, inputs: dict):
         return self._build_prompt_message(inputs)
@@ -174,6 +188,11 @@ class LLMExecutable(Executable):
             template = template_manager.get(name=template_name, filters=filters)
             return getattr(template, "content", None) if template else None
+        except Exception as e:
+            raise JiuWenBaseException(
+                error_code=StatusCode.WORKFLOW_LLM_TEMPLATE_ASSEMBLE_ERROR.code,
+                message=StatusCode.WORKFLOW_LLM_TEMPLATE_ASSEMBLE_ERROR.errmsg
+            ) from e
 
     def _build_template_filters(self) -> dict:
         filters = {}
 
@@ -201,15 +220,15 @@ class LLMExecutable(Executable):
         processed_inputs = self._process_inputs(user_inputs)
         return self._get_model_input(processed_inputs)
 
-    def _invoke_for_json_format(self, inputs: Input) -> Iterator[Output]:
+    async def _invoke_for_json_format(self, inputs: Input) -> AsyncIterator[Output]:
         model_inputs = self._prepare_model_inputs(inputs)
-        llm_output = self._llm.invoke(model_inputs)
+        llm_output = await self._llm.ainvoke(model_inputs)  # the async model interface must be awaited
         yield self._create_output(llm_output)
 
-    def _stream_with_chunks(self, inputs: Input) -> Iterator[Output]:
+    async def _stream_with_chunks(self, inputs: Input) -> AsyncIterator[Output]:
         model_inputs = self._prepare_model_inputs(inputs)
-        result = self._llm.stream(model_inputs)
-        for chunk in result:
+        # assumes self._llm.astream is itself an async generator
+        async for chunk in self._llm.astream(model_inputs):
             content = WorkflowLLMUtils.extract_content(chunk)
             yield content
 
diff --git a/tests/unit_tests/workflow/test_llm_comp.py b/tests/unit_tests/workflow/test_llm_comp.py
new file mode 100644
index 0000000..1618700
--- /dev/null
+++ b/tests/unit_tests/workflow/test_llm_comp.py
@@ -0,0 +1,108 @@
+import sys
+import types
+
+import pytest
+from unittest.mock import Mock
+
+fake_base = types.ModuleType("base")
+fake_base.logger = Mock()
+
+fake_exception_module = types.ModuleType("base")
+fake_exception_module.JiuWenBaseException = Mock()
+
+sys.modules["jiuwen.core.common.logging.base"] = fake_base
+sys.modules["jiuwen.core.common.exception.base"] = fake_exception_module
+
+from unittest.mock import patch, AsyncMock
+from jiuwen.core.common.configs.model_config import ModelConfig
+from jiuwen.core.common.constants.constant import USER_FIELDS
+from jiuwen.core.common.exception.exception import JiuWenBaseException
+from jiuwen.core.component.llm_comp import LLMCompConfig, LLMExecutable
+from jiuwen.core.context.context import Context
+from jiuwen.core.utils.llm.base import BaseModelInfo
+
+
+@pytest.fixture
+def fake_ctx():
+    from unittest.mock import MagicMock
+    ctx = MagicMock(spec=Context)
+    ctx.store = MagicMock()
+    ctx.store.read.return_value = []
+    return ctx
+
+
+@pytest.fixture
+def fake_input():
+    return lambda **kw: {USER_FIELDS: kw}
+
+
+@pytest.fixture
+def fake_model_config() -> ModelConfig:
+    """Build a minimal usable ModelConfig."""
+    return ModelConfig(
+        model_provider="openai",
+        model_info=BaseModelInfo(
+            api_key="sk-fake",
+            api_base="https://api.openai.com/v1",
+            model_name="gpt-3.5-turbo",
+            temperature=0.8,
+            top_p=0.9,
+            streaming=False,
+            timeout=30.0,
+        ),
+    )
+
+
+@patch(
+    "jiuwen.core.utils.llm.model_utils.model_factory.ModelFactory.get_model",
+    autospec=True,
+)
+class TestLLMExecutableInvoke:
+
+    @pytest.mark.asyncio
+    async def test_invoke_success(
+            self,
+            mock_get_model,  # the patched ModelFactory.get_model
+            fake_ctx,
+            fake_input,
+            fake_model_config,
+    ):
+        config = LLMCompConfig(
+            model=fake_model_config,
+            template_content=[{"role": "user", "content": "Hello {name}"}],
+            response_format={"type": "text"},
+            output_config={"result": {
+                "type": "string",
+                "required": True,
+            }},
+        )
+        exe = LLMExecutable(config)
+
+        fake_llm = AsyncMock()
+        fake_llm.invoke = Mock(return_value="mocked response")
+        mock_get_model.return_value = fake_llm  # return the mock LLM directly
+
+        output = await exe.invoke(fake_input(name="pytest"), fake_ctx)
+
+        assert output[USER_FIELDS] == {'result': 'mocked response'}
+        fake_llm.invoke.assert_called_once()
+
+    @pytest.mark.asyncio
+    async def test_invoke_llm_exception(
+            self,
+            mock_get_model,
+            fake_ctx,
+            fake_input,
+            fake_model_config,
+    ):
+        config = LLMCompConfig(model=fake_model_config, template_content=[{"role": "user", "content": "Hello {name}"}])
+        exe = LLMExecutable(config)
+
+        fake_llm = Mock()
+        fake_llm.invoke = Mock(side_effect=RuntimeError("LLM down"))
+        mock_get_model.return_value = fake_llm
+
+        with pytest.raises(JiuWenBaseException) as exc_info:
+            await exe.invoke(fake_input(), fake_ctx)
+
+        assert "LLM down" in str(exc_info.value.message)
-- 
Gitee