From 528d2c947139657264382978f2d84e1bb396cee3 Mon Sep 17 00:00:00 2001
From: yangzequ
Date: Sat, 12 Jul 2025 10:37:35 +0800
Subject: [PATCH] =?UTF-8?q?feat:=20=E6=96=B0=E5=A2=9ELLM=E7=BB=84=E4=BB=B6?=
 =?UTF-8?q?=E3=80=81=E6=8F=90=E9=97=AE=E5=99=A8=E7=BB=84=E4=BB=B6?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Add the LLM component and the questioner component. #ICLT7L
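
A minimal usage sketch (illustrative only: the provider name and field values are
placeholders, and it assumes the missing imports flagged in the TODO comments
below are resolved):

    from jiuwen.core.common.configs.model_config import ModelConfig
    from jiuwen.core.component.llm_comp import LLMCompConfig, LLMComponent
    from jiuwen.core.component.questioner_comp import FieldInfo, QuestionerComponent, QuestionerConfig

    llm = LLMComponent(LLMCompConfig(
        model=ModelConfig(model_provider="<provider>"),
        template_content=[{"role": "user", "content": "..."}],
        response_format={"type": "json"},
    ))
    questioner = QuestionerComponent(QuestionerConfig(
        model=ModelConfig(model_provider="<provider>"),
        question_content="请提供出行日期",
        field_names=[FieldInfo(field_name="date", description="出行日期", required=True)],
    ))
    # Each component is turned into a graph node via to_executable().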
---
 jiuwen/core/common/configs/model_config.py |  10 +
 jiuwen/core/common/enum/enum.py            |  16 +
 jiuwen/core/component/base.py              |  22 ++
 jiuwen/core/component/llm_comp.py          | 233 +++++++++++
 jiuwen/core/component/questioner_comp.py   | 431 +++++++++++++++++++++
 5 files changed, 712 insertions(+)
 create mode 100644 jiuwen/core/common/configs/model_config.py
 create mode 100644 jiuwen/core/common/enum/enum.py
 create mode 100644 jiuwen/core/component/llm_comp.py
 create mode 100644 jiuwen/core/component/questioner_comp.py

diff --git a/jiuwen/core/common/configs/model_config.py b/jiuwen/core/common/configs/model_config.py
new file mode 100644
index 0000000..0f91d7a
--- /dev/null
+++ b/jiuwen/core/common/configs/model_config.py
@@ -0,0 +1,10 @@
+#!/usr/bin/python3.10
+# coding: utf-8
+# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved
+from dataclasses import field, dataclass
+
+# TODO: BaseModelInfo is referenced below but never imported; import it from its defining module.
+
+@dataclass
+class ModelConfig:
+    model_provider: str
+    model_info: BaseModelInfo = field(default_factory=BaseModelInfo)
diff --git a/jiuwen/core/common/enum/enum.py b/jiuwen/core/common/enum/enum.py
new file mode 100644
index 0000000..3533494
--- /dev/null
+++ b/jiuwen/core/common/enum/enum.py
@@ -0,0 +1,16 @@
+#!/usr/bin/python3.10
+# coding: utf-8
+# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved
+from enum import Enum
+
+
+class MessageRole(str, Enum):
+    USER = "user"
+    ASSISTANT = "assistant"
+    FUNCTION = "function"
+
+
+class WorkflowLLMResponseType(Enum):
+    JSON = "json"
+    MARKDOWN = "markdown"
+    TEXT = "text"
\ No newline at end of file
diff --git a/jiuwen/core/component/base.py b/jiuwen/core/component/base.py
index 05efc59..c61f149 100644
--- a/jiuwen/core/component/base.py
+++ b/jiuwen/core/component/base.py
@@ -2,11 +2,33 @@
 # coding: utf-8
 # Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved
 from abc import ABC
+from dataclasses import dataclass, field
+from enum import Enum
+from typing import Optional
 
 from jiuwen.core.graph.base import Graph
 from jiuwen.core.graph.executable import Executable
 
 
+@dataclass
+class WorkflowComponentMetadata:
+    node_id: str
+    node_type: str
+    node_name: str
+
+
+@dataclass
+class ComponentConfig:
+    metadata: Optional[WorkflowComponentMetadata] = field(default=None)
+
+
+@dataclass
+class ComponentState:
+    comp_id: str
+    status: Enum
+
+
 class WorkflowComponent(ABC):
     def __init__(self):
         pass
diff --git a/jiuwen/core/component/llm_comp.py b/jiuwen/core/component/llm_comp.py
new file mode 100644
index 0000000..bef0600
--- /dev/null
+++ b/jiuwen/core/component/llm_comp.py
@@ -0,0 +1,233 @@
+#!/usr/bin/python3.10
+# coding: utf-8
+# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved
+from dataclasses import dataclass, field
+from typing import List, Any, Dict, Optional, Iterator, AsyncIterator
+
+from jiuwen.core.common.configs.model_config import ModelConfig
+from jiuwen.core.common.enum.enum import WorkflowLLMResponseType, MessageRole
+from jiuwen.core.common.exception.exception import JiuWenBaseException
+from jiuwen.core.common.exception.status_code import StatusCode
+from jiuwen.core.component.base import ComponentConfig, WorkflowComponent
+from jiuwen.core.context.context import Context
+from jiuwen.core.graph.executable import Executable, Input, Output
+
+# TODO: BaseChatModel, ModelFactory, Template, PromptEngine, TemplateManager,
+#  OutputFormatter and WorkflowLLMUtils are used below but are not imported in this patch.
+
+WORKFLOW_CHAT_HISTORY = "workflow_chat_history"
+CHAT_HISTORY_MAX_TURN = 3
+USER_FIELDS = "userFields"  # assumed key name, mirroring the userFields key used by the questioner component
+_ROLE = "role"
+_CONTENT = "content"
+ROLE_MAP = {"user": "用户", "assistant": "助手", "system": "系统"}
+_SPAN = "span"
+_WORKFLOW_DATA = "workflow_data"
+_ID = "id"
+_TYPE = "type"
+_INSTRUCTION_NAME = "instruction_name"
+_TEMPLATE_NAME = "template_name"
+
+RESPONSE_FORMAT_TO_PROMPT_MAP = {
+    WorkflowLLMResponseType.JSON.value: {
+        _INSTRUCTION_NAME: "jsonInstruction",
+        _TEMPLATE_NAME: "llm_json_formatting"
+    },
+    WorkflowLLMResponseType.MARKDOWN.value: {
+        _INSTRUCTION_NAME: "markdownInstruction",
+        _TEMPLATE_NAME: "llm_markdown_formatting"
+    }
+}
+
+
+@dataclass
+class LLMCompConfig(ComponentConfig):
+    model: Optional[ModelConfig] = None
+    deployMode: str = ""
+    template_content: List[Any] = field(default_factory=list)
+    response_format: Dict[str, Any] = field(default_factory=dict)
+    enable_history: bool = True
+    user_fields: Dict[str, Any] = field(default_factory=dict)
+    output_config: Dict[str, Any] = field(default_factory=dict)
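+
+
+# LLMExecutable lazily builds a chat model from ModelConfig via ModelFactory. invoke()
+# formats the user-role entry of template_content with the runtime inputs (plus recent
+# chat history under the CHAT_HISTORY key) and returns the formatted reply under the
+# userFields key; stream() yields chunks, except for JSON response formats, which are
+# buffered and emitted as a single output. Illustrative template_content shape (the
+# placeholder syntax is PromptEngine-specific and assumed here):
+#     [{"role": "user", "content": "请回答:{{query}}"}]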
+class LLMExecutable(Executable):
+    def __init__(self, component_config: LLMCompConfig):
+        super().__init__()
+        self._config = component_config
+        self._llm: Optional[BaseChatModel] = None
+        self._initialized: bool = False
+        self._context = None
+
+    def invoke(self, inputs: Input, context: Context) -> Output:
+        try:
+            self._set_context(context)
+            model_inputs = self._prepare_model_inputs(inputs)
+            response = self._llm.invoke(model_inputs)
+            return self._create_output(response)
+        except JiuWenBaseException:
+            raise
+        except Exception as e:
+            raise JiuWenBaseException(error_code=StatusCode.WORKFLOW_LLM_INIT_ERROR.code,
+                                      message=StatusCode.WORKFLOW_LLM_INIT_ERROR.errmsg.format(msg=str(e))) from e
+
+    async def ainvoke(self, inputs: Input, context: Context) -> Output:
+        pass
+
+    def stream(self, inputs: Input, context: Context) -> Iterator[Output]:
+        try:
+            self._set_context(context)
+
+            response_format_type = self._get_response_format().get(_TYPE)
+
+            if response_format_type == WorkflowLLMResponseType.JSON.value:
+                yield from self._invoke_for_json_format(inputs)
+            else:
+                yield from self._stream_with_chunks(inputs)
+        except JiuWenBaseException:
+            raise
+        except Exception as e:
+            raise JiuWenBaseException(
+                error_code=StatusCode.WORKFLOW_LLM_STREAMING_OUTPUT_ERROR.code,
+                message=StatusCode.WORKFLOW_LLM_STREAMING_OUTPUT_ERROR.errmsg.format(msg=str(e))
+            ) from e
+
+    async def astream(self, inputs: Input, context: Context) -> AsyncIterator[Output]:
+        pass
+
+    def _initialize_if_needed(self):
+        if not self._initialized:
+            try:
+                self._llm = self._create_llm_instance()
+                self._initialized = True
+            except Exception as e:
+                raise JiuWenBaseException(
+                    error_code=StatusCode.WORKFLOW_LLM_INIT_ERROR.code,
+                    message=StatusCode.WORKFLOW_LLM_INIT_ERROR.errmsg.format(msg=str(e))
+                ) from e
+
+    def _create_llm_instance(self):
+        return ModelFactory().get_model(self._config.model.model_provider,
+                                        self._config.model.model_info)
+
+    def _validate_inputs(self, inputs: Input) -> None:
+        pass
+
+    def _process_inputs(self, inputs: dict) -> dict:
+        processed_inputs = {}
+        if inputs:
+            processed_inputs = inputs.copy()
+        if self._context:
+            chat_history: list = self._context.store.read(WORKFLOW_CHAT_HISTORY)
+            chat_history = chat_history[:-1] if chat_history else []
+            full_input = ""
+            for history in chat_history[-CHAT_HISTORY_MAX_TURN:]:
+                full_input += "{}:{}\n".format(ROLE_MAP.get(history.get(_ROLE, "user"), "用户"),
+                                               history.get(_CONTENT))
+            # update the copied inputs rather than the caller's dict
+            processed_inputs.update({"CHAT_HISTORY": full_input})
+        return processed_inputs
+
+    def _build_prompt_message(self, inputs: dict) -> Template:
+        template_content_list = self._config.template_content
+        user_prompt = [element for element in template_content_list
+                       if isinstance(element, dict) and element.get(_ROLE, "") == MessageRole.USER.value]
+        if not user_prompt:
+            raise JiuWenBaseException(
+                error_code=StatusCode.WORKFLOW_LLM_INIT_ERROR.code,
+                message=StatusCode.WORKFLOW_LLM_INIT_ERROR.errmsg.format(msg="Failed to retrieve llm template content")
+            )
+        default_template = Template(name="default", content=str(user_prompt[0].get(_CONTENT)))
+        return PromptEngine.format(inputs, default_template)
+
+    def _get_model_input(self, inputs: dict):
+        return self._build_prompt_message(inputs)
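+
+    # For json/markdown response formats an extra formatting instruction is attached to
+    # response_format: when the caller has not supplied jsonInstruction/markdownInstruction,
+    # it is looked up from the matching prompt template, filtered by model name.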
+    def _get_response_format(self):
+        try:
+            response_format = self._config.response_format
+            if not response_format:
+                return {}
+
+            format_type = response_format.get(_TYPE)
+            if not format_type or format_type not in RESPONSE_FORMAT_TO_PROMPT_MAP:
+                return response_format
+
+            format_config = RESPONSE_FORMAT_TO_PROMPT_MAP[format_type]
+            instruction_name = format_config.get(_INSTRUCTION_NAME)
+
+            if response_format.get(instruction_name):
+                return response_format
+
+            instruction_content = self._get_instruction_from_template(format_config)
+            if instruction_content:
+                response_format[instruction_name] = instruction_content
+
+            return response_format
+
+        except Exception as e:
+            raise JiuWenBaseException(
+                error_code=StatusCode.WORKFLOW_LLM_TEMPLATE_ASSEMBLE_ERROR.code,
+                message=StatusCode.WORKFLOW_LLM_TEMPLATE_ASSEMBLE_ERROR.errmsg
+            ) from e
+
+    def _get_instruction_from_template(self, format_config: dict) -> Optional[str]:
+        template_name = format_config.get(_TEMPLATE_NAME)
+        try:
+            if not template_name:
+                return None
+            filters = self._build_template_filters()
+
+            template_manager = TemplateManager()
+            template = template_manager.get(name=template_name, filters=filters)
+
+            return getattr(template, "content", None) if template else None
+        except Exception:
+            # fall back to no extra formatting instruction when the template lookup fails
+            return None
+
+    def _build_template_filters(self) -> dict:
+        filters = {}
+
+        model_name = self._config.model.model_info.model_name
+        if model_name:
+            filters["model_name"] = model_name
+
+        return filters
+
+    def _create_output(self, llm_output) -> Output:
+        formatted_res = OutputFormatter.format_response(llm_output,
+                                                        self._config.response_format,
+                                                        self._config.output_config)
+        return {USER_FIELDS: formatted_res}
+
+    def _set_context(self, context):
+        self._context = context
+
+    def _prepare_model_inputs(self, inputs):
+        self._initialize_if_needed()
+        self._validate_inputs(inputs)
+
+        user_inputs = inputs.get(USER_FIELDS)
+        processed_inputs = self._process_inputs(user_inputs)
+        return self._get_model_input(processed_inputs)
+
+    def _invoke_for_json_format(self, inputs: Input) -> Iterator[Output]:
+        model_inputs = self._prepare_model_inputs(inputs)
+        llm_output = self._llm.invoke(model_inputs)
+        yield self._create_output(llm_output)
+
+    def _stream_with_chunks(self, inputs: Input) -> Iterator[Output]:
+        model_inputs = self._prepare_model_inputs(inputs)
+        result = self._llm.stream(model_inputs)
+        for chunk in result:
+            content = WorkflowLLMUtils.extract_content(chunk)
+            yield content
+
+    def _format_response_content(self, response_content: str) -> dict:
+        pass
+
+
+class LLMComponent(WorkflowComponent):
+    def __init__(self, component_config: Optional[LLMCompConfig] = None):
+        super().__init__()
+        self._executable = None
+        self._config = component_config
+
+    @property
+    def executable(self) -> LLMExecutable:
+        if self._executable is None:
+            self._executable = self.to_executable()
+        return self._executable
+
+    def to_executable(self) -> LLMExecutable:
+        return LLMExecutable(self._config)
diff --git a/jiuwen/core/component/questioner_comp.py b/jiuwen/core/component/questioner_comp.py
new file mode 100644
index 0000000..044d12b
--- /dev/null
+++ b/jiuwen/core/component/questioner_comp.py
@@ -0,0 +1,431 @@
+#!/usr/bin/python3.10
+# coding: utf-8
+# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved
+import ast
+import json
+import re
+from dataclasses import dataclass, field
+from enum import Enum
+from typing import Any, Optional, List, Dict, Iterator, AsyncIterator
+
+from pydantic import BaseModel, Field
+
+from jiuwen.core.common.configs.model_config import ModelConfig
+from jiuwen.core.common.exception.exception import JiuWenBaseException
+from jiuwen.core.common.exception.status_code import StatusCode
+from jiuwen.core.component.base import ComponentConfig, WorkflowComponent
+from jiuwen.core.context.context import Context
+from jiuwen.core.graph.executable import Executable, Input, Output
+
+# TODO: BaseChatModel, ModelFactory, Template, PromptEngine and TemplateManager are used
+#  below but are not imported in this patch.
+
+START_STR = "start"
+END_STR = "end"
+USER_INTERACT_STR = "user_interact"
+USER_FIELDS_KEY = "userFields"
+
+SUB_PLACEHOLDER_PATTERN = r'\{\{([^}]*)\}\}'
+CONTINUE_ASK_STATEMENT = "请您提供{non_extracted_key_fields_names}相关的信息"
+WORKFLOW_CHAT_HISTORY = "workflow_chat_history"
+TEMPLATE_NAME = "questioner"
+QUESTIONER_STATE_KEY = "questioner_state"
+
+
+class ExecutionStatus(Enum):
+    START = START_STR
+    USER_INTERACT = USER_INTERACT_STR
+    END = END_STR
+
+
+class QuestionerEvent(Enum):
+    START_EVENT = START_STR
+    END_EVENT = END_STR
+    USER_INTERACT_EVENT = USER_INTERACT_STR
+
+
+class ResponseType(Enum):
+    ReplyDirectly = "reply_directly"
+
+
+class FieldInfo(BaseModel):
+    field_name: str = Field(default="")
+    description: str = Field(default="")
+    cn_field_name: str = Field(default="")
+    required: bool = Field(default=False)
+    default_value: Any = Field(default="")
+
+
+@dataclass
+class QuestionerConfig(ComponentConfig):
+    model: Optional[ModelConfig] = field(default=None)
+    response_type: str = field(default=ResponseType.ReplyDirectly.value)
+    question_content: str = field(default="")
+    extract_fields_from_response: bool = field(default=True)
+    field_names: List[FieldInfo] = field(default_factory=list)
+    max_response: int = field(default=3)
+    with_chat_history: bool = field(default=True)
+    chat_history_max_rounds: int = field(default=5)
+    prompt_template: List[dict] = field(default_factory=list)
+    extra_prompt_for_fields_extraction: str = field(default="")
+    example_content: str = field(default="")
+
+
+class QuestionerOutput(BaseModel):
+    user_response: str = Field(default="")
+    question: str = Field(default="")
+    key_fields: dict = Field(default_factory=dict)
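+
+
+# QuestionerState is a small state machine that is persisted between turns:
+# START (first entry) -> USER_INTERACT (a question was asked, waiting for the user's reply)
+# -> END (all key fields collected or the direct reply received). handle_event() returns a
+# new state object of the matching subclass, so callers must keep the returned value.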
+class QuestionerState(BaseModel):
+    response_num: int = Field(default=0)
+    user_response: str = Field(default="")
+    extracted_key_fields: Dict[str, Any] = Field(default_factory=dict)
+    status: ExecutionStatus = Field(default=ExecutionStatus.START)
+
+    @classmethod
+    def deserialize(cls, raw_state: str):
+        state = cls.model_validate_json(raw_state)
+        return state.handle_event(QuestionerEvent(state.status.value))
+
+    def handle_event(self, event: QuestionerEvent):
+        if event == QuestionerEvent.START_EVENT:
+            return QuestionerStartState.from_state(self)
+        if event == QuestionerEvent.USER_INTERACT_EVENT:
+            return QuestionerInteractState.from_state(self)
+        if event == QuestionerEvent.END_EVENT:
+            return QuestionerEndState.from_state(self)
+        return self
+
+    def serialize(self) -> str:
+        return self.model_dump_json()
+
+    def is_undergoing_interaction(self):
+        return self.status in [ExecutionStatus.USER_INTERACT]
+
+
+class QuestionerStartState(QuestionerState):
+    @classmethod
+    def from_state(cls, questioner_state: QuestionerState):
+        return cls(response_num=questioner_state.response_num,
+                   user_response=questioner_state.user_response,
+                   extracted_key_fields=questioner_state.extracted_key_fields,
+                   status=ExecutionStatus.START)
+
+
+class QuestionerInteractState(QuestionerState):
+    @classmethod
+    def from_state(cls, questioner_state: QuestionerState):
+        return cls(response_num=questioner_state.response_num,
+                   user_response=questioner_state.user_response,
+                   extracted_key_fields=questioner_state.extracted_key_fields,
+                   status=ExecutionStatus.USER_INTERACT)
+
+    def handle_event(self, event: QuestionerEvent):
+        if event == QuestionerEvent.END_EVENT:
+            return QuestionerEndState.from_state(self)
+        return self
+
+
+class QuestionerEndState(QuestionerState):
+    @classmethod
+    def from_state(cls, questioner_state: QuestionerState):
+        return cls(response_num=questioner_state.response_num,
+                   user_response=questioner_state.user_response,
+                   extracted_key_fields=questioner_state.extracted_key_fields,
+                   status=ExecutionStatus.END)
+
+    def handle_event(self, event: QuestionerEvent):
+        return self
+
+
+class QuestionerUtils:
+    @staticmethod
+    def format_template(template: str, user_fields: dict):
+        def replace(match):
+            key = match.group(1)
+            return str(user_fields.get(key))
+
+        try:
+            result = re.sub(SUB_PLACEHOLDER_PATTERN, replace, template)
+            return result
+        except (KeyError, TypeError, AttributeError):
+            return ""
+
+    @staticmethod
+    def get_latest_k_rounds_chat(chat_history, rounds):
+        return chat_history[-rounds * 2 - 1:]
+
+    @staticmethod
+    def format_continue_ask_question(non_extracted_key_fields: List[FieldInfo]):
+        non_extracted_key_fields_names = list()
+        for param in non_extracted_key_fields:
+            non_extracted_key_fields_names.append(param.cn_field_name or param.description)
+        result = ", ".join(non_extracted_key_fields_names)
+        return CONTINUE_ASK_STATEMENT.format(non_extracted_key_fields_names=result)
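+
+
+# Builder-style handler for the "reply_directly" response type: configure it with
+# .config()/.model()/.state()/.prompt() and call .handle(inputs, context). Depending on the
+# current state it either asks the configured question, extracts key fields from the chat
+# history with the LLM, or returns the collected fields once the interaction has ended.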
+class QuestionerDirectReplyHandler:
+    def __init__(self):
+        self._config = None
+        self._model = None
+        self._state = None
+        self._prompt = None
+        self._query = ""
+
+    def config(self, config: QuestionerConfig):
+        self._config = config
+        return self
+
+    def model(self, model: BaseChatModel):
+        self._model = model
+        return self
+
+    def state(self, state: QuestionerState):
+        self._state = state
+        return self
+
+    def prompt(self, prompt):
+        self._prompt = prompt
+        return self
+
+    def handle(self, inputs: Input, context: Context):
+        if self._state.status == ExecutionStatus.START:
+            return self._handle_start_state(inputs, context)
+        if self._state.status == ExecutionStatus.USER_INTERACT:
+            return self._handle_user_interact_state(inputs, context)
+        if self._state.status == ExecutionStatus.END:
+            return self._handle_end_state(inputs, context)
+        return dict()
+
+    def _handle_start_state(self, inputs, context):
+        output = QuestionerOutput()
+        self._query = inputs.get("query", "")
+        chat_history = self._get_latest_chat_history(context)
+        if self._is_set_question_content():
+            user_fields = inputs.get(USER_FIELDS_KEY, dict())
+            output.question = QuestionerUtils.format_template(self._config.question_content, user_fields)
+            # handle_event returns a new state object, so the result must be kept
+            self._state = self._state.handle_event(QuestionerEvent.USER_INTERACT_EVENT)
+            return dict(userFields=output.model_dump(exclude_defaults=True))
+
+        if self._need_extract_fields():
+            is_continue_ask = self._initial_extract_from_chat_history(chat_history, output)
+            event = QuestionerEvent.USER_INTERACT_EVENT if is_continue_ask else QuestionerEvent.END_EVENT
+            self._state = self._state.handle_event(event)
+        else:
+            raise JiuWenBaseException(
+                error_code=StatusCode.WORKFLOW_QUESTIONER_QUESTION_EMPTY_DIRECT_COLLECTION_ERROR.code,
+                message=StatusCode.WORKFLOW_QUESTIONER_QUESTION_EMPTY_DIRECT_COLLECTION_ERROR.errmsg
+            )
+        return dict(userFields=output.model_dump(exclude_defaults=True))
+
+    def _handle_user_interact_state(self, inputs, context):
+        output = QuestionerOutput()
+        self._query = inputs.get("query", "")
+        chat_history = self._get_latest_chat_history(context)
+        user_response = chat_history[-1].get("content", "") if chat_history else ""
+
+        if self._is_set_question_content() and not self._need_extract_fields():
+            output.user_response = user_response
+            self._state = self._state.handle_event(QuestionerEvent.END_EVENT)
+            return dict(userFields=output.model_dump(exclude_defaults=True))
+
+        if self._need_extract_fields():
+            is_continue_ask = self._repeat_extract_from_chat_history(chat_history, output)
+            event = QuestionerEvent.USER_INTERACT_EVENT if is_continue_ask else QuestionerEvent.END_EVENT
+            self._state = self._state.handle_event(event)
+        else:
+            raise JiuWenBaseException(
+                error_code=StatusCode.WORKFLOW_QUESTIONER_QUESTION_EMPTY_DIRECT_COLLECTION_ERROR.code,
+                message=StatusCode.WORKFLOW_QUESTIONER_QUESTION_EMPTY_DIRECT_COLLECTION_ERROR.errmsg
+            )
+        return dict(userFields=output.model_dump(exclude_defaults=True))
+
+    def _handle_end_state(self, inputs, context):
+        return dict(
+            userFields=QuestionerOutput(user_response=self._state.user_response,
+                                        key_fields=self._state.extracted_key_fields).model_dump(exclude_defaults=True)
+        )
+
+    def _is_set_question_content(self):
+        return isinstance(self._config.question_content, str) and len(self._config.question_content) > 0
+
+    def _need_extract_fields(self):
+        return (self._config.extract_fields_from_response and
+                len(self._config.field_names) > len(self._state.extracted_key_fields))
+
+    def _initial_extract_from_chat_history(self, chat_history, output: QuestionerOutput) -> bool:
+        self._invoke_llm_and_parse_result(chat_history, output)
+
+        self._update_param_default_value(output)
+        self._update_state_of_key_fields(output.key_fields)
+
+        return self._check_if_continue_ask(output)
+
+    def _repeat_extract_from_chat_history(self, chat_history, output: QuestionerOutput) -> bool:
+        self._invoke_llm_and_parse_result(chat_history, output)
+
+        return self._check_if_continue_ask(output)
+
+    def _get_latest_chat_history(self, context) -> List:
+        result = list()
+        if self._config.with_chat_history:
+            raw_chat_history = context.store.read(WORKFLOW_CHAT_HISTORY) or list()
+            if raw_chat_history:
+                result = QuestionerUtils.get_latest_k_rounds_chat(raw_chat_history,
+                                                                  self._config.chat_history_max_rounds)
+        if not result or "user" == result[-1].get("role", ""):
+            result.append(dict(role="user", content=self._query))
+        return result
+
+    def _build_llm_inputs(self, chat_history: list = None) -> List:
+        prompt_template_input = self._create_prompt_template_keywords(chat_history)
+        prompt_template = PromptEngine.format(prompt_template_input, self._prompt)
+        return prompt_template.content
+
+    def _create_prompt_template_keywords(self, chat_history):
+        params_list, required_name_list = list(), list()
+        for param in self._config.field_names:
+            params_list.append(f"{param.field_name}: {param.description}")
+            if param.required:
+                required_name_list.append(param.cn_field_name or param.description)
+        required_name_str = "、".join(required_name_list) + f"{len(required_name_list)}个必要信息"
+        all_param_str = "\n".join(params_list)
+        dialogue_history_str = "\n".join([f"{_.get('role', '')}:{_.get('content', '')}" for _ in chat_history])
+
+        return dict(required_name=required_name_str, required_params_list=all_param_str,
+                    extra_info=self._config.extra_prompt_for_fields_extraction, example=self._config.example_content,
+                    dig_history=dialogue_history_str)
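+
+    # The extraction prompt asks the model for the key fields as JSON; the reply is parsed
+    # with json.loads first and, as a fallback for single-quoted pseudo-JSON, with
+    # ast.literal_eval. None and empty values are dropped from the result.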
+    def _invoke_llm_for_extraction(self, llm_inputs):
+        try:
+            response = self._model.invoke(llm_inputs).content
+        except Exception as e:
+            raise JiuWenBaseException(
+                error_code=StatusCode.INVOKE_LLM_FAILED.code,
+                message=StatusCode.INVOKE_LLM_FAILED.errmsg
+            ) from e
+
+        result = dict()
+        try:
+            result = json.loads(response, strict=False)
+            result = {k: v for k, v in result.items() if v is not None and str(v)}
+        except json.JSONDecodeError:
+            try:
+                result = {k: v for k, v in ast.literal_eval(response).items() if v is not None and str(v)}
+            except (SyntaxError, AttributeError, ValueError):
+                return result
+        return result
+
+    def _filter_non_extracted_key_fields(self) -> List[FieldInfo]:
+        return [_ for _ in self._config.field_names if _.field_name not in self._state.extracted_key_fields]
+
+    def _update_state_of_key_fields(self, key_fields):
+        self._state.extracted_key_fields.update(key_fields)
+
+    def _update_param_default_value(self, output: QuestionerOutput):
+        result = dict()
+        extracted_key_fields = self._state.extracted_key_fields
+        for param in self._config.field_names:
+            param_name = param.field_name
+            default_value = param.default_value
+            if default_value and param_name not in extracted_key_fields:
+                result.update({param_name: default_value})
+        output.key_fields.update(result)
+
+    def _increment_state_of_response_num(self):
+        self._state.response_num += 1
+
+    def _exceed_max_response(self):
+        return self._state.response_num > self._config.max_response
+
+    def _check_if_continue_ask(self, output: QuestionerOutput):
+        is_continue_ask = False
+        non_extracted_key_fields: List[FieldInfo] = self._filter_non_extracted_key_fields()
+        if non_extracted_key_fields:
+            if not self._exceed_max_response():
+                output.question = QuestionerUtils.format_continue_ask_question(non_extracted_key_fields)
+                is_continue_ask = True
+            else:
+                raise JiuWenBaseException(
+                    error_code=StatusCode.WORKFLOW_QUESTIONER_EXCEED_LOOP.code,
+                    message=StatusCode.WORKFLOW_QUESTIONER_EXCEED_LOOP.errmsg
+                )
+        if is_continue_ask:
+            output.key_fields.clear()
+        else:
+            output.key_fields.update(self._state.extracted_key_fields)
+        return is_continue_ask
+
+    def _invoke_llm_and_parse_result(self, chat_history, output):
+        llm_inputs = self._build_llm_inputs(chat_history=chat_history)
+        extracted_key_fields = self._invoke_llm_for_extraction(llm_inputs)
+        output.key_fields.update(extracted_key_fields)
+        self._increment_state_of_response_num()
+        self._update_state_of_key_fields(extracted_key_fields)
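+
+
+# QuestionerExecutable persists its state across turns: the serialized QuestionerState is
+# written to context.state under QUESTIONER_STATE_KEY after every invoke() and is restored
+# (promoted to the matching state subclass) on the next turn while an interaction is pending.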
+class QuestionerExecutable(Executable):
+    def __init__(self, config: QuestionerConfig):
+        super().__init__()
+        self._config = config
+        self._llm = self._create_llm_instance()
+        self._prompt: Template = self._init_prompt()
+        self._questioner_state = QuestionerState()
+
+    @staticmethod
+    def _load_state_from_context(context) -> QuestionerState:
+        state_str = context.state.get(QUESTIONER_STATE_KEY)
+        if state_str:
+            return QuestionerState.deserialize(state_str)
+        return QuestionerState()
+
+    @staticmethod
+    def _store_state_to_context(state: QuestionerState, context):
+        state_str = state.serialize()
+        context.state.set(QUESTIONER_STATE_KEY, state_str)
+
+    def invoke(self, inputs: Input, context: Context) -> Output:
+        state_from_context = self._load_state_from_context(context)
+        if state_from_context.is_undergoing_interaction():
+            self._questioner_state = state_from_context
+        self._questioner_state = self._questioner_state.handle_event(QuestionerEvent.START_EVENT)
+
+        invoke_result = dict()
+        if self._config.response_type == ResponseType.ReplyDirectly.value:
+            invoke_result = self._handle_questioner_direct_reply(inputs, context)
+
+        self._store_state_to_context(self._questioner_state, context)
+        return invoke_result
+
+    async def ainvoke(self, inputs: Input, context: Context) -> Output:
+        pass
+
+    def stream(self, inputs: Input, context: Context) -> Iterator[Output]:
+        pass
+
+    async def astream(self, inputs: Input, context: Context) -> AsyncIterator[Output]:
+        pass
+
+    def interrupt(self, message: dict):
+        return super().interrupt(message)
+
+    def _create_llm_instance(self) -> BaseChatModel:
+        return ModelFactory().get_model(model_provider=self._config.model.model_provider,
+                                        model_info=self._config.model.model_info)
+
+    def _init_prompt(self) -> Template:
+        if self._config.prompt_template:
+            return Template(name="user_prompt", content=self._config.prompt_template)
+
+        filters = dict(model_name=self._config.model.model_info.model_name)
+        return TemplateManager.get(name=TEMPLATE_NAME, filters=filters)
+
+    def _handle_questioner_direct_reply(self, inputs: Input, context: Context):
+        handler = (QuestionerDirectReplyHandler()
+                   .config(self._config)
+                   .model(self._llm)
+                   .state(self._questioner_state)
+                   .prompt(self._prompt))
+        result = handler.handle(inputs, context)
+        # keep the state transitions performed by the handler so invoke() persists them
+        self._questioner_state = handler._state
+        return result
+
+
+class QuestionerComponent(WorkflowComponent):
+    def __init__(self, questioner_comp_config: QuestionerConfig = None):
+        super().__init__()
+        self._questioner_config = questioner_comp_config
+        self._executable = None
+
+    def to_executable(self) -> Executable:
+        return QuestionerExecutable(self._questioner_config)
+
--
Gitee