diff --git a/.dockerignore b/.dockerignore
index 728281bcc8fecf11d1730acc96ae73e7ee0e442a..9e7e8498cf2e15e959462f02614827ff289ec2e3 100644
--- a/.dockerignore
+++ b/.dockerignore
@@ -9,4 +9,5 @@ Dockerfile
.git/
deploy/
docs/
+docs_for_openEuler/
.DS_Store
diff --git a/.gitignore b/.gitignore
index 64798293e2ad5a361b966d4e57d8e78573747901..723b472786fe1e6d55de7947d77dccda1ab1cb87 100644
--- a/.gitignore
+++ b/.gitignore
@@ -13,3 +13,4 @@ apps/embedding
logs
.git-credentials
.ruff_cache/
+config
diff --git a/ChangeLog b/ChangeLog
deleted file mode 100644
index 8b01da76a29b22f247b26e239234ce9a2555d30d..0000000000000000000000000000000000000000
--- a/ChangeLog
+++ /dev/null
@@ -1,3 +0,0 @@
-## [0.9.5] - 2025/04/15
-
-- 增加可视化工作流编辑页面
diff --git a/Dockerfile-base b/Dockerfile-base
index 3510709591d2f077f1060309dde7ea26ab9b7011..0dd94323db86e8e85fa341a9dc647951cdb136d1 100644
--- a/Dockerfile-base
+++ b/Dockerfile-base
@@ -9,7 +9,7 @@ RUN sed -i 's|repo.openeuler.org|mirrors.nju.edu.cn/openeuler|g' /etc/yum.repos.
sed -i '/metalink/d' /etc/yum.repos.d/openEuler.repo && \
sed -i '/metadata_expire/d' /etc/yum.repos.d/openEuler.repo && \
yum update -y &&\
- yum install -y python3 python3-pip shadow-utils findutils &&\
+ yum install -y python3 python3-pip shadow-utils findutils git nodejs npm &&\
groupadd -g 1001 eulercopilot && useradd -u 1001 -g eulercopilot eulercopilot &&\
yum clean all
diff --git a/README.md b/README.md
index 66fecbc36e50e45262126d8abf4c196b081f38f8..fac2bd6f5567e6a9885f66b7778da7869f3443c6 100644
--- a/README.md
+++ b/README.md
@@ -1,8 +1,8 @@
-# EulerCopilot 大模型应用平台
+# openEuler 智能化解决方案
#### 介绍
-EulerCopilot 是一款基于openEuler原生os构建的大模型应用平台,主要有以下几个核心功能:
+openEuler智能化解决方案(openEuler Intelligence)是一个基于openEuler操作系统的大模型应用平台,主要有以下几个核心功能:
- 多路增强RAG:单路->多路,检索后综合优选,回答更准确40%->90%;
- 知识库管理:可视化、全流程覆盖、多用户空间保障数据隐私安全;
diff --git a/apps/common/__init__.py b/apps/common/__init__.py
index 6dfaec20b8a2672e82a6e0a41166b299c55a085b..a8b5806d99daa2b0f3f798f71d6c9b05b5d67e04 100644
--- a/apps/common/__init__.py
+++ b/apps/common/__init__.py
@@ -1,5 +1,2 @@
-"""
-Framework 公共模块
-
-Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
-"""
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
+"""Framework 公共模块"""
diff --git a/apps/common/config.py b/apps/common/config.py
index 67c1a3a65ae2ac91a2df8137198d6b3ccb08f1fe..f5ced71c8523a735626445e4e8cb4222dab722a0 100644
--- a/apps/common/config.py
+++ b/apps/common/config.py
@@ -1,8 +1,6 @@
-"""
-配置文件处理模块
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
+"""配置文件处理模块"""
-Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
-"""
import os
from copy import deepcopy
from pathlib import Path
@@ -22,7 +20,7 @@ class Config(metaclass=SingletonMeta):
"""读取配置文件;当PROD环境变量设置时,配置文件将在读取后删除"""
config_file = os.getenv("CONFIG")
if config_file is None:
- config_file = "./config/config.toml"
+ config_file = Path(__file__).parents[2] / "config" / "config.toml"
self._config = ConfigModel.model_validate(toml.load(config_file))
if os.getenv("PROD"):
diff --git a/apps/common/cryptohub.py b/apps/common/cryptohub.py
index 24c49a2894c0884f86aba48c92a1348f4de9caf8..e98ea2b81f58ea5019d6f42f18d90fb6ea8f41fc 100644
--- a/apps/common/cryptohub.py
+++ b/apps/common/cryptohub.py
@@ -1,8 +1,5 @@
-"""
-加密解密模块
-
-Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
-"""
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
+"""加密解密模块"""
import hashlib
diff --git a/apps/common/oidc.py b/apps/common/oidc.py
index f4df6bdd3cc33184b3f6f7edd30f355c2171cd61..d57b57ec8eb1f9dc1a64a13ef4c749c9654b784a 100644
--- a/apps/common/oidc.py
+++ b/apps/common/oidc.py
@@ -1,8 +1,6 @@
-"""
-OIDC模块
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
+"""OIDC模块"""
-Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
-"""
import logging
from datetime import UTC, datetime, timedelta
from typing import Any
@@ -33,7 +31,8 @@ class OIDCProvider:
@staticmethod
async def set_token(user_sub: str, access_token: str, refresh_token: str) -> None:
"""设置MongoDB中的OIDC Token到sessions集合"""
- sessions_collection = MongoDB.get_collection("session")
+ mongo = MongoDB()
+ sessions_collection = mongo.get_collection("session")
try:
await sessions_collection.update_one(
diff --git a/apps/common/oidc_provider/authhub.py b/apps/common/oidc_provider/authhub.py
index f893a9da467726f83186004e376d4da463657daf..2301f4f2496a929be0ee4593285cbdebaaacf664 100644
--- a/apps/common/oidc_provider/authhub.py
+++ b/apps/common/oidc_provider/authhub.py
@@ -1,9 +1,10 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
"""Authhub OIDC Provider"""
import logging
from typing import Any
-import aiohttp
+import httpx
from fastapi import status
from apps.common.config import Config
@@ -41,15 +42,18 @@ class AuthhubOIDCProvider(OIDCProviderBase):
}
url = await cls.get_access_token_url()
result = None
- async with (
- aiohttp.ClientSession() as session,
- session.post(url, headers=headers, json=data, timeout=aiohttp.ClientTimeout(total=10)) as resp,
- ):
- if resp.status != status.HTTP_200_OK:
- err = f"[Authhub] 获取OIDC Token失败: {resp.status},完整输出: {await resp.text()}"
+ async with httpx.AsyncClient() as client:
+ resp = await client.post(
+ url,
+ headers=headers,
+ json=data,
+ timeout=10,
+ )
+ if resp.status_code != status.HTTP_200_OK:
+ err = f"[Authhub] 获取OIDC Token失败: {resp.status_code},完整输出: {resp.text}"
raise RuntimeError(err)
- logger.info("[Authhub] 获取OIDC Token成功: %s", await resp.text())
- result = await resp.json()
+ logger.info("[Authhub] 获取OIDC Token成功: %s", resp.text)
+ result = resp.json()
return {
"access_token": result["data"]["access_token"],
"refresh_token": result["data"]["refresh_token"],
@@ -72,15 +76,18 @@ class AuthhubOIDCProvider(OIDCProviderBase):
"client_id": login_config.app_id,
}
result = None
- async with (
- aiohttp.ClientSession() as session,
- session.post(url, headers=headers, json=data, timeout=aiohttp.ClientTimeout(total=10)) as resp,
- ):
- if resp.status != status.HTTP_200_OK:
- err = f"[Authhub] 获取用户信息失败: {resp.status},完整输出: {await resp.text()}"
+ async with httpx.AsyncClient() as client:
+ resp = await client.post(
+ url,
+ headers=headers,
+ json=data,
+ timeout=10,
+ )
+ if resp.status_code != status.HTTP_200_OK:
+ err = f"[Authhub] 获取用户信息失败: {resp.status_code},完整输出: {resp.text}"
raise RuntimeError(err)
- logger.info("[Authhub] 获取用户信息成功: %s", await resp.text())
- result = await resp.json()
+ logger.info("[Authhub] 获取用户信息成功: %s", resp.text)
+ result = resp.json()
return {
"user_sub": result["data"],
@@ -98,20 +105,18 @@ class AuthhubOIDCProvider(OIDCProviderBase):
"Content-Type": "application/json",
}
url = login_config.host_inner.rstrip("/") + "/oauth2/login-status"
- async with (
- aiohttp.ClientSession() as session,
- session.post(
+ async with httpx.AsyncClient() as client:
+ resp = await client.post(
url,
headers=headers,
json=data,
cookies=cookie,
- timeout=aiohttp.ClientTimeout(total=10),
- ) as resp,
- ):
- if resp.status != status.HTTP_200_OK:
- err = f"[Authhub] 获取登录状态失败: {resp.status},完整输出: {await resp.text()}"
+ timeout=10,
+ )
+ if resp.status_code != status.HTTP_200_OK:
+ err = f"[Authhub] 获取登录状态失败: {resp.status_code},完整输出: {resp.text}"
raise RuntimeError(err)
- result = await resp.json()
+ result = resp.json()
return {
"access_token": result["data"]["access_token"],
"refresh_token": result["data"]["refresh_token"],
@@ -126,12 +131,15 @@ class AuthhubOIDCProvider(OIDCProviderBase):
"Content-Type": "application/json",
}
url = login_config.host_inner.rstrip("/") + "/oauth2/logout"
- async with (
- aiohttp.ClientSession() as session,
- session.get(url, headers=headers, cookies=cookie, timeout=aiohttp.ClientTimeout(total=10)) as resp,
- ):
- if resp.status != status.HTTP_200_OK:
- err = f"[Authhub] 登出失败: {resp.status},完整输出: {await resp.text()}"
+ async with httpx.AsyncClient() as client:
+ resp = await client.get(
+ url,
+ headers=headers,
+ cookies=cookie,
+ timeout=10,
+ )
+ if resp.status_code != status.HTTP_200_OK:
+ err = f"[Authhub] 登出失败: {resp.status_code},完整输出: {resp.text}"
raise RuntimeError(err)
@classmethod
diff --git a/apps/common/oidc_provider/base.py b/apps/common/oidc_provider/base.py
index f05255c5c28746203cdf8894f5b758f37bb1d4d7..aea39e4df09c02be9c536e5bbcdf761250d9992c 100644
--- a/apps/common/oidc_provider/base.py
+++ b/apps/common/oidc_provider/base.py
@@ -1,3 +1,4 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
"""OIDC Provider Base"""
from typing import Any
diff --git a/apps/common/oidc_provider/openeuler.py b/apps/common/oidc_provider/openeuler.py
index 278bac40d2af663510e2cac1d6f3e4ed90b9a5c4..28e938f128e4db50efb8d40cc3498f560757fe60 100644
--- a/apps/common/oidc_provider/openeuler.py
+++ b/apps/common/oidc_provider/openeuler.py
@@ -1,9 +1,10 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
"""OpenEuler OIDC Provider"""
import logging
from typing import Any
-import aiohttp
+import httpx
from fastapi import status
from apps.common.config import Config
@@ -39,19 +40,22 @@ class OpenEulerOIDCProvider(OIDCProviderBase):
"code": code,
}
url = await cls.get_access_token_url()
- headers = {
- "Content-Type": "application/x-www-form-urlencoded",
- }
- result = None
- async with (
- aiohttp.ClientSession() as session,
- session.post(url, headers=headers, data=data, timeout=aiohttp.ClientTimeout(total=10)) as resp,
- ):
- if resp.status != status.HTTP_200_OK:
- err = f"[OpenEuler] 获取OIDC Token失败: {resp.status},完整输出: {await resp.text()}"
+
+ async with httpx.AsyncClient() as client:
+ resp = await client.post(
+ url,
+ headers={
+ "Content-Type": "application/x-www-form-urlencoded",
+ },
+ data=data,
+ timeout=10.0,
+ )
+ if resp.status_code != status.HTTP_200_OK:
+ err = f"[OpenEuler] 获取OIDC Token失败: {resp.status_code},完整输出: {resp.text}"
raise RuntimeError(err)
- logger.info("[OpenEuler] 获取OIDC Token成功: %s", await resp.text())
- result = await resp.json()
+ logger.info("[OpenEuler] 获取OIDC Token成功: %s", resp.text)
+ result = resp.json()
+
return {
"access_token": result["access_token"],
"refresh_token": result["refresh_token"],
@@ -67,20 +71,20 @@ class OpenEulerOIDCProvider(OIDCProviderBase):
err = "Access token is empty."
raise RuntimeError(err)
url = login_config.host_inner.rstrip("/") + "/oneid/oidc/user"
- headers = {
- "Authorization": access_token,
- }
- result = None
- async with (
- aiohttp.ClientSession() as session,
- session.get(url, headers=headers, timeout=aiohttp.ClientTimeout(total=10)) as resp,
- ):
- if resp.status != status.HTTP_200_OK:
- err = f"[OpenEuler] 获取OIDC用户失败: {resp.status},完整输出: {await resp.text()}"
+ async with httpx.AsyncClient() as client:
+ resp = await client.get(
+ url,
+ headers={
+ "Authorization": access_token,
+ },
+ timeout=10.0,
+ )
+ if resp.status_code != status.HTTP_200_OK:
+ err = f"[OpenEuler] 获取OIDC用户失败: {resp.status_code},完整输出: {resp.text}"
raise RuntimeError(err)
- logger.info("[OpenEuler] 获取OIDC用户成功: %s", await resp.text())
- result = await resp.json()
+ logger.info("[OpenEuler] 获取OIDC用户成功: %s", resp.text)
+ result = resp.json()
if not result["phone_number_verified"]:
err = "Could not validate credentials."
diff --git a/apps/common/queue.py b/apps/common/queue.py
index ba1fe3ed80b6be3943dfb5de0c83dca16f58bf04..8b481efeb9696fe3bc38b50422b6726dcb57b8b7 100644
--- a/apps/common/queue.py
+++ b/apps/common/queue.py
@@ -1,4 +1,6 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
"""消息队列模块"""
+
import asyncio
import json
import logging
diff --git a/apps/common/security.py b/apps/common/security.py
index 8955c4bd96974b6cc67c9d17a49f180a28e0ee0d..7bae582e69b8e835df84033a416f53c91c90b73e 100644
--- a/apps/common/security.py
+++ b/apps/common/security.py
@@ -1,8 +1,5 @@
-"""
-密文加密解密模块
-
-Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
-"""
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
+"""密文加密解密模块"""
import base64
import binascii
diff --git a/apps/common/singleton.py b/apps/common/singleton.py
index 6f9ed5599f0a05b5d69926e81c13bb1beecf32da..673bbd4905248ed5871af3e4c828ab5645036fea 100644
--- a/apps/common/singleton.py
+++ b/apps/common/singleton.py
@@ -1,3 +1,4 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
"""单例模式"""
import threading
@@ -8,7 +9,9 @@ class SingletonMeta(type):
"""单例元类"""
_instances: ClassVar[dict[type, Any]] = {}
- _lock: ClassVar[threading.Lock] = threading.Lock()
+ """单例实例字典"""
+ _lock: ClassVar[threading.RLock] = threading.RLock()
+ """可重入锁"""
def __call__(cls, *args, **kwargs): # noqa: ANN002, ANN003, ANN204
"""获取单例"""
@@ -16,4 +19,4 @@ class SingletonMeta(type):
if cls not in cls._instances:
instance = super().__call__(*args, **kwargs)
cls._instances[cls] = instance
- return cls._instances[cls]
+ return cls._instances[cls]
diff --git a/apps/common/wordscheck.py b/apps/common/wordscheck.py
index 80ea7d31c8447d0a38ef1197d754dcd84403f2b3..67d87c9a1d837d645067f679dd27f7ff67171368 100644
--- a/apps/common/wordscheck.py
+++ b/apps/common/wordscheck.py
@@ -1,8 +1,5 @@
-"""
-敏感词检查模块
-
-Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
-"""
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
+"""敏感词检查模块"""
import logging
from pathlib import Path
diff --git a/apps/constants.py b/apps/constants.py
index ca5421612c1682187f92785d3e677814d6a183ab..27e3e92aa82c5fa57c86e0bec211488188fd46e4 100644
--- a/apps/constants.py
+++ b/apps/constants.py
@@ -18,10 +18,15 @@ SLIDE_WINDOW_QUESTION_COUNT = 10
MAX_API_RESPONSE_LENGTH = 8192
# Executor最大步骤历史数
STEP_HISTORY_SIZE = 3
-
+# Session有效期(TTL),单位为分钟;当前取值为30天
+SESSION_TTL = 30 * 24 * 60
+# JSON生成最大尝试次数
+JSON_GEN_MAX_TRIAL = 3
+# 推理开始标记
REASONING_BEGIN_TOKEN = [
"",
]
+# 推理结束标记
REASONING_END_TOKEN = [
"",
]
diff --git a/apps/dependency/__init__.py b/apps/dependency/__init__.py
index 824f0b4369a2b95697b53f786ec821362b00473d..83c26ea2f11c0fdbaf704a9fdcd91b0eacb57825 100644
--- a/apps/dependency/__init__.py
+++ b/apps/dependency/__init__.py
@@ -1,11 +1,6 @@
-"""
-FastAPI 依赖注入模块
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
+"""FastAPI 依赖注入模块"""
-Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
-"""
-
-from apps.dependency.csrf import verify_csrf_token
-from apps.dependency.session import VerifySessionMiddleware
from apps.dependency.user import (
get_session,
get_user,
@@ -15,11 +10,9 @@ from apps.dependency.user import (
)
__all__ = [
- "VerifySessionMiddleware",
"get_session",
"get_user",
"get_user_by_api_key",
"verify_api_key",
- "verify_csrf_token",
"verify_user",
]
diff --git a/apps/dependency/csrf.py b/apps/dependency/csrf.py
deleted file mode 100644
index fb747b5ede895863f2625d4ebb17cf50b1854a70..0000000000000000000000000000000000000000
--- a/apps/dependency/csrf.py
+++ /dev/null
@@ -1,35 +0,0 @@
-"""
-CSRF Token校验
-
-Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
-"""
-
-from fastapi import HTTPException, Request, Response, status
-
-from apps.common.config import Config
-from apps.manager.session import SessionManager
-
-
-async def verify_csrf_token(request: Request, response: Response) -> Response | None:
- """验证CSRF Token"""
- if not Config().get_config().fastapi.csrf:
- return None
-
- csrf_token = request.headers["x-csrf-token"].strip('"')
- session = request.cookies["ECSESSION"]
-
- if not await SessionManager.verify_csrf_token(session, csrf_token):
- raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="CSRF token is invalid.")
-
- new_csrf_token = await SessionManager.create_csrf_token(session)
- if not new_csrf_token:
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Renew CSRF token failed.")
-
- if Config().get_config().deploy.cookie == "DEBUG":
- response.set_cookie("_csrf_tk", new_csrf_token, max_age=Config().get_config().fastapi.session_ttl * 60,
- domain=Config().get_config().fastapi.domain)
- else:
- response.set_cookie("_csrf_tk", new_csrf_token, max_age=Config().get_config().fastapi.session_ttl * 60,
- secure=True, domain=Config().get_config().fastapi.domain, samesite="strict")
- return response
-
diff --git a/apps/dependency/session.py b/apps/dependency/session.py
deleted file mode 100644
index e83a6a97ebedc03cd5c3f1fb729ebcbadd5ce502..0000000000000000000000000000000000000000
--- a/apps/dependency/session.py
+++ /dev/null
@@ -1,90 +0,0 @@
-"""
-浏览器Session校验
-
-Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
-"""
-
-from typing import Any
-
-from fastapi import Response
-from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
-from starlette.requests import Request
-
-from apps.common.config import Config
-from apps.manager.session import SessionManager
-
-BYPASS_LIST = [
- "/health_check",
- "/api/auth/login",
- "/api/auth/logout",
-]
-
-
-class VerifySessionMiddleware(BaseHTTPMiddleware):
- """浏览器Session校验中间件"""
-
- def _check_bypass_list(self, path: str) -> bool:
- """检查请求路径是否需要跳过验证"""
- return path in BYPASS_LIST
-
- def _validate_client(self, request: Request) -> str:
- """验证客户端信息并返回主机地址"""
- if request.client is None or request.client.host is None:
- err = "[VerifySessionMiddleware] 无法检测请求来源IP!"
- raise ValueError(err)
- return request.client.host
-
- def _update_cookie_header(self, request: Request, session_id: str) -> None:
- """更新请求头中的cookie信息"""
- cookie_str = ""
- for item in request.scope["headers"]:
- if item[0] == b"cookie":
- cookie_str = item[1].decode()
- request.scope["headers"].remove(item)
- break
-
- all_cookies = ""
- if cookie_str:
- other_headers = cookie_str.split(";")
- all_cookies = "; ".join(item for item in other_headers if "ECSESSION" not in item)
-
- all_cookies = f"{all_cookies}; ECSESSION={session_id}" if all_cookies else f"ECSESSION={session_id}"
- request.scope["headers"].append((b"cookie", all_cookies.encode()))
-
- def _set_response_cookie(self, response: Response, session_id: str) -> None:
- """设置响应cookie"""
- # 检查 是否其他dependence 设置过cookie
- if "ECSESSION" in response.headers.get("set-cookie", ""):
- return
-
- cookie_params: dict[str, Any] = {
- "key": "ECSESSION",
- "value": session_id,
- "domain": Config().get_config().fastapi.domain,
- }
-
- if Config().get_config().deploy.cookie != "DEBUG":
- cookie_params.update({
- "httponly": True,
- "secure": True,
- "samesite": "strict",
- "max_age": Config().get_config().fastapi.session_ttl * 60,
- })
-
- response.set_cookie(**cookie_params)
-
- async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
- """浏览器Session校验中间件"""
- if self._check_bypass_list(request.url.path):
- return await call_next(request)
-
- host = self._validate_client(request)
- cookie = request.cookies.get("ECSESSION", "")
- session_id = await SessionManager.get_session(cookie, host)
-
- if session_id != cookie:
- self._update_cookie_header(request, session_id)
-
- response = await call_next(request)
- self._set_response_cookie(response, session_id)
- return response
diff --git a/apps/dependency/user.py b/apps/dependency/user.py
index 4699edb47ad19e968262075c8df56139de31fc11..d4d258fee5e969c3b72c4c08713cc0beb6d787f7 100644
--- a/apps/dependency/user.py
+++ b/apps/dependency/user.py
@@ -1,19 +1,14 @@
-"""
-用户鉴权
-
-Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
-"""
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
+"""用户鉴权"""
import logging
-from fastapi import Depends, Response
+from fastapi import Depends
from fastapi.security import OAuth2PasswordBearer
from starlette import status
from starlette.exceptions import HTTPException
from starlette.requests import HTTPConnection
-from apps.common.config import Config
-from apps.common.oidc import oidc_provider
from apps.manager.api_key import ApiKeyManager
from apps.manager.session import SessionManager
@@ -21,76 +16,29 @@ logger = logging.getLogger(__name__)
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
-async def _verify_oidc_auth(request: HTTPConnection, response: Response) -> str:
+async def _get_session_id_from_request(request: HTTPConnection) -> str | None:
"""
- 验证OIDC认证状态并获取用户信息
+ 从请求的 Authorization 头(Bearer 方案)中提取 session_id
:param request: HTTP请求
- :return: 用户信息字典
- :raises: HTTPException 当OIDC验证失败时
+ :return: session_id;请求头缺失或不是 Bearer 格式时返回 None
"""
- try:
- tokens = await oidc_provider.get_login_status(request.cookies)
- except Exception as err:
- raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="[OIDC] 检查OIDC登录状态失败") from err
-
- if not tokens:
- raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="[OIDC] 检查OIDC登录状态失败")
-
- try:
- user_info = await oidc_provider.get_oidc_user(tokens["access_token"])
- except Exception as err:
- raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="[OIDC] 获取用户信息失败") from err
-
- if not user_info:
- raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="[OIDC] 获取用户信息失败")
-
- # 创建新的session
- if request.client is None:
- raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="[OIDC] 获取登录IP失败")
-
- user_sub = user_info["user_sub"]
- user_host = request.client.host
- try:
- current_session = request.cookies["ECSESSION"]
- await SessionManager.delete_session(current_session)
- except Exception:
- logger.exception("[VerifySessionMiddleware] 删除session失败")
-
- current_session = await SessionManager.create_session(user_host, user_sub)
-
- # 设置cookie
- if Config().get_config().deploy.cookie == "DEBUG":
- response.set_cookie(
- "ECSESSION",
- current_session,
- )
- else:
- response.set_cookie(
- "ECSESSION",
- current_session,
- max_age=Config().get_config().fastapi.session_ttl * 60,
- secure=True,
- domain=Config().get_config().fastapi.domain,
- httponly=True,
- samesite="strict",
- )
+ session_id = None
+ auth_header = request.headers.get("Authorization")
+ if auth_header and auth_header.startswith("Bearer "):
+ session_id = auth_header.split(" ", 1)[1]
- return user_sub
+ return session_id
-async def verify_user(request: HTTPConnection, response: Response) -> None:
+async def verify_user(request: HTTPConnection) -> None:
"""
验证Session是否已鉴权;未鉴权则抛出HTTP 401;接口级dependence
:param request: HTTP请求
:return: None
"""
- session_id = request.cookies["ECSESSION"]
- if await SessionManager.verify_user(session_id):
- return
-
- await _verify_oidc_auth(request, response)
+ request.state.session_id = await get_session(request)
async def get_session(request: HTTPConnection) -> str:
@@ -100,26 +48,44 @@ async def get_session(request: HTTPConnection) -> str:
:param request: HTTP请求
:return: Session ID
"""
- session_id = request.cookies["ECSESSION"]
+ session_id = await _get_session_id_from_request(request)
+ if not session_id:
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail="Session ID 不存在",
+ )
if not await SessionManager.verify_user(session_id):
- raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication Error.")
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail="Session ID 鉴权失败",
+ )
return session_id
-async def get_user(request: HTTPConnection, response: Response) -> str:
+async def get_user(request: HTTPConnection) -> str:
"""
验证Session是否已鉴权;若已鉴权,查询对应的user_sub;若未鉴权,抛出HTTP 401;参数级dependence
:param request: HTTP请求体
:return: 用户sub
"""
- session_id = request.cookies["ECSESSION"]
+ session_id = await _get_session_id_from_request(request)
+ if not session_id:
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail="Session ID 不存在",
+ )
- user = await SessionManager.get_user(session_id)
- if user:
- return user
+ user_sub = await SessionManager.get_user(session_id)
+ if not user_sub:
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail="Session ID 鉴权失败",
+ )
- return await _verify_oidc_auth(request, response)
+ request.state.user_sub = user_sub
+ request.state.session_id = session_id
+ return user_sub
async def verify_api_key(api_key: str = Depends(oauth2_scheme)) -> None:
diff --git a/apps/entities/__init__.py b/apps/entities/__init__.py
index eb970fbf3966fb8f4c9b4b801ffc0582fe8fb5ca..d44768392a78ae68e301f11652c231dd0ea027f6 100644
--- a/apps/entities/__init__.py
+++ b/apps/entities/__init__.py
@@ -1,5 +1,2 @@
-"""
-Framework 数据结构定义
-
-Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
-"""
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
+"""Framework 数据结构定义"""
diff --git a/apps/entities/agent.py b/apps/entities/agent.py
new file mode 100644
index 0000000000000000000000000000000000000000..69426e90078aeaa6247fbfde64d40f4f4c7b6483
--- /dev/null
+++ b/apps/entities/agent.py
@@ -0,0 +1,23 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
+"""agent 相关数据结构"""
+
+from pydantic import Field
+
+from apps.entities.enum_var import (
+ AppType,
+ MetadataType,
+)
+from apps.entities.flow import Permission
+from apps.entities.mcp import MCPMetadataBase
+
+
+class AgentAppMetadata(MCPMetadataBase):
+ """智能体App的元数据"""
+
+ type: MetadataType = MetadataType.APP
+ app_type: AppType = Field(default=AppType.AGENT, description="应用类型", frozen=True)
+ published: bool = Field(description="是否发布", default=False)
+ history_len: int = Field(description="对话轮次", default=3, le=10)
+ mcp_service: list[str] = Field(default=[], alias="mcpService", description="MCP服务id列表")
+ permission: Permission | None = Field(description="应用权限配置", default=None)
+ version: str = Field(description="元数据版本")
diff --git a/apps/entities/api_key.py b/apps/entities/api_key.py
index a2f2566da4895f7e56705e4f4d15c7dbc33b14c8..ccc6e2f4ee5b08d7e3735b4c0ad2bf8c89071651 100644
--- a/apps/entities/api_key.py
+++ b/apps/entities/api_key.py
@@ -1,3 +1,4 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
"""API密钥相关数据结构"""
from pydantic import BaseModel
diff --git a/apps/entities/appcenter.py b/apps/entities/appcenter.py
index bf021c826b552ca64e7277fa52b311db0f0f03b6..1244f72525ab6de9c5d1e9246a9852e9bbbd7270 100644
--- a/apps/entities/appcenter.py
+++ b/apps/entities/appcenter.py
@@ -1,18 +1,16 @@
-"""
-应用中心相关 API 基础数据结构定义
-
-Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved.
-"""
+# Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved.
+"""应用中心相关 API 基础数据结构定义"""
from pydantic import BaseModel, Field
-from apps.entities.enum_var import PermissionType
+from apps.entities.enum_var import AppType, PermissionType
class AppCenterCardItem(BaseModel):
"""应用中心卡片数据结构"""
app_id: str = Field(..., alias="appId", description="应用ID")
+ app_type: AppType = Field(..., alias="appType", description="应用类型")
icon: str = Field(..., description="应用图标")
name: str = Field(..., description="应用名称")
description: str = Field(..., description="应用简介")
@@ -55,6 +53,7 @@ class AppFlowInfo(BaseModel):
class AppData(BaseModel):
"""应用信息数据结构"""
+ app_type: AppType = Field(..., alias="appType", description="应用类型")
icon: str = Field(default="", description="图标")
name: str = Field(..., max_length=20, description="应用名称")
description: str = Field(..., max_length=150, description="应用简介")
@@ -65,3 +64,4 @@ class AppData(BaseModel):
permission: AppPermissionData = Field(
default_factory=lambda: AppPermissionData(authorizedUsers=None), description="权限配置")
workflows: list[AppFlowInfo] = Field(default=[], description="工作流信息列表")
+ mcp_service: list[str] = Field(default=[], alias="mcpService", description="MCP服务id列表")
diff --git a/apps/entities/collection.py b/apps/entities/collection.py
index af0b5a88e128b2c40c6e86eb814a9889bde654aa..1d6aa097da435e611a794f3a7d93fa2da3f37861 100644
--- a/apps/entities/collection.py
+++ b/apps/entities/collection.py
@@ -1,15 +1,14 @@
-"""
-MongoDB中的数据结构
-
-Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
-"""
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
+"""MongoDB中的数据结构"""
import uuid
from datetime import UTC, datetime
from pydantic import BaseModel, Field
+from apps.common.config import Config
from apps.constants import NEW_CHAT
+from apps.templates.generate_llm_operator_config import llm_provider_dict
class Blacklist(BaseModel):
@@ -56,12 +55,44 @@ class User(BaseModel):
is_whitelisted: bool = False
credit: int = 100
api_key: str | None = None
- kb_id: str | None = None
conversations: list[str] = []
domains: list[UserDomainData] = []
app_usage: dict[str, AppUsageData] = {}
fav_apps: list[str] = []
fav_services: list[str] = []
+ is_admin: bool = Field(default=False, description="是否为管理员")
+
+
+class LLM(BaseModel):
+ """
+ 大模型信息
+
+ Collection: llm
+ """
+
+ id: str = Field(default_factory=lambda: str(uuid.uuid4()), alias="_id")
+ user_sub: str = Field(default="", description="用户ID")
+ icon: str = Field(default=llm_provider_dict["ollama"]["icon"], description="图标")
+ openai_base_url: str = Field(default=Config().get_config().llm.endpoint)
+ openai_api_key: str = Field(default=Config().get_config().llm.key)
+ model_name: str = Field(default=Config().get_config().llm.model)
+ max_tokens: int | None = Field(default=Config().get_config().llm.max_tokens)
+ created_at: float = Field(default_factory=lambda: round(datetime.now(tz=UTC).timestamp(), 3))
+
+
+class LLMItem(BaseModel):
+ """大模型信息"""
+
+ llm_id: str = Field(default="empty")
+ model_name: str = Field(default=Config().get_config().llm.model)
+ icon: str = Field(default=llm_provider_dict["ollama"]["icon"])
+
+
+class KnowledgeBaseItem(BaseModel):
+ """知识库信息"""
+
+ kb_id: str
+ kb_name: str
class Conversation(BaseModel):
@@ -80,7 +111,9 @@ class Conversation(BaseModel):
tasks: list[str] = []
unused_docs: list[str] = []
record_groups: list[str] = []
- debug : bool = Field(default=False)
+ debug: bool = Field(default=False)
+ llm: LLMItem | None = None
+ kb_list: list[KnowledgeBaseItem] = Field(default=[])
class Document(BaseModel):
diff --git a/apps/entities/config.py b/apps/entities/config.py
index bec3718e6fae8cf6f7d0be22449fafda204be311..b88a81f1afb018663713a3bf4c0b9b62436e59d4 100644
--- a/apps/entities/config.py
+++ b/apps/entities/config.py
@@ -1,3 +1,4 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
"""配置文件数据结构"""
from typing import Literal
@@ -55,8 +56,6 @@ class FastAPIConfig(BaseModel):
"""FastAPI配置"""
domain: str = Field(description="当前实例的域名")
- session_ttl: int = Field(description="用户需要刷新Token的间隔(min)", default=30)
- csrf: bool = Field(description="是否启用CSRF Token功能", default=False)
class MinioConfig(BaseModel):
@@ -84,8 +83,8 @@ class LLMConfig(BaseModel):
key: str = Field(description="LLM API密钥")
endpoint: str = Field(description="LLM API URL地址")
model: str = Field(description="LLM API 模型名")
- max_tokens: int = Field(description="LLM API 最大Token数", default=8192)
- temperature: float = Field(description="LLM API 温度", default=0.7)
+ max_tokens: int | None = Field(description="LLM API 最大Token数", default=None)
+ temperature: float | None = Field(description="LLM API 温度", default=None)
class FunctionCallConfig(BaseModel):
@@ -95,8 +94,8 @@ class FunctionCallConfig(BaseModel):
model: str = Field(description="Function Call 模型名")
endpoint: str = Field(description="Function Call API URL地址")
api_key: str = Field(description="Function Call API密钥")
- max_tokens: int = Field(description="Function Call 最大Token数", default=8192)
- temperature: float = Field(description="Function Call 温度", default=0.7)
+ max_tokens: int | None = Field(description="Function Call 最大Token数", default=None)
+ temperature: float | None = Field(description="Function Call 温度", default=None)
class SecurityConfig(BaseModel):
diff --git a/apps/entities/enum_var.py b/apps/entities/enum_var.py
index 281d53e07d3ecf9756bf7bf01ce28b1fad6a21b2..52a0bb163868873d381265ae67032d95661215e3 100644
--- a/apps/entities/enum_var.py
+++ b/apps/entities/enum_var.py
@@ -1,8 +1,5 @@
-"""
-枚举类型
-
-Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
-"""
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
+"""枚举类型"""
from enum import Enum
@@ -60,6 +57,7 @@ class MetadataType(str, Enum):
SERVICE = "service"
APP = "app"
+ MCP_SERVICE = "mcp_service"
class EdgeType(str, Enum):
@@ -155,3 +153,18 @@ class CommentType(str, Enum):
LIKE = "liked"
DISLIKE = "disliked"
NONE = "none"
+
+
+class MCPSearchType(str, Enum):
+ """搜索类型"""
+
+ ALL = "all"
+ NAME = "name"
+ AUTHOR = "author"
+
+
+class AppType(str, Enum):
+ """应用中心应用类型"""
+
+ FLOW = "flow"
+ AGENT = "agent"
diff --git a/apps/entities/flow.py b/apps/entities/flow.py
index 22f278b5d7e8d71d5879e48e4c7b3bec9b3a6f50..8785b25eaf3d685f28dd6b41f4e90e2281f2e17d 100644
--- a/apps/entities/flow.py
+++ b/apps/entities/flow.py
@@ -1,8 +1,5 @@
-"""
-App、Flow和Service等外置配置数据结构
-
-Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
-"""
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
+"""App、Flow和Service等外置配置数据结构"""
from typing import Any
@@ -10,6 +7,7 @@ from pydantic import BaseModel, Field
from apps.entities.appcenter import AppLink
from apps.entities.enum_var import (
+ AppType,
EdgeType,
MetadataType,
PermissionType,
@@ -134,6 +132,7 @@ class AppMetadata(MetadataBase):
"""App的元数据"""
type: MetadataType = MetadataType.APP
+ app_type: AppType = Field(default=AppType.FLOW, description="应用类型", frozen=True)
published: bool = Field(description="是否发布", default=False)
links: list[AppLink] = Field(description="相关链接", default=[])
first_questions: list[str] = Field(description="首次提问", default=[])
diff --git a/apps/entities/flow_topology.py b/apps/entities/flow_topology.py
index b547f09bf1b968665e53f7bab10f05656d0e292b..0c6463bb82f355ba05d475344683e970cfa46e00 100644
--- a/apps/entities/flow_topology.py
+++ b/apps/entities/flow_topology.py
@@ -1,8 +1,5 @@
-"""
-前端展示flow用到的数据结构
-
-Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
-"""
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
+"""前端展示flow用到的数据结构"""
from typing import Any
diff --git a/apps/entities/mcp.py b/apps/entities/mcp.py
new file mode 100644
index 0000000000000000000000000000000000000000..252fd57db8cb3c75198ba2d5cb7bba2efdf0c2eb
--- /dev/null
+++ b/apps/entities/mcp.py
@@ -0,0 +1,143 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
+"""MCP 相关数据结构"""
+
+from enum import Enum
+from typing import Any
+
+from lancedb.pydantic import LanceModel, Vector
+from pydantic import BaseModel, Field
+
+from apps.entities.enum_var import (
+ MetadataType,
+)
+
+
+class MCPType(str, Enum):
+ """MCP 类型"""
+
+ SSE = "sse"
+ STDIO = "stdio"
+ STREAMABLE = "stream"
+
+
+class MCPMetadataBase(BaseModel):
+ """
+ Metadata of an MCPService or MCPApp
+
+ Note: the ``hashes`` field is excluded on save and load
+ """
+
+ id: str = Field(description="元数据ID")
+ type: MetadataType = Field(description="元数据类型")
+ icon: str = Field(description="图标", default="")
+ name: str = Field(description="元数据名称")
+ description: str = Field(description="元数据描述")
+ author: str = Field(description="创建者的用户名")
+ hashes: dict[str, str] = Field(description="配置文件的hash值", default={})
+
+
+class MCPServerConfig(BaseModel):
+ """MCP 服务器配置"""
+
+ name: str = Field(description="MCP 服务器自然语言名称", default="")
+ description: str = Field(description="MCP 服务器自然语言描述", default="")
+ type: MCPType = Field(description="MCP 服务器类型", default=MCPType.STDIO)
+ disabled: bool = Field(description="MCP 服务器是否禁用", default=False)
+ auto_install: bool = Field(description="是否自动安装MCP服务器", default=True, alias="autoInstall")
+ icon_path: str = Field(description="MCP 服务器图标路径", default="", alias="iconPath")
+ env: dict[str, str] = Field(description="MCP 服务器环境变量", default={})
+ auto_approve: list[str] = Field(description="自动批准的MCP工具ID列表", default=[], alias="autoApprove")
+
+
+class MCPServerStdioConfig(MCPServerConfig):
+ """Stdio-type MCP server configuration (command plus its arguments)"""
+
+ command: str = Field(description="MCP 服务器命令")
+ args: list[str] = Field(description="MCP 服务器命令参数")
+
+
+class MCPServerSSEConfig(MCPServerConfig):
+ """SSE-type MCP server configuration (server URL)"""
+
+ url: str = Field(description="MCP 服务器地址", default="")
+
+
+class MCPConfig(BaseModel):
+ """MCP 配置"""
+
+ mcp_servers: dict[
+ str,
+ MCPServerSSEConfig | MCPServerStdioConfig,
+ ] = Field(description="MCP 服务器配置", alias="mcpServers")
+
+
+class MCPTool(BaseModel):
+ """MCP工具"""
+
+ id: str = Field(description="MCP工具ID")
+ name: str = Field(description="MCP工具名称")
+ description: str = Field(description="MCP工具描述")
+ mcp_id: str = Field(description="MCP ID")
+ input_schema: dict[str, Any] = Field(description="MCP工具输入参数")
+
+
+class MCPCollection(BaseModel):
+ """MCP相关信息,存储在MongoDB的 ``mcp`` 集合中"""
+
+ id: str = Field(description="MCP ID", alias="_id")
+ name: str = Field(description="MCP 自然语言名称")
+ description: str = Field(description="MCP 自然语言描述")
+ type: MCPType = Field(description="MCP 类型")
+ activated: list[str] = Field(description="激活该MCP的用户ID列表", default=[])
+ tools: list[MCPTool] = Field(description="MCP工具列表", default=[])
+
+
+class MCPVector(LanceModel):
+ """MCP向量化数据,存储在LanceDB的 ``mcp`` 表中"""
+
+ id: str = Field(description="MCP ID")
+ embedding: Vector(dim=1024) = Field(description="MCP描述的向量信息") # type: ignore[call-arg]
+
+
+class MCPToolVector(LanceModel):
+ """MCP工具向量化数据,存储在LanceDB的 ``mcp_tool`` 表中"""
+
+ id: str = Field(description="工具ID")
+ mcp_id: str = Field(description="MCP ID")
+ embedding: Vector(dim=1024) = Field(description="MCP工具描述的向量信息") # type: ignore[call-arg]
+
+
+class MCPSelectResult(BaseModel):
+ """MCP选择结果"""
+
+ mcp_id: str = Field(description="MCP Server的ID")
+
+
+class MCPToolSelectResult(BaseModel):
+ """MCP工具选择结果"""
+
+ name: str = Field(description="工具名称")
+
+
+class MCPServiceMetadata(MCPMetadataBase):
+ """MCPService的元数据"""
+
+ type: MetadataType = MetadataType.SERVICE
+ config: MCPConfig = Field(description="MCP服务配置")
+ config_str: str = Field(description="MCP服务配置字符串", alias="configStr")
+ tools: list[MCPTool] = Field(description="MCP服务Tools列表")
+ mcp_type: MCPType = Field(description="MCP 类型", alias="mcpType")
+
+
+class MCPPlanItem(BaseModel):
+ """A single item of an MCP plan: plan content, tool name and tool instruction"""
+
+ plan: str = Field(description="计划内容")
+ tool: str = Field(description="工具名称")
+ instruction: str = Field(description="工具指令")
+
+
+class MCPPlan(BaseModel):
+ """MCP 计划"""
+
+ plans: list[MCPPlanItem] = Field(description="计划列表")
diff --git a/apps/entities/message.py b/apps/entities/message.py
index 557d311b0ccb677dace093b320c1c0a012b29adc..0ae807790a6b2dbe0f41fe5e61bf587602c6ddc4 100644
--- a/apps/entities/message.py
+++ b/apps/entities/message.py
@@ -1,8 +1,5 @@
-"""
-队列中的消息结构
-
-Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
-"""
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
+"""队列中的消息结构"""
from typing import Any
diff --git a/apps/entities/node.py b/apps/entities/node.py
index dcb458a88cf42c475100cd0a407171bca15f8819..fe25a690db02116b457c2c8483fd576376bb924f 100644
--- a/apps/entities/node.py
+++ b/apps/entities/node.py
@@ -1,4 +1,6 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
"""Node实体类"""
+
from typing import Any
from pydantic import BaseModel, Field
diff --git a/apps/entities/pool.py b/apps/entities/pool.py
index 807a55a33e3ec5b2e11b0915db8689d021c2fb21..4b0c319eb1ff30aa4ccec19f69e2b66e51802fa0 100644
--- a/apps/entities/pool.py
+++ b/apps/entities/pool.py
@@ -1,8 +1,5 @@
-"""
-App和Service等数据库内数据结构
-
-Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
-"""
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
+"""App和Service等数据库内数据结构"""
from datetime import UTC, datetime
from typing import Any
@@ -10,7 +7,7 @@ from typing import Any
from pydantic import BaseModel, Field
from apps.entities.appcenter import AppLink
-from apps.entities.enum_var import CallType, PermissionType
+from apps.entities.enum_var import AppType, CallType, PermissionType
from apps.entities.flow import AppFlow, Permission
@@ -102,6 +99,7 @@ class AppPool(BaseData):
"""
author: str = Field(description="作者的用户ID")
+ app_type: AppType = Field(description="应用类型", default=AppType.FLOW)
type: str = Field(description="应用类型", default="default")
icon: str = Field(description="应用图标", default="")
published: bool = Field(description="是否发布", default=False)
@@ -111,3 +109,4 @@ class AppPool(BaseData):
permission: Permission = Field(description="应用权限配置", default=Permission())
flows: list[AppFlow] = Field(description="Flow列表", default=[])
hashes: dict[str, str] = Field(description="关联文件的hash值", default={})
+ mcp_service: list[str] = Field(default=[], alias="mcpService", description="MCP服务id列表")
diff --git a/apps/entities/rag_data.py b/apps/entities/rag_data.py
index c66b989878078446881eafd9f366e7ac2ccadb98..2847bf387c9f4b0b8386d833c070664652c33f49 100644
--- a/apps/entities/rag_data.py
+++ b/apps/entities/rag_data.py
@@ -1,24 +1,24 @@
-"""
-请求RAG相关接口时,使用的数据类型
-
-Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
-"""
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
+"""请求RAG相关接口时,使用的数据类型"""
from typing import Literal
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
class RAGQueryReq(BaseModel):
"""查询RAG时的POST请求体"""
- question: str
- history: list[dict[str, str]] = []
- language: str = "zh"
- kb_sn: str | None = None
- top_k: int = 5
- fetch_source: bool = False
- document_ids: list[str] = []
+ kb_ids: list[str] = Field(default=[], description="资产id", alias="kbIds")
+ query: str = Field(default="", description="查询内容")
+ top_k: int = Field(default=5, description="返回的结果数量", alias="topK")
+ doc_ids: list[str] = Field(default=None, description="文档id", alias="docIds")
+ search_method: str = Field(default="keyword_and_vector",
+ description="检索方法", alias="searchMethod")
+ is_related_surrounding: bool = Field(default=True, description="是否关联上下文", alias="isRelatedSurrounding")
+ is_classify_by_doc: bool = Field(default=True, description="是否按文档分类", alias="isClassifyByDoc")
+ is_rerank: bool = Field(default=False, description="是否重新排序", alias="isRerank")
+ tokens_limit: int = Field(default=8192, description="token限制", alias="tokensLimit")
class RAGFileParseReqItem(BaseModel):
diff --git a/apps/entities/record.py b/apps/entities/record.py
index 68ee72885038ae9f33dcc391794905b2b264ac61..2f6306b035f71f5fa58ca94470a8adbb3b2dbab5 100644
--- a/apps/entities/record.py
+++ b/apps/entities/record.py
@@ -1,8 +1,5 @@
-"""
-Record数据结构
-
-Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
-"""
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
+"""Record数据结构"""
import uuid
from datetime import UTC, datetime
@@ -103,7 +100,7 @@ class Record(RecordData):
user_sub: str
key: dict[str, Any] = {}
content: str
- comment: RecordComment= Field(default=RecordComment())
+ comment: RecordComment = Field(default=RecordComment())
flow: list[str] = Field(default=[])
diff --git a/apps/entities/request_data.py b/apps/entities/request_data.py
index 89e94ae47e969e869e35f7dc596384efdceb9e10..61714aca1cbb218f91e695cc1812cbb22f2adccb 100644
--- a/apps/entities/request_data.py
+++ b/apps/entities/request_data.py
@@ -1,8 +1,5 @@
-"""
-FastAPI 请求体
-
-Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
-"""
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
+"""FastAPI 请求体"""
from typing import Any
@@ -12,6 +9,7 @@ from apps.common.config import Config
from apps.entities.appcenter import AppData
from apps.entities.enum_var import CommentType
from apps.entities.flow_topology import FlowItem
+from apps.entities.mcp import MCPType
class RequestDataApp(BaseModel):
@@ -93,6 +91,23 @@ class ModFavAppRequest(BaseModel):
favorited: bool = Field(..., description="是否收藏")
+class UpdateMCPServiceRequest(BaseModel):
+ """POST /api/mcpservice 请求数据结构"""
+
+ service_id: str | None = Field(None, alias="serviceId", description="服务ID(更新时传递)")
+ icon: str = Field(description="图标", default="")
+ name: str = Field(..., description="MCP服务名称")
+ description: str = Field(..., description="MCP服务描述")
+ config: str = Field(..., description="MCP服务配置")
+ mcp_type: MCPType = Field(description="MCP传输协议(Stdio/SSE/Streamable)", default=MCPType.STDIO, alias="mcpType")
+
+
+class ActiveMCPServiceRequest(BaseModel):
+ """POST /api/mcp/{serviceId} 请求数据结构"""
+
+ active: bool = Field(description="是否激活mcp服务")
+
+
class UpdateServiceRequest(BaseModel):
"""POST /api/service 请求数据结构"""
@@ -142,13 +157,29 @@ class PostDomainData(BaseModel):
domain_description: str = Field(..., max_length=2000)
-class PostKnowledgeIDData(BaseModel):
- """添加知识库"""
-
- kb_id: str
-
-
class PutFlowReq(BaseModel):
"""创建/修改流拓扑结构"""
flow: FlowItem
+
+
+class UpdateLLMReq(BaseModel):
+ """更新大模型请求体"""
+
+ icon: str = Field(description="图标", default="")
+ openai_base_url: str = Field(default="", description="OpenAI API Base URL", alias="openaiBaseUrl")
+ openai_api_key: str = Field(default="", description="OpenAI API Key", alias="openaiApiKey")
+ model_name: str = Field(default="", description="模型名称", alias="modelName")
+ max_tokens: int = Field(default=8192, description="最大token数", alias="maxTokens")
+
+
+class DeleteLLMReq(BaseModel):
+ """删除大模型请求体"""
+
+ llm_id: str = Field(description="大模型ID", alias="llmId")
+
+
+class UpdateKbReq(BaseModel):
+ """更新知识库请求体"""
+
+ kb_ids: list[str] = Field(description="知识库ID列表", alias="kbIds", default=[])
diff --git a/apps/entities/response_data.py b/apps/entities/response_data.py
index decb55dc5502a6676e684724f477e53e238b9be3..edcc2b5dcef0e8cc31a49ec50b4740bea3318b34 100644
--- a/apps/entities/response_data.py
+++ b/apps/entities/response_data.py
@@ -1,8 +1,5 @@
-"""
-FastAPI 返回数据结构
-
-Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
-"""
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
+"""FastAPI 返回数据结构"""
from typing import Any
@@ -17,8 +14,10 @@ from apps.entities.flow_topology import (
NodeServiceItem,
PositionItem,
)
+from apps.entities.mcp import MCPTool, MCPType
from apps.entities.record import RecordData
from apps.entities.user import UserInfo
+from apps.templates.generate_llm_operator_config import llm_provider_dict
class ResponseData(BaseModel):
@@ -47,6 +46,7 @@ class AuthUserMsg(BaseModel):
user_sub: str
revision: bool
+ is_admin: bool
class AuthUserRsp(ResponseData):
@@ -85,6 +85,21 @@ class GetBlacklistQuestionRsp(ResponseData):
result: GetBlacklistQuestionMsg
+class LLMIteam(BaseModel):
+ """LLM entry nested in each GET /api/conversation result item"""
+
+ icon: str = Field(default=llm_provider_dict["ollama"]["icon"])
+ llm_id: str = Field(alias="llmId", default="empty")
+ model_name: str = Field(alias="modelName", default="Ollama LLM")
+
+
+class KbIteam(BaseModel):
+ """Knowledge base entry nested in each GET /api/conversation result item"""
+
+ kb_id: str = Field(alias="kbId")
+ kb_name: str = Field(alias="kbName")
+
+
class ConversationListItem(BaseModel):
"""GET /api/conversation Result数据结构"""
@@ -94,6 +109,8 @@ class ConversationListItem(BaseModel):
created_time: str = Field(alias="createdTime")
app_id: str = Field(alias="appId")
debug: bool = Field(alias="debug")
+ llm: LLMIteam | None = Field(alias="llm", default=None)
+ kb_list: list[KbIteam] = Field(alias="kbList", default=[])
class ConversationListMsg(BaseModel):
@@ -214,16 +231,33 @@ class OidcRedirectRsp(ResponseData):
result: OidcRedirectMsg
-class GetKnowledgeIDMsg(BaseModel):
+class KnowledgeBaseItem(BaseModel):
+ """知识库列表项数据结构"""
+
+ kb_id: str = Field(..., alias="kbId", description="知识库ID")
+ kb_name: str = Field(..., description="知识库名称", alias="kbName")
+ description: str = Field(..., description="知识库描述")
+ is_used: bool = Field(..., description="是否使用", alias="isUsed")
+
+
+class TeamKnowledgeBaseItem(BaseModel):
+ """团队知识库列表项数据结构"""
+
+ team_id: str = Field(..., alias="teamId", description="团队ID")
+ team_name: str = Field(..., alias="teamName", description="团队名称")
+ kb_list: list[KnowledgeBaseItem] = Field(default=[], description="知识库列表")
+
+
+class ListTeamKnowledgeMsg(BaseModel):
"""GET /api/knowledge Result数据结构"""
- kb_id: str
+ team_kb_list: list[TeamKnowledgeBaseItem] = Field(default=[], alias="teamKbList", description="团队知识库列表")
-class GetKnowledgeIDRsp(ResponseData):
+class ListTeamKnowledgeRsp(ResponseData):
"""GET /api/knowledge 返回数据结构"""
- result: GetKnowledgeIDMsg
+ result: ListTeamKnowledgeMsg
class BaseAppOperationMsg(BaseModel):
@@ -396,6 +430,75 @@ class NodeServiceListRsp(ResponseData):
result: NodeServiceListMsg
+class MCPServiceCardItem(BaseModel):
+ """插件中心:MCP服务卡片数据结构"""
+
+ mcpservice_id: str = Field(..., alias="mcpserviceId", description="mcp服务ID")
+ name: str = Field(..., description="mcp服务名称")
+ description: str = Field(..., description="mcp服务简介")
+ icon: str = Field(..., description="mcp服务图标")
+ author: str = Field(..., description="mcp服务作者")
+ is_active: bool = Field(alias="isActive", description="mcp服务是否激活", default=False)
+
+
+class BaseMCPServiceOperationMsg(BaseModel):
+ """插件中心:MCP服务操作Result数据结构"""
+
+ service_id: str = Field(..., alias="serviceId", description="服务ID")
+
+
+class GetMCPServiceListMsg(BaseModel):
+ """GET /api/service Result数据结构"""
+
+ current_page: int = Field(..., alias="currentPage", description="当前页码")
+ total_count: int = Field(..., alias="totalCount", description="总服务数")
+ services: list[MCPServiceCardItem] = Field(..., description="解析后的服务列表")
+
+
+class GetMCPServiceListRsp(ResponseData):
+ """GET /api/service 返回数据结构"""
+
+ result: GetMCPServiceListMsg = Field(..., title="Result")
+
+
+class UpdateMCPServiceMsg(BaseModel):
+ """插件中心:MCP服务属性数据结构"""
+
+ service_id: str = Field(..., alias="serviceId", description="MCP服务ID")
+ name: str = Field(..., description="MCP服务名称")
+
+
+class UpdateMCPServiceRsp(ResponseData):
+ """POST /api/mcp_service 返回数据结构"""
+
+ result: UpdateMCPServiceMsg = Field(..., title="Result")
+
+
+class GetMCPServiceDetailMsg(BaseModel):
+ """GET /api/mcp_service/{serviceId} Result数据结构"""
+
+ service_id: str = Field(..., alias="serviceId", description="MCP服务ID")
+ icon: str = Field(description="图标", default="")
+ name: str = Field(..., description="MCP服务名称")
+ description: str = Field(description="MCP服务描述")
+ data: str = Field(description="MCP服务配置")
+ tools: list[MCPTool] = Field(description="MCP服务Tools列表", default=[])
+ is_active: bool = Field(alias="isActive", description="mcp服务是否激活", default=False)
+ mcp_type: MCPType = Field(alias="mcpType", description="MCP 类型")
+
+
+class GetMCPServiceDetailRsp(ResponseData):
+ """GET /api/service/{serviceId} 返回数据结构"""
+
+ result: GetMCPServiceDetailMsg = Field(..., title="Result")
+
+
+class DeleteMCPServiceRsp(ResponseData):
+ """DELETE /api/service/{serviceId} 返回数据结构"""
+
+ result: BaseMCPServiceOperationMsg = Field(..., title="Result")
+
+
class NodeMetaDataRsp(ResponseData):
"""GET /api/flow/service/node 返回数据结构"""
@@ -444,12 +547,54 @@ class FlowStructureDeleteRsp(ResponseData):
result: FlowStructureDeleteMsg
+
class UserGetMsp(BaseModel):
"""GET /api/user result"""
- user_info_list : list[UserInfo] = Field(alias="userInfoList", default=[])
+ user_info_list: list[UserInfo] = Field(alias="userInfoList", default=[])
+
class UserGetRsp(ResponseData):
"""GET /api/user 返回数据结构"""
result: UserGetMsp
+
+
+class ActiveMCPServiceRsp(ResponseData):
+ """POST /api/mcp/active/{serviceId} 返回数据结构"""
+
+ result: BaseMCPServiceOperationMsg = Field(..., title="Result")
+
+
+class LLMProvider(BaseModel):
+ """LLM提供商数据结构"""
+
+ provider: str = Field(..., description="LLM提供商")
+ description: str = Field(..., description="LLM提供商描述")
+ url: str | None = Field(default=None, description="LLM提供商URL")
+ icon: str = Field(..., description="LLM提供商图标")
+
+
+class ListLLMProviderRsp(ResponseData):
+ """GET /api/llm/provider 返回数据结构"""
+
+ result: list[LLMProvider] = Field(default=[], title="Result")
+
+
+class LLM(BaseModel):
+ """LLM数据结构"""
+
+ llm_id: str = Field(..., alias="llmId", description="LLM ID")
+ icon: str = Field(default="", description="LLM图标", max_length=25536)
+ openai_base_url: str = Field(default="https://api.openai.com/v1",
+ description="OpenAI API Base URL", alias="openaiBaseUrl")
+ openai_api_key: str = Field(description="OpenAI API Key", alias="openaiApiKey")
+ model_name: str = Field(description="模型名称", alias="modelName")
+ max_tokens: int = Field(description="最大token数", alias="maxTokens")
+ is_editable: bool = Field(default=True, description="是否可编辑", alias="isEditable")
+
+
+class ListLLMRsp(ResponseData):
+ """GET /api/llm 返回数据结构"""
+
+ result: list[LLM] = Field(default=[], title="Result")
diff --git a/apps/entities/scheduler.py b/apps/entities/scheduler.py
index c767b66b898397d012e9e604d3d36f4ec4ca5c53..c1e31405222d32de92a5891a4090ce977a7191ac 100644
--- a/apps/entities/scheduler.py
+++ b/apps/entities/scheduler.py
@@ -1,8 +1,5 @@
-"""
-插件、工作流、步骤相关数据结构定义
-
-Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
-"""
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
+"""插件、工作流、步骤相关数据结构定义"""
from typing import Any
diff --git a/apps/entities/session.py b/apps/entities/session.py
index fb619833804463878e22c4ee60ffc20762154349..361987cc23d93e428e7fee12a520f65435946f61 100644
--- a/apps/entities/session.py
+++ b/apps/entities/session.py
@@ -1,3 +1,4 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
"""Session相关数据结构"""
from datetime import datetime
diff --git a/apps/entities/task.py b/apps/entities/task.py
index e4eedbbb8e6defec047cd81b2cd68fd77cf2b40f..77d8c1a47f85454396ddaa6a16caeea0b70208a1 100644
--- a/apps/entities/task.py
+++ b/apps/entities/task.py
@@ -1,8 +1,5 @@
-"""
-Task相关数据结构定义
-
-Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
-"""
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
+"""Task相关数据结构定义"""
import uuid
from datetime import UTC, datetime
@@ -47,6 +44,7 @@ class ExecutorState(BaseModel):
step_name: str = Field(description="当前步骤名称")
app_id: str = Field(description="应用ID")
slot: dict[str, Any] = Field(description="待填充参数的JSON Schema", default={})
+ error_info: dict[str, Any] = Field(description="错误信息", default={})
class TaskIds(BaseModel):
diff --git a/apps/entities/user.py b/apps/entities/user.py
index 15a0c73b3282385c3cb9e53a330131af5ae1e58c..61aa2587b8ae6255dc3885b04a96f12310f70773 100644
--- a/apps/entities/user.py
+++ b/apps/entities/user.py
@@ -1,8 +1,5 @@
-"""
-User用户信息数据结构
-
-Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved.
-"""
+# Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved.
+"""User用户信息数据结构"""
from pydantic import BaseModel, Field
diff --git a/apps/entities/vector.py b/apps/entities/vector.py
index 361c0d5beac7d083f58d0e3680ebcdcb77c308ba..1cdc85c9b9fbe5aaa902e1b93238001e5bbe44c5 100644
--- a/apps/entities/vector.py
+++ b/apps/entities/vector.py
@@ -1,8 +1,5 @@
-"""
-向量数据库数据结构;数据将存储在LanceDB中
-
-Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
-"""
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
+"""向量数据库数据结构;数据将存储在LanceDB中"""
from lancedb.pydantic import LanceModel, Vector
diff --git a/apps/llm/__init__.py b/apps/llm/__init__.py
index 8b01256ea0cc406c385ac49fe9115b09451bd38e..e28ad31d70fd0bbc692694ff57f16014fd7580d5 100644
--- a/apps/llm/__init__.py
+++ b/apps/llm/__init__.py
@@ -1,4 +1,2 @@
-"""模型调用模块
-
-Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
-"""
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
+"""模型调用模块"""
diff --git a/apps/llm/embedding.py b/apps/llm/embedding.py
index b6faebba773d43632e0dc35c9d65796f36b9787e..28ab86b49a96d19cc6e2db83930d5aff48f82260 100644
--- a/apps/llm/embedding.py
+++ b/apps/llm/embedding.py
@@ -1,6 +1,6 @@
"""Embedding模型"""
-import aiohttp
+import httpx
from apps.common.config import Config
@@ -9,6 +9,12 @@ class Embedding:
"""Embedding模型"""
# TODO: 应当自动检测向量维度
+ @classmethod
+ async def _get_embedding_dimension(cls) -> int:
+ """获取Embedding的维度"""
+ embedding = await cls.get_embedding(["测试文本"])
+ return len(embedding[0])
+
@classmethod
async def _get_openai_embedding(cls, text: list[str]) -> list[list[float]]:
@@ -26,16 +32,14 @@ class Embedding:
if Config().get_config().embedding.api_key:
headers["Authorization"] = f"Bearer {Config().get_config().embedding.api_key}"
- async with (
- aiohttp.ClientSession() as session,
- session.post(
+ async with httpx.AsyncClient() as client:
+ response = await client.post(
api,
json=data,
headers=headers,
- timeout=aiohttp.ClientTimeout(total=60),
- ) as response,
- ):
- json = await response.json()
+ timeout=60.0,
+ )
+ json = response.json()
return [item["embedding"] for item in json["data"]]
@classmethod
@@ -48,22 +52,20 @@ class Embedding:
if Config().get_config().embedding.api_key:
headers["Authorization"] = f"Bearer {Config().get_config().embedding.api_key}"
- session = aiohttp.ClientSession()
-
- result = []
- for single_text in text:
- data = {
- "inputs": single_text,
- "normalize": True,
- }
- async with session.post(
- api, json=data, headers=headers, timeout=aiohttp.ClientTimeout(total=60),
- ) as response:
- json = await response.json()
+ async with httpx.AsyncClient() as client:
+ result = []
+ for single_text in text:
+ data = {
+ "inputs": single_text,
+ "normalize": True,
+ }
+ response = await client.post(
+ api, json=data, headers=headers, timeout=60.0,
+ )
+ json = response.json()
result.append(json[0])
- await session.close()
- return result
+ return result
@classmethod
async def get_embedding(cls, text: list[str]) -> list[list[float]]:
diff --git a/apps/llm/function.py b/apps/llm/function.py
index a17370a0be450d3a64de747c28831088ba77bbc7..1f995fe7ba187cead03aa6fc62a4cbce1ec05a65 100644
--- a/apps/llm/function.py
+++ b/apps/llm/function.py
@@ -1,271 +1,305 @@
-"""
-用于FunctionCall的大模型
-
-Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
-"""
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
+"""用于FunctionCall的大模型"""
import json
+import logging
+import re
+from textwrap import dedent
from typing import Any
-from asyncer import asyncify
+from jinja2 import BaseLoader
+from jinja2.sandbox import SandboxedEnvironment
+from jsonschema import Draft7Validator
from apps.common.config import Config
-from apps.constants import REASONING_BEGIN_TOKEN, REASONING_END_TOKEN
-from apps.scheduler.json_schema import build_regex_from_schema
+from apps.constants import JSON_GEN_MAX_TRIAL, REASONING_END_TOKEN
+from apps.llm.prompt import JSON_GEN_BASIC
+
+logger = logging.getLogger(__name__)
class FunctionLLM:
"""用于FunctionCall的模型"""
- _client: Any
-
def __init__(self) -> None:
"""
初始化用于FunctionCall的模型
目前支持:
- - sglang
- vllm
- ollama
+ - function_call
+ - json_mode
+ - structured_output
"""
- if Config().get_config().function_call.backend == "sglang":
- import sglang
- from sglang.lang.chat_template import get_chat_template
-
- if not Config().get_config().function_call.api_key:
- self._client = sglang.RuntimeEndpoint(Config().get_config().function_call.endpoint)
- else:
- self._client = sglang.RuntimeEndpoint(
- Config().get_config().function_call.endpoint, api_key=Config().get_config().function_call.api_key,
- )
- self._client.chat_template = get_chat_template("chatml")
-
- if (
- Config().get_config().function_call.backend == "vllm"
- or Config().get_config().function_call.backend == "openai"
- ):
- import openai
+ # 暂存config;这里可以替代为从其他位置获取
+ self._config = Config().get_config().function_call
+ if not self._config.model:
+ err_msg = "[FunctionCall] 未设置FuntionCall所用模型!"
+ logger.error(err_msg)
+ raise ValueError(err_msg)
- if not Config().get_config().function_call.api_key:
- self._client = openai.AsyncOpenAI(base_url=Config().get_config().function_call.endpoint + "/v1")
- else:
- self._client = openai.AsyncOpenAI(
- base_url=Config().get_config().function_call.endpoint + "/v1",
- api_key=Config().get_config().function_call.api_key,
- )
+ self._params = {
+ "model": self._config.model,
+ "messages": [],
+ }
- if Config().get_config().function_call.backend == "ollama":
+ if self._config.backend == "ollama":
import ollama
- if not Config().get_config().function_call.api_key:
- self._client = ollama.AsyncClient(host=Config().get_config().function_call.endpoint)
+ if not self._config.api_key:
+ self._client = ollama.AsyncClient(host=self._config.endpoint)
else:
self._client = ollama.AsyncClient(
- host=Config().get_config().function_call.endpoint,
+ host=self._config.endpoint,
headers={
- "Authorization": f"Bearer {Config().get_config().function_call.api_key}",
+ "Authorization": f"Bearer {self._config.api_key}",
},
)
- @staticmethod
- def _sglang_func(
- s, messages: list[dict[str, Any]], schema: dict[str, Any], max_tokens: int, temperature: float, # noqa: ANN001
- ) -> None:
- """
- 构建sglang需要的执行函数
+ else:
+ import openai
- :param s: sglang context
- :param messages: 历史消息
- :param schema: 输出JSON Schema
- :param max_tokens: 最大Token长度
- :param temperature: 大模型温度
- """
- for msg in messages:
- if msg["role"] == "user":
- s += s.user(msg["content"])
- elif msg["role"] == "assistant":
- s += s.assistant(msg["content"])
- elif msg["role"] == "system":
- s += s.system(msg["content"])
+ if not self._config.api_key:
+ self._client = openai.AsyncOpenAI(base_url=self._config.endpoint)
else:
- err_msg = f"Unknown message role: {msg['role']}"
- raise ValueError(err_msg)
+ self._client = openai.AsyncOpenAI(
+ base_url=self._config.endpoint,
+ api_key=self._config.api_key,
+ )
- # 如果Schema为空,认为是直接问答,不加输出限制
- if not schema:
- s += s.assistant(s.gen(name="output", max_tokens=max_tokens, temperature=temperature))
- else:
- s += s.assistant(
- s.gen(
- name="output",
- regex=build_regex_from_schema(json.dumps(schema)),
- max_tokens=max_tokens,
- temperature=temperature,
- ),
- )
-
- async def _call_vllm(
- self, messages: list[dict[str, Any]], schema: dict[str, Any], max_tokens: int, temperature: float,
+
+ async def _call_openai(
+ self,
+ messages: list[dict[str, str]],
+ schema: dict[str, Any],
+ max_tokens: int | None = None,
+ temperature: float | None = None,
) -> str:
"""
- 调用vllm模型生成JSON
+ 调用openai模型生成JSON
- :param messages: 历史消息列表
- :param schema: 输出JSON Schema
- :param max_tokens: 最大Token长度
- :param temperature: 大模型温度
+ :param list[dict[str, str]] messages: 历史消息列表
+ :param dict[str, Any] schema: 输出JSON Schema
+ :param int | None max_tokens: 最大Token长度
+ :param float | None temperature: 大模型温度
:return: 生成的JSON
+ :rtype: str
"""
- model = Config().get_config().function_call.model
- if not model:
- err_msg = "未设置FuntionCall所用模型!"
- raise ValueError(err_msg)
-
- param = {
- "model": model,
+ self._params.update({
"messages": messages,
"max_tokens": max_tokens,
"temperature": temperature,
- "stream": True,
- }
-
- # 如果Schema不为空,认为是FunctionCall,需要指定输出格式
- if schema:
- param["extra_body"] = {"guided_json": schema}
-
- chat = await self._client.chat.completions.create(**param)
+ })
+
+ if self._config.backend == "vllm":
+ self._params["extra_body"] = {"guided_json": schema}
+
+ elif self._config.backend == "json_mode":
+ logger.warning("[FunctionCall] json_mode无法确保输出格式符合要求,使用效果将受到影响")
+ self._params["response_format"] = {"type": "json_object"}
+
+ elif self._config.backend == "structured_output":
+ self._params["response_format"] = {
+ "type": "json_schema",
+ "json_schema": {
+ "name": "generate",
+ "description": "Generate answer based on the background information",
+ "schema": schema,
+ "strict": True,
+ },
+ }
- reasoning = False
- result = ""
- async for chunk in chat:
- chunk_str = chunk.choices[0].delta.content or ""
- for token in REASONING_BEGIN_TOKEN:
- if token in chunk_str:
- reasoning = True
- break
+ elif self._config.backend == "function_call":
+ logger.warning("[FunctionCall] function_call无法确保一定调用工具,使用效果将受到影响")
+ self._params["tools"] = [
+ {
+ "type": "function",
+ "function": {
+ "name": "generate",
+ "description": "Generate answer based on the background information",
+ "parameters": schema,
+ },
+ },
+ ]
- for token in REASONING_END_TOKEN:
- if token in chunk_str:
- reasoning = False
- chunk_str = ""
- break
+ response = await self._client.chat.completions.create(**self._params) # type: ignore[arg-type]
+ try:
+ logger.info("[FunctionCall] 大模型输出:%s", response.choices[0].message.tool_calls[0].function.arguments)
+ return response.choices[0].message.tool_calls[0].function.arguments
+ except Exception: # noqa: BLE001
+ ans = response.choices[0].message.content
+ logger.info("[FunctionCall] 大模型输出:%s", ans)
+ return await FunctionLLM.process_response(ans)
- if not reasoning:
- result += chunk_str
- return result.strip().strip(" ").strip("\n")
- async def _call_openai(
- self, messages: list[dict[str, Any]], schema: dict[str, Any], max_tokens: int, temperature: float,
- ) -> str:
- """
- 调用openai模型生成JSON
+ @staticmethod
+ async def process_response(response: str) -> str:
+ """处理大模型的输出"""
+ # 去掉推理过程,避免干扰
+ for token in REASONING_END_TOKEN:
+ response = response.split(token)[-1]
+
+ # 尝试解析JSON
+ response = dedent(response).strip()
+ error_flag = False
+ try:
+ json.loads(response)
+ except Exception: # noqa: BLE001
+ error_flag = True
- :param messages: 历史消息列表
- :param schema: 输出JSON Schema
- :param max_tokens: 最大Token长度
- :param temperature: 大模型温度
- :return: 生成的JSON
- """
- model = Config().get_config().function_call.model
- if not model:
- err_msg = "未设置FuntionCall所用模型!"
- raise ValueError(err_msg)
+ if not error_flag:
+ return response
- param = {
- "model": model,
- "messages": messages,
- "max_tokens": max_tokens,
- "temperature": temperature,
- }
+ # 尝试提取```json中的JSON
+ logger.warning("[FunctionCall] 直接解析失败!尝试提取```json中的JSON")
+ try:
+ json_str = re.findall(r"```json(.*)```", response, re.DOTALL)[-1]
+ json_str = dedent(json_str).strip()
+ json.loads(json_str)
+ except Exception: # noqa: BLE001
+ # 尝试直接通过括号匹配JSON
+ logger.warning("[FunctionCall] 提取失败!尝试正则提取JSON")
+ try:
+ json_str = re.findall(r"\{.*\}", response, re.DOTALL)[-1]
+ json_str = dedent(json_str).strip()
+ json.loads(json_str)
+ except Exception: # noqa: BLE001
+ json_str = "{}"
- if schema:
- tool_data = {
- "type": "function",
- "function": {
- "name": "output",
- "description": "Call the function to get the output",
- "parameters": schema,
- },
- }
- param["tools"] = [tool_data]
- param["tool_choice"] = "required"
+ return json_str
- response = await self._client.chat.completions.create(**param)
- try:
- ans = response.choices[0].message.tool_calls[0].function.arguments or ""
- except IndexError:
- ans = ""
- return ans
async def _call_ollama(
- self, messages: list[dict[str, Any]], schema: dict[str, Any], max_tokens: int, temperature: float,
+ self,
+ messages: list[dict[str, str]],
+ schema: dict[str, Any],
+ max_tokens: int | None = None,
+ temperature: float | None = None,
) -> str:
"""
调用ollama模型生成JSON
- :param messages: 历史消息列表
- :param schema: 输出JSON Schema
- :param max_tokens: 最大Token长度
- :param temperature: 大模型温度
+ :param list[dict[str, str]] messages: list of history messages
+ :param dict[str, Any] schema: output JSON Schema, sent to the client as "format"
+ :param int | None max_tokens: maximum token length, sent as the "num_predict" option
+ :param float | None temperature: model temperature, sent as the "temperature" option
:return: 生成的对话回复
+ :rtype: str
"""
- param = {
- "model": Config().get_config().function_call.model,
+ self._params.update({  # NOTE(review): mutates self._params in place, so these keys stay on the instance after the call — confirm intended
"messages": messages,
"options": {
"temperature": temperature,
- "num_ctx": max_tokens,
"num_predict": max_tokens,
},
- }
- # 如果Schema不为空,认为是FunctionCall,需要指定输出格式
- if schema:
- param["format"] = schema
+ "format": schema,
+ })
- response = await self._client.chat(**param)
- return response.message.content or ""
+ response = await self._client.chat(**self._params) # type: ignore[arg-type]
+ return await self.process_response(response.message.content or "")
- async def _call_sglang(
- self, messages: list[dict[str, Any]], schema: dict[str, Any], max_tokens: int, temperature: float,
- ) -> str:
- """
- 调用sglang模型生成JSON
- :param messages: 历史消息
- :param schema: 输出JSON Schema
- :param max_tokens: 最大Token长度
- :param temperature: 大模型温度
- :return: 生成的JSON
- """
- # 构造sglang执行函数
- import sglang
-
- sglang.set_default_backend(self._client)
-
- sglang_func = sglang.function(self._sglang_func)
- state = await asyncify(sglang_func.run)(messages, schema, max_tokens, temperature) # type: ignore[arg-type]
- return state["output"]
-
- async def call(self, **kwargs) -> str: # noqa: ANN003
+ async def call(
+ self,
+ messages: list[dict[str, Any]],
+ schema: dict[str, Any],
+ max_tokens: int | None = None,
+ temperature: float | None = None,
+ ) -> dict[str, Any]:
"""
调用FunctionCall小模型
- 暂不开放流式输出
+ Streaming output is not supported; returns the parsed JSON as a dict ({} when parsing fails).
"""
- if Config().get_config().function_call.backend == "vllm":
- json_str = await self._call_vllm(**kwargs)
+ # Fall back to the configured defaults when max_tokens / temperature are not given
+ if max_tokens is None:
+ max_tokens = self._config.max_tokens
+ if temperature is None:
+ temperature = self._config.temperature
- elif Config().get_config().function_call.backend == "sglang":
- json_str = await self._call_sglang(**kwargs)
+ if self._config.backend == "ollama":
+ json_str = await self._call_ollama(messages, schema, max_tokens, temperature)
- elif Config().get_config().function_call.backend == "ollama":
- json_str = await self._call_ollama(**kwargs)
-
- elif Config().get_config().function_call.backend == "openai":
- json_str = await self._call_openai(**kwargs)
+ elif self._config.backend in ["function_call", "json_mode", "response_format", "vllm"]:
+ json_str = await self._call_openai(messages, schema, max_tokens, temperature)
else:
err = "未知的Function模型后端"
raise ValueError(err)
- return json_str
+ try:
+ return json.loads(json_str)
+ except Exception: # noqa: BLE001
+ logger.error("[FunctionCall] 大模型JSON解析失败:%s", json_str) # noqa: TRY400
+ return {}
+
+
+class JsonGenerator:
+ """JSON generator: asks the function-call LLM for JSON and validates it against a JSON Schema, retrying on failure."""
+
+ def __init__(self, query: str, conversation: list[dict[str, str]], schema: dict[str, Any]) -> None:
+ """Store the query, conversation and schema, and set up the sandboxed Jinja2 environment."""
+ self._query = query
+ self._conversation = conversation
+ self._schema = schema
+
+ self._trial = {}  # NOTE(review): never reassigned in this class, so the template always receives an empty previous_trial — confirm
+ self._count = 0
+ self._env = SandboxedEnvironment(
+ loader=BaseLoader(),
+ autoescape=False,
+ trim_blocks=True,
+ lstrip_blocks=True,
+ )
+ self._err_info = ""
+
+
+ async def _assemble_message(self) -> str:
+ """Render the JSON_GEN_BASIC template into the user prompt."""
+ # Flag whether the configured backend is "function_call"
+ function_call = Config().get_config().function_call.backend == "function_call"
+
+ # Render the template
+ template = self._env.from_string(JSON_GEN_BASIC)
+ return template.render(
+ query=self._query,
+ conversation=self._conversation,
+ previous_trial=self._trial,
+ schema=self._schema,
+ function_call=function_call,
+ err_info=self._err_info,
+ )
+
+ async def _single_trial(self, max_tokens: int | None = None, temperature: float | None = None) -> dict[str, Any]:
+ """Run a single attempt: build the prompt and call FunctionLLM once."""
+ prompt = await self._assemble_message()
+ messages = [
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": prompt},
+ ]
+ function = FunctionLLM()
+ return await function.call(messages, self._schema, max_tokens, temperature)
+
+
+ async def generate(self) -> dict[str, Any]:
+ """Generate JSON: retry up to JSON_GEN_MAX_TRIAL times until the result validates; return {} otherwise."""
+ Draft7Validator.check_schema(self._schema)
+ validator = Draft7Validator(self._schema)
+ logger.info("[JSONGenerator] Schema:%s", self._schema)
+
+ while self._count < JSON_GEN_MAX_TRIAL:
+ self._count += 1
+ result = await self._single_trial()
+ logger.info("[JSONGenerator] 得到:%s", result)
+ try:
+ validator.validate(result)
+ except Exception as err: # noqa: BLE001
+ err_info = str(err)
+ err_info = err_info.split("\n\n")[0]  # keep only the text before the first blank line of the error message
+ self._err_info = err_info
+ logger.info("[JSONGenerator] 验证失败:%s", self._err_info)
+ continue
+ return result
+
+ return {}
diff --git a/apps/llm/patterns/__init__.py b/apps/llm/patterns/__init__.py
index 65772e1f52e1052cb257f42e8a6e8707fb3c1293..b12f2a1201f8636fdbb2d6caadb40d4daa7e71cd 100644
--- a/apps/llm/patterns/__init__.py
+++ b/apps/llm/patterns/__init__.py
@@ -1,25 +1,18 @@
-"""
-LLM大模型Prompt模板
-
-Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
-"""
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
+"""LLM大模型Prompt模板"""
from apps.llm.patterns.core import CorePattern
-from apps.llm.patterns.domain import Domain
from apps.llm.patterns.executor import (
ExecutorSummary,
ExecutorThought,
)
-from apps.llm.patterns.json_gen import Json
from apps.llm.patterns.recommend import Recommend
from apps.llm.patterns.select import Select
__all__ = [
"CorePattern",
- "Domain",
"ExecutorSummary",
"ExecutorThought",
- "Json",
"Recommend",
"Select",
]
diff --git a/apps/llm/patterns/core.py b/apps/llm/patterns/core.py
index 756fa14331dae266d4b7e67078d9c7f98a59d273..4ef8133a9fed1b1e62f1ceb578c6bdb5a93b12a5 100644
--- a/apps/llm/patterns/core.py
+++ b/apps/llm/patterns/core.py
@@ -1,12 +1,8 @@
-"""
-基础大模型范式抽象类
-
-Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
-"""
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
+"""基础大模型范式抽象类"""
from abc import ABC, abstractmethod
from textwrap import dedent
-from typing import Any, ClassVar
class CorePattern(ABC):
@@ -16,8 +12,6 @@ class CorePattern(ABC):
"""系统提示词"""
user_prompt: str = ""
"""用户提示词"""
- slot_schema: ClassVar[dict[str, Any]] = {}
- """输出格式的JSON Schema"""
input_tokens: int = 0
"""输入Token数量"""
output_tokens: int = 0
diff --git a/apps/llm/patterns/domain.py b/apps/llm/patterns/domain.py
deleted file mode 100644
index 26bd8fa77a7481e4e357426da9b0fd1e0b11eb50..0000000000000000000000000000000000000000
--- a/apps/llm/patterns/domain.py
+++ /dev/null
@@ -1,84 +0,0 @@
-"""
-LLM Pattern: 从问答中提取领域信息
-
-Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
-"""
-
-from typing import Any, ClassVar
-
-from apps.llm.patterns.core import CorePattern
-from apps.llm.patterns.json_gen import Json
-from apps.llm.reasoning import ReasoningLLM
-from apps.llm.snippet import convert_context_to_prompt
-
-
-class Domain(CorePattern):
- """从问答中提取领域信息"""
-
- user_prompt: str = r"""
-
-
- 根据对话上文,提取推荐系统所需的关键词标签,要求:
- 1. 实体名词、技术术语、时间范围、地点、产品等关键信息均可作为关键词标签
- 2. 至少一个关键词与对话的话题有关
- 3. 标签需精简,不得重复,不得超过10个字
- 4. 使用JSON格式输出,不要包含XML标签,不要包含任何解释说明
-
-
-
-
- 北京天气如何?
- 北京今天晴。
-
-
-
-
-
-
- {conversation}
-