diff --git a/apps/common/oidc.py b/apps/common/oidc.py
index 46c7ad93d207ee43bc0ac23f48e5f45f60d17d0f..a54f1a60d25f0d5ca68da0c64a249e5e246aca35 100644
--- a/apps/common/oidc.py
+++ b/apps/common/oidc.py
@@ -10,7 +10,7 @@ from apps.constants import OIDC_ACCESS_TOKEN_EXPIRE_TIME, OIDC_REFRESH_TOKEN_EXP
 from apps.models import Session, SessionType
 
 from .config import config
-from .oidc_provider import AuthhubOIDCProvider, OpenEulerOIDCProvider
+from .oidc_provider import AutheliaOIDCProvider, AuthhubOIDCProvider, OpenEulerOIDCProvider
 
 logger = logging.getLogger(__name__)
 
@@ -24,6 +24,8 @@ class OIDCProvider:
             self.provider = OpenEulerOIDCProvider()
         elif config.login.provider == "authhub":
             self.provider = AuthhubOIDCProvider()
+        elif config.login.provider == "authelia":
+            self.provider = AutheliaOIDCProvider()
         else:
             err = f"[OIDC] 未知OIDC提供商: {config.login.provider}"
             logger.error(err)
diff --git a/apps/common/oidc_provider/__init__.py b/apps/common/oidc_provider/__init__.py
index f84ba7204e3735f63a142f5b500e7591ddb809b8..cde9c166197653ce1e933b332331949a613e5a04 100644
--- a/apps/common/oidc_provider/__init__.py
+++ b/apps/common/oidc_provider/__init__.py
@@ -1,11 +1,13 @@
 # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
 """OIDC Provider"""
 
