diff --git a/apps/common/queue.py b/apps/common/queue.py index 87089a92c107afee457104f76a705c6e68418b04..5cd66e617956d16884f07c92ece5fd4f701e6141 100644 --- a/apps/common/queue.py +++ b/apps/common/queue.py @@ -8,7 +8,7 @@ from collections.abc import AsyncGenerator from datetime import UTC, datetime from typing import Any -from apps.llm import LLMConfig +from apps.llm import LLM from apps.schemas.enum_var import EventType from apps.schemas.message import ( HeartbeatData, @@ -36,7 +36,7 @@ class MessageQueue: self._close = False self._heartbeat_task = asyncio.get_event_loop().create_task(self._heartbeat()) - async def push_output(self, task: TaskData, llm: LLMConfig, event_type: str, data: dict[str, Any]) -> None: + async def push_output(self, task: TaskData, llm: LLM, event_type: str, data: dict[str, Any]) -> None: """组装用于向用户(前端/Shell端)输出的消息""" if event_type == EventType.DONE.value: await self._queue.put("[DONE]") @@ -48,8 +48,8 @@ class MessageQueue: metadata = MessageMetadata( timeCost=step_time, - inputTokens=llm.reasoning.input_tokens, - outputTokens=llm.reasoning.output_tokens, + inputTokens=llm.input_tokens, + outputTokens=llm.output_tokens, ) if task.state: diff --git a/apps/constants.py b/apps/constants.py index 4e8a2ed27f8cd10e51e443ce6bacffac81af8139..da50474d420842e890de4559ede978bfba1be773 100644 --- a/apps/constants.py +++ b/apps/constants.py @@ -49,6 +49,6 @@ AGENT_MAX_RETRY_TIMES = 3 # MCP Agent 最大步骤数 AGENT_MAX_STEPS = 25 # MCP Agent 最终步骤名称 -AGENT_FINAL_STEP_NAME = "FIANL" +AGENT_FINAL_STEP_NAME = "FINAL" # 大模型超时时间 LLM_TIMEOUT = 300.0 diff --git a/apps/llm/__init__.py b/apps/llm/__init__.py index 30e487a5c4904a8011638b26bc78747edb301c65..1188d95d09fd1554dedcaa159bf954d6bba9da33 100644 --- a/apps/llm/__init__.py +++ b/apps/llm/__init__.py @@ -1,16 +1,13 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. 
"""模型调用模块""" -from .embedding import Embedding, embedding +from .embedding import embedding from .generator import json_generator from .llm import LLM -from .schema import LLMConfig from .token import token_calculator __all__ = [ "LLM", - "Embedding", - "LLMConfig", "embedding", "json_generator", "token_calculator", diff --git a/apps/llm/generator.py b/apps/llm/generator.py index 1a32c1bacdac0ac44d01d593abb716b7c32ade21..f2f681faa7cb41cb8e4c07e70f7903d7503b92f8 100644 --- a/apps/llm/generator.py +++ b/apps/llm/generator.py @@ -15,7 +15,6 @@ from apps.schemas.llm import LLMFunctions from .llm import LLM from .prompt import JSON_GEN_BASIC, JSON_NO_FUNCTION_CALL -from .schema import LLMConfig _logger = logging.getLogger(__name__) @@ -24,7 +23,7 @@ JSON_GEN_MAX_TRIAL = 3 class JsonGenerator: """综合Json生成器(全局单例)""" - def __init__(self) -> None: + def __init__(self, llm: LLM | None = None) -> None: """创建JsonGenerator实例""" # Jinja2环境,可以复用 self._env = SandboxedEnvironment( @@ -35,22 +34,24 @@ class JsonGenerator: extensions=["jinja2.ext.loopcontrols"], ) # 初始化时设为None,调用init后设置 - self._llm: LLM | None = None + self._llm: LLM | None = llm self._support_function_call: bool = False + if llm is not None: + self._check_function_call_support() - def init(self, llm_config: LLMConfig) -> None: + def init(self, llm: LLM) -> None: """初始化JsonGenerator,设置LLM配置""" - # 选择LLM:优先使用Function模型(如果存在且支持FunctionCall),否则回退到Reasoning模型 - self._llm, self._support_function_call = self._select_llm(llm_config) - - def _select_llm(self, llm_config: LLMConfig) -> tuple[LLM, bool]: - """选择LLM:优先使用Function模型(如果存在且支持FunctionCall),否则回退到Reasoning模型""" - if llm_config.function is not None and LLMType.FUNCTION in llm_config.function.config.llmType: - _logger.info("[JSONGenerator] 使用Function模型,支持FunctionCall") - return llm_config.function, True - - _logger.info("[JSONGenerator] Function模型不可用或不支持FunctionCall,回退到Reasoning模型") - return llm_config.reasoning, False + self._llm = llm + self._check_function_call_support() + + def _check_function_call_support(self) -> None: + """检查LLM是否支持FunctionCall""" + if self._llm is not None and LLMType.FUNCTION in self._llm.config.llmType: + _logger.info("[JSONGenerator] LLM支持FunctionCall") + self._support_function_call = True + else: + _logger.info("[JSONGenerator] LLM不支持FunctionCall,将使用prompt方式") + self._support_function_call = False async def _single_trial( self, diff --git a/apps/llm/schema.py b/apps/llm/schema.py deleted file mode 100644 index 53296a1bd51d379510fdc14adf9104e2f2976f13..0000000000000000000000000000000000000000 --- a/apps/llm/schema.py +++ /dev/null @@ -1,20 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. 
-"""LLM配置""" - -from pydantic import BaseModel, Field - -from .embedding import Embedding -from .llm import LLM - - -class LLMConfig(BaseModel): - """LLM配置""" - - reasoning: LLM = Field(description="推理LLM") - function: LLM | None = Field(description="函数LLM") - embedding: Embedding | None = Field(description="Embedding") - - class Config: - """配置""" - - arbitrary_types_allowed = True diff --git a/apps/models/task.py b/apps/models/task.py index c799b243f00e3600f8fc9383381152b7ca6a3ad3..6d2f33a3e2b2b85235f7f6eed0edc530325a075e 100644 --- a/apps/models/task.py +++ b/apps/models/task.py @@ -16,7 +16,6 @@ from .base import Base class ExecutorStatus(str, PyEnum): """执行器状态""" - UNKNOWN = "unknown" INIT = "init" WAITING = "waiting" RUNNING = "running" @@ -35,7 +34,6 @@ class LanguageType(str, PyEnum): class StepStatus(str, PyEnum): """步骤状态""" - UNKNOWN = "unknown" INIT = "init" WAITING = "waiting" RUNNING = "running" diff --git a/apps/scheduler/executor/agent.py b/apps/scheduler/executor/agent.py index b82f0f965d617174cc6b3bd025eef3f23e5e91d6..e4aba23f654e1576b00f56cbaed0d0539de4751d 100644 --- a/apps/scheduler/executor/agent.py +++ b/apps/scheduler/executor/agent.py @@ -91,7 +91,7 @@ class MCPAgentExecutor(BaseExecutor): self.tools[tool.id] = tool self.tools[AGENT_FINAL_STEP_NAME] = MCPTools( - id=AGENT_FINAL_STEP_NAME, mcpId="", toolName="Final Tool", description="结束流程的工具", + mcpId="", toolName=AGENT_FINAL_STEP_NAME, description="结束流程的工具", inputSchema={}, outputSchema={}, ) @@ -140,6 +140,38 @@ class MCPAgentExecutor(BaseExecutor): _logger.error(err) raise RuntimeError(err) + def _get_error_message_str(self, error_message: dict | str | None) -> str: + """将错误消息转换为字符串""" + if isinstance(error_message, dict): + return str(error_message.get("err_msg", "")) + return str(error_message) if error_message else "" + + def _update_last_context_status(self, step_status: StepStatus, output_data: dict | None = None) -> None: + """更新最后一个context的状态(如果是当前步骤)""" + self._validate_task_state() + state = cast("ExecutorCheckpoint", self.task.state) + if len(self.task.context) and self.task.context[-1].stepId == state.stepId: + self.task.context[-1].stepStatus = step_status + if output_data is not None: + self.task.context[-1].outputData = output_data + + async def _add_error_to_context(self, step_status: StepStatus) -> None: + """添加错误到context,先移除重复的步骤再添加""" + self._remove_last_context_if_same_step() + self.task.context.append( + self._create_executor_history(step_status=step_status), + ) + + async def _handle_final_step(self) -> None: + """处理最终步骤,设置状态并总结""" + self._validate_task_state() + state = cast("ExecutorCheckpoint", self.task.state) + state.executorStatus = ExecutorStatus.SUCCESS + state.stepStatus = StepStatus.SUCCESS + self.task.state = state + await self._push_message(EventType.EXECUTOR_STOP, data={}) + await self.summarize() + def _create_executor_history( self, step_status: StepStatus, @@ -176,25 +208,26 @@ class MCPAgentExecutor(BaseExecutor): """处理步骤错误并继续下一步""" self._validate_task_state() state = cast("ExecutorCheckpoint", self.task.state) + error_output = {"message": state.errorMessage} + + # 先更新stepStatus + state.stepStatus = StepStatus.ERROR + self.task.state = state + await self._push_message( - EventType.STEP_ERROR, - data={ - "message": state.errorMessage, - }, + EventType.STEP_END, + data=error_output, ) + + # 更新或添加错误到context if len(self.task.context) and self.task.context[-1].stepId == state.stepId: - self.task.context[-1].stepStatus = StepStatus.ERROR - self.task.context[-1].outputData = { - 
"message": state.errorMessage, - } + self._update_last_context_status(StepStatus.ERROR, error_output) else: self.task.context.append( self._create_executor_history( step_status=StepStatus.ERROR, input_data=self._current_input, - output_data={ - "message": state.errorMessage, - }, + output_data=error_output, ), ) await self.get_next_step() @@ -207,19 +240,21 @@ class MCPAgentExecutor(BaseExecutor): current_tool = cast("MCPTools", self._current_tool) confirm_message = await self._planner.get_tool_risk(current_tool, self._current_input, "") + + # 先更新状态 + state.executorStatus = ExecutorStatus.WAITING + state.stepStatus = StepStatus.WAITING + self.task.state = state + await self._push_message( EventType.STEP_WAITING_FOR_START, confirm_message.model_dump(exclude_none=True, by_alias=True), ) - await self._push_message(EventType.FLOW_STOP, {}) - state.executorStatus = ExecutorStatus.WAITING - state.stepStatus = StepStatus.WAITING self.task.context.append( self._create_executor_history( step_status=state.stepStatus, extra_data=confirm_message.model_dump(exclude_none=True, by_alias=True), ), ) - self.task.state = state async def run_step(self) -> None: """执行步骤""" @@ -287,6 +322,10 @@ class MCPAgentExecutor(BaseExecutor): "message": message, } + # 先更新状态为成功 + state.stepStatus = StepStatus.SUCCESS + self.task.state = state + await self._push_message(EventType.STEP_INPUT, self._current_input) await self._push_message(EventType.STEP_OUTPUT, output_data) self.task.context.append( @@ -296,8 +335,6 @@ class MCPAgentExecutor(BaseExecutor): output_data=output_data, ), ) - state.stepStatus = StepStatus.SUCCESS - self.task.state = state async def generate_params_with_null(self) -> None: """生成参数补充""" @@ -311,22 +348,21 @@ class MCPAgentExecutor(BaseExecutor): self._current_input, state.errorMessage, ) - error_msg_str = ( - str(state.errorMessage.get("err_msg", "")) - if isinstance(state.errorMessage, dict) - else str(state.errorMessage) - ) + error_msg_str = self._get_error_message_str(state.errorMessage) error_message = await self._planner.change_err_message_to_description( error_message=error_msg_str, tool=current_tool, input_params=self._current_input, ) + + # 先更新状态 + state.executorStatus = ExecutorStatus.WAITING + state.stepStatus = StepStatus.PARAM + self.task.state = state + await self._push_message( EventType.STEP_WAITING_FOR_PARAM, data={"message": error_message, "params": params_with_null}, ) - await self._push_message(EventType.FLOW_STOP, data={}) - state.executorStatus = ExecutorStatus.WAITING - state.stepStatus = StepStatus.PARAM self.task.context.append( self._create_executor_history( step_status=state.stepStatus, @@ -336,7 +372,6 @@ class MCPAgentExecutor(BaseExecutor): }, ), ) - self.task.state = state async def get_next_step(self) -> None: """获取下一步""" @@ -374,26 +409,18 @@ class MCPAgentExecutor(BaseExecutor): state = cast("ExecutorCheckpoint", self.task.state) state.stepStatus = StepStatus.ERROR state.executorStatus = ExecutorStatus.ERROR - await self._push_message( - EventType.FLOW_FAILED, - data={}, - ) - self._remove_last_context_if_same_step() - self.task.context.append( - self._create_executor_history(step_status=state.stepStatus), - ) self.task.state = state - async def work(self) -> None: # noqa: C901, PLR0912, PLR0915 + await self._push_message(EventType.STEP_END, data={}) + await self._push_message(EventType.EXECUTOR_STOP, data={}) + await self._add_error_to_context(state.stepStatus) + + async def work(self) -> None: # noqa: C901, PLR0912 """执行当前步骤""" self._validate_task_state() state = 
cast("ExecutorCheckpoint", self.task.state) if state.stepStatus == StepStatus.INIT: - await self._push_message( - EventType.STEP_INIT, - data={}, - ) await self.get_tool_input_param(is_first=True) if not self._user.autoExecute: # 等待用户确认 @@ -411,17 +438,11 @@ class MCPAgentExecutor(BaseExecutor): else: state.executorStatus = ExecutorStatus.CANCELLED state.stepStatus = StepStatus.CANCELLED - await self._push_message( - EventType.STEP_CANCEL, - data={}, - ) - await self._push_message( - EventType.FLOW_CANCEL, - data={}, - ) - if len(self.task.context) and self.task.context[-1].stepId == state.stepId: - self.task.context[-1].stepStatus = StepStatus.CANCELLED self.task.state = state + + await self._push_message(EventType.STEP_END, data={}) + await self._push_message(EventType.EXECUTOR_STOP, data={}) + self._update_last_context_status(StepStatus.CANCELLED) return max_retry = 5 for i in range(max_retry): @@ -434,34 +455,23 @@ class MCPAgentExecutor(BaseExecutor): # 错误处理 if self._retry_times >= AGENT_MAX_RETRY_TIMES: await self.error_handle_after_step() + elif self._user.autoExecute: + await self._handle_step_error_and_continue() else: - user_info = await UserManager.get_user(self.task.metadata.userId) - if not user_info: - err = "[MCPAgentExecutor] 用户不存在" - _logger.error(err) - raise RuntimeError(err) - - if user_info.autoExecute: - await self._handle_step_error_and_continue() + mcp_tool = self.tools[state.stepName] + error_msg = self._get_error_message_str(state.errorMessage) + is_param_error = await self._planner.is_param_error( + await self._host.assemble_memory(self.task.runtime, self.task.context), + error_msg, + mcp_tool, + "", + self._current_input, + ) + if is_param_error.is_param_error: + # 如果是参数错误,生成参数补充 + await self.generate_params_with_null() else: - mcp_tool = self.tools[state.stepName] - error_msg = ( - str(state.errorMessage.get("err_msg", "")) - if isinstance(state.errorMessage, dict) - else str(state.errorMessage) - ) - is_param_error = await self._planner.is_param_error( - await self._host.assemble_memory(self.task.runtime, self.task.context), - error_msg, - mcp_tool, - "", - self._current_input, - ) - if is_param_error.is_param_error: - # 如果是参数错误,生成参数补充 - await self.generate_params_with_null() - else: - await self._handle_step_error_and_continue() + await self._handle_step_error_and_continue() elif state.stepStatus == StepStatus.SUCCESS: await self.get_next_step() @@ -478,12 +488,9 @@ class MCPAgentExecutor(BaseExecutor): async def run(self) -> None: """执行MCP Agent的主逻辑""" - if not self.task.state: - err = "[MCPAgentExecutor] 任务状态不存在" - _logger.error(err) - raise RuntimeError(err) - + self._validate_task_state() state = cast("ExecutorCheckpoint", self.task.state) + # 初始化MCP服务 await self.load_mcp() data = {} @@ -499,14 +506,10 @@ class MCPAgentExecutor(BaseExecutor): state.executorStatus = ExecutorStatus.RUNNING self.task.state = state - await self._push_message(EventType.FLOW_START, data=data) if state.stepName == AGENT_FINAL_STEP_NAME: # 如果已经是最后一步,直接结束 - state.executorStatus = ExecutorStatus.SUCCESS - self.task.state = state - await self._push_message(EventType.FLOW_SUCCESS, data={}) - await self.summarize() + await self._handle_final_step() return try: @@ -515,11 +518,7 @@ class MCPAgentExecutor(BaseExecutor): if state.stepName == AGENT_FINAL_STEP_NAME: # 如果已经是最后一步,直接结束 - state.executorStatus = ExecutorStatus.SUCCESS - state.stepStatus = StepStatus.SUCCESS - self.task.state = state - await self._push_message(EventType.FLOW_SUCCESS, data={}) - await self.summarize() + await 
self._handle_final_step() except Exception as e: _logger.exception("[MCPAgentExecutor] 执行过程中发生错误") state.executorStatus = ExecutorStatus.ERROR @@ -528,14 +527,11 @@ class MCPAgentExecutor(BaseExecutor): "data": {}, } state.stepStatus = StepStatus.ERROR - await self._push_message(EventType.STEP_ERROR, data={}) - await self._push_message(EventType.FLOW_FAILED, data={}) - - self._remove_last_context_if_same_step() - self.task.context.append( - self._create_executor_history(step_status=state.stepStatus), - ) self.task.state = state + + await self._push_message(EventType.STEP_END, data={}) + await self._push_message(EventType.EXECUTOR_STOP, data={}) + await self._add_error_to_context(state.stepStatus) finally: for mcp_service in self._mcp_list: try: diff --git a/apps/scheduler/executor/base.py b/apps/scheduler/executor/base.py index 43876db364f01f8f832a8367d283485a3d4a9308..dae1249775eff1ce03d50cba70e1c3932173ec0a 100644 --- a/apps/scheduler/executor/base.py +++ b/apps/scheduler/executor/base.py @@ -15,7 +15,7 @@ from apps.services.record import RecordManager if TYPE_CHECKING: from apps.common.queue import MessageQueue - from apps.llm import LLMConfig + from apps.llm import LLM from apps.schemas.task import TaskData @@ -24,7 +24,7 @@ class BaseExecutor(BaseModel, ABC): task: "TaskData" msg_queue: "MessageQueue" - llm: "LLMConfig" + llm: "LLM" question: str @@ -74,7 +74,7 @@ class BaseExecutor(BaseModel, ABC): 统一的消息推送接口 :param event_type: 事件类型 - :param data: 消息数据,如果是FLOW_START事件且data为None,则自动构建FlowStartContent + :param data: 消息数据,如果是EXECUTOR_START事件且data为None,则自动构建ExecutorStartContent """ if event_type == EventType.TEXT_ADD.value and isinstance(data, str): data = TextAddContent(text=data).model_dump(exclude_none=True, by_alias=True) diff --git a/apps/scheduler/executor/flow.py b/apps/scheduler/executor/flow.py index 4d820503a63d949088335389874185b2fb9d92a7..9de9e7e42c77f5afbf9a9f7c552be0b6b1a66150 100644 --- a/apps/scheduler/executor/flow.py +++ b/apps/scheduler/executor/flow.py @@ -78,7 +78,7 @@ class FlowExecutor(BaseExecutor): # 尝试恢复State if ( self.state - and self.state.executorStatus not in [ExecutorStatus.INIT, ExecutorStatus.UNKNOWN] + and self.state.executorStatus != ExecutorStatus.INIT ): # 创建ExecutorState self.state = ExecutorCheckpoint( @@ -182,8 +182,6 @@ class FlowExecutor(BaseExecutor): 数据通过向Queue发送消息的方式传输 """ logger.info("[FlowExecutor] 运行工作流") - # 推送Flow开始消息 - await self._push_message(EventType.FLOW_START.value) # 获取首个步骤 first_step = StepQueueItem( @@ -271,7 +269,4 @@ class FlowExecutor(BaseExecutor): # FlowStop需要返回总时间,需要倒推最初的开始时间(当前时间减去当前已用总时间) self.task.runtime.time = round(datetime.now(UTC).timestamp(), 2) - self.task.runtime.fullTime # 推送Flow停止消息 - if is_error: - await self._push_message(EventType.FLOW_FAILED.value) - else: - await self._push_message(EventType.FLOW_SUCCESS.value) + await self._push_message(EventType.EXECUTOR_STOP.value) diff --git a/apps/scheduler/mcp/host.py b/apps/scheduler/mcp/host.py index 5d608e054cbc209670cd81a41b2dfe3d32705085..1e7e2f17485ee0463e217d18c7f82bd4fb60b270 100644 --- a/apps/scheduler/mcp/host.py +++ b/apps/scheduler/mcp/host.py @@ -10,7 +10,7 @@ from jinja2 import BaseLoader from jinja2.sandbox import SandboxedEnvironment from mcp.types import TextContent -from apps.llm import LLMConfig, json_generator +from apps.llm import LLM, json_generator from apps.models import LanguageType, MCPTools from apps.scheduler.mcp.prompt import MEMORY_TEMPLATE from apps.scheduler.pool.mcp.client import MCPClient @@ -24,7 +24,7 @@ logger = 
logging.getLogger(__name__) class MCPHost: """MCP宿主服务""" - def __init__(self, user_id: str, task_id: uuid.UUID, llm: LLMConfig, language: LanguageType) -> None: + def __init__(self, user_id: str, task_id: uuid.UUID, llm: LLM, language: LanguageType) -> None: """初始化MCP宿主""" self._task_id = task_id self._user_id = user_id diff --git a/apps/scheduler/mcp/plan.py b/apps/scheduler/mcp/plan.py index 70e6b996312a329b56de5318bf0279ae4b0c77f4..ba74910552a4c1513867e0e8b6a6fe46f1660e09 100644 --- a/apps/scheduler/mcp/plan.py +++ b/apps/scheduler/mcp/plan.py @@ -6,7 +6,7 @@ import logging from jinja2 import BaseLoader from jinja2.sandbox import SandboxedEnvironment -from apps.llm import LLMConfig, json_generator +from apps.llm import LLM, json_generator from apps.models import LanguageType, MCPTools from apps.schemas.mcp import MCPPlan @@ -17,7 +17,7 @@ _logger = logging.getLogger(__name__) class MCPPlanner: """MCP 用户目标拆解与规划""" - def __init__(self, user_goal: str, language: LanguageType, llm: LLMConfig) -> None: + def __init__(self, user_goal: str, language: LanguageType, llm: LLM) -> None: """初始化MCP规划器""" self._user_goal = user_goal self._env = SandboxedEnvironment( @@ -55,7 +55,7 @@ class MCPPlanner: {"role": "user", "content": prompt}, ] result = "" - async for chunk in self._llm.reasoning.call( + async for chunk in self._llm.call( message, streaming=False, ): @@ -99,7 +99,7 @@ class MCPPlanner: ) result = "" - async for chunk in self._llm.reasoning.call( + async for chunk in self._llm.call( [{"role": "user", "content": prompt}], streaming=False, ): diff --git a/apps/scheduler/mcp/select.py b/apps/scheduler/mcp/select.py index ee3bffe314c97c5297ec21637b648a97631f88ff..1028f9b1d1db6a2e209ea0f846051331a4a1078e 100644 --- a/apps/scheduler/mcp/select.py +++ b/apps/scheduler/mcp/select.py @@ -6,7 +6,7 @@ import logging from sqlalchemy import select from apps.common.postgres import postgres -from apps.llm import LLMConfig, embedding, json_generator +from apps.llm import LLM, embedding, json_generator from apps.models import LanguageType, MCPTools from apps.schemas.mcp import MCPSelectResult from apps.services.mcp_service import MCPServiceManager @@ -19,7 +19,7 @@ logger = logging.getLogger(__name__) class MCPSelector: """MCP选择器""" - def __init__(self, llm: LLMConfig, language: LanguageType = LanguageType.CHINESE) -> None: + def __init__(self, llm: LLM, language: LanguageType = LanguageType.CHINESE) -> None: """初始化MCP选择器""" self._llm = llm self._language = language @@ -32,7 +32,7 @@ class MCPSelector: {"role": "user", "content": prompt}, ] result = "" - async for chunk in self._llm.reasoning.call(message): + async for chunk in self._llm.call(message): result += chunk.content or "" return result diff --git a/apps/scheduler/mcp_agent/base.py b/apps/scheduler/mcp_agent/base.py index ea3b01a8383a691533d6f6af1d0c1733774b950a..853c27c383902d717853f3f0832bcca7a7973a81 100644 --- a/apps/scheduler/mcp_agent/base.py +++ b/apps/scheduler/mcp_agent/base.py @@ -4,7 +4,7 @@ import logging from typing import Any -from apps.llm import LLMConfig, json_generator +from apps.llm import LLM, json_generator from apps.models import LanguageType from apps.schemas.task import TaskData @@ -15,11 +15,11 @@ class MCPBase: """MCP基类""" _user_id: str - _llm: LLMConfig + _llm: LLM _goal: str _language: LanguageType - def __init__(self, task: TaskData, llm: LLMConfig) -> None: + def __init__(self, task: TaskData, llm: LLM) -> None: """初始化MCP基类""" self._user_id = task.metadata.userId self._llm = llm @@ -34,7 +34,7 @@ class MCPBase: 
{"role": "user", "content": "Please provide a JSON response based on the above information and schema."}, ] result = "" - async for chunk in self._llm.reasoning.call( + async for chunk in self._llm.call( message, streaming=False, ): diff --git a/apps/scheduler/mcp_agent/plan.py b/apps/scheduler/mcp_agent/plan.py index e11b7286ff49efe0ca670382379f38232e9e29ce..8140ec04b55f8ca5ce0271e38577f17520a2be8b 100644 --- a/apps/scheduler/mcp_agent/plan.py +++ b/apps/scheduler/mcp_agent/plan.py @@ -212,7 +212,7 @@ class MCPPlanner(MCPBase): memory=memory, goal=self._goal, ) - async for chunk in self._llm.reasoning.call( + async for chunk in self._llm.call( [{"role": "user", "content": prompt}], streaming=True, ): diff --git a/apps/scheduler/scheduler/executor.py b/apps/scheduler/scheduler/executor.py index bfae212bbb20d4312495334a27f97c566b8ca8ab..4eebcd37dfb168f06a92108459d8892e8b560bbe 100644 --- a/apps/scheduler/scheduler/executor.py +++ b/apps/scheduler/scheduler/executor.py @@ -6,7 +6,7 @@ import logging import uuid from apps.common.queue import MessageQueue -from apps.llm import LLMConfig +from apps.llm import LLM from apps.models import AppType, ExecutorStatus from apps.scheduler.executor.agent import MCPAgentExecutor from apps.scheduler.executor.flow import FlowExecutor @@ -27,11 +27,11 @@ class ExecutorMixin: task: TaskData post_body: RequestData queue: MessageQueue - llm: LLMConfig + llm: LLM async def get_top_flow(self) -> str: """获取Top1 Flow (由FlowMixin实现)""" - ... + ... # noqa: PIE790 async def _determine_app_id(self) -> uuid.UUID | None: """确定最终使用的 app_id""" @@ -64,16 +64,7 @@ class ExecutorMixin: return final_app_id async def _create_executor_task(self, final_app_id: uuid.UUID | None) -> asyncio.Task | None: - """ - 根据 app_id 创建对应的执行器任务 - - Args: - final_app_id: 要使用的 app_id,None 表示使用 QA 模式 - - Returns: - 创建的异步任务,如果发生错误则返回 None - - """ + """根据 app_id 创建对应的执行器任务""" if final_app_id is None: _logger.info("[Scheduler] 运行智能问答模式") return asyncio.create_task(self._run_qa()) diff --git a/apps/scheduler/scheduler/init.py b/apps/scheduler/scheduler/init.py index bdd53a3e9d220867c0a464bd129d2144cd0abfc4..96a12d3f8796ed61c8e79c36c4f4104d82d96670 100644 --- a/apps/scheduler/scheduler/init.py +++ b/apps/scheduler/scheduler/init.py @@ -8,7 +8,7 @@ from jinja2 import BaseLoader from jinja2.sandbox import SandboxedEnvironment from apps.common.queue import MessageQueue -from apps.llm import LLM, LLMConfig, embedding +from apps.llm import LLM from apps.models import Task, TaskRuntime, User from apps.schemas.request_data import RequestData from apps.schemas.task import TaskData @@ -27,54 +27,21 @@ class InitializationMixin: user: User post_body: RequestData queue: MessageQueue - llm: LLMConfig + llm: LLM _env: SandboxedEnvironment - async def _get_scheduler_llm(self, reasoning_llm_id: str) -> LLMConfig: + async def _get_scheduler_llm(self, reasoning_llm_id: str) -> LLM: """获取RAG大模型""" reasoning_llm = await LLMManager.get_llm(reasoning_llm_id) if not reasoning_llm: err = "[Scheduler] 获取问答用大模型ID失败" _logger.error(err) raise ValueError(err) - reasoning_llm = LLM(reasoning_llm) - - function_llm = None - if not self.user.functionLLM: - _logger.error("[Scheduler] 用户 %s 没有设置函数调用大模型,相关功能将被禁用", self.user.id) - else: - function_llm = await LLMManager.get_llm(self.user.functionLLM) - if not function_llm: - _logger.error( - "[Scheduler] 用户 %s 设置的函数调用大模型ID %s 不存在,相关功能将被禁用", - self.user.id, self.user.functionLLM, - ) - else: - function_llm = LLM(function_llm) - - embedding_obj = None - if not self.user.embeddingLLM: - 
_logger.error("[Scheduler] 用户 %s 没有设置向量模型,相关功能将被禁用", self.user.id) - else: - embedding_llm_config = await LLMManager.get_llm(self.user.embeddingLLM) - if not embedding_llm_config: - _logger.error( - "[Scheduler] 用户 %s 设置的向量模型ID %s 不存在,相关功能将被禁用", - self.user.id, self.user.embeddingLLM, - ) - else: - await embedding.init(embedding_llm_config) - embedding_obj = embedding - - return LLMConfig( - reasoning=reasoning_llm, - function=function_llm, - embedding=embedding_obj, - ) + return LLM(reasoning_llm) - def _create_new_task(self, task_id: uuid.UUID, user_id: str, conversation_id: uuid.UUID | None) -> TaskData: + def _create_new_task(self, task_id: uuid.UUID, user_id: str, conversation_id: uuid.UUID | None) -> None: """创建新的TaskData""" - return TaskData( + self.task = TaskData( metadata=Task( id=task_id, userId=user_id, @@ -93,7 +60,7 @@ class InitializationMixin: if not conversation_id: _logger.info("[Scheduler] 无Conversation ID,直接创建新任务") - self.task = self._create_new_task(task_id, user_id, None) + self._create_new_task(task_id, user_id, None) return _logger.info("[Scheduler] 尝试从Conversation ID %s 恢复任务", conversation_id) @@ -128,7 +95,7 @@ class InitializationMixin: if not restored: _logger.info("[Scheduler] 无法恢复任务,创建新任务") - self.task = self._create_new_task(task_id, user_id, conversation_id) + self._create_new_task(task_id, user_id, conversation_id) async def _init_user(self, user_id: str) -> None: """初始化用户""" diff --git a/apps/scheduler/scheduler/message.py b/apps/scheduler/scheduler/message.py index 8a0f0bd96d613a6fae56b85be42a5a0949564fca..a4c9230a534ac152fd3360fb9c0c14b88081c34b 100644 --- a/apps/scheduler/scheduler/message.py +++ b/apps/scheduler/scheduler/message.py @@ -2,9 +2,10 @@ """消息推送相关的Mixin类""" import logging +from typing import Any from apps.common.queue import MessageQueue -from apps.llm import LLMConfig +from apps.llm import LLM from apps.schemas.enum_var import EventType from apps.schemas.task import TaskData @@ -16,9 +17,23 @@ class MessagingMixin: queue: MessageQueue task: TaskData - llm: LLMConfig + llm: LLM async def _push_done_message(self) -> None: """推送任务完成消息""" _logger.info("[Scheduler] 发送结束消息") await self.queue.push_output(self.task, self.llm, event_type=EventType.DONE.value, data={}) + + async def _push_executor_start_message(self, data: dict[str, Any] | None = None) -> None: + """推送executor.start消息""" + _logger.info("[Scheduler] 发送executor.start消息") + if data is None: + data = {} + await self.queue.push_output(self.task, self.llm, event_type=EventType.EXECUTOR_START.value, data=data) + + async def _push_executor_stop_message(self, data: dict[str, Any] | None = None) -> None: + """推送executor.stop消息""" + _logger.info("[Scheduler] 发送executor.stop消息") + if data is None: + data = {} + await self.queue.push_output(self.task, self.llm, event_type=EventType.EXECUTOR_STOP.value, data=data) diff --git a/apps/schemas/enum_var.py b/apps/schemas/enum_var.py index 5606ed92ba3cae8827a076af1b1f2772dc10faf7..9b35b01327491a531894fbb932dee255136d2f2a 100644 --- a/apps/schemas/enum_var.py +++ b/apps/schemas/enum_var.py @@ -25,21 +25,15 @@ class EventType(str, Enum): """事件类型""" HEARTBEAT = "heartbeat" - INIT = "init" TEXT_ADD = "text.add" GRAPH = "graph" STEP_WAITING_FOR_START = "step.waiting_for_start" STEP_WAITING_FOR_PARAM = "step.waiting_for_param" - FLOW_START = "flow.start" - STEP_INIT = "step.init" + EXECUTOR_START = "executor.start" STEP_INPUT = "step.input" STEP_OUTPUT = "step.output" - STEP_CANCEL = "step.cancel" - STEP_ERROR = "step.error" - FLOW_STOP = "flow.stop" - 
FLOW_FAILED = "flow.failed" - FLOW_SUCCESS = "flow.success" - FLOW_CANCEL = "flow.cancel" + STEP_END = "step.end" + EXECUTOR_STOP = "executor.stop" DONE = "done" diff --git a/apps/schemas/message.py b/apps/schemas/message.py index 9f6dac514fffe28641da6135c5804bdd6b4d1aef..e149adb1895a44efd27133daf70ef8e0ee0bb51e 100644 --- a/apps/schemas/message.py +++ b/apps/schemas/message.py @@ -34,7 +34,7 @@ class MessageExecutor(BaseModel): executor_id: str = Field(description="Flow ID", alias="executorId") executor_name: str = Field(description="Flow名称", alias="executorName") executor_status: ExecutorStatus = Field( - description="Flow状态", alias="executorStatus", default=ExecutorStatus.UNKNOWN, + description="Flow状态", alias="executorStatus", default=ExecutorStatus.INIT, ) step_id: uuid.UUID = Field(description="当前步骤ID", alias="stepId") step_name: str = Field(description="当前步骤名称", alias="stepName") diff --git a/apps/services/llm.py b/apps/services/llm.py index 6eb3b6532d6ba58f59d66be02e7a0cf2dfe6da5f..8f99df31fc445297358cd6277562d1d07fcff671 100644 --- a/apps/services/llm.py +++ b/apps/services/llm.py @@ -6,7 +6,7 @@ import logging from sqlalchemy import select from apps.common.postgres import postgres -from apps.models import LLMData, User +from apps.models import GlobalSettings, LLMData from apps.schemas.request_data import UpdateLLMReq from apps.schemas.response_data import ( LLMAdminInfo, @@ -162,16 +162,13 @@ class LLMManager: await session.commit() async with postgres.session() as session: - # 清除所有FunctionLLM的引用 - user = list((await session.scalars( - select(User).where(User.functionLLM == llm_id), - )).all()) - for item in user: - item.functionLLM = None - # 清除所有EmbeddingLLM的引用 - user = list((await session.scalars( - select(User).where(User.embeddingLLM == llm_id), - )).all()) - for item in user: - item.embeddingLLM = None + # 清除GlobalSettings中的引用 + global_settings = (await session.scalars( + select(GlobalSettings), + )).all() + for settings in global_settings: + if settings.functionLlmId == llm_id: + settings.functionLlmId = None + if settings.embeddingLlmId == llm_id: + settings.embeddingLlmId = None await session.commit() diff --git a/design/executor/agent.md b/design/executor/agent.md index 623993062023b10d5e99c625031eb8f2b8221182..927754d7fa5a708484db8d2187a0d7d0eee66b4f 100644 --- a/design/executor/agent.md +++ b/design/executor/agent.md @@ -777,11 +777,11 @@ erDiagram } ExecutorStatus { - string value "UNKNOWN|INIT|WAITING|RUNNING|SUCCESS|ERROR|CANCELLED" + string value "INIT|WAITING|RUNNING|SUCCESS|ERROR|CANCELLED" } StepStatus { - string value "UNKNOWN|INIT|WAITING|RUNNING|SUCCESS|ERROR|PARAM|CANCELLED" + string value "INIT|WAITING|RUNNING|SUCCESS|ERROR|PARAM|CANCELLED" } ``` @@ -791,7 +791,6 @@ erDiagram classDiagram class ExecutorStatus { <> - UNKNOWN 未知状态 INIT 初始化 WAITING 等待用户输入 RUNNING 运行中 @@ -802,7 +801,6 @@ classDiagram class StepStatus { <> - UNKNOWN 未知状态 INIT 初始化 WAITING 等待用户确认 RUNNING 运行中 diff --git a/design/executor/step.md b/design/executor/step.md index 697b4fa0efcb3206f29a259e26e26f33768a273b..1498def85e70b2781fa09106513376a02965ac08 100644 --- a/design/executor/step.md +++ b/design/executor/step.md @@ -960,7 +960,6 @@ erDiagram classDiagram class StepStatus { <> - UNKNOWN 未知状态 INIT 初始化 WAITING 等待确认 RUNNING 运行中 @@ -972,7 +971,6 @@ classDiagram class StepType { <> - UNKNOWN 未知类型 START 开始节点 END 结束节点 AGENT 智能体节点