diff --git a/apps/scheduler/call/sql.py b/apps/scheduler/call/sql.py index 6368293936fabda693b10597dbbe90c8123bc4ec..312c5faafed695c2eb015d7f1ab22ca2197f03b4 100644 --- a/apps/scheduler/call/sql.py +++ b/apps/scheduler/call/sql.py @@ -23,7 +23,8 @@ logger = logging.getLogger("ray") class SQLOutput(BaseModel): """SQL工具的输出""" - messages: list[str] = Field(description="SQL工具的执行结果列表") + message: str = Field(description="SQL工具的执行结果") + dataset: list[dict[str, Any]] = Field(description="SQL工具的执行结果") class SQL(CoreCall): @@ -31,73 +32,68 @@ class SQL(CoreCall): name: str = "数据库" description: str = "使用大模型生成SQL语句,用于查询数据库中的结构化数据" - message: str = Field(description="SQL工具的执行结果") - database_id: str = Field(description="数据库的id", alias="db_sn", default=None) - table_id_list: Optional[list[str]] = Field(description="表id列表", default=[]) - top_k: int = Field(description="返回的答案数量", default=1) - use_llm_enhancements: Optional[bool] = Field(description="是否使用大模型增强", default=False) + sql: Optional[str] = Field(description="用户输入") - def init(self, _syscall_vars: CallVars, **_kwargs) -> dict[str, Any]: + + def init(self, _syscall_vars: CallVars, **_kwargs) -> None: # noqa: ANN003 """初始化SQL工具。""" # 初始化aiohttp的ClientSession - self._params_dict = { - "database_id": self.database_id, - "table_id_list": self.table_id_list, - "topk": self.top_k, - "use_llm_enhancements": self.use_llm_enhancements, - "question": _syscall_vars.question, - } + self._session = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(300)) + + + async def exec(self, _slot_data: dict[str, Any]) -> SQLOutput: + """运行SQL工具""" + # 获取必要参数 + syscall_vars: CallVars = getattr(self, "_syscall_vars") - self._generate_url = config["CHAT2DB_HOST"].rstrip("/") + "/sql/generate" - self._execute_url = config["CHAT2DB_HOST"].rstrip("/") + "/sql/execute" - self._headers = { + # 若手动设置了SQL,则直接使用 + session = await PostgreSQL.get_session() + if params.sql: + try: + result = (await session.execute(text(params.sql))).all() + await session.close() + + dataset_list = [db_item._asdict() for db_item in result] + return SQLOutput( + message="SQL查询成功!", + dataset=dataset_list, + ) + except Exception as e: + raise CallError(message=f"SQL查询错误:{e!s}", data={}) from e + + # 若未设置SQL,则调用Chat2DB工具API,获取SQL + post_data = { + "question": syscall_vars.question, + "topk_sql": 5, + "use_llm_enhancements": True, + } + headers = { "Content-Type": "application/json", } - return self._params_dict + async with self._session.post(config["SQL_URL"], ssl=False, json=post_data, headers=headers) as response: + if response.status != status.HTTP_200_OK: + raise CallError( + message=f"SQL查询错误:API返回状态码{response.status}, 详细原因为{response.reason}。", + data={"response": await response.text()}, + ) + result = json.loads(await response.text()) + logger.info("SQL工具返回的信息为:%s", result) + await self._session.close() + + for item in result["sql_list"]: + try: + db_result = (await session.execute(text(item["sql"]))).all() + await session.close() + + dataset_list = [db_item._asdict() for db_item in db_result] + return SQLOutput( + message="数据库查询成功!", + dataset=dataset_list, + ) + except Exception: # noqa: PERF203 + logger.exception("SQL查询错误,正在换用下一条SQL语句。") - async def exec(self) -> SQLOutput: - """运行SQL工具""" - # 获取必要参数 - sql_list = [] - retry = 0 - max_retry = 3 - while retry < max_retry: - async with aiohttp.ClientSession() as session, session.post(self._generate_url, headers=self._headers, json=self._params_dict) as response: - # 检查响应状态码 - if response.status == status.HTTP_200_OK: - result = await response.json() - sub_sql_list = result['result']['sql_list'] - sql_list += sub_sql_list - else: - text = await response.text() - logger.error("[SQL] 调用失败:%s", text) - continue - if len(sql_list) >= self._params_dict['topk']: - break - retry += 1 - sql_exec_results = [] - for sql_dict in sql_list: - database_id = sql_dict['database_id'] - sql = sql_dict['sql'] - data = { - 'database_id': database_id, - 'sql': sql - } - async with aiohttp.ClientSession() as session, session.post(self._execute_url, headers=self._headers, json=data) as response: - # 检查响应状态码 - if response.status == status.HTTP_200_OK: - result = await response.json() - sql_exec_result = result['result'] - sql_exec_results.append(sql_exec_result) - else: - text = await response.text() - logger.error("[SQL] 调用失败:%s", text) - continue - if len(sql_exec_results) >= self._params_dict['topk']: - return SQLOutput( - messages=sql_exec_results, - ).model_dump(exclude_none=True, by_alias=True) raise CallError( message="SQL查询错误:SQL语句错误,数据库查询失败!", data={},