+from .authelia import AutheliaOIDCProvider
 from .authhub import AuthhubOIDCProvider
 from .base import OIDCProviderBase
 from .openeuler import OpenEulerOIDCProvider
 
 __all__ = [
+    "AutheliaOIDCProvider",
     "AuthhubOIDCProvider",
     "OIDCProviderBase",
     "OpenEulerOIDCProvider",
diff --git a/apps/common/oidc_provider/authelia.py b/apps/common/oidc_provider/authelia.py
new file mode 100644
index 0000000000000000000000000000000000000000..e57630ab4aad42752eedb5dbb6cca8ffa146de00
--- /dev/null
+++ b/apps/common/oidc_provider/authelia.py
@@ -0,0 +1,160 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
+"""Authelia OIDC Provider""" + +import logging +import secrets +from typing import Any + +import httpx +from fastapi import status + +from apps.common.config import config +from apps.common.oidc_provider.base import OIDCProviderBase +from apps.schemas.config import OIDCConfig + +logger = logging.getLogger(__name__) + + +class AutheliaOIDCProvider(OIDCProviderBase): + """Authelia OIDC Provider""" + + @classmethod + def _get_login_config(cls) -> OIDCConfig: + """获取并验证登录配置""" + login_config = config.login.settings + if not isinstance(login_config, OIDCConfig): + err = "OpenEuler OIDC配置错误" + raise TypeError(err) + return login_config + + @classmethod + async def get_oidc_token(cls, code: str) -> dict[str, Any]: + """获取Authelia OIDC Token""" + login_config = cls._get_login_config() + + data = { + "client_id": login_config.app_id, + "client_secret": login_config.app_secret, + "redirect_uri": login_config.login_api, + "grant_type": "authorization_code", + "code": code, + } + headers = { + "Content-Type": "application/x-www-form-urlencoded", + } + url = await cls.get_access_token_url() + result = None + async with httpx.AsyncClient(verify=False) as client: # noqa: S501 + resp = await client.post( + url, + headers=headers, + data=data, + timeout=10, + ) + if resp.status_code != status.HTTP_200_OK: + err = f"[Authelia] 获取OIDC Token失败: {resp.status_code},完整输出: {resp.text}" + raise RuntimeError(err) + logger.info("[Authelia] 获取OIDC Token成功: %s", resp.text) + result = resp.json() + return { + "access_token": result["access_token"], + "refresh_token": result.get("refresh_token", ""), + } + + @classmethod + async def get_oidc_user(cls, access_token: str) -> dict[str, Any]: + """获取Authelia OIDC用户""" + login_config = cls._get_login_config() + + if not access_token: + err = "Access token is empty." 
+ raise RuntimeError(err) + headers = { + "Authorization": f"Bearer {access_token}", + "Content-Type": "application/json", + } + url = login_config.host.rstrip("/") + "/api/oidc/userinfo" + result = None + async with httpx.AsyncClient(verify=False) as client: # noqa: S501 + resp = await client.get( + url, + headers=headers, + timeout=10, + ) + if resp.status_code != status.HTTP_200_OK: + err = f"[Authelia] 获取用户信息失败: {resp.status_code},完整输出: {resp.text}" + raise RuntimeError(err) + logger.info("[Authelia] 获取用户信息成功: %s", resp.text) + result = resp.json() + + return { + "user_sub": result.get("sub", result.get("preferred_username", "")), + "user_name": result.get("name", result.get("preferred_username", result.get("nickname", ""))), + } + + @classmethod + async def get_login_status(cls, cookie: dict[str, str]) -> dict[str, Any]: + """检查登录状态;Authelia通过session cookie检查""" + login_config = cls._get_login_config() + + headers = { + "Content-Type": "application/json", + } + url = login_config.host.rstrip("/") + "/api/user/info" + async with httpx.AsyncClient(verify=False) as client: # noqa: S501 + resp = await client.get( + url, + headers=headers, + cookies=cookie, + timeout=10, + ) + if resp.status_code != status.HTTP_200_OK: + err = f"[Authelia] 获取登录状态失败: {resp.status_code},完整输出: {resp.text}" + raise RuntimeError(err) + + # Authelia 返回用户信息表示已登录 + return { + "access_token": "", + "refresh_token": "", + } + + @classmethod + async def oidc_logout(cls, cookie: dict[str, str]) -> None: + """触发OIDC的登出""" + login_config = cls._get_login_config() + + headers = { + "Content-Type": "application/json", + } + url = login_config.host.rstrip("/") + "/api/logout" + async with httpx.AsyncClient(verify=False) as client: # noqa: S501 + resp = await client.post( + url, + headers=headers, + cookies=cookie, + timeout=10, + ) + # Authelia登出成功通常返回200或302重定向 + if resp.status_code not in [status.HTTP_200_OK, status.HTTP_302_FOUND]: + err = f"[Authelia] 登出失败: {resp.status_code},完整输出: {resp.text}" 
+ raise RuntimeError(err) + + @classmethod + async def get_redirect_url(cls) -> str: + """获取Authelia OIDC 重定向URL""" + login_config = cls._get_login_config() + + # 生成随机的 state 参数以确保安全性和唯一性 + state = secrets.token_urlsafe(32) + return (f"{login_config.host.rstrip('/')}/api/oidc/authorization?" + f"client_id={login_config.app_id}&" + f"response_type=code&" + f"scope=openid profile email&" + f"redirect_uri={login_config.login_api}&" + f"state={state}") + + @classmethod + async def get_access_token_url(cls) -> str: + """获取Authelia OIDC 访问Token URL""" + login_config = cls._get_login_config() + return login_config.host.rstrip("/") + "/api/oidc/token" diff --git a/apps/common/queue.py b/apps/common/queue.py index c0981527276e7e5613294e38ed0cdf1b4a2c6490..685511477eea0bda1fd8eb6df576aa5e125975ff 100644 --- a/apps/common/queue.py +++ b/apps/common/queue.py @@ -8,6 +8,7 @@ from collections.abc import AsyncGenerator from datetime import UTC, datetime from typing import Any +from apps.llm import LLMConfig from apps.schemas.enum_var import EventType from apps.schemas.message import ( HeartbeatData, @@ -15,7 +16,6 @@ from apps.schemas.message import ( MessageFlow, MessageMetadata, ) -from apps.schemas.scheduler import LLMConfig from apps.schemas.task import TaskData logger = logging.getLogger(__name__) diff --git a/apps/llm/__init__.py b/apps/llm/__init__.py index 4938632956ed55b0c295cc487712f49a674af25d..fa0ef8e4fc88ab295492772422bfee42e51ba906 100644 --- a/apps/llm/__init__.py +++ b/apps/llm/__init__.py @@ -4,11 +4,13 @@ from .embedding import Embedding from .generator import JsonGenerator from .llm import LLM +from .schema import LLMConfig from .token import token_calculator __all__ = [ "LLM", "Embedding", "JsonGenerator", + "LLMConfig", "token_calculator", ] diff --git a/apps/llm/generator.py b/apps/llm/generator.py index 045076fcfcc3246053e62a253ffb1c007bdaf615..3ce5a300924912e0cdd9758967cc7569ca64b094 100644 --- a/apps/llm/generator.py +++ b/apps/llm/generator.py @@ -12,10 
+12,10 @@ from jsonschema import Draft7Validator from apps.models import LLMType from apps.schemas.llm import LLMFunctions -from apps.schemas.scheduler import LLMConfig from .llm import LLM from .prompt import JSON_GEN_BASIC, JSON_NO_FUNCTION_CALL +from .schema import LLMConfig _logger = logging.getLogger(__name__) @@ -72,7 +72,7 @@ class JsonGenerator: # 如果支持FunctionCall,使用provider的function调用逻辑 result = await self._call_with_function() else: - # 如果不支持FunctionCall,使用JSON_GEN_BASIC和provider的call进行调用 + # 如果不支持FunctionCall,使用JSON_GEN_BASIC result = await self._call_without_function(max_tokens, temperature) # 校验结果 @@ -157,9 +157,16 @@ class JsonGenerator: while self._count < JSON_GEN_MAX_TRIAL: self._count += 1 try: + # 如果_single_trial没有抛出异常,直接返回结果,不进行重试 return await self._single_trial() - except Exception: # noqa: BLE001 - # 校验失败,_single_trial已经将错误信息记录到self._err_info中并记录日志 - continue - + except Exception: + # 每次捕获异常时都记录错误 + _logger.exception( + "[JSONGenerator] 第 %d/%d 次尝试失败", + self._count, + JSON_GEN_MAX_TRIAL, + ) + # 如果还有重试机会,继续下一次尝试 + if self._count < JSON_GEN_MAX_TRIAL: + continue return {} diff --git a/apps/llm/schema.py b/apps/llm/schema.py new file mode 100644 index 0000000000000000000000000000000000000000..53296a1bd51d379510fdc14adf9104e2f2976f13 --- /dev/null +++ b/apps/llm/schema.py @@ -0,0 +1,20 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. 
+"""LLM配置""" + +from pydantic import BaseModel, Field + +from .embedding import Embedding +from .llm import LLM + + +class LLMConfig(BaseModel): + """LLM配置""" + + reasoning: LLM = Field(description="推理LLM") + function: LLM | None = Field(description="函数LLM") + embedding: Embedding | None = Field(description="Embedding") + + class Config: + """配置""" + + arbitrary_types_allowed = True diff --git a/apps/scheduler/call/slot/prompt.py b/apps/scheduler/call/slot/prompt.py index 89c3829791f07d33207d5988af674ac8d2830f67..0767cda56a3fed186caccea2f969cf7a34f42e84 100644 --- a/apps/scheduler/call/slot/prompt.py +++ b/apps/scheduler/call/slot/prompt.py @@ -1,181 +1,200 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """自动参数填充工具的提示词""" +from textwrap import dedent + from apps.models import LanguageType +# 主查询提示词模板 SLOT_GEN_PROMPT: dict[LanguageType, str] = { - LanguageType.CHINESE: r""" + LanguageType.CHINESE: dedent(r""" - 你是一个可以使用工具的AI助手,正尝试使用工具来完成任务。 - 目前,你正在生成一个JSON参数对象,以作为调用工具的输入。 - 请根据用户输入、背景信息、工具信息和JSON Schema内容,生成符合要求的JSON对象。 - - 背景信息将在中给出,工具信息将在中给出,JSON Schema将在中给出,\ - 用户的问题将在中给出。 - 请在中输出生成的JSON对象。 - - 要求: - 1. 严格按照JSON Schema描述的JSON格式输出,不要编造不存在的字段。 - 2. JSON字段的值优先使用用户输入中的内容。如果用户输入没有,则使用背景信息中的内容。 - 3. 只输出JSON对象,不要输出任何解释说明,不要输出任何其他内容。 - 4. 如果JSON Schema中描述的JSON字段是可选的,则可以不输出该字段。 - 5. example中仅为示例,不要照搬example中的内容,不要将example中的内容作为输出。 - + + 你需要使用 fill_parameters 工具来填充参数。请仔细结合对话上下文,根据用户的问题和历史对话内容,\ +为当前工具调用生成合适的参数。要求: + 1. 严格按照下面提供的 JSON Schema 格式输出,不要编造不存在的字段; + 2. 参数值优先从用户的当前问题中提取;如果当前问题中没有,则从对话历史中查找相关信息; + 3. 必须仔细理解对话上下文,确保生成的参数与上下文保持一致,符合用户的真实意图; + 4. 只输出 JSON 对象,不要包含任何解释说明或其他内容; + 5. 如果 JSON Schema 中的字段是可选的(optional),在无法确定其值时可以省略该字段; + 6. 不要编造或猜测参数值,所有参数必须基于对话上下文中的实际信息。 + - 用户询问杭州今天的天气情况。AI回复杭州今天晴,温度20℃。用户询问杭州明天的天气情况。 + 用户询问:"北京今天天气怎么样?" AI回复:"北京今天晴,温度20℃。" 用户继续问:"明天呢?" - - 杭州明天的天气情况如何? 
- 工具名称:check_weather 工具描述:查询指定城市的天气信息 - + { "type": "object", "properties": { - "city": { - "type": "string", - "description": "城市名称" - }, - "date": { - "type": "string", - "description": "查询日期" - }, - "required": ["city", "date"] - } + "city": {"type": "string", "description": "城市名称"}, + "date": {"type": "string", "description": "查询日期"} + }, + "required": ["city", "date"] } - - + + { - "city": "杭州", + "city": "北京", "date": "明天" } - + - - - 以下是对用户给你的任务的历史总结,在中给出: - - {{summary}} - - 附加的条目化信息: - {{ facts }} - - - 在本次任务中,你已经调用过一些工具,并获得了它们的输出,在中给出: - - {% for tool in history_data %} - - {{ tool.step_name }} - {{ tool.step_description }} - {{ tool.output_data }} - - {% endfor %} - - - - {{question}} - - - 工具名称:{{current_tool["name"]}} - 工具描述:{{current_tool["description"]}} - - - {{schema}} - - - """, - LanguageType.ENGLISH: r""" - - You are an AI assistant capable of using tools to complete tasks. - Currently, you are generating a JSON parameter object as input for calling a tool. - Please generate a compliant JSON object based on user input, background information, tool information, \ -and JSON Schema content. - - Background information will be provided in , tool information in , JSON Schema in \ -, and the user's question in . - Output the generated JSON object in . - - Requirements: - 1. Strictly follow the JSON format described in the JSON Schema. Do not fabricate non-existent fields. - 2. Prioritize using values from user input for JSON fields. If not available, use content from background \ -information. - 3. Only output the JSON object. Do not include any explanations or additional content. - 4. Optional fields in the JSON Schema may be omitted. - 5. Examples are for illustration only. Do not copy content from examples or use them as output. - 6. Respond in the same language as the user's question by default, unless explicitly requested otherwise. - - - User asked about today's weather in Hangzhou. AI replied it's sunny, 20℃. 
User then asks about \ -tomorrow's weather in Hangzhou. - - - What's the weather like in Hangzhou tomorrow? - - - Tool name: check_weather - Tool description: Query weather information for specified cities - - - { - "type": "object", - "properties": { - "city": { - "type": "string", - "description": "City name" - }, - "date": { - "type": "string", - "description": "Query date" + + 工具名称:{{current_tool["name"]}} + 工具描述:{{current_tool["description"]}} + + + + {{schema}} + + + 现在,请使用 fill_parameters 工具来填充参数: + """).strip("\n"), + LanguageType.ENGLISH: dedent(r""" + + + You need to use the fill_parameters tool to fill in the parameters. Please carefully combine the \ +conversation context, and generate appropriate parameters for the current tool call based on the user's question \ +and historical conversation. Requirements: + 1. Strictly follow the JSON Schema format provided below for output; do not fabricate \ +non-existent fields; + 2. Parameter values should be extracted from the user's current question first; if not available \ +in the current question, search for relevant information from the conversation history; + 3. You must carefully understand the conversation context to ensure that the generated parameters \ +are consistent with the context and meet the user's true intent; + 4. Only output the JSON object; do not include any explanations or other content; + 5. If a field in the JSON Schema is optional, you may omit it when its value cannot be determined; + 6. Do not fabricate or guess parameter values; all parameters must be based on actual information \ +from the conversation context. + + + + + User asks: "What's the weather like in Beijing today?" AI replies: "It's sunny in Beijing today, \ +20℃." User continues: "What about tomorrow?" 
+ + + Tool name: check_weather + Tool description: Query weather information for specified cities + + + { + "type": "object", + "properties": { + "city": {"type": "string", "description": "City name"}, + "date": {"type": "string", "description": "Query date"} }, "required": ["city", "date"] } - } - - - { - "city": "Hangzhou", - "date": "tomorrow" - } - - - - - Historical summary of tasks given by user, provided in : - - {{summary}} - - Additional itemized information: - {{ facts }} - - - During this task, you have called some tools and obtained their outputs, provided in : - - {% for tool in history_data %} - - {{ tool.step_name }} - {{ tool.step_description }} - {{ tool.output_data }} - - {% endfor %} - - - - {{question}} - + + + { + "city": "Beijing", + "date": "tomorrow" + } + + + + Tool name: {{current_tool["name"]}} Tool description: {{current_tool["description"]}} - + + {{schema}} - - - """, + + + Now, please use the fill_parameters tool to fill in the parameters: + """).strip("\n"), +} + +# 任务背景模板(用于conversation中的user消息) +SLOT_SUMMARY_TEMPLATE: dict[LanguageType, str] = { + LanguageType.CHINESE: dedent(r""" + ## 任务背景 + + {{summary}} + {%- if facts %} + + ### 重要事实信息 + + 以下是从对话中提取的关键事实,在填充参数时请特别参考这些信息: + + {%- for fact in facts %} + - {{fact}} + {%- endfor %} + {%- endif %} + """).strip("\n"), + LanguageType.ENGLISH: dedent(r""" + ## Task Background + + {{summary}} + {%- if facts %} + + ### Important Facts + + The following are key facts extracted from the conversation. 
Please refer to this information when filling \ +in parameters: + + {%- for fact in facts %} + - {{fact}} + {%- endfor %} + {%- endif %} + """).strip("\n"), +} + +# 历史工具调用模板(用于conversation中的assistant消息) +SLOT_HISTORY_TEMPLATE: dict[LanguageType, str] = { + LanguageType.CHINESE: dedent(r""" + ## 历史工具调用记录 + + 以下是我在处理当前任务时已经调用过的工具及其输出结果,这些信息可能包含了填充参数所需的数据: + + {%- for tool in history_data %} + + ### 工具 {{loop.index}}: {{tool.step_name}} + {%- if tool.step_description %} + + **功能描述**: {{tool.step_description}} + {%- endif %} + {%- if tool.output_data %} + + **输出结果**: + ```json + {{tool.output_data | tojson(ensure_ascii=False, indent=2)}} + ``` + {%- endif %} + {%- endfor %} + """).strip("\n"), + LanguageType.ENGLISH: dedent(r""" + ## Historical Tool Call Records + + The following are the tools I have already called while processing the current task and their output results. \ +This information may contain the data needed to fill in the parameters: + + {%- for tool in history_data %} + + ### Tool {{loop.index}}: {{tool.step_name}} + {%- if tool.step_description %} + + **Function Description**: {{tool.step_description}} + {%- endif %} + {%- if tool.output_data %} + + **Output Result**: + ```json + {{tool.output_data | tojson(ensure_ascii=False, indent=2)}} + ``` + {%- endif %} + {%- endfor %} + """).strip("\n"), } diff --git a/apps/scheduler/call/slot/slot.py b/apps/scheduler/call/slot/slot.py index 53c4894fdbfcc6c79c97a5dc7a6d9e3a15b9721b..9f34f33d3042dad3d2d7fe2ffc1aafa88ff84050 100644 --- a/apps/scheduler/call/slot/slot.py +++ b/apps/scheduler/call/slot/slot.py @@ -16,7 +16,7 @@ from apps.scheduler.slot.slot import Slot as SlotProcessor from apps.schemas.enum_var import CallOutputType from apps.schemas.scheduler import CallInfo, CallOutputChunk, CallVars -from .prompt import SLOT_GEN_PROMPT +from .prompt import SLOT_GEN_PROMPT, SLOT_HISTORY_TEMPLATE, SLOT_SUMMARY_TEMPLATE from .schema import SlotInput, SlotOutput if TYPE_CHECKING: @@ -47,39 +47,80 @@ class 
Slot(CoreCall, input_model=SlotInput, output_model=SlotOutput):
 
     async def _llm_slot_fill(self, remaining_schema: dict[str, Any]) -> tuple[str, dict[str, Any]]:
