From 8920aad01590312022ecf6c0865ad71a356abe91 Mon Sep 17 00:00:00 2001 From: z30057876 Date: Tue, 9 Sep 2025 20:48:32 +0800 Subject: [PATCH] =?UTF-8?q?=E8=BF=81=E7=A7=BBRAG=E9=80=BB=E8=BE=91?= =?UTF-8?q?=EF=BC=9B=E8=AE=A9MCP=20Node=E5=8E=86=E5=8F=B2=E4=B8=8D?= =?UTF-8?q?=E5=86=8D=E5=88=86=E7=A6=BB=E4=BF=9D=E5=AD=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/llm/function.py | 1 - apps/scheduler/call/llm/prompt.py | 57 ---------- apps/scheduler/call/mcp/mcp.py | 6 +- apps/scheduler/call/rag/prompt.py | 105 ----------------- apps/scheduler/call/rag/rag.py | 30 +++-- apps/scheduler/call/rag/schema.py | 35 +++--- apps/scheduler/call/sql/schema.py | 2 +- apps/scheduler/call/sql/sql.py | 138 ++++++++--------------- apps/scheduler/executor/agent.py | 20 ++-- apps/scheduler/executor/flow.py | 4 +- apps/scheduler/executor/prompt.py | 173 ++++++++++++++++++++++++++++- apps/scheduler/executor/qa.py | 124 +++++++++++++++------ apps/scheduler/mcp/host.py | 61 ++++------ apps/scheduler/mcp/prompt.py | 4 +- apps/scheduler/mcp/select.py | 20 +++- apps/scheduler/mcp_agent/host.py | 2 +- apps/scheduler/mcp_agent/plan.py | 13 +-- apps/scheduler/mcp_agent/prompt.py | 8 +- apps/schemas/mcp.py | 11 ++ apps/schemas/rag_data.py | 29 +---- apps/services/rag.py | 140 ----------------------- 21 files changed, 413 insertions(+), 570 deletions(-) delete mode 100644 apps/services/rag.py diff --git a/apps/llm/function.py b/apps/llm/function.py index 491087916..00d622ca2 100644 --- a/apps/llm/function.py +++ b/apps/llm/function.py @@ -25,7 +25,6 @@ class FunctionLLM: """用于FunctionCall的模型""" timeout: float = 30.0 - config: LLMData def __init__(self, llm_config: LLMData | None = None) -> None: """ diff --git a/apps/scheduler/call/llm/prompt.py b/apps/scheduler/call/llm/prompt.py index 397ab84e1..13db1349c 100644 --- a/apps/scheduler/call/llm/prompt.py +++ b/apps/scheduler/call/llm/prompt.py @@ -69,60 +69,3 @@ LLM_DEFAULT_PROMPT: str = dedent( 现在,输出你的回答: """, ).strip("\n") -LLM_ERROR_PROMPT: dict[LanguageType, str] = { - LanguageType.CHINESE: dedent( - r""" - - 你是一位智能助手,能够根据用户的问题,使用Python工具获取信息,并作出回答。你在使用工具解决回答用户的问题时,发生了错误。 - 你的任务是:分析工具(Python程序)的异常信息,分析造成该异常可能的原因,并以通俗易懂的方式,将原因告知用户。 - - 当前时间:{{ time }},可以作为时间参照。 - 发生错误的程序异常信息将在中给出,用户的问题将在中给出,上下文背景信息将在中给出。 - 注意:输出不要包含任何XML标签,不要编造任何信息。若你认为用户提问与背景信息无关,请忽略背景信息。 - - - - {{ error_info }} - - - - {{ question }} - - - - {{ context }} - - - 现在,输出你的回答: - """, - ).strip("\n"), - LanguageType.ENGLISH: dedent( - r""" - - You are an intelligent assistant. When using Python tools to answer user questions, an error occurred. - Your task is: Analyze the exception information of the tool (Python program), analyze the possible \ -causes of the error, and inform the user in an easy-to-understand way. - - Current time: {{ time }}, which can be used as a reference. - The program exception information that occurred will be given in , the user's question \ -will be given in , and the context background information will be given in . - Note: Do not include any XML tags in the output. Do not make up any information. If you think the \ -user's question is unrelated to the background information, please ignore the background information. - - - - {{ error_info }} - - - - {{ question }} - - - - {{ context }} - - - Now, please output your answer: - """, - ).strip("\n"), -} diff --git a/apps/scheduler/call/mcp/mcp.py b/apps/scheduler/call/mcp/mcp.py index eb2e33302..92d6c98f7 100644 --- a/apps/scheduler/call/mcp/mcp.py +++ b/apps/scheduler/call/mcp/mcp.py @@ -86,8 +86,10 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput): """初始化MCP""" # 获取MCP交互类 self._host = MCPHost( + call_vars.ids.user_sub, call_vars.ids.task_id, - call_vars.ids.executor_id, + self._llm_obj, + self._sys_vars.language, ) self._tool_list = await self._host.get_tool_list(self.mcp_list) self._call_vars = call_vars @@ -125,7 +127,7 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput): yield self._create_output(MCP_GENERATE["START"][self._sys_vars.language], MCPMessageType.PLAN_BEGIN) # 选择工具并生成计划 - selector = MCPSelector() + selector = MCPSelector(self._llm_obj) top_tool = await selector.select_top_tool(self._call_vars.question, self.mcp_list) planner = MCPPlanner(self._call_vars.question, language=self._sys_vars.language) self._plan = await planner.create_plan(top_tool, self.max_steps) diff --git a/apps/scheduler/call/rag/prompt.py b/apps/scheduler/call/rag/prompt.py index b574555b2..d32f1c329 100644 --- a/apps/scheduler/call/rag/prompt.py +++ b/apps/scheduler/call/rag/prompt.py @@ -122,108 +122,3 @@ application scenarios." """).strip("\n"), } - -GEN_RAG_ANSWER: dict[LanguageType, str] = { - LanguageType.CHINESE: r""" - - 你是openEuler社区的智能助手。请结合给出的背景信息, 回答用户的提问,并且基于给出的背景信息在相关句子后\ -进行脚注。 - 一个例子将在中给出。 - 上下文背景信息将在中给出。 - 用户的提问将在中给出。 - 注意: - 1.输出不要包含任何XML标签,不要编造任何信息。若你认为用户提问与背景信息无关,请忽略背景信息直接作答。 - 2.脚注的格式为[[1]],[[2]],[[3]]等,脚注的内容为提供的文档的id。 - 3.脚注只出现在回答的句子的末尾,例如句号、问号等标点符号后面。 - 4.不要对脚注本身进行解释或说明。 - 5.请不要使用中的文档的id作为脚注。 - - - - - - openEuler社区是一个开源操作系统社区,致力于推动Linux操作系统的发展。 - - - openEuler社区的目标是为用户提供一个稳定、安全、高效的操作系统平台,并且支持多种硬件架构。 - - - - - openEuler社区的成员来自世界各地,包括开发者、用户和企业。 - - - openEuler社区的成员共同努力,推动开源操作系统的发展,并且为用户提供支持和帮助。 - - - - - openEuler社区的目标是什么? - - - openEuler社区是一个开源操作系统社区,致力于推动Linux操作系统的发展。[[1]] openEuler社区的目标是为\ -用户提供一个稳定、安全、高效的操作系统平台,并且支持多种硬件架构。[[1]] - - - - - {bac_info} - - - {user_question} - - """, - LanguageType.ENGLISH: r""" - - You are a helpful assistant of openEuler community. Please answer the user's question based on the \ -given background information and add footnotes after the related sentences. - An example will be given in . - The background information will be given in . - The user's question will be given in . - Note: - 1. Do not include any XML tags in the output, and do not make up any information. If you think the \ -user's question is unrelated to the background information, please ignore the background information and directly \ -answer. - 2. Your response should not exceed 250 words. - - - - - - openEuler community is an open source operating system community, committed to promoting \ -the development of the Linux operating system. - - - openEuler community aims to provide users with a stable, secure, and efficient operating \ -system platform, and support multiple hardware architectures. - - - - - Members of the openEuler community come from all over the world, including developers, \ -users, and enterprises. - - - Members of the openEuler community work together to promote the development of open \ -source operating systems, and provide support and assistance to users. - - - - - What is the goal of openEuler community? - - - openEuler community is an open source operating system community, committed to promoting the \ -development of the Linux operating system. [[1]] openEuler community aims to provide users with a stable, secure, \ -and efficient operating system platform, and support multiple hardware architectures. [[1]] - - - - - {bac_info} - - - {user_question} - - """, - } diff --git a/apps/scheduler/call/rag/rag.py b/apps/scheduler/call/rag/rag.py index d69fbe1ce..8bd842ee9 100644 --- a/apps/scheduler/call/rag/rag.py +++ b/apps/scheduler/call/rag/rag.py @@ -2,6 +2,7 @@ """RAG工具:查询知识库""" import logging +import uuid from collections.abc import AsyncGenerator from datetime import UTC, datetime from typing import Any @@ -13,6 +14,7 @@ from jinja2.sandbox import SandboxedEnvironment from pydantic import Field from apps.common.config import config +from apps.llm.token import TokenCalculator from apps.scheduler.call.core import CoreCall from apps.schemas.enum_var import CallOutputType, LanguageType from apps.schemas.scheduler import ( @@ -37,9 +39,9 @@ _logger = logging.getLogger(__name__) class RAG(CoreCall, input_model=RAGInput, output_model=RAGOutput): """RAG工具:查询知识库""" - knowledge_base_ids: list[str] = Field(description="知识库的id列表", default=[]) + kb_ids: list[uuid.UUID] = Field(description="知识库的id列表", default=[]) top_k: int = Field(description="返回的分片数量", default=5) - document_ids: list[str] | None = Field(description="文档id列表", default=None) + doc_ids: list[str] | None = Field(description="文档id列表", default=None) search_method: str = Field(description="检索方法", default=SearchMethod.KEYWORD_AND_VECTOR.value) is_related_surrounding: bool = Field(description="是否关联上下文", default=True) is_classify_by_doc: bool = Field(description="是否按文档分类", default=False) @@ -47,6 +49,7 @@ class RAG(CoreCall, input_model=RAGInput, output_model=RAGOutput): is_compress: bool = Field(description="是否压缩", default=False) tokens_limit: int = Field(description="token限制", default=8192) + @classmethod def info(cls, language: LanguageType = LanguageType.CHINESE) -> CallInfo: """返回Call的名称和描述""" @@ -61,6 +64,7 @@ class RAG(CoreCall, input_model=RAGInput, output_model=RAGOutput): } return i18n_info[language] + async def _init(self, call_vars: CallVars) -> RAGInput: """初始化RAG工具""" if not call_vars.ids.session_id: @@ -70,10 +74,10 @@ class RAG(CoreCall, input_model=RAGInput, output_model=RAGOutput): return RAGInput( session_id=call_vars.ids.session_id, - kbIds=self.knowledge_base_ids, + kbIds=self.kb_ids, topK=self.top_k, query=call_vars.question, - docIds=self.document_ids, + docIds=self.doc_ids, searchMethod=self.search_method, isRelatedSurrounding=self.is_related_surrounding, isClassifyByDoc=self.is_classify_by_doc, @@ -82,6 +86,7 @@ class RAG(CoreCall, input_model=RAGInput, output_model=RAGOutput): tokensLimit=self.tokens_limit, ) + async def _assemble_doc_info( self, doc_chunk_list: list[dict[str, Any]], @@ -92,7 +97,7 @@ class RAG(CoreCall, input_model=RAGInput, output_model=RAGOutput): doc_info_list = [] doc_cnt = 0 doc_id_map = {} - remaining_tokens = max_tokens * 0.8 + remaining_tokens = round(max_tokens * 0.8) for doc_chunk in doc_chunk_list: if doc_chunk["docId"] not in doc_id_map: @@ -127,7 +132,8 @@ class RAG(CoreCall, input_model=RAGInput, output_model=RAGOutput): break chunk_text = chunk["text"] chunk_text = TokenCalculator().get_k_tokens_words_from_content( - content=chunk_text, k=remaining_tokens) + content=chunk_text, k=remaining_tokens, + ) remaining_tokens -= TokenCalculator().calculate_token_length(messages=[ {"role": "user", "content": ""}, {"role": "user", "content": chunk_text}, @@ -141,6 +147,7 @@ class RAG(CoreCall, input_model=RAGInput, output_model=RAGOutput): bac_info += "" return bac_info, doc_info_list + async def _get_doc_info(self, doc_ids: list[str], data: RAGInput) -> AsyncGenerator[CallOutputChunk, None]: """获取文档信息""" url = config.rag.rag_service.rstrip("/") + "/chunk/search" @@ -151,7 +158,7 @@ class RAG(CoreCall, input_model=RAGInput, output_model=RAGOutput): doc_chunk_list = [] if doc_ids: default_kb_id = "00000000-0000-0000-0000-000000000000" - tmp_data = RAGQueryReq( + tmp_data = RAGInput( kbIds=[default_kb_id], query=data.query, topK=data.top_k, @@ -184,6 +191,7 @@ class RAG(CoreCall, input_model=RAGInput, output_model=RAGOutput): _logger.exception("[RAG] 获取文档分片失败") return doc_chunk_list + async def _exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]: """调用RAG工具""" data = RAGInput(**input_data) @@ -196,14 +204,14 @@ class RAG(CoreCall, input_model=RAGInput, output_model=RAGOutput): lstrip_blocks=True, ) tmpl = env.from_string(QUESTION_REWRITE[self._sys_vars.language]) - prompt = tmpl.render(history="", question=data.question) + prompt = tmpl.render(history="", question=data.query) # 使用_json方法直接获取JSON结果 json_result = await self._json([ {"role": "user", "content": prompt}, ], schema=QuestionRewriteOutput.model_json_schema()) # 直接使用解析后的JSON结果 - data.question = QuestionRewriteOutput.model_validate(json_result).question + data.query = QuestionRewriteOutput.model_validate(json_result).question except Exception: _logger.exception("[RAG] 问题重写失败,使用原始问题") @@ -232,7 +240,7 @@ class RAG(CoreCall, input_model=RAGInput, output_model=RAGOutput): yield CallOutputChunk( type=CallOutputType.DATA, content=RAGOutput( - question=data.question, + question=data.query, corpus=corpus, ).model_dump(exclude_none=True, by_alias=True), ) @@ -244,7 +252,7 @@ class RAG(CoreCall, input_model=RAGInput, output_model=RAGOutput): raise CallError( message=f"rag调用失败:{text}", data={ - "question": data.question, + "question": data.query, "status": response.status_code, "text": text, }, diff --git a/apps/scheduler/call/rag/schema.py b/apps/scheduler/call/rag/schema.py index f3a6ec69b..bceedebb1 100644 --- a/apps/scheduler/call/rag/schema.py +++ b/apps/scheduler/call/rag/schema.py @@ -1,6 +1,7 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """RAG工具的输入和输出""" +import uuid from enum import Enum from pydantic import BaseModel, Field @@ -37,26 +38,18 @@ class RAGOutput(DataBase): class RAGInput(DataBase): """RAG工具的输入""" + # 来自RAGInput的独有字段 session_id: str = Field(description="会话id") - knowledge_base_ids: list[str] = Field(description="知识库的id列表", default=[], alias="kbIds") - top_k: int = Field(description="返回的分片数量", default=5, alias="topK") - question: str = Field(description="用户输入", default="", alias="query") - document_ids: list[str] | None = Field(description="文档id列表", default=None, alias="docIds") - search_method: str = Field( - description="检索方法", - default=SearchMethod.KEYWORD_AND_VECTOR.value, - alias="searchMethod", - ) - is_related_surrounding: bool = Field( - description="是否关联上下文", - default=True, - alias="isRelatedSurrounding", - ) - is_classify_by_doc: bool = Field( - description="是否按文档分类", - default=True, - alias="isClassifyByDoc", - ) - is_rerank: bool = Field(description="是否重新排序", default=False, alias="isRerank") is_compress: bool = Field(description="是否压缩", default=False, alias="isCompress") - tokens_limit: int = Field(description="token限制", default=8192, alias="tokensLimit") + + # 来自RAGQueryReq的字段(以RAGQueryReq为准) + kb_ids: list[uuid.UUID] = Field(default=[], description="资产id", alias="kbIds") + query: str = Field(default="", description="查询内容") + top_k: int = Field(default=5, description="返回的结果数量", alias="topK") + doc_ids: list[str] | None = Field(default=None, description="文档id", alias="docIds") + search_method: str = Field(default="dynamic_weighted_keyword_and_vector", + description="检索方法", alias="searchMethod") + is_related_surrounding: bool = Field(default=True, description="是否关联上下文", alias="isRelatedSurrounding") + is_classify_by_doc: bool = Field(default=True, description="是否按文档分类", alias="isClassifyByDoc") + is_rerank: bool = Field(default=False, description="是否重新排序", alias="isRerank") + tokens_limit: int | None = Field(default=None, description="token限制", alias="tokensLimit") diff --git a/apps/scheduler/call/sql/schema.py b/apps/scheduler/call/sql/schema.py index 06ffb4f06..1fc4cfd0a 100644 --- a/apps/scheduler/call/sql/schema.py +++ b/apps/scheduler/call/sql/schema.py @@ -17,5 +17,5 @@ class SQLInput(DataBase): class SQLOutput(DataBase): """SQL工具的输出""" - dataset: list[dict[str, Any]] = Field(description="SQL工具的执行结果") + result: list[dict[str, Any]] = Field(description="SQL工具的执行结果") sql: str = Field(description="SQL语句") diff --git a/apps/scheduler/call/sql/sql.py b/apps/scheduler/call/sql/sql.py index abe575a11..99396d0d9 100644 --- a/apps/scheduler/call/sql/sql.py +++ b/apps/scheduler/call/sql/sql.py @@ -23,10 +23,6 @@ from .schema import SQLInput, SQLOutput logger = logging.getLogger(__name__) MESSAGE = { - "invaild": { - LanguageType.CHINESE: "SQL查询错误:无法生成有效的SQL语句!", - LanguageType.ENGLISH: "SQL query error: Unable to generate valid SQL statements!", - }, "fail": { LanguageType.CHINESE: "SQL查询错误:SQL语句执行失败!", LanguageType.ENGLISH: "SQL query error: SQL statement execution failed!", @@ -37,10 +33,13 @@ MESSAGE = { class SQL(CoreCall, input_model=SQLInput, output_model=SQLOutput): """SQL工具。用于调用外置的Chat2DB工具的API,获得SQL语句;再在PostgreSQL中执行SQL语句,获得数据。""" - database_url: str = Field(description="数据库连接地址") + database_type: str = Field(description="数据库类型",default="postgres") # mysql mongodb opengauss postgres + host: str = Field(description="数据库地址",default="localhost") + port: int = Field(description="数据库端口",default=5432) + username: str = Field(description="数据库用户名",default="root") + password: str = Field(description="数据库密码",default="root") + database: str = Field(description="数据库名称",default="postgres") table_name_list: list[str] = Field(description="表名列表",default=[]) - top_k: int = Field(description="生成SQL语句数量",default=5) - use_llm_enhancements: bool = Field(description="是否使用大模型增强", default=False) @classmethod @@ -66,98 +65,57 @@ class SQL(CoreCall, input_model=SQLInput, output_model=SQLOutput): ) - async def _generate_sql(self, data: SQLInput) -> list[dict[str, Any]]: - """生成SQL语句列表""" - post_data = { - "database_url": self.database_url, - "table_name_list": self.table_name_list, - "question": data.question, - "topk": self.top_k, - "use_llm_enhancements": self.use_llm_enhancements, - } - headers = {"Content-Type": "application/json"} - - sql_list = [] - request_num = 0 - max_request = 5 - - while request_num < max_request and len(sql_list) < self.top_k: - try: - async with httpx.AsyncClient() as client: - response = await client.post( - config.extra.sql_url + "/database/sql", - headers=headers, - json=post_data, - timeout=60.0, - ) - request_num += 1 - if response.status_code == status.HTTP_200_OK: - result = response.json() - if result["code"] == status.HTTP_200_OK: - sql_list.extend(result["result"]["sql_list"]) - else: - logger.error("[SQL] 生成失败:%s", response.text) - except Exception: - logger.exception("[SQL] 生成失败") - request_num += 1 - - return sql_list - - - async def _execute_sql( - self, - sql_list: list[dict[str, Any]], - ) -> tuple[list[dict[str, Any]] | None, str | None]: - """执行SQL语句并返回结果""" - headers = {"Content-Type": "application/json"} - - for sql_dict in sql_list: - try: - async with httpx.AsyncClient() as client: - response = await client.post( - config.extra.sql_url + "/sql/execute", - headers=headers, - json={ - "database_id": sql_dict["database_id"], - "sql": sql_dict["sql"], - }, - timeout=60.0, - ) - if response.status_code == status.HTTP_200_OK: - result = response.json() - if result["code"] == status.HTTP_200_OK: - return result["result"], sql_dict["sql"] - else: - logger.error("[SQL] 调用失败:%s", response.text) - except Exception: - logger.exception("[SQL] 调用失败") - - return None, None - - async def _exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]: - """运行SQL工具""" + """运行SQL工具, 支持MySQL, MongoDB, PostgreSQL, OpenGauss""" data = SQLInput(**input_data) - # 生成SQL语句 - sql_list = await self._generate_sql(data) - if not sql_list: - raise CallError( - message=MESSAGE["invaild"][self._sys_vars.language], - data={}, - ) - - # 执行SQL语句 - sql_exec_results, sql_exec = await self._execute_sql(sql_list) - if sql_exec_results is None or sql_exec is None: + headers = {"Content-Type": "application/json"} + post_data = { + "type": self.database_type, + "host": self.host, + "port": self.port, + "username": self.username, + "password": self.password, + "database": self.database, + "goal": data.question, + "table_list": self.table_name_list, + } + try: + async with httpx.AsyncClient() as client: + response = await client.post( + config.extra.sql_url + "/sql/handler", + headers=headers, + json=post_data, + timeout=60.0, + ) + + result = response.json() + if response.status_code == status.HTTP_200_OK: + if result["code"] == status.HTTP_200_OK: + result_data = result["result"] + sql_exec_results = result_data.get("execute_result") + sql_exec = result_data.get("sql") + sql_exec_risk = result_data.get("risk") + logger.info( + "[SQL] 调用成功\n[SQL 语句]: %s\n[SQL 结果]: %s\n[SQL 风险]: %s", + sql_exec, + sql_exec_results, + sql_exec_risk, + ) + + else: + logger.error("[SQL] 调用失败:%s", response.text) + logger.error("[SQL] 错误信息:%s", response.json()["result"]) + except Exception as e: + logger.exception("[SQL] 调用失败") raise CallError( message=MESSAGE["fail"][self._sys_vars.language], data={}, - ) + ) from e # 返回结果 data = SQLOutput( - dataset=sql_exec_results, + result=sql_exec_results, sql=sql_exec, ).model_dump(exclude_none=True, by_alias=True) diff --git a/apps/scheduler/executor/agent.py b/apps/scheduler/executor/agent.py index 992cf71c3..bf54ce3a2 100644 --- a/apps/scheduler/executor/agent.py +++ b/apps/scheduler/executor/agent.py @@ -11,12 +11,11 @@ from pydantic import Field from apps.constants import AGENT_FINAL_STEP_NAME, AGENT_MAX_RETRY_TIMES, AGENT_MAX_STEPS from apps.models.mcp import MCPTools from apps.models.task import ExecutorHistory -from apps.scheduler.call.slot.slot import Slot from apps.scheduler.executor.base import BaseExecutor from apps.scheduler.mcp_agent.host import MCPHost from apps.scheduler.mcp_agent.plan import MCPPlanner from apps.scheduler.pool.mcp.pool import MCPPool -from apps.schemas.enum_var import EventType, ExecutorStatus, LanguageType, StepStatus +from apps.schemas.enum_var import EventType, ExecutorStatus, StepStatus from apps.schemas.mcp import Step from apps.schemas.message import FlowParams from apps.services.appcenter import AppCenterManager @@ -49,8 +48,8 @@ class MCPAgentExecutor(BaseExecutor): self._mcp_list = [] self._current_input = {} # 初始化MCP Host相关对象 - self._planner = MCPPlanner(self.task.runtime.userInput, self.llm, self.task.runtime.language) - self._host = MCPHost(self.task.metadata.userSub, self.llm) + self._planner = MCPPlanner(self.task, self.llm) + self._host = MCPHost(self.task, self.llm) user = await UserManager.get_user(self.task.metadata.userSub) if not user: err = "[MCPAgentExecutor] 用户不存在: %s" @@ -359,6 +358,7 @@ class MCPAgentExecutor(BaseExecutor): if self.task.state.stepStatus == StepStatus.PARAM: if len(self.task.context) and self.task.context[-1].stepId == self.task.state.stepId: del self.task.context[-1] + await self.get_tool_input_param(is_first=False) elif self.task.state.stepStatus == StepStatus.WAITING: if self.params: if len(self.task.context) and self.task.context[-1].stepId == self.task.state.stepId: @@ -395,7 +395,7 @@ class MCPAgentExecutor(BaseExecutor): EventType.STEP_ERROR, data={ "message": self.task.state.errorMessage, - } + }, ) if len(self.task.context) and self.task.context[-1].stepId == self.task.state.stepId: self.task.context[-1].stepStatus = StepStatus.ERROR @@ -413,7 +413,7 @@ class MCPAgentExecutor(BaseExecutor): executorId=self.task.state.executorId, executorName=self.task.state.executorName, executorStatus=self.task.state.executorStatus, - inputData=self.task.state.currentInput, + inputData=self._current_input, outputData={ "message": self.task.state.errorMessage, }, @@ -423,13 +423,11 @@ class MCPAgentExecutor(BaseExecutor): else: mcp_tool = self.tools[self.task.state.toolId] is_param_error = await self._planner.is_param_error( - self.task.runtime.userInput, await self._host.assemble_memory(self.task.runtime, self.task.context), self.task.state.errorMessage, mcp_tool, self.task.state.stepDescription, - self.task.state.currentInput, - language=self.task.runtime.language, + self._current_input, ) if is_param_error.is_param_error: # 如果是参数错误,生成参数补充 @@ -457,7 +455,7 @@ class MCPAgentExecutor(BaseExecutor): executorId=self.task.state.executorId, executorName=self.task.state.executorName, executorStatus=self.task.state.executorStatus, - inputData=self.task.state.currentInput, + inputData=self._current_input, outputData={ "message": self.task.state.errorMessage, }, @@ -495,7 +493,7 @@ class MCPAgentExecutor(BaseExecutor): # 初始化状态 self.task.state.executorId = str(uuid.uuid4()) self.task.state.executorName = (await self._planner.get_flow_name()).flow_name - flow_risk = await self._planner.get_flow_excute_risk(self.tool_list, self.task.language) + flow_risk = await self._planner.get_flow_excute_risk(self.tool_list) if self._user.autoExecute: data = flow_risk.model_dump(exclude_none=True, by_alias=True) await self.get_next_step() diff --git a/apps/scheduler/executor/flow.py b/apps/scheduler/executor/flow.py index 3362c918d..f22885ece 100644 --- a/apps/scheduler/executor/flow.py +++ b/apps/scheduler/executor/flow.py @@ -9,13 +9,13 @@ from datetime import UTC, datetime from pydantic import Field from apps.models.task import ExecutorCheckpoint -from apps.scheduler.call.llm.prompt import LLM_ERROR_PROMPT from apps.schemas.enum_var import EventType, ExecutorStatus, LanguageType, SpecialCallType, StepStatus from apps.schemas.flow import Flow, Step from apps.schemas.request_data import RequestDataApp from apps.schemas.task import StepQueueItem from .base import BaseExecutor +from .prompt import FLOW_ERROR_PROMPT from .step import StepExecutor logger = logging.getLogger(__name__) @@ -219,7 +219,7 @@ class FlowExecutor(BaseExecutor): node=SpecialCallType.LLM.value, type=SpecialCallType.LLM.value, params={ - "user_prompt": LLM_ERROR_PROMPT[self.task.runtime.language].replace( + "user_prompt": FLOW_ERROR_PROMPT[self.task.runtime.language].replace( "{{ error_info }}", self.state.errorMessage["err_msg"], ), diff --git a/apps/scheduler/executor/prompt.py b/apps/scheduler/executor/prompt.py index 4200461ff..727ba3045 100644 --- a/apps/scheduler/executor/prompt.py +++ b/apps/scheduler/executor/prompt.py @@ -1,9 +1,11 @@ """Executor相关大模型提示词""" +from textwrap import dedent + from apps.schemas.enum_var import LanguageType EXECUTOR_REASONING: dict[LanguageType, str] = { - LanguageType.CHINESE: r""" + LanguageType.CHINESE: dedent(r""" 你是一个可以使用工具的智能助手。 @@ -33,8 +35,8 @@ EXECUTOR_REASONING: dict[LanguageType, str] = { 请综合以上信息,再次一步一步地进行思考,并给出见解和行动: - """, - LanguageType.ENGLISH: r""" + """), + LanguageType.ENGLISH: dedent(r""" You are an intelligent assistant who can use tools. @@ -66,5 +68,168 @@ and give the next action. Please integrate the above information, think step by step again, provide insights, and give actions: - """, + """), +} + +FLOW_ERROR_PROMPT: dict[LanguageType, str] = { + LanguageType.CHINESE: dedent( + r""" + + 你是一位智能助手,能够根据用户的问题,使用Python工具获取信息,并作出回答。你在使用工具解决回答用户的问题时,发生了错误。 + 你的任务是:分析工具(Python程序)的异常信息,分析造成该异常可能的原因,并以通俗易懂的方式,将原因告知用户。 + + 当前时间:{{ time }},可以作为时间参照。 + 发生错误的程序异常信息将在中给出,用户的问题将在中给出,上下文背景信息将在中给出。 + 注意:输出不要包含任何XML标签,不要编造任何信息。若你认为用户提问与背景信息无关,请忽略背景信息。 + + + + {{ error_info }} + + + + {{ question }} + + + + {{ context }} + + + 现在,输出你的回答: + """, + ).strip("\n"), + LanguageType.ENGLISH: dedent( + r""" + + You are an intelligent assistant. When using Python tools to answer user questions, an error occurred. + Your task is: Analyze the exception information of the tool (Python program), analyze the possible \ +causes of the error, and inform the user in an easy-to-understand way. + + Current time: {{ time }}, which can be used as a reference. + The program exception information that occurred will be given in , the user's question \ +will be given in , and the context background information will be given in . + Note: Do not include any XML tags in the output. Do not make up any information. If you think the \ +user's question is unrelated to the background information, please ignore the background information. + + + + {{ error_info }} + + + + {{ question }} + + + + {{ context }} + + + Now, please output your answer: + """, + ).strip("\n"), } + +GEN_RAG_ANSWER: dict[LanguageType, str] = { + LanguageType.CHINESE: r""" + + 你是openEuler社区的智能助手。请结合给出的背景信息, 回答用户的提问,并且基于给出的背景信息在相关句子后\ +进行脚注。 + 一个例子将在中给出。 + 上下文背景信息将在中给出。 + 用户的提问将在中给出。 + 注意: + 1.输出不要包含任何XML标签,不要编造任何信息。若你认为用户提问与背景信息无关,请忽略背景信息直接作答。 + 2.脚注的格式为[[1]],[[2]],[[3]]等,脚注的内容为提供的文档的id。 + 3.脚注只出现在回答的句子的末尾,例如句号、问号等标点符号后面。 + 4.不要对脚注本身进行解释或说明。 + 5.请不要使用中的文档的id作为脚注。 + + + + + + openEuler社区是一个开源操作系统社区,致力于推动Linux操作系统的发展。 + + + openEuler社区的目标是为用户提供一个稳定、安全、高效的操作系统平台,并且支持多种硬件架构。 + + + + + openEuler社区的成员来自世界各地,包括开发者、用户和企业。 + + + openEuler社区的成员共同努力,推动开源操作系统的发展,并且为用户提供支持和帮助。 + + + + + openEuler社区的目标是什么? + + + openEuler社区是一个开源操作系统社区,致力于推动Linux操作系统的发展。[[1]] openEuler社区的目标是为\ +用户提供一个稳定、安全、高效的操作系统平台,并且支持多种硬件架构。[[1]] + + + + + {bac_info} + + + {user_question} + + """, + LanguageType.ENGLISH: r""" + + You are a helpful assistant of openEuler community. Please answer the user's question based on the \ +given background information and add footnotes after the related sentences. + An example will be given in . + The background information will be given in . + The user's question will be given in . + Note: + 1. Do not include any XML tags in the output, and do not make up any information. If you think the \ +user's question is unrelated to the background information, please ignore the background information and directly \ +answer. + 2. Your response should not exceed 250 words. + + + + + + openEuler community is an open source operating system community, committed to promoting \ +the development of the Linux operating system. + + + openEuler community aims to provide users with a stable, secure, and efficient operating \ +system platform, and support multiple hardware architectures. + + + + + Members of the openEuler community come from all over the world, including developers, \ +users, and enterprises. + + + Members of the openEuler community work together to promote the development of open \ +source operating systems, and provide support and assistance to users. + + + + + What is the goal of openEuler community? + + + openEuler community is an open source operating system community, committed to promoting the \ +development of the Linux operating system. [[1]] openEuler community aims to provide users with a stable, secure, \ +and efficient operating system platform, and support multiple hardware architectures. [[1]] + + + + + {bac_info} + + + {user_question} + + """, + } diff --git a/apps/scheduler/executor/qa.py b/apps/scheduler/executor/qa.py index 1fc185664..38eb47481 100644 --- a/apps/scheduler/executor/qa.py +++ b/apps/scheduler/executor/qa.py @@ -2,16 +2,16 @@ import logging import uuid +from collections.abc import AsyncGenerator from datetime import UTC, datetime -from textwrap import dedent +from apps.llm.token import TokenCalculator from apps.models.document import Document +from apps.scheduler.call import LLM, RAG from apps.schemas.enum_var import EventType, ExecutorStatus from apps.schemas.message import DocumentAddContent, TextAddContent -from apps.schemas.rag_data import RAGEventData from apps.schemas.record import RecordDocument from apps.services.document import DocumentManager -from apps.services.rag import RAG from .base import BaseExecutor @@ -21,7 +21,7 @@ _logger = logging.getLogger(__name__) class QAExecutor(BaseExecutor): """用于执行智能问答的Executor""" - async def get_docs(self, conversation_id: uuid.UUID) -> tuple[list[RecordDocument] | list[Document], list[str]]: + async def _get_docs(self, conversation_id: uuid.UUID) -> tuple[list[RecordDocument] | list[Document], list[str]]: """获取当前问答可供关联的文档""" doc_ids = [] # 从Conversation中获取刚上传的文档 @@ -32,43 +32,95 @@ class QAExecutor(BaseExecutor): return docs, doc_ids - async def _push_rag_chunk(self, content: str) -> RAGEventData | None: + async def _push_rag_text(self, content: str) -> None: """推送RAG单个消息块""" # 如果是换行 if not content or not content.rstrip().rstrip("\n"): - return None + return + await self._push_message( + event_type=EventType.TEXT_ADD.value, + data=TextAddContent(text=content).model_dump(exclude_none=True, by_alias=True), + ) - try: - content_obj = RAGEventData.model_validate_json(dedent(content[6:]).rstrip("\n")) - # 如果是空消息 - if not content_obj.content: - return None - - # 推送消息 - if content_obj.event_type == EventType.TEXT_ADD.value: - await self._push_message( - event_type=content_obj.event_type, - data=TextAddContent(text=content_obj.content).model_dump(exclude_none=True, by_alias=True), - ) - elif content_obj.event_type == EventType.DOCUMENT_ADD.value: - await self._push_message( - event_type=content_obj.event_type, - data=DocumentAddContent( - documentId=content_obj.content.get("id", ""), - documentOrder=content_obj.content.get("order", 0), - documentAuthor=content_obj.content.get("author", ""), - documentName=content_obj.content.get("name", ""), - documentAbstract=content_obj.content.get("abstract", ""), - documentType=content_obj.content.get("extension", ""), - documentSize=content_obj.content.get("size", 0), - createdAt=round(content_obj.content.get("created_at", datetime.now(tz=UTC).timestamp()), 3), - ).model_dump(exclude_none=True, by_alias=True), + + async def chat_with_llm_base_on_rag( + self, + doc_ids: list[str], + data: RAGInput, + ) -> AsyncGenerator[str, None]: + """获取RAG服务的结果""" + bac_info, doc_info_list = await self._assemble_doc_info( + doc_chunk_list=doc_chunk_list, max_tokens=llm_config.maxToken, + ) + messages = [ + *history, + { + "role": "system", + "content": "You are a helpful assistant.", + }, + { + "role": "user", + "content": RAG.user_prompt[self.task.runtime.language].format( + bac_info=bac_info, + user_question=data.query, + ), + }, + ] + input_tokens = TokenCalculator().calculate_token_length(messages=messages) + output_tokens = 0 + doc_cnt: int = 0 + for doc_info in doc_info_list: + doc_cnt = max(doc_cnt, doc_info["order"]) + yield ( + "data: " + + json.dumps( + { + "event_type": EventType.DOCUMENT_ADD.value, + "input_tokens": input_tokens, + "output_tokens": output_tokens, + "content": doc_info, + }, + ensure_ascii=False, ) - except Exception: - _logger.exception("[Scheduler] RAG服务返回错误数据") - return None - else: - return content_obj + + "\n\n" + ) + max_footnote_length = 4 + tmp_doc_cnt = doc_cnt + while tmp_doc_cnt > 0: + tmp_doc_cnt //= 10 + max_footnote_length += 1 + + full_result = "" + async for chunk in self._llm( + messages, + streaming=True, + ): + full_result += chunk + yield chunk + + # 匹配脚注 + footnotes = re.findall(r"\[\[\d+\]\]", tmp_chunk) + # 去除编号大于doc_cnt的脚注 + footnotes = [fn for fn in footnotes if int(fn[2:-2]) > doc_cnt] + footnotes = list(set(footnotes)) # 去重 + + + async def _push_rag_doc(self, doc: Document) -> None: + """推送RAG单个消息块""" + # 如果是换行 + await self._push_message( + event_type=EventType.DOCUMENT_ADD.value, + data=DocumentAddContent( + documentId=doc.id, + documentOrder=doc.order, + documentAuthor=doc.author, + documentName=doc.name, + documentAbstract=doc.abstract, + documentType=doc.extension, + documentSize=doc.size, + createdAt=round(doc.created_at, 3), + ).model_dump(exclude_none=True, by_alias=True), + ) async def run(self) -> None: """运行QA""" diff --git a/apps/scheduler/mcp/host.py b/apps/scheduler/mcp/host.py index de4d39df5..c381f4d21 100644 --- a/apps/scheduler/mcp/host.py +++ b/apps/scheduler/mcp/host.py @@ -12,15 +12,13 @@ from mcp.types import TextContent from apps.llm.function import JsonGenerator from apps.models.mcp import MCPTools -from apps.models.task import ExecutorHistory from apps.scheduler.mcp.prompt import MEMORY_TEMPLATE from apps.scheduler.pool.mcp.client import MCPClient from apps.scheduler.pool.mcp.pool import MCPPool -from apps.schemas.enum_var import ExecutorStatus, LanguageType, StepStatus -from apps.schemas.mcp import MCPPlanItem +from apps.schemas.enum_var import LanguageType +from apps.schemas.mcp import MCPContext, MCPPlanItem from apps.schemas.scheduler import LLMConfig from apps.services.mcp_service import MCPServiceManager -from apps.services.task import TaskManager logger = logging.getLogger(__name__) @@ -28,14 +26,13 @@ logger = logging.getLogger(__name__) class MCPHost: """MCP宿主服务""" - def __init__(self, user_sub: str,task_id: uuid.UUID, llm: LLMConfig) -> None: + def __init__(self, user_sub: str, task_id: uuid.UUID, llm: LLMConfig, language: LanguageType) -> None: """初始化MCP宿主""" self._task_id = task_id self._user_sub = user_sub - # 注意:runtime在工作流中是flow_id和step_description,在Agent中可为标识Agent的id和description - self._runtime_id = runtime_id - self._runtime_name = runtime_name self._context_list = [] + self._language = language + self._llm = llm self._env = SandboxedEnvironment( loader=BaseLoader(), autoescape=False, @@ -63,19 +60,10 @@ class MCPHost: async def assemble_memory(self) -> str: """组装记忆""" - task = await TaskManager.get_task_data_by_task_id(self._task_id) - if not task: - logger.error("任务 %s 不存在", self._task_id) - return "" - - context_list = [] - for ctx_id in self._context_list: - context = next((ctx for ctx in task.context if ctx.id == ctx_id), None) - if not context: - continue - context_list.append(context) + # 从host的context_list中获取context + context_list = self._context_list - return self._env.from_string(MEMORY_TEMPLATE[self.language]).render( + return self._env.from_string(MEMORY_TEMPLATE[self._language]).render( context_list=context_list, ) @@ -101,34 +89,26 @@ class MCPHost: "message": result, } - # 创建context;注意用法 - context = ExecutorHistory( - taskId=self._task_id, - executorId=self._runtime_id, - executorName=self._runtime_name, - executorStatus=ExecutorStatus.RUNNING, - stepId=uuid.uuid4(), - stepName=tool.toolName, - # description是规划的实际内容 - stepDescription=plan_item.content, - stepStatus=StepStatus.SUCCESS, - inputData=input_data, - outputData=output_data, + # 创建简化版context + context = MCPContext( + step_description=plan_item.content, + input_data=input_data, + output_data=output_data, ) - # 保存到task - task = await TaskManager.get_task_data_by_task_id(self._task_id) - if not task: - logger.error("任务 %s 不存在", self._task_id) - return {} - self._context_list.append(context.id) - context.append(context) + # 保存到host的context_list + self._context_list.append(context) return output_data async def _fill_params(self, tool: MCPTools, query: str) -> dict[str, Any]: """填充工具参数""" + if not self._llm.function: + err = "[MCPHost] 未设置FunctionCall模型" + logger.error(err) + raise RuntimeError(err) + # 更清晰的输入·指令,这样可以调用generate llm_query = rf""" 请使用参数生成工具,生成满足以下目标的工具参数: @@ -138,6 +118,7 @@ class MCPHost: # 进行生成 json_generator = JsonGenerator( + self._llm.function, llm_query, [ {"role": "system", "content": "You are a helpful assistant."}, diff --git a/apps/scheduler/mcp/prompt.py b/apps/scheduler/mcp/prompt.py index 476094f4f..40af1ae31 100644 --- a/apps/scheduler/mcp/prompt.py +++ b/apps/scheduler/mcp/prompt.py @@ -506,8 +506,8 @@ MEMORY_TEMPLATE: dict[str, str] = { r""" {% for ctx in context_list %} - Step {{ loop.index }}: {{ ctx.step_description }} - - Called tool `{{ ctx.step_id }}` and provided parameters `{{ ctx.input_data }}` - - Execution status: {{ ctx.status }} + - Called tool `{{ ctx.step_name }}` and provided parameters `{{ ctx.input_data }}` + - Execution status: {{ ctx.step_status }} - Got data: `{{ ctx.output_data }}` {% endfor %} """, diff --git a/apps/scheduler/mcp/select.py b/apps/scheduler/mcp/select.py index 3594b9d89..84d7f4d87 100644 --- a/apps/scheduler/mcp/select.py +++ b/apps/scheduler/mcp/select.py @@ -6,6 +6,7 @@ import logging from sqlalchemy import select from apps.common.postgres import postgres +from apps.llm.function import JsonGenerator from apps.models.mcp import MCPTools from apps.schemas.mcp import MCPSelectResult from apps.schemas.scheduler import LLMConfig @@ -42,14 +43,23 @@ class MCPSelector: raise RuntimeError(err) logger.info("[MCPSelector] 调用结构化输出小模型") - message = [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": reasoning_result}, - ] schema = MCPSelectResult.model_json_schema() # schema中加入选项 schema["properties"]["mcp_id"]["enum"] = mcp_ids - result = await self._llm.function.call(messages=message, schema=schema) + + # 使用JsonGenerator生成JSON + conversation = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": reasoning_result}, + ] + generator = JsonGenerator( + llm=self._llm.function, + query=reasoning_result, + conversation=conversation, + schema=schema, + ) + result = await generator.generate() + try: result = MCPSelectResult.model_validate(result) except Exception: diff --git a/apps/scheduler/mcp_agent/host.py b/apps/scheduler/mcp_agent/host.py index b41b83722..1b8e3ed92 100644 --- a/apps/scheduler/mcp_agent/host.py +++ b/apps/scheduler/mcp_agent/host.py @@ -90,7 +90,7 @@ class MCPHost(MCPBase): current_goal=current_goal, tool_description=mcp_tool.description, input_schema=mcp_tool.inputSchema, - current_input=current_input, + input_params=current_input, error_message=error_message, params=params, params_description=params_description, diff --git a/apps/scheduler/mcp_agent/plan.py b/apps/scheduler/mcp_agent/plan.py index b060d6ddf..943a9312d 100644 --- a/apps/scheduler/mcp_agent/plan.py +++ b/apps/scheduler/mcp_agent/plan.py @@ -21,7 +21,6 @@ from apps.scheduler.mcp_agent.prompt import ( RISK_EVALUATE, ) from apps.scheduler.slot.slot import Slot -from apps.schemas.enum_var import LanguageType from apps.schemas.mcp import ( FlowName, FlowRisk, @@ -68,13 +67,9 @@ class MCPPlanner(MCPBase): # 使用Step模型解析结果 return Step.model_validate(step) - async def get_flow_excute_risk( - self, - tools: list[MCPTools], - language: LanguageType = LanguageType.CHINESE, - ) -> FlowRisk: + async def get_flow_excute_risk(self, tools: list[MCPTools]) -> FlowRisk: """获取当前流程的风险评估结果""" - template = _env.from_string(GENERATE_FLOW_EXCUTE_RISK[language]) + template = _env.from_string(GENERATE_FLOW_EXCUTE_RISK[self._language]) prompt = template.render( goal=self._goal, tools=tools, @@ -165,9 +160,7 @@ class MCPPlanner(MCPBase): # 解析为结构化数据 return await self._parse_result(result, schema_with_null) - async def generate_answer( - self, memory: str, - ) -> AsyncGenerator[str, None]: + async def generate_answer(self, memory: str) -> AsyncGenerator[str, None]: """生成最终回答""" template = _env.from_string(FINAL_ANSWER[self._language]) prompt = template.render( diff --git a/apps/scheduler/mcp_agent/prompt.py b/apps/scheduler/mcp_agent/prompt.py index 68a1f0a57..313b4d3cc 100644 --- a/apps/scheduler/mcp_agent/prompt.py +++ b/apps/scheduler/mcp_agent/prompt.py @@ -1208,8 +1208,8 @@ REPAIR_PARAMS: dict[LanguageType, str] = { ## 工具入参Schema {{input_schema}} - ## 工具入参 - {{input_param}} + ## 工具当前的入参 + {{input_params}} ## 运行报错 {{error_message}} @@ -1301,8 +1301,8 @@ parameter descriptions. ## Tool input schema {{input_schema}} - ## Tool input parameters - {{input_param}} + ## Current tool input parameters + {{input_params}} ## Runtime error {{error_message}} diff --git a/apps/schemas/mcp.py b/apps/schemas/mcp.py index fda1044fe..063b37a33 100644 --- a/apps/schemas/mcp.py +++ b/apps/schemas/mcp.py @@ -18,6 +18,17 @@ class MCPStatus(str, Enum): ERROR = "error" +class MCPContext(BaseModel): + """MCP执行上下文""" + + step_description: str + """步骤描述""" + input_data: dict[str, Any] + """输入数据""" + output_data: dict[str, Any] + """输出数据""" + + class MCPBasicConfig(BaseModel): """MCP 基本配置""" diff --git a/apps/schemas/rag_data.py b/apps/schemas/rag_data.py index 605559e49..79dc21229 100644 --- a/apps/schemas/rag_data.py +++ b/apps/schemas/rag_data.py @@ -1,25 +1,9 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """请求RAG相关接口时,使用的数据类型""" -import uuid -from typing import Any, Literal +from typing import Literal -from pydantic import BaseModel, Field - - -class RAGQueryReq(BaseModel): - """查询RAG时的POST请求体""" - - kb_ids: list[uuid.UUID] = Field(default=[], description="资产id", alias="kbIds") - query: str = Field(default="", description="查询内容") - top_k: int = Field(default=5, description="返回的结果数量", alias="topK") - doc_ids: list[str] | None = Field(default=None, description="文档id", alias="docIds") - search_method: str = Field(default="dynamic_weighted_keyword_and_vector", - description="检索方法", alias="searchMethod") - is_related_surrounding: bool = Field(default=True, description="是否关联上下文", alias="isRelatedSurrounding") - is_classify_by_doc: bool = Field(default=True, description="是否按文档分类", alias="isClassifyByDoc") - is_rerank: bool = Field(default=False, description="是否重新排序", alias="isRerank") - tokens_limit: int | None = Field(default=None, description="token限制", alias="tokensLimit") +from pydantic import BaseModel class RAGFileParseReqItem(BaseModel): @@ -31,15 +15,6 @@ class RAGFileParseReqItem(BaseModel): type: str -class RAGEventData(BaseModel): - """RAG服务返回的事件数据""" - - content: Any - event_type: str - input_tokens: int = 0 - output_tokens: int = 0 - - class RAGFileParseReq(BaseModel): """请求RAG处理文件时的POST请求体""" diff --git a/apps/services/rag.py b/apps/services/rag.py deleted file mode 100644 index f7cd42c9a..000000000 --- a/apps/services/rag.py +++ /dev/null @@ -1,140 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. -"""对接Euler Copilot RAG""" - -import json -import logging -import re -from collections.abc import AsyncGenerator - -from apps.llm.token import TokenCalculator -from apps.schemas.enum_var import EventType, LanguageType -from apps.schemas.rag_data import RAGQueryReq - -logger = logging.getLogger(__name__) - - -class RAG: - """调用RAG服务,获取知识库答案""" - - @staticmethod - async def chat_with_llm_base_on_rag( # noqa: C901, PLR0913 - user_sub: str, - llm_id: str, - history: list[dict[str, str]], - doc_ids: list[str], - data: RAGQueryReq, - language: LanguageType = LanguageType.CHINESE, - ) -> AsyncGenerator[str, None]: - """获取RAG服务的结果""" - if history: - pass - doc_chunk_list = await RAG.get_doc_info_from_rag( - user_sub=user_sub, max_tokens=llm_config.maxToken, doc_ids=doc_ids, data=data, - ) - bac_info, doc_info_list = await RAG.assemble_doc_info( - doc_chunk_list=doc_chunk_list, max_tokens=llm_config.maxToken, - ) - messages = [ - *history, - { - "role": "system", - "content": RAG.system_prompt, - }, - { - "role": "user", - "content": RAG.user_prompt[language].format( - bac_info=bac_info, - user_question=data.query, - ), - }, - ] - input_tokens = TokenCalculator().calculate_token_length(messages=messages) - output_tokens = 0 - doc_cnt: int = 0 - for doc_info in doc_info_list: - doc_cnt = max(doc_cnt, doc_info["order"]) - yield ( - "data: " - + json.dumps( - { - "event_type": EventType.DOCUMENT_ADD.value, - "input_tokens": input_tokens, - "output_tokens": output_tokens, - "content": doc_info, - }, - ensure_ascii=False, - ) - + "\n\n" - ) - max_footnote_length = 4 - tmp_doc_cnt = doc_cnt - while tmp_doc_cnt > 0: - tmp_doc_cnt //= 10 - max_footnote_length += 1 - buffer = "" - async for chunk in reasion_llm.call( - messages, - max_tokens=llm_config.maxToken, - streaming=True, - temperature=0.7, - result_only=False, - model=llm_config.modelName, - ): - tmp_chunk = buffer + chunk - # 防止脚注被截断 - if len(tmp_chunk) >= 2 and tmp_chunk[-2:] != "]]": - index = len(tmp_chunk) - 1 - while index >= max(0, len(tmp_chunk) - max_footnote_length) and tmp_chunk[index] != "]": - index -= 1 - if index >= 0: - buffer = tmp_chunk[index + 1:] - tmp_chunk = tmp_chunk[:index + 1] - else: - buffer = "" - # 匹配脚注 - footnotes = re.findall(r"\[\[\d+\]\]", tmp_chunk) - # 去除编号大于doc_cnt的脚注 - footnotes = [fn for fn in footnotes if int(fn[2:-2]) > doc_cnt] - footnotes = list(set(footnotes)) # 去重 - if footnotes: - for fn in footnotes: - tmp_chunk = tmp_chunk.replace(fn, "") - output_tokens += TokenCalculator().calculate_token_length( - messages=[ - {"role": "assistant", "content": tmp_chunk}, - ], - pure_text=True, - ) - yield ( - "data: " - + json.dumps( - { - "event_type": EventType.TEXT_ADD.value, - "content": tmp_chunk, - "input_tokens": input_tokens, - "output_tokens": output_tokens, - }, - ensure_ascii=False, - ) - + "\n\n" - ) - if buffer: - output_tokens += TokenCalculator().calculate_token_length( - messages=[ - {"role": "assistant", "content": buffer}, - ], - pure_text=True, - ) - yield ( - "data: " - + json.dumps( - { - "event_type": EventType.TEXT_ADD.value, - "content": buffer, - "input_tokens": input_tokens, - "output_tokens": output_tokens, - }, - ensure_ascii=False, - ) - + "\n\n" - ) -- Gitee