diff --git a/apps/common/oidc.py b/apps/common/oidc.py
index 46c7ad93d207ee43bc0ac23f48e5f45f60d17d0f..a54f1a60d25f0d5ca68da0c64a249e5e246aca35 100644
--- a/apps/common/oidc.py
+++ b/apps/common/oidc.py
@@ -10,7 +10,7 @@ from apps.constants import OIDC_ACCESS_TOKEN_EXPIRE_TIME, OIDC_REFRESH_TOKEN_EXP
from apps.models import Session, SessionType
from .config import config
-from .oidc_provider import AuthhubOIDCProvider, OpenEulerOIDCProvider
+from .oidc_provider import AutheliaOIDCProvider, AuthhubOIDCProvider, OpenEulerOIDCProvider
logger = logging.getLogger(__name__)
@@ -24,6 +24,8 @@ class OIDCProvider:
self.provider = OpenEulerOIDCProvider()
elif config.login.provider == "authhub":
self.provider = AuthhubOIDCProvider()
+ elif config.login.provider == "authelia":
+ self.provider = AutheliaOIDCProvider()
else:
err = f"[OIDC] 未知OIDC提供商: {config.login.provider}"
logger.error(err)
diff --git a/apps/common/oidc_provider/__init__.py b/apps/common/oidc_provider/__init__.py
index f84ba7204e3735f63a142f5b500e7591ddb809b8..cde9c166197653ce1e933b332331949a613e5a04 100644
--- a/apps/common/oidc_provider/__init__.py
+++ b/apps/common/oidc_provider/__init__.py
@@ -1,11 +1,13 @@
# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
"""OIDC Provider"""
+from .authelia import AutheliaOIDCProvider
from .authhub import AuthhubOIDCProvider
from .base import OIDCProviderBase
from .openeuler import OpenEulerOIDCProvider
__all__ = [
+ "AutheliaOIDCProvider",
"AuthhubOIDCProvider",
"OIDCProviderBase",
"OpenEulerOIDCProvider",
diff --git a/apps/common/oidc_provider/authelia.py b/apps/common/oidc_provider/authelia.py
new file mode 100644
index 0000000000000000000000000000000000000000..e57630ab4aad42752eedb5dbb6cca8ffa146de00
--- /dev/null
+++ b/apps/common/oidc_provider/authelia.py
@@ -0,0 +1,160 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
+"""Authelia OIDC Provider"""
+
+import logging
+import secrets
+from typing import Any
+
+import httpx
+from fastapi import status
+
+from apps.common.config import config
+from apps.common.oidc_provider.base import OIDCProviderBase
+from apps.schemas.config import OIDCConfig
+
+logger = logging.getLogger(__name__)
+
+
+class AutheliaOIDCProvider(OIDCProviderBase):
+ """Authelia OIDC Provider"""
+
+ @classmethod
+ def _get_login_config(cls) -> OIDCConfig:
+ """获取并验证登录配置"""
+ login_config = config.login.settings
+ if not isinstance(login_config, OIDCConfig):
+ err = "OpenEuler OIDC配置错误"
+ raise TypeError(err)
+ return login_config
+
+ @classmethod
+ async def get_oidc_token(cls, code: str) -> dict[str, Any]:
+ """获取Authelia OIDC Token"""
+ login_config = cls._get_login_config()
+
+ data = {
+ "client_id": login_config.app_id,
+ "client_secret": login_config.app_secret,
+ "redirect_uri": login_config.login_api,
+ "grant_type": "authorization_code",
+ "code": code,
+ }
+ headers = {
+ "Content-Type": "application/x-www-form-urlencoded",
+ }
+ url = await cls.get_access_token_url()
+ result = None
+ async with httpx.AsyncClient(verify=False) as client: # noqa: S501
+ resp = await client.post(
+ url,
+ headers=headers,
+ data=data,
+ timeout=10,
+ )
+ if resp.status_code != status.HTTP_200_OK:
+ err = f"[Authelia] 获取OIDC Token失败: {resp.status_code},完整输出: {resp.text}"
+ raise RuntimeError(err)
+ logger.info("[Authelia] 获取OIDC Token成功: %s", resp.text)
+ result = resp.json()
+ return {
+ "access_token": result["access_token"],
+ "refresh_token": result.get("refresh_token", ""),
+ }
+
+ @classmethod
+ async def get_oidc_user(cls, access_token: str) -> dict[str, Any]:
+ """获取Authelia OIDC用户"""
+ login_config = cls._get_login_config()
+
+ if not access_token:
+ err = "Access token is empty."
+ raise RuntimeError(err)
+ headers = {
+ "Authorization": f"Bearer {access_token}",
+ "Content-Type": "application/json",
+ }
+ url = login_config.host.rstrip("/") + "/api/oidc/userinfo"
+ result = None
+ async with httpx.AsyncClient(verify=False) as client: # noqa: S501
+ resp = await client.get(
+ url,
+ headers=headers,
+ timeout=10,
+ )
+ if resp.status_code != status.HTTP_200_OK:
+ err = f"[Authelia] 获取用户信息失败: {resp.status_code},完整输出: {resp.text}"
+ raise RuntimeError(err)
+ logger.info("[Authelia] 获取用户信息成功: %s", resp.text)
+ result = resp.json()
+
+ return {
+ "user_sub": result.get("sub", result.get("preferred_username", "")),
+ "user_name": result.get("name", result.get("preferred_username", result.get("nickname", ""))),
+ }
+
+ @classmethod
+ async def get_login_status(cls, cookie: dict[str, str]) -> dict[str, Any]:
+ """检查登录状态;Authelia通过session cookie检查"""
+ login_config = cls._get_login_config()
+
+ headers = {
+ "Content-Type": "application/json",
+ }
+ url = login_config.host.rstrip("/") + "/api/user/info"
+ async with httpx.AsyncClient(verify=False) as client: # noqa: S501
+ resp = await client.get(
+ url,
+ headers=headers,
+ cookies=cookie,
+ timeout=10,
+ )
+ if resp.status_code != status.HTTP_200_OK:
+ err = f"[Authelia] 获取登录状态失败: {resp.status_code},完整输出: {resp.text}"
+ raise RuntimeError(err)
+
+ # Authelia 返回用户信息表示已登录
+ return {
+ "access_token": "",
+ "refresh_token": "",
+ }
+
+ @classmethod
+ async def oidc_logout(cls, cookie: dict[str, str]) -> None:
+ """触发OIDC的登出"""
+ login_config = cls._get_login_config()
+
+ headers = {
+ "Content-Type": "application/json",
+ }
+ url = login_config.host.rstrip("/") + "/api/logout"
+ async with httpx.AsyncClient(verify=False) as client: # noqa: S501
+ resp = await client.post(
+ url,
+ headers=headers,
+ cookies=cookie,
+ timeout=10,
+ )
+ # Authelia登出成功通常返回200或302重定向
+ if resp.status_code not in [status.HTTP_200_OK, status.HTTP_302_FOUND]:
+ err = f"[Authelia] 登出失败: {resp.status_code},完整输出: {resp.text}"
+ raise RuntimeError(err)
+
+ @classmethod
+ async def get_redirect_url(cls) -> str:
+ """获取Authelia OIDC 重定向URL"""
+ login_config = cls._get_login_config()
+
+ # 生成随机的 state 参数以确保安全性和唯一性
+ state = secrets.token_urlsafe(32)
+ return (f"{login_config.host.rstrip('/')}/api/oidc/authorization?"
+ f"client_id={login_config.app_id}&"
+ f"response_type=code&"
+ f"scope=openid profile email&"
+ f"redirect_uri={login_config.login_api}&"
+ f"state={state}")
+
+ @classmethod
+ async def get_access_token_url(cls) -> str:
+ """获取Authelia OIDC 访问Token URL"""
+ login_config = cls._get_login_config()
+ return login_config.host.rstrip("/") + "/api/oidc/token"
diff --git a/apps/common/queue.py b/apps/common/queue.py
index c0981527276e7e5613294e38ed0cdf1b4a2c6490..685511477eea0bda1fd8eb6df576aa5e125975ff 100644
--- a/apps/common/queue.py
+++ b/apps/common/queue.py
@@ -8,6 +8,7 @@ from collections.abc import AsyncGenerator
from datetime import UTC, datetime
from typing import Any
+from apps.llm import LLMConfig
from apps.schemas.enum_var import EventType
from apps.schemas.message import (
HeartbeatData,
@@ -15,7 +16,6 @@ from apps.schemas.message import (
MessageFlow,
MessageMetadata,
)
-from apps.schemas.scheduler import LLMConfig
from apps.schemas.task import TaskData
logger = logging.getLogger(__name__)
diff --git a/apps/llm/__init__.py b/apps/llm/__init__.py
index 4938632956ed55b0c295cc487712f49a674af25d..fa0ef8e4fc88ab295492772422bfee42e51ba906 100644
--- a/apps/llm/__init__.py
+++ b/apps/llm/__init__.py
@@ -4,11 +4,13 @@
from .embedding import Embedding
from .generator import JsonGenerator
from .llm import LLM
+from .schema import LLMConfig
from .token import token_calculator
__all__ = [
"LLM",
"Embedding",
"JsonGenerator",
+ "LLMConfig",
"token_calculator",
]
diff --git a/apps/llm/generator.py b/apps/llm/generator.py
index 045076fcfcc3246053e62a253ffb1c007bdaf615..3ce5a300924912e0cdd9758967cc7569ca64b094 100644
--- a/apps/llm/generator.py
+++ b/apps/llm/generator.py
@@ -12,10 +12,10 @@ from jsonschema import Draft7Validator
from apps.models import LLMType
from apps.schemas.llm import LLMFunctions
-from apps.schemas.scheduler import LLMConfig
from .llm import LLM
from .prompt import JSON_GEN_BASIC, JSON_NO_FUNCTION_CALL
+from .schema import LLMConfig
_logger = logging.getLogger(__name__)
@@ -72,7 +72,7 @@ class JsonGenerator:
# 如果支持FunctionCall,使用provider的function调用逻辑
result = await self._call_with_function()
else:
- # 如果不支持FunctionCall,使用JSON_GEN_BASIC和provider的call进行调用
+ # 如果不支持FunctionCall,使用JSON_GEN_BASIC
result = await self._call_without_function(max_tokens, temperature)
# 校验结果
@@ -157,9 +157,16 @@ class JsonGenerator:
while self._count < JSON_GEN_MAX_TRIAL:
self._count += 1
try:
+ # 如果_single_trial没有抛出异常,直接返回结果,不进行重试
return await self._single_trial()
- except Exception: # noqa: BLE001
- # 校验失败,_single_trial已经将错误信息记录到self._err_info中并记录日志
- continue
-
+ except Exception:
+ # 每次捕获异常时都记录错误
+ _logger.exception(
+ "[JSONGenerator] 第 %d/%d 次尝试失败",
+ self._count,
+ JSON_GEN_MAX_TRIAL,
+ )
+ # 如果还有重试机会,继续下一次尝试
+ if self._count < JSON_GEN_MAX_TRIAL:
+ continue
return {}
diff --git a/apps/llm/schema.py b/apps/llm/schema.py
new file mode 100644
index 0000000000000000000000000000000000000000..53296a1bd51d379510fdc14adf9104e2f2976f13
--- /dev/null
+++ b/apps/llm/schema.py
@@ -0,0 +1,20 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
+"""LLM配置"""
+
+from pydantic import BaseModel, Field
+
+from .embedding import Embedding
+from .llm import LLM
+
+
+class LLMConfig(BaseModel):
+ """LLM配置"""
+
+ reasoning: LLM = Field(description="推理LLM")
+ function: LLM | None = Field(description="函数LLM")
+ embedding: Embedding | None = Field(description="Embedding")
+
+ class Config:
+ """配置"""
+
+ arbitrary_types_allowed = True
diff --git a/apps/scheduler/call/slot/prompt.py b/apps/scheduler/call/slot/prompt.py
index 89c3829791f07d33207d5988af674ac8d2830f67..0767cda56a3fed186caccea2f969cf7a34f42e84 100644
--- a/apps/scheduler/call/slot/prompt.py
+++ b/apps/scheduler/call/slot/prompt.py
@@ -1,181 +1,200 @@
# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
"""自动参数填充工具的提示词"""
+from textwrap import dedent
+
from apps.models import LanguageType
+# 主查询提示词模板
SLOT_GEN_PROMPT: dict[LanguageType, str] = {
- LanguageType.CHINESE: r"""
+ LanguageType.CHINESE: dedent(r"""
- 你是一个可以使用工具的AI助手,正尝试使用工具来完成任务。
- 目前,你正在生成一个JSON参数对象,以作为调用工具的输入。
- 请根据用户输入、背景信息、工具信息和JSON Schema内容,生成符合要求的JSON对象。
-
- 背景信息将在中给出,工具信息将在中给出,JSON Schema将在中给出,\
- 用户的问题将在中给出。
- 请在
+
+ 你需要使用 fill_parameters 工具来填充参数。请仔细结合对话上下文,根据用户的问题和历史对话内容,\
+为当前工具调用生成合适的参数。要求:
+ 1. 严格按照下面提供的 JSON Schema 格式输出,不要编造不存在的字段;
+ 2. 参数值优先从用户的当前问题中提取;如果当前问题中没有,则从对话历史中查找相关信息;
+ 3. 必须仔细理解对话上下文,确保生成的参数与上下文保持一致,符合用户的真实意图;
+ 4. 只输出 JSON 对象,不要包含任何解释说明或其他内容;
+ 5. 如果 JSON Schema 中的字段是可选的(optional),在无法确定其值时可以省略该字段;
+ 6. 不要编造或猜测参数值,所有参数必须基于对话上下文中的实际信息。
+
- 用户询问杭州今天的天气情况。AI回复杭州今天晴,温度20℃。用户询问杭州明天的天气情况。
+ 用户询问:"北京今天天气怎么样?" AI回复:"北京今天晴,温度20℃。" 用户继续问:"明天呢?"
-
- 杭州明天的天气情况如何?
-
工具名称:check_weather
工具描述:查询指定城市的天气信息
-
+
{
"type": "object",
"properties": {
- "city": {
- "type": "string",
- "description": "城市名称"
- },
- "date": {
- "type": "string",
- "description": "查询日期"
- },
- "required": ["city", "date"]
- }
+ "city": {"type": "string", "description": "城市名称"},
+ "date": {"type": "string", "description": "查询日期"}
+ },
+ "required": ["city", "date"]
}
-
-
+
+
+
-
-
- 以下是对用户给你的任务的历史总结,在中给出:
-
- {{summary}}
-
- 附加的条目化信息:
- {{ facts }}
-
-
- 在本次任务中,你已经调用过一些工具,并获得了它们的输出,在中给出:
-
- {% for tool in history_data %}
-
- {{ tool.step_name }}
- {{ tool.step_description }}
-
-
- {% endfor %}
-
-
-
- {{question}}
-
-
- 工具名称:{{current_tool["name"]}}
- 工具描述:{{current_tool["description"]}}
-
-
- {{schema}}
-
-