-        """使用大模型填充参数;若大模型解析度足够,则直接返回结果"""
+        """使用JsonGenerator填充参数;若大模型解析度足够,则直接返回结果"""
         env = SandboxedEnvironment(
             loader=BaseLoader(),
             autoescape=False,
             trim_blocks=True,
             lstrip_blocks=True,
         )
-        template = env.from_string(SLOT_GEN_PROMPT[self._sys_vars.language])
-        conversation = [
-            {"role": "system", "content": "You are a helpful assistant."},
-            {"role": "user", "content": template.render(
-                current_tool={
-                    "name": self.name,
-                    "description": self.description,
-                },
-                schema=remaining_schema,
-                history_data=self._flow_history,
+        # 获取当前语言
+        language = self._sys_vars.language
+
+        # 渲染查询模板(不包含历史信息)
+        query_template = env.from_string(SLOT_GEN_PROMPT[language])
+        query = query_template.render(
+            current_tool={
+                "name": self.name,
+                "description": self.description,
+            },
+            schema=remaining_schema,
+        )
+
+        # 组装conversation:将summary和历史工具调用组装为对话格式
+        conversation = []
+
+        # 使用Jinja2模板渲染任务总结
+        if self.summary or self.facts:
+            summary_template = env.from_string(SLOT_SUMMARY_TEMPLATE[language])
+            summary_content = summary_template.render(
                 summary=self.summary,
-                question=self._question,
                 facts=self.facts,
-            )},
-        ]
+            )
+            conversation.append({
+                "role": "user",
+                "content": summary_content,
+            })
+            assistant_response = (
+                "我理解了任务背景,请继续。"
+                if language == LanguageType.CHINESE
+                else "I understand the task context. Please continue."
+            )
+            conversation.append({
+                "role": "assistant",
+                "content": assistant_response,
+            })
+
+        # 使用Jinja2模板渲染历史工具调用
+        if self._flow_history:
+            history_template = env.from_string(SLOT_HISTORY_TEMPLATE[language])
+            history_content = history_template.render(
+                history_data=self._flow_history,
+            )
+            if history_content.strip():
+                conversation.append({
+                    "role": "assistant",
+                    "content": history_content,
+                })
+
+        # 构建OpenAI标准FunctionCall格式
+        function = {
+            "name": "fill_parameters",
+            "description": f"Fill the missing parameters for {self.name}. {self.description}",
+            "parameters": remaining_schema,
+        }
+
+        # 使用JsonGenerator进行参数填充
+        generator = JsonGenerator(
+            llm_config=self._llm_obj,
+            query=query,
+            conversation=conversation,
+            function=function,
+        )
 
