From 05fdf683466dc0e55d2dc13016295fb0ad3f3b8c Mon Sep 17 00:00:00 2001 From: z30057876 Date: Thu, 28 Aug 2025 17:44:24 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=94=B9LLM=E7=9B=B8=E5=85=B3?= =?UTF-8?q?=E5=A4=84=E7=90=86=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/llm/embedding.py | 160 ++++++++++++++++---- apps/llm/function.py | 11 +- apps/models/llm.py | 13 ++ apps/models/user.py | 2 - apps/models/vectors.py | 113 -------------- apps/routers/tag.py | 8 + apps/scheduler/call/core.py | 23 ++- apps/scheduler/executor/agent.py | 231 +++++++++++++++-------------- apps/scheduler/mcp_agent/plan.py | 5 +- apps/scheduler/mcp_agent/prompt.py | 9 +- apps/schemas/response_data.py | 13 ++ apps/schemas/scheduler.py | 11 ++ apps/services/llm.py | 13 +- apps/services/tag.py | 19 +++ 14 files changed, 350 insertions(+), 281 deletions(-) delete mode 100644 apps/models/vectors.py diff --git a/apps/llm/embedding.py b/apps/llm/embedding.py index 298cc9ff..cb0cfd50 100644 --- a/apps/llm/embedding.py +++ b/apps/llm/embedding.py @@ -3,38 +3,144 @@ import logging import httpx - -from apps.common.config import config - -logger = logging.getLogger(__name__) +from pgvector.sqlalchemy import Vector +from sqlalchemy import Column, ForeignKey, Index, String, text +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.orm import declarative_base + +from apps.common.postgres import postgres +from apps.models.llm import EmbeddingBackend, LLMData + +_logger = logging.getLogger(__name__) +_VectorBase = declarative_base() +_flow_pool_vector_table = { + "__tablename__": "framework_flow_vector", + "appId": Column(UUID(as_uuid=True), ForeignKey("framework_app.id"), nullable=False), + "id": Column(String(255), ForeignKey("framework_flow.id"), primary_key=True), + "__table_args__": ( + Index( + "hnsw_index", + "embedding", + postgresql_using="hnsw", + postgresql_with={"m": 16, "ef_construction": 200}, + postgresql_ops={"embedding": "vector_cosine_ops"}, + ), + ), +} +_service_pool_vector_table = { + "__tablename__": "framework_service_vector", + "id": Column(UUID(as_uuid=True), ForeignKey("framework_service.id"), primary_key=True), + "__table_args__": ( + Index( + "hnsw_index", + "embedding", + postgresql_using="hnsw", + postgresql_with={"m": 16, "ef_construction": 200}, + postgresql_ops={"embedding": "vector_cosine_ops"}, + ), + ), +} +_node_pool_vector_table = { + "__tablename__": "framework_node_vector", + "id": Column(String(255), ForeignKey("framework_node.id"), primary_key=True), + "serviceId": Column(UUID(as_uuid=True), ForeignKey("framework_service.id"), nullable=True), + "__table_args__": ( + Index( + "hnsw_index", + "embedding", + postgresql_using="hnsw", + postgresql_with={"m": 16, "ef_construction": 200}, + postgresql_ops={"embedding": "vector_cosine_ops"}, + ), + ), +} +_mcp_vector_table = { + "__tablename__": "framework_mcp_vector", + "id": Column(String(255), ForeignKey("framework_mcp.id"), primary_key=True), + "__table_args__": ( + Index( + "hnsw_index", + "embedding", + postgresql_using="hnsw", + postgresql_with={"m": 16, "ef_construction": 200}, + postgresql_ops={"embedding": "vector_cosine_ops"}, + ), + ), +} +_mcp_tool_vector_table = { + "__tablename__": "framework_mcp_tool_vector", + "id": Column(String(255), ForeignKey("framework_mcp_tool.id"), primary_key=True), + "mcpId": Column(String(255), ForeignKey("framework_mcp.id"), nullable=False), + "__table_args__": ( + Index( + "hnsw_index", + "embedding", + postgresql_using="hnsw", + postgresql_with={"m": 16, "ef_construction": 200}, + postgresql_ops={"embedding": "vector_cosine_ops"}, + ), + ), +} class Embedding: """Embedding模型""" - # TODO: 应当自动检测向量维度 - @classmethod - async def _get_embedding_dimension(cls) -> int: + async def _get_embedding_dimension(self) -> int: """获取Embedding的维度""" - embedding = await cls.get_embedding(["测试文本"]) + embedding = await self.get_embedding(["测试文本"]) return len(embedding[0]) - - @classmethod - async def _get_openai_embedding(cls, text: list[str]) -> list[list[float]]: + async def _delete_vector(self) -> None: + """删除所有Vector表""" + async with postgres.session() as session: + await session.execute(text("DROP TABLE IF EXISTS framework_flow_vector")) + await session.execute(text("DROP TABLE IF EXISTS framework_service_vector")) + await session.execute(text("DROP TABLE IF EXISTS framework_node_vector")) + await session.execute(text("DROP TABLE IF EXISTS framework_mcp_vector")) + await session.execute(text("DROP TABLE IF EXISTS framework_mcp_tool_vector")) + await session.commit() + + async def _create_vector_table(self, dim: int) -> None: + """根据检测出的维度创建Vector表""" + if dim <= 0: + err = "[Embedding] 检测到的Embedding维度为0,无法创建Vector表" + _logger.error(err) + raise RuntimeError(err) + # 给所有的Dict加入embedding字段 + for table in [ + _flow_pool_vector_table, + _service_pool_vector_table, + _node_pool_vector_table, + _mcp_vector_table, + _mcp_tool_vector_table, + ]: + table["embedding"] = Column(Vector(dim), nullable=False) + # 创建表 + _VectorBase.metadata.create_all(_VectorBase.metadata) + + async def __init__(self, llm_config: LLMData | None = None) -> None: + """初始化Embedding模型""" + if not llm_config or not llm_config.embeddingBackend: + err = "[Embedding] 未设置Embedding模型" + _logger.error(err) + raise RuntimeError(err) + self._config: LLMData = llm_config + + async def _get_openai_embedding(self, text: list[str]) -> list[list[float]]: """访问OpenAI兼容的Embedding API,获得向量化数据""" - api = config.embedding.endpoint + "/embeddings" + api = self._config.openaiBaseUrl + "/embeddings" data = { "input": text, - "model": config.embedding.model, + "model": self._config.modelName, "encoding_format": "float", } headers = { "Content-Type": "application/json", } - if config.embedding.api_key: - headers["Authorization"] = f"Bearer {config.embedding.api_key}" + if self._config.openaiAPIKey: + headers["Authorization"] = f"Bearer {self._config.openaiAPIKey}" async with httpx.AsyncClient() as client: response = await client.post( @@ -47,15 +153,14 @@ class Embedding: return [item["embedding"] for item in json["data"]] - @classmethod - async def _get_tei_embedding(cls, text: list[str]) -> list[list[float]]: + async def _get_tei_embedding(self, text: list[str]) -> list[list[float]]: """访问TEI兼容的Embedding API,获得向量化数据""" - api = config.embedding.endpoint + "/embed" + api = self._config.openaiBaseUrl + "/embed" headers = { "Content-Type": "application/json", } - if config.embedding.api_key: - headers["Authorization"] = f"Bearer {config.embedding.api_key}" + if self._config.openaiAPIKey: + headers["Authorization"] = f"Bearer {self._config.openaiAPIKey}" async with httpx.AsyncClient() as client: result = [] @@ -73,8 +178,7 @@ class Embedding: return result - @classmethod - async def get_embedding(cls, text: list[str]) -> list[list[float]]: + async def get_embedding(self, text: list[str]) -> list[list[float]]: """ 访问OpenAI兼容的Embedding API,获得向量化数据 @@ -82,13 +186,13 @@ class Embedding: :return: 文本对应的向量(顺序与text一致,也为List) """ try: - if config.embedding.type == "openai": - return await cls._get_openai_embedding(text) - if config.embedding.type == "mindie": - return await cls._get_tei_embedding(text) + if self._config.embeddingBackend == EmbeddingBackend.OPENAI: + return await self._get_openai_embedding(text) + if self._config.embeddingBackend == EmbeddingBackend.TEI: + return await self._get_tei_embedding(text) - logger.error("不支持的Embedding API类型: %s", config.embedding.type) + _logger.error("[Embedding] 不支持的Embedding API类型: %s", self._config.modelName) return [[0.0] * 1024 for _ in range(len(text))] except Exception: - logger.exception("获取Embedding失败") + _logger.exception("[Embedding] 获取Embedding失败") return [[0.0] * 1024 for _ in range(len(text))] diff --git a/apps/llm/function.py b/apps/llm/function.py index 88ee8d69..b6766e49 100644 --- a/apps/llm/function.py +++ b/apps/llm/function.py @@ -38,17 +38,12 @@ class FunctionLLM: - structured_output """ # 暂存config;这里可以替代为从其他位置获取 - if not llm_config: - err = "未设置大模型配置" + if not llm_config or not llm_config.functionCallBackend: + err = "[FunctionLLM] 未设置FunctionCall模型" logger.error(err) raise RuntimeError(err) - self._config: LLMData = llm_config - - if not self._config.modelName: - err_msg = "[FunctionCall] 未设置FuntionCall所用模型!" - logger.error(err_msg) - raise ValueError(err_msg) + self._config: LLMData = llm_config self._params = { "model": self._config.modelName, "messages": [], diff --git a/apps/models/llm.py b/apps/models/llm.py index a819ea35..0e8801ca 100644 --- a/apps/models/llm.py +++ b/apps/models/llm.py @@ -39,6 +39,15 @@ class FunctionCallBackend(str, PyEnum): """Structured Output""" +class EmbeddingBackend(str, PyEnum): + """Embedding后端""" + + OPENAI = "openai" + """OpenAI""" + TEI = "tei" + """TEI""" + + class LLMData(Base): """大模型信息""" @@ -63,6 +72,10 @@ class LLMData(Base): Enum(FunctionCallBackend), default=None, nullable=True, ) """Function Call后端""" + embeddingBackend: Mapped[EmbeddingBackend | None] = mapped_column( # noqa: N815 + Enum(EmbeddingBackend), default=None, nullable=True, + ) + """Embedding后端""" createdAt: Mapped[DateTime] = mapped_column( # noqa: N815 DateTime, default_factory=lambda: datetime.now(tz=UTC), diff --git a/apps/models/user.py b/apps/models/user.py index abd844b3..b32a3cd5 100644 --- a/apps/models/user.py +++ b/apps/models/user.py @@ -36,8 +36,6 @@ class User(Base): """用户个人令牌""" selectedKB: Mapped[list[uuid.UUID]] = mapped_column(ARRAY(UUID), default=[], nullable=False) # noqa: N815 """用户选择的知识库的ID""" - reasoningLLM: Mapped[str | None] = mapped_column(String(255), default=None, nullable=True) # noqa: N815 - """用户选择的问答模型ID""" functionLLM: Mapped[str | None] = mapped_column(String(255), default=None, nullable=True) # noqa: N815 """用户选择的函数模型ID""" embeddingLLM: Mapped[str | None] = mapped_column(String(255), default=None, nullable=True) # noqa: N815 diff --git a/apps/models/vectors.py b/apps/models/vectors.py deleted file mode 100644 index 7378d07e..00000000 --- a/apps/models/vectors.py +++ /dev/null @@ -1,113 +0,0 @@ -"""向量表""" - -import uuid - -from pgvector.sqlalchemy import Vector -from sqlalchemy import ForeignKey, Index, String -from sqlalchemy.dialects.postgresql import UUID -from sqlalchemy.orm import Mapped, mapped_column - -from .base import Base - - -class FlowPoolVector(Base): - """Flow向量数据""" - - __tablename__ = "framework_flow_vector" - appId: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("framework_app.id"), nullable=False) # noqa: N815 - """所属App的ID""" - embedding: Mapped[list[float]] = mapped_column(Vector(1024), nullable=False) - """向量数据""" - id: Mapped[str] = mapped_column(String(255), ForeignKey("framework_flow.id"), primary_key=True) - """Flow的ID""" - __table_args__ = ( - Index( - "hnsw_index", - "embedding", - postgresql_using="hnsw", - postgresql_with={"m": 16, "ef_construction": 200}, - postgresql_ops={"embedding": "vector_cosine_ops"}, - ), - ) - - -class ServicePoolVector(Base): - """Service向量数据""" - - __tablename__ = "framework_service_vector" - embedding: Mapped[list[float]] = mapped_column(Vector(1024), nullable=False) - """向量数据""" - id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("framework_service.id"), primary_key=True) - """Service的ID""" - __table_args__ = ( - Index( - "hnsw_index", - "embedding", - postgresql_using="hnsw", - postgresql_with={"m": 16, "ef_construction": 200}, - postgresql_ops={"embedding": "vector_cosine_ops"}, - ), - ) - - -class NodePoolVector(Base): - """Node向量数据""" - - __tablename__ = "framework_node_vector" - embedding: Mapped[list[float]] = mapped_column(Vector(1024), nullable=False) - """向量数据""" - serviceId: Mapped[uuid.UUID | None] = mapped_column( # noqa: N815 - UUID(as_uuid=True), ForeignKey("framework_service.id"), nullable=True, - ) - """Service的ID""" - id: Mapped[str] = mapped_column(String(255), ForeignKey("framework_node.id"), primary_key=True) - """Node的ID""" - __table_args__ = ( - Index( - "hnsw_index", - "embedding", - postgresql_using="hnsw", - postgresql_with={"m": 16, "ef_construction": 200}, - postgresql_ops={"embedding": "vector_cosine_ops"}, - ), - ) - - -class MCPVector(Base): - """MCP向量数据""" - - __tablename__ = "framework_mcp_vector" - embedding: Mapped[list[float]] = mapped_column(Vector(1024), nullable=False) - """向量数据""" - id: Mapped[str] = mapped_column(String(255), ForeignKey("framework_mcp.id"), primary_key=True) - """MCP的ID""" - __table_args__ = ( - Index( - "hnsw_index", - "embedding", - postgresql_using="hnsw", - postgresql_with={"m": 16, "ef_construction": 200}, - postgresql_ops={"embedding": "vector_cosine_ops"}, - ), - ) - - -class MCPToolVector(Base): - """MCP工具向量数据""" - - __tablename__ = "framework_mcp_tool_vector" - embedding: Mapped[list[float]] = mapped_column(Vector(1024), nullable=False) - """向量数据""" - mcpId: Mapped[str] = mapped_column(String(255), ForeignKey("framework_mcp.id"), nullable=False, index=True) # noqa: N815 - """MCP的ID""" - id: Mapped[str] = mapped_column(String(255), ForeignKey("framework_mcp_tool.id"), primary_key=True) - """MCP工具的ID""" - __table_args__ = ( - Index( - "hnsw_index", - "embedding", - postgresql_using="hnsw", - postgresql_with={"m": 16, "ef_construction": 200}, - postgresql_ops={"embedding": "vector_cosine_ops"}, - ), - ) diff --git a/apps/routers/tag.py b/apps/routers/tag.py index 7d555311..2ad2e64f 100644 --- a/apps/routers/tag.py +++ b/apps/routers/tag.py @@ -31,8 +31,16 @@ router = APIRouter( @router.get("") async def get_user_tag( request: Request, + *, + user_only: bool = False, ) -> JSONResponse: """获取所有标签;传入user_sub的时候获取特定用户的标签信息,不传则获取所有标签""" + if user_only: + return JSONResponse(status_code=status.HTTP_200_OK, content=ResponseData( + code=status.HTTP_200_OK, + message="[Tag] Get user tag success.", + result=await TagManager.get_tag_by_user_sub(request.state.user_sub), + ).model_dump(exclude_none=True, by_alias=True)) return JSONResponse(status_code=status.HTTP_200_OK, content=ResponseData( code=status.HTTP_200_OK, message="[Tag] Get all tag success.", diff --git a/apps/scheduler/call/core.py b/apps/scheduler/call/core.py index 3bf42c36..26791623 100644 --- a/apps/scheduler/call/core.py +++ b/apps/scheduler/call/core.py @@ -95,30 +95,30 @@ class CoreCall(BaseModel): @staticmethod def _assemble_call_vars(executor: "StepExecutor") -> CallVars: """组装CallVars""" - if not executor.state: + if not executor.task.state: err = "[CoreCall] 当前ExecutorState为空" logger.error(err) raise ValueError(err) history = {} history_order = [] - for item in executor.context: + for item in executor.task.context: history[item.stepId] = item history_order.append(item.stepId) return CallVars( - language=executor.runtime.language, + language=executor.task.runtime.language, ids=CallIds( - task_id=executor.task.id, - executor_id=executor.state.executorId, - session_id=executor.task.sessionId, - user_sub=executor.task.userSub, - app_id=executor.state.appId, + task_id=executor.task.metadata.id, + executor_id=executor.task.state.executorId, + session_id=executor.task.runtime.sessionId, + user_sub=executor.task.metadata.userSub, + app_id=executor.task.state.appId, ), question=executor.question, history=history, history_order=history_order, - summary=executor.runtime.reasoning, + summary=executor.task.runtime.reasoning, ) @@ -164,7 +164,6 @@ class CoreCall(BaseModel): await obj._set_input(executor) return obj - async def _set_input(self, executor: "StepExecutor") -> None: """获取Call的输入""" self._sys_vars = self._assemble_call_vars(executor) @@ -197,7 +196,7 @@ class CoreCall(BaseModel): async def _llm(self, messages: list[dict[str, Any]], *, streaming: bool = False) -> AsyncGenerator[str, None]: """Call可直接使用的LLM非流式调用""" user_sub = self._sys_vars.ids.user_sub - llm_id = await LLMManager.get_user_default_llm(user_sub) + llm_id = await LLMManager.get_user_selected_llm(user_sub) if not llm_id: err = f"[CoreCall] 用户{user_sub}未设置默认LLM" logger.error(err) @@ -230,7 +229,7 @@ class CoreCall(BaseModel): async def _json(self, messages: list[dict[str, Any]], schema: dict[str, Any]) -> dict[str, Any]: """Call可直接使用的JSON生成""" user_sub = self._sys_vars.ids.user_sub - llm_id = await LLMManager.get_user_default_llm(user_sub) + llm_id = await LLMManager.get_user_selected_llm(user_sub) if not llm_id: err = f"[CoreCall] 用户{user_sub}未设置默认LLM" logger.error(err) diff --git a/apps/scheduler/executor/agent.py b/apps/scheduler/executor/agent.py index 3a63914d..960fb534 100644 --- a/apps/scheduler/executor/agent.py +++ b/apps/scheduler/executor/agent.py @@ -98,9 +98,14 @@ class MCPAgentExecutor(BaseExecutor): async def get_tool_input_param(self, *, is_first: bool) -> None: """获取工具输入参数""" + if not self.task.state: + err = "[MCPAgentExecutor] 任务状态不存在" + logger.error(err) + raise RuntimeError(err) + if is_first: # 获取第一个输入参数 - mcp_tool = self.tools[self.task.state.toolId] + mcp_tool = self.tools[self.task.state.stepName] self.task.state.currentInput = await self.host.get_first_input_params( mcp_tool, self.task.runtime.userInput, self.task, ) @@ -112,7 +117,7 @@ class MCPAgentExecutor(BaseExecutor): else: params = {} params_description = "" - mcp_tool = self.tools[self.task.state.tool_id] + mcp_tool = self.tools[self.task.state.stepName] self.task.state.currentInput = await self.host.fill_params( mcp_tool, self.task.runtime.userInput, @@ -131,7 +136,7 @@ class MCPAgentExecutor(BaseExecutor): raise RuntimeError(err) # 发送确认消息 - mcp_tool = self.tools[self.task.state.tool_id] + mcp_tool = self.tools[self.task.state.stepName] confirm_message = await self.planner.get_tool_risk( mcp_tool, self.task.state.currentInput, "", self.resoning_llm, self.task.runtime.language, ) @@ -167,7 +172,7 @@ class MCPAgentExecutor(BaseExecutor): self.task.state.executorStatus = ExecutorStatus.RUNNING self.task.state.stepStatus = StepStatus.RUNNING - mcp_tool = self.tools[self.task.state.toolId] + mcp_tool = self.tools[self.task.state.stepName] mcp_client = (await self.mcp_pool.get(mcp_tool.mcp_id, self.task.metadata.userSub)) try: output_params = await mcp_client.call_tool(mcp_tool.name, self.task.state.currentInput) @@ -180,7 +185,10 @@ class MCPAgentExecutor(BaseExecutor): except Exception as e: logger.exception("[MCPAgentExecutor] 执行步骤 %s 时发生错误", mcp_tool.name) self.task.state.stepStatus = StepStatus.ERROR - self.task.state.errorMessage = str(e) + self.task.state.errorMessage = { + "err_msg": str(e), + "data": self.task.state.currentInput, + } return logger.error(f"当前工具名称: {mcp_tool.name}, 输出参数: {output_params}") if output_params.isError: @@ -228,7 +236,7 @@ class MCPAgentExecutor(BaseExecutor): logger.error(err) raise RuntimeError(err) - mcp_tool = self.tools[self.task.state.toolId] + mcp_tool = self.tools[self.task.state.stepName] params_with_null = await self.planner.get_missing_param( mcp_tool, self.task.state.currentInput, @@ -289,12 +297,9 @@ class MCPAgentExecutor(BaseExecutor): tool_id=FINAL_TOOL_ID, description=FINAL_TOOL_ID, ) - tool_id = step.tool_id - step_name = FINAL_TOOL_ID if tool_id == FINAL_TOOL_ID else self.tools[tool_id].name step_description = step.description self.task.state.stepId = uuid.uuid4() - self.task.state.toolId = tool_id - self.task.state.stepName = step_name + self.task.state.stepName = step.tool_id self.task.state.stepDescription = step_description self.task.state.stepStatus = StepStatus.INIT self.task.state.currentInput = {} @@ -334,29 +339,34 @@ class MCPAgentExecutor(BaseExecutor): async def work(self) -> None: """执行当前步骤""" - if self.state.stepStatus == StepStatus.INIT: + if not self.task.state: + err = "[MCPAgentExecutor] 任务状态不存在" + logger.error(err) + raise RuntimeError(err) + + if self.task.state.stepStatus == StepStatus.INIT: await self.push_message( EventType.STEP_INIT, data={}, ) await self.get_tool_input_param(is_first=True) - user_info = await UserManager.get_userinfo_by_user_sub(self.task.userSub) + user_info = await UserManager.get_user(self.task.metadata.userSub) if not user_info.auto_execute: # 等待用户确认 await self.confirm_before_step() return - self.state.stepStatus = StepStatus.RUNNING - elif self.state.stepStatus in [StepStatus.PARAM, StepStatus.WAITING, StepStatus.RUNNING]: - if self.state.stepStatus == StepStatus.PARAM: - if len(self.context) and self.context[-1].stepId == self.state.stepId: - del self.context[-1] - elif self.state.stepStatus == StepStatus.WAITING: + self.task.state.stepStatus = StepStatus.RUNNING + elif self.task.state.stepStatus in [StepStatus.PARAM, StepStatus.WAITING, StepStatus.RUNNING]: + if self.task.state.stepStatus == StepStatus.PARAM: + if len(self.task.context) and self.task.context[-1].stepId == self.task.state.stepId: + del self.task.context[-1] + elif self.task.state.stepStatus == StepStatus.WAITING: if self.params: - if len(self.context) and self.context[-1].stepId == self.state.stepId: - del self.context[-1] + if len(self.task.context) and self.task.context[-1].stepId == self.task.state.stepId: + del self.task.context[-1] else: - self.state.executorStatus = ExecutorStatus.CANCELLED - self.state.stepStatus = StepStatus.CANCELLED + self.task.state.executorStatus = ExecutorStatus.CANCELLED + self.task.state.stepStatus = StepStatus.CANCELLED await self.push_message( EventType.STEP_CANCEL, data={}, @@ -365,17 +375,17 @@ class MCPAgentExecutor(BaseExecutor): EventType.FLOW_CANCEL, data={}, ) - if len(self.context) and self.context[-1].stepId == self.state.stepId: - self.context[-1].stepStatus = StepStatus.CANCELLED + if len(self.task.context) and self.task.context[-1].stepId == self.task.state.stepId: + self.task.context[-1].stepStatus = StepStatus.CANCELLED return max_retry = 5 for i in range(max_retry): if i != 0: await self.get_tool_input_param(is_first=True) await self.run_step() - if self.state.stepStatus == StepStatus.SUCCESS: + if self.task.state.stepStatus == StepStatus.SUCCESS: break - elif self.state.stepStatus == StepStatus.ERROR: + elif self.task.state.stepStatus == StepStatus.ERROR: # 错误处理 if self.task.state.retry_times >= 3: await self.error_handle_after_step() @@ -385,42 +395,42 @@ class MCPAgentExecutor(BaseExecutor): await self.push_message( EventType.STEP_ERROR, data={ - "message": self.state.errorMessage, + "message": self.task.state.errorMessage, } ) - if len(self.context) and self.context[-1].stepId == self.state.stepId: - self.context[-1].stepStatus = StepStatus.ERROR - self.context[-1].outputData = { - "message": self.state.errorMessage, + if len(self.task.context) and self.task.context[-1].stepId == self.task.state.stepId: + self.task.context[-1].stepStatus = StepStatus.ERROR + self.task.context[-1].outputData = { + "message": self.task.state.errorMessage, } else: - self.context.append( + self.task.context.append( ExecutorHistory( - task_id=self.task.id, - step_id=self.task.state.step_id, - step_name=self.task.state.step_name, - step_description=self.task.state.step_description, - step_status=StepStatus.ERROR, - flow_id=self.task.state.flow_id, - flow_name=self.task.state.flow_name, - flow_status=self.task.state.flow_status, - input_data=self.task.state.current_input, - output_data={ - "message": self.task.state.error_message, + taskId=self.task.metadata.id, + stepId=self.task.state.stepId, + stepName=self.task.state.stepName, + stepDescription=self.task.state.stepDescription, + stepStatus=StepStatus.ERROR, + executorId=self.task.state.executorId, + executorName=self.task.state.executorName, + executorStatus=self.task.state.executorStatus, + inputData=self.task.state.currentInput, + outputData={ + "message": self.task.state.errorMessage, }, - ) + ), ) await self.get_next_step() else: - mcp_tool = self.tools[self.state.toolId] - is_param_error = await MCPPlanner.is_param_error( - self.runtime.question, - await MCPHost.assemble_memory(self.task), - self.state.errorMessage, + mcp_tool = self.tools[self.task.state.toolId] + is_param_error = await self.planner.is_param_error( + self.task.runtime.question, + await self.host.assemble_memory(self.task.runtime, self.task.context), + self.task.state.errorMessage, mcp_tool, - self.state.stepDescription, - self.state.currentInput, - language=self.runtime.language, + self.task.state.stepDescription, + self.task.state.currentInput, + language=self.task.runtime.language, ) if is_param_error.is_param_error: # 如果是参数错误,生成参数补充 @@ -429,78 +439,83 @@ class MCPAgentExecutor(BaseExecutor): await self.push_message( EventType.STEP_ERROR, data={ - "message": self.state.errorMessage, - } + "message": self.task.state.errorMessage, + }, ) - if len(self.context) and self.context[-1].stepId == self.state.stepId: - self.context[-1].stepStatus = StepStatus.ERROR - self.context[-1].outputData = { - "message": self.state.errorMessage, + if len(self.task.context) and self.task.context[-1].stepId == self.task.state.stepId: + self.task.context[-1].stepStatus = StepStatus.ERROR + self.task.context[-1].outputData = { + "message": self.task.state.errorMessage, } else: self.task.context.append( - FlowStepHistory( - task_id=self.task.id, - step_id=self.task.state.step_id, - step_name=self.task.state.step_name, - step_description=self.task.state.step_description, - step_status=StepStatus.ERROR, - flow_id=self.task.state.flow_id, - flow_name=self.task.state.flow_name, - flow_status=self.task.state.flow_status, - input_data=self.task.state.current_input, - output_data={ - "message": self.task.state.error_message, + ExecutorHistory( + taskId=self.task.metadata.id, + stepId=self.task.state.stepId, + stepName=self.task.state.stepName, + stepDescription=self.task.state.stepDescription, + stepStatus=StepStatus.ERROR, + executorId=self.task.state.executorId, + executorName=self.task.state.executorName, + executorStatus=self.task.state.executorStatus, + inputData=self.task.state.currentInput, + outputData={ + "message": self.task.state.errorMessage, }, ), ) await self.get_next_step() - elif self.state.stepStatus == StepStatus.SUCCESS: + elif self.task.state.stepStatus == StepStatus.SUCCESS: await self.get_next_step() async def summarize(self) -> None: """总结""" - async for chunk in MCPPlanner.generate_answer( - self.runtime.userInput, - (await MCPHost.assemble_memory(self.task)), + async for chunk in self.planner.generate_answer( + self.task.runtime.userInput, + (await self.host.assemble_memory(self.task.runtime, self.task.context)), self.resoning_llm, - self.runtime.language, + self.task.runtime.language, ): await self.push_message( EventType.TEXT_ADD, data=chunk, ) - self.runtime.fullAnswer += chunk + self.task.runtime.fullAnswer += chunk async def run(self) -> None: """执行MCP Agent的主逻辑""" + if not self.task.state: + err = "[MCPAgentExecutor] 任务状态不存在" + logger.error(err) + raise RuntimeError(err) + # 初始化MCP服务 await self.load_state() await self.load_mcp() - if self.state.executorStatus == ExecutorStatus.INIT: + if self.task.state.executorStatus == ExecutorStatus.INIT: # 初始化状态 try: - self.state.executorId = str(uuid.uuid4()) - self.state.executorName = (await MCPPlanner.get_flow_name( - self.task.runtime.question, self.resoning_llm, self.task.language + self.task.state.executorId = str(uuid.uuid4()) + self.task.state.executorName = (await self.planner.get_flow_name( + self.task.runtime.question, self.resoning_llm, self.task.runtime.language )).flow_name - await TaskManager.save_task(self.task.id, self.task) + await TaskManager.save_task(self.task.metadata.id, self.task) await self.get_next_step() except Exception as e: logger.exception("[MCPAgentExecutor] 初始化失败") - self.task.state.flow_status = FlowStatus.ERROR - self.task.state.error_message = str(e) + self.task.state.executorStatus = ExecutorStatus.ERROR + self.task.state.errorMessage = str(e) await self.push_message( EventType.FLOW_FAILED, data={}, ) return - self.state.executorStatus = ExecutorStatus.RUNNING + self.task.state.executorStatus = ExecutorStatus.RUNNING await self.push_message( EventType.FLOW_START, data={}, ) - if self.state.toolId == FINAL_TOOL_ID: + if self.task.state.toolId == FINAL_TOOL_ID: # 如果已经是最后一步,直接结束 self.state.executorStatus = ExecutorStatus.SUCCESS await self.push_message( @@ -510,15 +525,15 @@ class MCPAgentExecutor(BaseExecutor): await self.summarize() return try: - while self.state.executorStatus == ExecutorStatus.RUNNING: + while self.task.state.executorStatus == ExecutorStatus.RUNNING: if self.state.toolId == FINAL_TOOL_ID: break await self.work() - await TaskManager.save_task(self.task.id, self.task) + await TaskManager.save_task(self.task.metadata.id, self.task) if self.state.toolId == FINAL_TOOL_ID: # 如果已经是最后一步,直接结束 - self.state.executorStatus = ExecutorStatus.SUCCESS - self.state.stepStatus = StepStatus.SUCCESS + self.task.state.executorStatus = ExecutorStatus.SUCCESS + self.task.state.stepStatus = StepStatus.SUCCESS await self.push_message( EventType.FLOW_SUCCESS, data={}, @@ -526,9 +541,9 @@ class MCPAgentExecutor(BaseExecutor): await self.summarize() except Exception as e: logger.exception("[MCPAgentExecutor] 执行过程中发生错误") - self.state.executorName = ExecutorStatus.ERROR - self.state.errorMessage = str(e) - self.state.stepStatus = StepStatus.ERROR + self.task.state.executorName = ExecutorStatus.ERROR + self.task.state.errorMessage = str(e) + self.task.state.stepStatus = StepStatus.ERROR await self.push_message( EventType.STEP_ERROR, data={}, @@ -537,25 +552,25 @@ class MCPAgentExecutor(BaseExecutor): EventType.FLOW_FAILED, data={}, ) - if len(self.context) and self.context[-1].stepId == self.state.stepId: - del self.context[-1] - self.context.append( + if len(self.task.context) and self.task.context[-1].stepId == self.task.state.stepId: + del self.task.context[-1] + self.task.context.append( ExecutorHistory( - task_id=self.task.id, - step_id=self.task.state.step_id, - step_name=self.task.state.step_name, - step_description=self.task.state.step_description, - step_status=self.task.state.step_status, - flow_id=self.task.state.flow_id, - flow_name=self.task.state.flow_name, - flow_status=self.task.state.flow_status, - input_data={}, - output_data={}, + taskId=self.task.metadata.id, + stepId=self.task.state.stepId, + stepName=self.task.state.stepName, + stepDescription=self.task.state.stepDescription, + stepStatus=self.task.state.stepStatus, + executorId=self.task.state.executorId, + executorName=self.task.state.executorName, + executorStatus=self.task.state.executorStatus, + inputData={}, + outputData={}, ), ) finally: for mcp_service in self.mcp_list: try: - await self.mcp_pool.stop(mcp_service.id, self.task.ids.user_sub) - except Exception as e: - logger.error("[MCPAgentExecutor] 停止MCP客户端时发生错误: %s", traceback.format_exc()) + await self.mcp_pool.stop(mcp_service.id, self.task.metadata.userSub) + except Exception: + logger.exception("[MCPAgentExecutor] 停止MCP客户端时发生错误") diff --git a/apps/scheduler/mcp_agent/plan.py b/apps/scheduler/mcp_agent/plan.py index c658aa7f..33c626ac 100644 --- a/apps/scheduler/mcp_agent/plan.py +++ b/apps/scheduler/mcp_agent/plan.py @@ -44,11 +44,10 @@ class MCPPlanner(MCPBase): goal: str llm: ReasoningLLM - def __init__(self, goal: str, llm: ReasoningLLM, language: LanguageType) -> None: + def __init__(self, user_sub: str, goal: str, language: LanguageType) -> None: """初始化MCPPlanner""" - super().__init__() + super().__init__(user_sub) self.goal = goal - self.llm = llm self.language = language async def get_flow_name(self) -> FlowName: diff --git a/apps/scheduler/mcp_agent/prompt.py b/apps/scheduler/mcp_agent/prompt.py index 4577e178..b1a9d16f 100644 --- a/apps/scheduler/mcp_agent/prompt.py +++ b/apps/scheduler/mcp_agent/prompt.py @@ -679,7 +679,7 @@ GET_MISSING_PARAMS: dict[LanguageType, str] = { "required": ["host", "port", "username", "password"] } # 运行报错 - 执行端口扫描命令时,出现了错误:`password is not correct`。 + {"err_msg": "执行端口扫描命令时,出现了错误:`password is not correct`。", "data": {} } # 输出 ```json { @@ -763,8 +763,9 @@ input parameters, input parameter schema, and runtime errors, and output a JSON- }, "required": ["host", "port", "username", "password"] } - # Run error - When executing the port scan command, an error occurred: `password is not correct`. + # Error info + {"err_msg": "When executing the port scan command, an error occurred: `password is not correct`.", \ +"data": {} } # Output ```json { @@ -783,7 +784,7 @@ input parameters, input parameter schema, and runtime errors, and output a JSON- {{input_param}} # Tool input parameter schema (some fields can be null) {{input_schema}} - # Run error + # Error info {{error_message}} # Output """, diff --git a/apps/schemas/response_data.py b/apps/schemas/response_data.py index 9d2ef79d..ac630b3f 100644 --- a/apps/schemas/response_data.py +++ b/apps/schemas/response_data.py @@ -632,3 +632,16 @@ class GetOperaRsp(ResponseData): """GET /api/operate 返回数据结构""" result: list[OperateAndBindType] = Field(..., title="Result") + + +class UserSelectedLLMData(BaseModel): + """用户选择的LLM数据结构""" + + functionLLM: str | None = Field(default=None, description="函数模型ID") # noqa: N815 + embeddingLLM: str | None = Field(default=None, description="向量模型ID") # noqa: N815 + + +class UserSelectedLLMRsp(ResponseData): + """GET /api/user/llm 返回数据结构""" + + result: UserSelectedLLMData = Field(..., title="Result") diff --git a/apps/schemas/scheduler.py b/apps/schemas/scheduler.py index d52bf85a..6c495d8c 100644 --- a/apps/schemas/scheduler.py +++ b/apps/schemas/scheduler.py @@ -6,11 +6,22 @@ from typing import Any from pydantic import BaseModel, Field +from apps.llm.embedding import Embedding +from apps.llm.function import FunctionLLM +from apps.llm.reasoning import ReasoningLLM from apps.models.task import ExecutorHistory from .enum_var import CallOutputType, LanguageType +class LLMConfig(BaseModel): + """LLM配置""" + + reasoning: ReasoningLLM = Field(description="推理LLM") + function: FunctionLLM = Field(description="函数LLM") + embedding: Embedding = Field(description="Embedding") + + class CallInfo(BaseModel): """Call的名称和描述""" diff --git a/apps/services/llm.py b/apps/services/llm.py index ceb76c6d..1c9b8b87 100644 --- a/apps/services/llm.py +++ b/apps/services/llm.py @@ -12,7 +12,11 @@ from apps.schemas.request_data import ( UpdateLLMReq, UpdateUserSpecialLLMReq, ) -from apps.schemas.response_data import LLMProvider, LLMProviderInfo +from apps.schemas.response_data import ( + LLMProvider, + LLMProviderInfo, + UserSelectedLLMData, +) from apps.templates.generate_llm_operator_config import llm_provider_dict logger = logging.getLogger(__name__) @@ -41,7 +45,7 @@ class LLMManager: @staticmethod - async def get_user_default_llm(user_sub: str) -> str | None: + async def get_user_selected_llm(user_sub: str) -> UserSelectedLLMData | None: """ 通过用户ID获取大模型ID @@ -56,7 +60,10 @@ class LLMManager: logger.error("[LLMManager] 用户 %s 不存在", user_sub) return None - return user.reasoningLLM + return UserSelectedLLMData( + functionLLM=user.functionLLM, + embeddingLLM=user.embeddingLLM, + ) @staticmethod diff --git a/apps/services/tag.py b/apps/services/tag.py index f0c03151..b3748d79 100644 --- a/apps/services/tag.py +++ b/apps/services/tag.py @@ -8,6 +8,7 @@ from sqlalchemy import select from apps.common.postgres import postgres from apps.models.tag import Tag +from apps.models.user import UserTag from apps.schemas.request_data import PostTagData logger = logging.getLogger(__name__) @@ -40,6 +41,24 @@ class TagManager: return (await session.scalars(select(Tag).where(Tag.name == name).limit(1))).one_or_none() + @staticmethod + async def get_tag_by_user_sub(user_sub: str) -> list[Tag]: + """ + 根据用户sub获取用户标签信息 + + :param user_sub: 用户sub + :return: 用户标签信息 + """ + tags = [] + async with postgres.session() as session: + user_tags = (await session.scalars(select(UserTag).where(UserTag.userSub == user_sub))).all() + for user_tag in user_tags: + tag = (await session.scalars(select(Tag).where(Tag.id == user_tag.tag))).one_or_none() + if tag: + tags.append(tag) + return list(tags) + + @staticmethod async def add_tag(data: PostTagData) -> None: """ -- Gitee