From b0df53ed8b7168f6c0d377410d0b5ac2729630e5 Mon Sep 17 00:00:00 2001 From: z30057876 Date: Thu, 23 Jan 2025 23:23:04 +0800 Subject: [PATCH 1/2] =?UTF-8?q?=E6=9B=B4=E6=96=B0Call=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/llm/patterns/select.py | 4 +- apps/scheduler/call/api.py | 113 ++++++++++++--------------- apps/scheduler/call/choice.py | 10 +-- apps/scheduler/call/next_flow.py | 1 - apps/scheduler/call/rag.py | 18 +++++ apps/scheduler/call/reformat.py | 2 +- apps/scheduler/call/render/render.py | 8 +- apps/scheduler/call/sql.py | 90 ++++++++++++--------- apps/scheduler/pool/loader/btdl.py | 1 + apps/scheduler/pool/loader/call.py | 20 ++--- 10 files changed, 145 insertions(+), 122 deletions(-) diff --git a/apps/llm/patterns/select.py b/apps/llm/patterns/select.py index f7bb3284..97aba3d9 100644 --- a/apps/llm/patterns/select.py +++ b/apps/llm/patterns/select.py @@ -78,8 +78,8 @@ class Select(CorePattern): choices_prompt = "" choice_str_list = [] for choice in choices: - choices_prompt += "- {}: {}\n".format(choice["name"], choice["description"]) - choice_str_list.append(choice["name"]) + choices_prompt += "- {}: {}\n".format(choice["branch"], choice["description"]) + choice_str_list.append(choice["branch"]) return choices_prompt, choice_str_list async def _generate_single_attempt(self, task_id: str, user_input: str, choice_list: list[str]) -> str: diff --git a/apps/scheduler/call/api.py b/apps/scheduler/call/api.py index 5ca7c7ac..2098b84b 100644 --- a/apps/scheduler/call/api.py +++ b/apps/scheduler/call/api.py @@ -23,27 +23,27 @@ class _APIParams(BaseModel): ] = Field(description="API接口的HTTP Method") timeout: int = Field(description="工具超时时间", default=300) input_data: dict[str, Any] = Field(description="固定数据", default={}) + auth: dict[str, Any] = Field(description="API鉴权信息", default={}) service_id: Optional[str] = Field(description="服务ID") class _APIOutput(BaseModel): """API调用工具的输出""" + http_code: int = Field(description="API调用工具的HTTP返回码") + message: str = Field(description="API调用工具的执行结果") output: dict[str, Any] = Field(description="API调用工具的输出") -class API(CoreCall): +class API(metaclass=CoreCall, param_cls=_APIParams, output_cls=_APIOutput): """API调用工具""" name: str = "api" description: str = "根据给定的用户输入和历史记录信息,向某一个API接口发送请求、获取数据。" - params: type[_APIParams] = _APIParams - async def init(self, syscall_vars: SysCallVars, **kwargs) -> None: # noqa: ANN003 + def init(self, syscall_vars: SysCallVars, **kwargs) -> None: # noqa: ANN003 """初始化API调用工具""" - # 插件鉴权 - self._auth = json.loads(str(plugin_data.auth)) # 从spec中找出该接口对应的spec for item in full_spec.endpoints: name, _, _ = item @@ -53,10 +53,10 @@ class API(CoreCall): err = "[API] Endpoint not found." raise ValueError(err) - if method == "POST": + if kwargs["method"] == "POST": if "requestBody" in self._spec[2]: self.slot_schema, self._data_type = self._check_data_type(self._spec[2]["requestBody"]["content"]) - elif method == "GET": + elif kwargs["method"] == "GET": if "parameters" in self._spec[2]: self.slot_schema = self.parameters_to_spec(self._spec[2]["parameters"]) self._data_type = "json" @@ -65,11 +65,11 @@ class API(CoreCall): raise NotImplementedError(err) - async def call(self, slot_data: dict[str, Any]) -> CallResult: + async def __call__(self, slot_data: dict[str, Any]) -> _APIOutput: """调用API,然后返回LLM解析后的数据""" self._session = aiohttp.ClientSession() try: - result = await self._call_api(method, url, slot_data) + result = await self._call_api(slot_data) await self._session.close() return result except Exception as e: @@ -112,78 +112,53 @@ class API(CoreCall): return schema - @staticmethod - def process( - response_data: Optional[str], url: str, usage: str, response_schema: dict[str, Any], - ) -> CallResult: - """对返回值进行整体处理""" - # 如果结果太长,不使用大模型进行总结;否则使用大模型生成自然语言总结 - if response_data is None: - return CallResult( - output={}, - output_schema={}, - message=f"调用接口{url}成功,但返回值为空。", - ) - - if len(response_data) > MAX_API_RESPONSE_LENGTH: - response_data = response_data[:MAX_API_RESPONSE_LENGTH] - response_data = response_data[:response_data.rfind(",") - 1] - response_data = untrunc.complete(response_data) - - response_data = APISanitizer._process_response_schema(response_data, response_schema) - - message = dedent(f"""调用API从外部数据源获取了数据。API和数据源的描述为:{usage}""") - - return CallResult( - output=json.loads(response_data), - output_schema=response_schema, - message=message, - ) - + async def _make_api_call(self, data: Optional[dict], files: aiohttp.FormData): # noqa: ANN202, C901 + # 获取必要参数 + params: _APIParams = getattr(self, "_params") + syscall_vars: SysCallVars = getattr(self, "_syscall_vars") - async def _make_api_call(self, method: str, url: str, data: Optional[dict], files: aiohttp.FormData): # noqa: ANN202, C901 """调用API""" if self._data_type != "form": - header = { + req_header = { "Content-Type": "application/json", } else: - header = {} - cookie = {} - params = {} + req_header = {} + req_cookie = {} + req_params = {} if data is None: data = {} if self._auth is not None and "type" in self._auth: if self._auth["type"] == "header": - header.update(self._auth["args"]) + req_header.update(self._auth["args"]) elif self._auth["type"] == "cookie": - cookie.update(self._auth["args"]) + req_cookie.update(self._auth["args"]) elif self._auth["type"] == "params": - params.update(self._auth["args"]) + req_params.update(self._auth["args"]) elif self._auth["type"] == "oidc": token = await TokenManager.get_plugin_token( self._auth["domain"], - self._syscall_vars.session_id, + syscall_vars.session_id, self._auth["access_token_url"], int(self._auth["token_expire_time"]), ) - header.update({"access-token": token}) + req_header.update({"access-token": token}) - if self._params.method == "GET": - params.update(data) - return self._session.get(self._server + url, params=params, headers=header, cookies=cookie, - timeout=self.params.timeout) - if method == "POST": + if params.method == "GET": + req_params.update(data) + return self._session.get(params.full_url, params=req_params, headers=req_header, cookies=req_cookie, + timeout=params.timeout) + if params.method == "POST": if self._data_type == "form": form_data = files for key, val in data.items(): form_data.add_field(key, val) - return self._session.post(self._server + url, data=form_data, headers=header, cookies=cookie, - timeout=self.params.timeout) - return self._session.post(self._server + url, json=data, headers=header, cookies=cookie, - timeout=self.params.timeout) + return self._session.post(params.full_url, data=form_data, headers=req_header, cookies=req_cookie, + timeout=params.timeout) + return self._session.post(params.full_url, json=data, headers=req_header, cookies=req_cookie, + timeout=params.timeout) err = "Method not implemented." raise NotImplementedError(err) @@ -201,9 +176,12 @@ class API(CoreCall): raise NotImplementedError(err) - async def _call_api(self, method: str, url: str, slot_data: Optional[dict[str, Any]] = None) -> CallResult: - LOGGER.info(f"调用接口{url},请求数据为{slot_data}") - session_context = await self._make_api_call(method, url, slot_data, aiohttp.FormData()) + async def _call_api(self, slot_data: Optional[dict[str, Any]] = None) -> _APIOutput: + # 获取必要参数 + params: _APIParams = getattr(self, "_params") + LOGGER.info(f"调用接口{params.full_url},请求数据为{slot_data}") + + session_context = await self._make_api_call(slot_data, aiohttp.FormData()) async with session_context as response: if response.status >= status.HTTP_400_BAD_REQUEST: text = f"API发生错误:API返回状态码{response.status}, 原因为{response.reason}。" @@ -219,6 +197,19 @@ class API(CoreCall): response_schema = self._spec[2]["responses"]["content"]["application/json"]["schema"] else: response_schema = {} - LOGGER.info(f"调用接口{url}, 结果为 {response_data}") + LOGGER.info(f"调用接口{params.full_url}, 结果为 {response_data}") + + # 如果没有返回结果 + if response_data is None: + return _APIOutput( + http_code=response.status, + output={}, + message=f"调用接口{params.full_url},作用为但返回值为空。", + ) - return self.process(response_data, url, self._spec[1], response_schema) + response_data = self._process_response_schema(response_data, response_schema) + return _APIOutput( + output=json.loads(response_data), + output_schema=response_schema, + message=f"""调用API从外部数据源获取了数据。API和数据源的描述为:{usage}""", + ) diff --git a/apps/scheduler/call/choice.py b/apps/scheduler/call/choice.py index f2845fa0..2d696e75 100644 --- a/apps/scheduler/call/choice.py +++ b/apps/scheduler/call/choice.py @@ -14,8 +14,8 @@ from apps.scheduler.call.core import CoreCall class _ChoiceBranch(BaseModel): """Choice工具的选项""" - name: str = Field(description="选项的名称") - value: str = Field(description="选项的值") + branch: str = Field(description="选项的名称") + description: str = Field(description="选项的描述") class _ChoiceParams(BaseModel): @@ -37,7 +37,6 @@ class Choice(metaclass=CoreCall): name: str = "choice" description: str = "选择工具,用于根据给定的上下文和问题,判断正确/错误,或从选项列表中选择最符合用户要求的一项。" - params: type[_ChoiceParams] = _ChoiceParams async def __call__(self, _slot_data: dict[str, Any]) -> _ChoiceOutput: @@ -51,11 +50,12 @@ class Choice(metaclass=CoreCall): previous_data = syscall_vars.history[-1].output_data try: + choice_list = [item.model_dump() for item in params.choices] result = await Select().generate( question=params.propose, background=syscall_vars.background, data=previous_data, - choices=params.choices, + choices=choice_list, task_id=syscall_vars.task_id, ) except Exception as e: @@ -63,5 +63,5 @@ class Choice(metaclass=CoreCall): return _ChoiceOutput( next_step=result, - message=f"针对“{self.params.propose}”,作出的选择为:{result}。", + message=f"针对“{params.propose}”,作出的选择为:{result}。", ) diff --git a/apps/scheduler/call/next_flow.py b/apps/scheduler/call/next_flow.py index c60b9063..67b49b27 100644 --- a/apps/scheduler/call/next_flow.py +++ b/apps/scheduler/call/next_flow.py @@ -12,4 +12,3 @@ class NextFlowCall(metaclass=CoreCall): async def __call__(self, _slot_data: dict[str, Any]): """调用NextFlow工具""" - diff --git a/apps/scheduler/call/rag.py b/apps/scheduler/call/rag.py index e69de29b..08f7c974 100644 --- a/apps/scheduler/call/rag.py +++ b/apps/scheduler/call/rag.py @@ -0,0 +1,18 @@ +"""RAG工具:查询知识库 + +Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +""" + +from typing import Any + +from apps.scheduler.call.core import CoreCall + + +class RAG(metaclass=CoreCall): + """RAG工具:查询知识库""" + + name: str = "rag" + description: str = "RAG工具,用于查询知识库" + + async def __call__(self, _slot_data: dict[str, Any]): + """调用RAG工具""" diff --git a/apps/scheduler/call/reformat.py b/apps/scheduler/call/reformat.py index 608dc05e..9e25e5f6 100644 --- a/apps/scheduler/call/reformat.py +++ b/apps/scheduler/call/reformat.py @@ -68,7 +68,7 @@ class Reformat(metaclass=CoreCall, param_cls=_ReformatParam, output_cls=_Reforma "time": time, "question": syscall_vars.question, }, ensure_ascii=False) - history_str = json.dumps([CallResult(**item.output_data).output for item in syscall_vars.history], ensure_ascii=False) + history_str = json.dumps([item.output_data["output"] for item in syscall_vars.history if "output" in item.output_data], ensure_ascii=False) data_template = dedent(f""" local extra = {extra_str}; local history = {history_str}; diff --git a/apps/scheduler/call/render/render.py b/apps/scheduler/call/render/render.py index 443f5bd5..f6eb2bf0 100644 --- a/apps/scheduler/call/render/render.py +++ b/apps/scheduler/call/render/render.py @@ -66,7 +66,13 @@ class Render(metaclass=CoreCall, param_cls=_RenderParam, output_cls=_RenderOutpu syscall_vars: SysCallVars = getattr(self, "_syscall_vars") # 检测前一个工具是否为SQL - data = syscall_vars.history[-1].output_data + if "dataset" not in syscall_vars.history[-1].output_data: + raise CallError( + message="图表生成失败!Render必须在SQL后调用!", + data={}, + ) + + data = syscall_vars.history[-1].output_data["output"] if data["type"] != "sql" or "dataset" not in data: raise CallError( message="图表生成失败!Render必须在SQL后调用!", diff --git a/apps/scheduler/call/sql.py b/apps/scheduler/call/sql.py index 43dd359d..05701a2b 100644 --- a/apps/scheduler/call/sql.py +++ b/apps/scheduler/call/sql.py @@ -4,47 +4,70 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. """ import json -from typing import Any +from typing import Any, Optional import aiohttp from fastapi import status -from sqlalchemy import create_engine, text +from pydantic import BaseModel, Field +from sqlalchemy import text from apps.common.config import config from apps.constants import LOGGER -from apps.entities.plugin import CallError, CallResult, SysCallVars +from apps.entities.plugin import CallError, SysCallVars +from apps.models.postgres import PostgreSQL from apps.scheduler.call.core import CoreCall -class SQL(CoreCall): +class _SQLParams(BaseModel): + """SQL工具的参数""" + + sql: Optional[str] = Field(description="用户输入") + + +class _SQLOutput(BaseModel): + """SQL工具的输出""" + + message: str = Field(description="SQL工具的执行结果") + dataset: list[dict[str, Any]] = Field(description="SQL工具的执行结果") + + +class SQL(metaclass=CoreCall, param_cls=_SQLParams, output_cls=_SQLOutput): """SQL工具。用于调用外置的Chat2DB工具的API,获得SQL语句;再在PostgreSQL中执行SQL语句,获得数据。""" name: str = "sql" description: str = "SQL工具,用于查询数据库中的结构化数据" - async def init(self, syscall_vars: SysCallVars, **kwargs) -> None: # noqa: ANN003 + def init(self, _syscall_vars: SysCallVars, **_kwargs) -> None: # noqa: ANN003 """初始化SQL工具。""" - await super().init(syscall_vars, **kwargs) # 初始化aiohttp的ClientSession self._session = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(300)) - # 初始化SQLAlchemy Engine - try: - db_url = f'postgresql+psycopg2://{config["POSTGRES_USER"]}:{config["POSTGRES_PWD"]}@{config["POSTGRES_HOST"]}/{config["POSTGRES_DATABASE"]}' - self._engine = create_engine(db_url, pool_size=20, max_overflow=80, pool_recycle=300, pool_pre_ping=True) - except Exception as e: - raise CallError(message=f"数据库连接失败:{e!s}", data={}) from e - async def call(self, _slot_data: dict[str, Any]) -> CallResult: - """运行SQL工具。 + async def __call__(self, _slot_data: dict[str, Any]) -> _SQLOutput: + """运行SQL工具""" + # 获取必要参数 + params: _SQLParams = getattr(self, "_params") + syscall_vars: SysCallVars = getattr(self, "_syscall_vars") + + # 若手动设置了SQL,则直接使用 + session = await PostgreSQL.get_session() + if params.sql: + try: + result = (await session.execute(text(params.sql))).all() + await session.close() + + dataset_list = [db_item._asdict() for db_item in result] + return _SQLOutput( + message="SQL查询成功!", + dataset=dataset_list, + ) + except Exception as e: + raise CallError(message=f"SQL查询错误:{e!s}", data={}) from e - 访问Chat2DB工具API,拿到针对用户输入的最多5条SQL语句。依次尝试每一条语句,直到查询出数据或全部不可用。 - :param slot_data: 经用户确认后的参数(目前未使用) - :return: 从数据库中查询得到的数据,或报错信息 - """ + # 若未设置SQL,则调用Chat2DB工具API,获取SQL post_data = { - "question": self._syscall_vars.question, + "question": syscall_vars.question, "topk_sql": 5, "use_llm_enhancements": True, } @@ -60,29 +83,18 @@ class SQL(CoreCall): ) result = json.loads(await response.text()) LOGGER.info(f"SQL工具返回的信息为:{result}") - await self._session.close() + for item in result["sql_list"]: try: - with self._engine.connect() as connection: - db_result = connection.execute(text(item["sql"])).all() - dataset_list = [db_item._asdict() for db_item in db_result] - return CallResult( - output={ - "dataset": dataset_list, - }, - output_schema={ - "dataset": { - "type": "array", - "description": "数据库查询结果", - "items": { - "type": "object", - "description": "数据库查询结果的每一行", - }, - }, - }, - message="数据库查询成功!", - ) + db_result = (await session.execute(text(item["sql"]))).all() + await session.close() + + dataset_list = [db_item._asdict() for db_item in db_result] + return _SQLOutput( + message="数据库查询成功!", + dataset=dataset_list, + ) except Exception as e: # noqa: PERF203 LOGGER.error(f"SQL查询错误,错误信息为:{e},正在换用下一条SQL语句。") diff --git a/apps/scheduler/pool/loader/btdl.py b/apps/scheduler/pool/loader/btdl.py index 75623973..57d9277c 100644 --- a/apps/scheduler/pool/loader/btdl.py +++ b/apps/scheduler/pool/loader/btdl.py @@ -47,6 +47,7 @@ class BTDLLoader: BTDLLoader._check_single_argument(value, strict=False) def _load_single_subcmd(self, binary_name: str, subcmd_spec: dict[str, Any]) -> dict[str, tuple[str, str, dict[str, Any], dict[str, Any], str]]: + """加载单个子命令""" if "name" not in subcmd_spec: err = "subcommand must have a name" raise ValueError(err) diff --git a/apps/scheduler/pool/loader/call.py b/apps/scheduler/pool/loader/call.py index 36fae596..988390eb 100644 --- a/apps/scheduler/pool/loader/call.py +++ b/apps/scheduler/pool/loader/call.py @@ -12,7 +12,7 @@ import apps.scheduler.call as system_call from apps.common.config import config from apps.constants import CALL_DIR, LOGGER from apps.entities.enum_var import CallType -from apps.entities.flow import CallMetadata +from apps.entities.pool import CallPool from apps.models.mongo import MongoDB @@ -54,7 +54,7 @@ class CallLoader: return True @staticmethod - async def _load_system_call() -> list[CallMetadata]: + async def _load_system_call() -> list[CallPool]: """加载系统Call""" metadata = [] @@ -66,7 +66,7 @@ class CallLoader: continue metadata.append( - CallMetadata( + CallPool( _id=call_name, type=CallType.SYSTEM, name=call_cls.name, @@ -78,7 +78,7 @@ class CallLoader: return metadata @classmethod - async def _load_python_call(cls) -> list[CallMetadata]: + async def _load_python_call(cls) -> list[CallPool]: """加载Python Call""" call_dir = Path(config["SERVICE_DIR"]) / CALL_DIR metadata = [] @@ -136,25 +136,21 @@ class CallLoader: @staticmethod - async def load_one() - - - @staticmethod - async def load() -> None: + async def load_one() -> None: """加载Call""" call_metadata = await CallLoader._load_system_call() call_metadata.extend(await CallLoader._load_python_call()) @staticmethod - async def get() -> list[CallMetadata]: + async def get() -> list[CallPool]: """获取当前已知的所有Call元数据""" call_collection = MongoDB.get_collection("call") - result: list[CallMetadata] = [] + result: list[CallPool] = [] try: cursor = call_collection.find({}) async for item in cursor: - result.extend([CallMetadata(**item)]) + result.extend([CallPool.model_validate(item)]) except Exception as e: LOGGER.error(msg=f"获取Call元数据失败:{e}") -- Gitee From 10363c47af2ba5d180e3580f45f929f97091140d Mon Sep 17 00:00:00 2001 From: z30057876 Date: Fri, 24 Jan 2025 01:06:35 +0800 Subject: [PATCH 2/2] =?UTF-8?q?Loader=E6=8D=A2=E7=94=A8anyio?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/entities/pool.py | 5 ++-- apps/entities/task.py | 1 + apps/models/postgres.py | 16 ++++++++++++ apps/scheduler/call/api.py | 28 ++++++++++---------- apps/scheduler/embedding.py | 21 --------------- apps/scheduler/executor/flow.py | 1 - apps/scheduler/executor/message.py | 3 --- apps/scheduler/pool/loader/call.py | 25 ++++-------------- apps/scheduler/pool/loader/metadata.py | 6 ++--- apps/scheduler/pool/loader/openapi.py | 16 +++++++----- apps/scheduler/pool/loader/service.py | 36 +++++++++++++++++++++++--- apps/scheduler/scheduler/flow.py | 23 ++++++++-------- apps/scheduler/scheduler/message.py | 26 +++++++++---------- 13 files changed, 107 insertions(+), 100 deletions(-) delete mode 100644 apps/scheduler/embedding.py diff --git a/apps/entities/pool.py b/apps/entities/pool.py index 85206e39..ca0773b4 100644 --- a/apps/entities/pool.py +++ b/apps/entities/pool.py @@ -49,7 +49,7 @@ class CallPool(PoolBase): collection: call """ - id: str = Field(description="Call的ID") + id: str = Field(description="Call的ID", alias="_id") type: CallType = Field(description="Call的类型") path: str = Field(description="Call的路径") @@ -65,8 +65,7 @@ class NodePool(PoolBase): 2. Python Node的路径格式样例:“tune::call.tune.CheckSystem” """ - id: str = Field(description="Node的ID", default_factory=lambda: str(uuid.uuid4())) - created_at: None = None + id: str = Field(description="Node的ID", default_factory=lambda: str(uuid.uuid4()), alias="_id") service_id: str = Field(description="Node所属的Service ID") call_id: str = Field(description="所使用的Call的ID") fixed_params: dict[str, Any] = Field(description="Node的固定参数", default={}) diff --git a/apps/entities/task.py b/apps/entities/task.py index 3b9af5b0..e51f543d 100644 --- a/apps/entities/task.py +++ b/apps/entities/task.py @@ -56,6 +56,7 @@ class TaskBlock(BaseModel): class RequestDataApp(BaseModel): """模型对话中包含的app信息""" + app_id: str = Field(description="应用ID", alias="appId") flow_id: str = Field(description="Flow ID", alias="flowId") params: dict[str, Any] = Field(description="插件参数") diff --git a/apps/models/postgres.py b/apps/models/postgres.py index 639fc878..9ede34c2 100644 --- a/apps/models/postgres.py +++ b/apps/models/postgres.py @@ -3,6 +3,7 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. """ +import aiohttp from sqlalchemy.ext.asyncio import ( AsyncSession, async_sessionmaker, @@ -37,3 +38,18 @@ class PostgreSQL: async def get_session(cls) -> AsyncSession: """获取异步会话""" return async_sessionmaker(cls._engine, class_=AsyncSession, expire_on_commit=False)() + + + @staticmethod + async def get_embedding(text: list[str]) -> list[float]: + """访问Vectorize的Embedding API,获得向量化数据 + + :param text: 待向量化文本(多条文本组成List) + :return: 文本对应的向量(顺序与text一致,也为List) + """ + api = config["VECTORIZE_HOST"].rstrip("/") + "/embedding" + + async with aiohttp.ClientSession() as session, session.post( + api, json={"texts": text}, timeout=30, + ) as response: + return await response.json() diff --git a/apps/scheduler/call/api.py b/apps/scheduler/call/api.py index 2098b84b..50a78658 100644 --- a/apps/scheduler/call/api.py +++ b/apps/scheduler/call/api.py @@ -130,19 +130,19 @@ class API(metaclass=CoreCall, param_cls=_APIParams, output_cls=_APIOutput): if data is None: data = {} - if self._auth is not None and "type" in self._auth: - if self._auth["type"] == "header": - req_header.update(self._auth["args"]) - elif self._auth["type"] == "cookie": - req_cookie.update(self._auth["args"]) - elif self._auth["type"] == "params": - req_params.update(self._auth["args"]) - elif self._auth["type"] == "oidc": + if params.auth is not None and "type" in params.auth: + if params.auth["type"] == "header": + req_header.update(params.auth["args"]) + elif params.auth["type"] == "cookie": + req_cookie.update(params.auth["args"]) + elif params.auth["type"] == "params": + req_params.update(params.auth["args"]) + elif params.auth["type"] == "oidc": token = await TokenManager.get_plugin_token( - self._auth["domain"], + params.auth["domain"], syscall_vars.session_id, - self._auth["access_token_url"], - int(self._auth["token_expire_time"]), + params.auth["access_token_url"], + int(params.auth["token_expire_time"]), ) req_header.update({"access-token": token}) @@ -190,6 +190,7 @@ class API(metaclass=CoreCall, param_cls=_APIParams, output_cls=_APIOutput): message=text, data={"api_response_data": await response.text()}, ) + response_status = response.status response_data = await response.text() # 返回值只支持JSON的情况 @@ -202,14 +203,15 @@ class API(metaclass=CoreCall, param_cls=_APIParams, output_cls=_APIOutput): # 如果没有返回结果 if response_data is None: return _APIOutput( - http_code=response.status, + http_code=response_status, output={}, message=f"调用接口{params.full_url},作用为但返回值为空。", ) response_data = self._process_response_schema(response_data, response_schema) return _APIOutput( + http_code=response_status, output=json.loads(response_data), - output_schema=response_schema, message=f"""调用API从外部数据源获取了数据。API和数据源的描述为:{usage}""", ) + diff --git a/apps/scheduler/embedding.py b/apps/scheduler/embedding.py deleted file mode 100644 index 03493803..00000000 --- a/apps/scheduler/embedding.py +++ /dev/null @@ -1,21 +0,0 @@ -"""从Vectorize获取向量化数据 - -Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. -""" -import aiohttp - -from apps.common.config import config - - -async def get_embedding(text: list[str]) -> list[float]: - """访问Vectorize的Embedding API,获得向量化数据 - - :param text: 待向量化文本(多条文本组成List) - :return: 文本对应的向量(顺序与text一致,也为List) - """ - api = config["VECTORIZE_HOST"].rstrip("/") + "/embedding" - - async with aiohttp.ClientSession() as session, session.post( - api, json={"texts": text}, timeout=30, - ) as response: - return await response.json() diff --git a/apps/scheduler/executor/flow.py b/apps/scheduler/executor/flow.py index 1563168e..c663ffde 100644 --- a/apps/scheduler/executor/flow.py +++ b/apps/scheduler/executor/flow.py @@ -9,7 +9,6 @@ from apps.constants import LOGGER, MAX_SCHEDULER_HISTORY_SIZE from apps.entities.enum_var import StepStatus from apps.entities.flow import Step from apps.entities.plugin import ( - CallResult, SysCallVars, SysExecVars, ) diff --git a/apps/scheduler/executor/message.py b/apps/scheduler/executor/message.py index 6ecdd0bd..2cea3e59 100644 --- a/apps/scheduler/executor/message.py +++ b/apps/scheduler/executor/message.py @@ -12,9 +12,6 @@ from apps.entities.message import ( StepOutputContent, TextAddContent, ) -from apps.entities.plugin import ( - CallResult, -) from apps.entities.task import ExecutorState, FlowHistory from apps.llm.patterns.executor import ExecutorResult from apps.manager.task import TaskManager diff --git a/apps/scheduler/pool/loader/call.py b/apps/scheduler/pool/loader/call.py index 988390eb..617834ec 100644 --- a/apps/scheduler/pool/loader/call.py +++ b/apps/scheduler/pool/loader/call.py @@ -14,6 +14,7 @@ from apps.constants import CALL_DIR, LOGGER from apps.entities.enum_var import CallType from apps.entities.pool import CallPool from apps.models.mongo import MongoDB +from apps.scheduler.pool.util import get_short_hash class CallLoader: @@ -40,19 +41,6 @@ class CallLoader: return flag - @staticmethod - def _check_package(package_name: str) -> bool: - """检查包是否符合要求""" - package = sys.modules["call." + package_name] - - if not hasattr(package, "service") or not isinstance(package.service, str) or not package.service: - return False - - if not hasattr(package, "__all__") or not isinstance(package.__all__, list) or not package.__all__: # noqa: SIM103 - return False - - return True - @staticmethod async def _load_system_call() -> list[CallPool]: """加载系统Call""" @@ -101,10 +89,6 @@ class CallLoader: if not call_file.is_dir(): continue - if not cls._check_package(call_file.name): - LOGGER.info(msg=f"包call.{call_file.name}不符合Call标准要求,跳过载入。") - continue - # 载入包 try: call_package = importlib.import_module("call." + call_file.name) @@ -118,13 +102,14 @@ class CallLoader: LOGGER.info(msg=f"类{call_cls.__name__}不符合Call标准要求,跳过载入。") continue + cls_path = f"{call_package.service}::call.{call_file.name}.{call_id}" metadata.append( - CallMetadata( - _id=call_id, + CallPool( + _id=get_short_hash(cls_path.encode()), type=CallType.PYTHON, name=call_cls.name, description=call_cls.description, - path=f"{call_package.service}::call.{call_file.name}.{call_id}", + path=cls_path, ), ) except Exception as e: diff --git a/apps/scheduler/pool/loader/metadata.py b/apps/scheduler/pool/loader/metadata.py index 6e51c09c..1038ac90 100644 --- a/apps/scheduler/pool/loader/metadata.py +++ b/apps/scheduler/pool/loader/metadata.py @@ -2,10 +2,10 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. """ -from pathlib import Path from typing import Any, Union import yaml +from anyio import Path from apps.common.config import config from apps.constants import LOGGER @@ -24,7 +24,7 @@ class MetadataLoader: """加载【单个】元数据""" # 检查yaml格式 try: - metadata_dict = yaml.safe_load(file_path.read_text()) + metadata_dict = yaml.safe_load(await file_path.read_text()) metadata_type = metadata_dict["type"] except Exception as e: err = f"metadata.yaml读取失败: {e}" @@ -74,4 +74,4 @@ class MetadataLoader: raise RuntimeError(err) from e yaml_data = data.model_dump(by_alias=True, exclude_none=True) - resource_path.write_text(yaml.safe_dump(yaml_data)) + await resource_path.write_text(yaml.safe_dump(yaml_data)) diff --git a/apps/scheduler/pool/loader/openapi.py b/apps/scheduler/pool/loader/openapi.py index 05b8be9b..46a6c36d 100644 --- a/apps/scheduler/pool/loader/openapi.py +++ b/apps/scheduler/pool/loader/openapi.py @@ -2,9 +2,8 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. """ -from pathlib import Path - import yaml +from anyio import Path from apps.constants import LOGGER from apps.entities.pool import NodePool @@ -23,14 +22,15 @@ class OpenAPILoader: @classmethod - def load(cls, yaml_path: Path) -> ReducedOpenAPISpec: + async def load(cls, yaml_path: Path) -> ReducedOpenAPISpec: """从本地磁盘加载OpenAPI文档""" if not yaml_path.exists(): msg = f"File not found: {yaml_path}" raise FileNotFoundError(msg) - with yaml_path.open(mode="r") as f: - spec = yaml.safe_load(f) + f = await yaml_path.open(mode="r") + spec = yaml.safe_load(await f.read()) + await f.aclose() return reduce_openapi_spec(spec) @@ -56,11 +56,13 @@ class OpenAPILoader: return nodes @classmethod - async def load_one(cls, yaml_path: Path) -> None: + async def load_one(cls, yaml_path: Path) -> list[NodePool]: """加载单个OpenAPI文档,可以直接指定路径""" try: - spec = cls.load(yaml_path) + spec = await cls.load(yaml_path) except Exception as e: err = f"加载OpenAPI文档失败:{e}" LOGGER.error(msg=err) raise RuntimeError(err) from e + + return cls._process_spec(yaml_path.name, spec) diff --git a/apps/scheduler/pool/loader/service.py b/apps/scheduler/pool/loader/service.py index d5ef4254..77a09c4b 100644 --- a/apps/scheduler/pool/loader/service.py +++ b/apps/scheduler/pool/loader/service.py @@ -2,12 +2,16 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. """ -from pathlib import Path from typing import Any +from anyio import Path + from apps.common.config import config +from apps.entities.vector import NodeVector, ServiceVector from apps.models.mongo import MongoDB +from apps.models.postgres import PostgreSQL from apps.scheduler.pool.loader.metadata import MetadataLoader +from apps.scheduler.pool.loader.openapi import OpenAPILoader class ServiceLoader: @@ -21,9 +25,33 @@ class ServiceLoader: """加载单个Service""" service_path = Path(config["SERVICE_DIR"]) / "service" / service_dir # 载入元数据 - metadata = MetadataLoader.load(service_path / "metadata.yaml") - # 载入OpenAPI文档 - + metadata = await MetadataLoader.load(service_path / "metadata.yaml") + # 载入OpenAPI文档,获取Node列表 + nodes = await OpenAPILoader.load_one(service_path / "openapi.yaml") + + # 向量化所有数据 + session = await PostgreSQL.get_session() + service_vec = ServiceVector( + _id=metadata.id, + embedding=PostgreSQL.get_embedding([metadata.description]), + ) + session.add(service_vec) + + node_descriptions = [] + for node in nodes: + node_descriptions += [node.description] + + node_vecs = await PostgreSQL.get_embedding(node_descriptions) + for i, data in enumerate(node_vecs): + node_vec = NodeVector( + _id=nodes[i].id, + embedding=data, + ) + session.add(node_vec) + + await session.commit() + + diff --git a/apps/scheduler/scheduler/flow.py b/apps/scheduler/scheduler/flow.py index 4804cce6..c03978fb 100644 --- a/apps/scheduler/scheduler/flow.py +++ b/apps/scheduler/scheduler/flow.py @@ -4,12 +4,12 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. """ from typing import Optional -from apps.entities.task import RequestDataPlugin +from apps.entities.task import RequestDataApp from apps.llm.patterns import Select from apps.scheduler.pool.pool import Pool -async def choose_flow(task_id: str, question: str, origin_plugin_list: list[RequestDataPlugin]) -> tuple[str, Optional[RequestDataPlugin]]: +async def choose_flow(task_id: str, question: str, origin_app_list: list[RequestDataApp]) -> tuple[str, Optional[RequestDataApp]]: """依据用户的输入和选择,构造对应的Flow。 - 当用户没有选择任何Plugin时,直接进行智能问答 @@ -21,21 +21,21 @@ async def choose_flow(task_id: str, question: str, origin_plugin_list: list[Requ :result: 经LLM选择的Plugin ID和Flow ID """ # 去掉无效的插件选项:plugin_id为空 - plugin_ids = [] + app_ids = [] flow_ids = [] - for item in origin_plugin_list: - if not item.plugin_id: + for item in origin_app_list: + if not item.app_id: continue - plugin_ids.append(item.plugin_id) + app_ids.append(item.app_id) if item.flow_id: flow_ids.append(item) # 用户什么都不选,直接智能问答 - if len(plugin_ids) == 0: + if len(app_ids) == 0: return "", None # 用户只选了auto - if len(plugin_ids) == 1 and plugin_ids[0] == "auto": + if len(app_ids) == 1 and app_ids[0] == "auto": # 用户要求自动识别 plugin_top = Pool().get_k_plugins(question) # 聚合插件的Flow @@ -67,11 +67,10 @@ async def choose_flow(task_id: str, question: str, origin_plugin_list: list[Requ if selected_id == "KnowledgeBase": return "", None - plugin_id = selected_id.split("/")[0] + app_id = selected_id.split("/")[0] flow_id = selected_id.split("/")[1] - return plugin_id, RequestDataPlugin( - plugin_id=plugin_id, + return app_id, RequestDataApp( + app_id=app_id, flow_id=flow_id, params={}, - auth={}, ) diff --git a/apps/scheduler/scheduler/message.py b/apps/scheduler/scheduler/message.py index fb7bd87b..9771b108 100644 --- a/apps/scheduler/scheduler/message.py +++ b/apps/scheduler/scheduler/message.py @@ -34,17 +34,17 @@ async def push_init_message(task_id: str, queue: MessageQueue, post_body: Reques # 组装feature if is_flow: feature = InitContentFeature( - max_tokens=post_body.features.max_tokens, - context_num=post_body.features.context_num, - enable_feedback=False, - enable_regenerate=False, + maxTokens=post_body.features.max_tokens, + contextNum=post_body.features.context_num, + enableFeedback=False, + enableRegenerate=False, ) else: feature = InitContentFeature( - max_tokens=post_body.features.max_tokens, - context_num=post_body.features.context_num, - enable_feedback=True, - enable_regenerate=True, + maxTokens=post_body.features.max_tokens, + contextNum=post_body.features.context_num, + enableFeedback=True, + enableRegenerate=True, ) # 保存必要信息到Task @@ -54,7 +54,7 @@ async def push_init_message(task_id: str, queue: MessageQueue, post_body: Reques await TaskManager.set_task(task_id, task) # 推送初始化消息 - await queue.push_output(event_type=EventType.INIT, data=InitContent(feature=feature, created_at=created_at).model_dump(exclude_none=True, by_alias=True)) + await queue.push_output(event_type=EventType.INIT, data=InitContent(feature=feature, createdAt=created_at).model_dump(exclude_none=True, by_alias=True)) async def push_rag_message(task_id: str, queue: MessageQueue, user_sub: str, rag_data: RAGQueryReq) -> None: @@ -108,9 +108,9 @@ async def _push_rag_chunk(task_id: str, queue: MessageQueue, content: str, rag_i async def push_document_message(queue: MessageQueue, doc: Union[RecordDocument, Document]) -> None: """推送文档消息""" content = DocumentAddContent( - document_id=doc.id, - document_name=doc.name, - document_type=doc.type, - document_size=round(doc.size, 2), + documentId=doc.id, + documentName=doc.name, + documentType=doc.type, + documentSize=round(doc.size, 2), ) await queue.push_output(event_type=EventType.DOCUMENT_ADD, data=content.model_dump(exclude_none=True, by_alias=True)) -- Gitee