-        # 使用大模型进行尝试
-        answer = ""
-        async for chunk in self._llm(messages=conversation, streaming=True):
-            answer += chunk
-        answer = await JsonGenerator.process_response(answer)
-        try:
-            data = json.loads(answer)
-        except Exception:  # noqa: BLE001
-            data = {}
+        data = await generator.generate()
+        answer = json.dumps(data, ensure_ascii=False)
         return answer, data
 
     async def _function_slot_fill(self, answer: str, remaining_schema: dict[str, Any]) -> dict[str, Any]:
diff --git a/apps/scheduler/executor/base.py b/apps/scheduler/executor/base.py
index 28496445cc47aeb2f6f3e07a625788866e94fa90..1ff24a31db9fef40b0f852efa80f2c841726e0d1 100644
--- a/apps/scheduler/executor/base.py
+++ b/apps/scheduler/executor/base.py
@@ -15,7 +15,7 @@ from apps.services.record import RecordManager
 
 if TYPE_CHECKING:
     from apps.common.queue import MessageQueue
-    from apps.schemas.scheduler import LLMConfig
+    from apps.llm import LLMConfig
     from apps.schemas.task import TaskData
 
 
diff --git a/apps/scheduler/mcp/host.py b/apps/scheduler/mcp/host.py
index bb3aa7d5c1d78d0860fc6197a9eeaaaae3ae7420..610fd1b2fa9c9bb54d68c8903af1b51ff20c289e 100644
--- a/apps/scheduler/mcp/host.py
+++ b/apps/scheduler/mcp/host.py
@@ -10,13 +10,12 @@ from jinja2 import BaseLoader
 from jinja2.sandbox import SandboxedEnvironment
 from mcp.types import TextContent
 
-from apps.llm import JsonGenerator
+from apps.llm import JsonGenerator, LLMConfig
 from apps.models import LanguageType, MCPTools
 from apps.scheduler.mcp.prompt import MEMORY_TEMPLATE
 from apps.scheduler.pool.mcp.client import MCPClient
 from apps.scheduler.pool.mcp.pool import mcp_pool
 from apps.schemas.mcp import MCPContext, MCPPlanItem
-from apps.schemas.scheduler import LLMConfig
 from apps.services.mcp_service import MCPServiceManager
 
 logger = logging.getLogger(__name__)
diff --git a/apps/scheduler/mcp/plan.py b/apps/scheduler/mcp/plan.py
index e00648a14ecd289d8c9751a7e5ff1bf2640fca1e..f24ae8c1fb64951722a11a2323f616066d0683b3 100644
--- a/apps/scheduler/mcp/plan.py
+++ b/apps/scheduler/mcp/plan.py
@@ -6,10 +6,9 @@ import logging
 from jinja2 import BaseLoader
 from jinja2.sandbox import SandboxedEnvironment
 
-from apps.llm import JsonGenerator
+from apps.llm import JsonGenerator, LLMConfig
 from apps.models import LanguageType, MCPTools
 from apps.schemas.mcp import MCPPlan
-from apps.schemas.scheduler import LLMConfig
 
 from .prompt import CREATE_PLAN, FINAL_ANSWER
diff --git a/apps/scheduler/mcp/select.py b/apps/scheduler/mcp/select.py
index 17dd699f3e8132fd5ac0a1fe94dbca03113dac34..0f4362de8a605d05219f21ba1ccdfb0770e5b047 100644
--- a/apps/scheduler/mcp/select.py
+++ b/apps/scheduler/mcp/select.py
@@ -6,10 +6,9 @@ import logging
 from sqlalchemy import select
 
 from apps.common.postgres import postgres
-from apps.llm import JsonGenerator
+from apps.llm import JsonGenerator, LLMConfig
 from apps.models import LanguageType, MCPTools
 from apps.schemas.mcp import MCPSelectResult
-from apps.schemas.scheduler import LLMConfig
 from apps.services.mcp_service import MCPServiceManager
 
 from .prompt import MCP_FUNCTION_SELECT
diff --git a/apps/scheduler/mcp_agent/base.py b/apps/scheduler/mcp_agent/base.py
index 1a4d168db2f959277cc14f0199abb246a9560e70..934e65a6eb3aecdf2306ae4de4ce9a3c5f8d4aa9 100644
--- a/apps/scheduler/mcp_agent/base.py
+++ b/apps/scheduler/mcp_agent/base.py
@@ -4,9 +4,8 @@ import logging
 from typing import Any
 
-from apps.llm import JsonGenerator
+from apps.llm import JsonGenerator, LLMConfig
 from apps.models import LanguageType
-from apps.schemas.scheduler import LLMConfig
 from apps.schemas.task import TaskData
 
 _logger = logging.getLogger(__name__)
diff --git a/apps/scheduler/scheduler/scheduler.py b/apps/scheduler/scheduler/scheduler.py
index a08cb4547f0ec795f3bdde6eceb5ffd67158be80..4c0ce897e7980b68941f1e199793bdf6aad78f95 100644
--- a/apps/scheduler/scheduler/scheduler.py
+++ b/apps/scheduler/scheduler/scheduler.py
@@ -11,7 +11,7 @@ from jinja2.sandbox import SandboxedEnvironment
 
 from apps.common.queue import MessageQueue
 from apps.common.security import Security
-from apps.llm import LLM, Embedding, JsonGenerator
+from apps.llm import LLM, Embedding, JsonGenerator, LLMConfig
 from apps.models import (
     AppType,
     Conversation,
@@ -34,7 +34,7 @@ from apps.schemas.message import (
 )
 from apps.schemas.record import FlowHistory, RecordContent
 from apps.schemas.request_data import RequestData
-from apps.schemas.scheduler import LLMConfig, TopFlow
+from apps.schemas.scheduler import TopFlow
 from apps.schemas.task import TaskData
 from apps.services.activity import Activity
 from apps.services.appcenter import AppCenterManager
diff --git a/apps/schemas/config.py b/apps/schemas/config.py
index dba8fab9c8a43b6a0b551d039c28e1a1c6c5ab29..34e48b57d02417f9c3567be951cca7dbfbaf01fb 100644
--- a/apps/schemas/config.py
+++ b/apps/schemas/config.py
@@ -33,7 +33,9 @@ class FixedUserConfig(BaseModel):
 class LoginConfig(BaseModel):
     """OIDC配置"""
 
-    provider: Literal["authhub", "openeuler", "disable"] = Field(description="OIDC Provider", default="authhub")
+    provider: Literal[
+        "authhub", "openeuler", "authelia", "disable",
+    ] = Field(description="OIDC Provider", default="authhub")
     admin_user: list[str] = Field(description="管理员用户ID列表", default=[])
     settings: OIDCConfig | FixedUserConfig = Field(description="OIDC 配置")
diff --git a/apps/schemas/scheduler.py b/apps/schemas/scheduler.py
index 9c7595ea3599dfbdce55c4a5be483a8794d9ef39..98c34ef6881079ca4d5525be4faa9f8f882b8221 100644
--- a/apps/schemas/scheduler.py
+++ b/apps/schemas/scheduler.py
@@ -6,25 +6,11 @@ from typing import Any
 
 from pydantic import BaseModel, Field
 
-from apps.llm import LLM, Embedding
 from apps.models import ExecutorHistory, LanguageType
 
 from .enum_var import CallOutputType
 
 
-class LLMConfig(BaseModel):
-    """LLM配置"""
-
-    reasoning: LLM = Field(description="推理LLM")
-    function: LLM | None = Field(description="函数LLM")
-    embedding: Embedding | None = Field(description="Embedding")
-
-    class Config:
-        """配置"""
-
-        arbitrary_types_allowed = True
-
-
 class CallInfo(BaseModel):
     """Call的名称和描述"""