diff --git a/apps/common/encoder.py b/apps/common/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..94f535dde56d23d65d26389e73867e48040de88c --- /dev/null +++ b/apps/common/encoder.py @@ -0,0 +1,21 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +"""JSON编码器模块""" + +import json +from typing import Any +from uuid import UUID + + +class UUIDEncoder(json.JSONEncoder): + """支持UUID类型的JSON编码器""" + + def default(self, obj: Any) -> Any: + """ + 重写default方法以支持UUID类型序列化 + + :param obj: 待序列化的对象 + :return: 序列化后的对象 + """ + if isinstance(obj, UUID): + return str(obj) + return super().default(obj) diff --git a/apps/common/oidc.py b/apps/common/oidc.py index a54f1a60d25f0d5ca68da0c64a249e5e246aca35..37317014144656778ed1d7950ba089582725e468 100644 --- a/apps/common/oidc.py +++ b/apps/common/oidc.py @@ -32,17 +32,17 @@ class OIDCProvider: raise NotImplementedError(err) @staticmethod - async def set_token(user_sub: str, access_token: str, refresh_token: str) -> None: + async def set_token(user_id: str, access_token: str, refresh_token: str) -> None: """设置MongoDB中的OIDC Token到sessions集合""" async with postgres.session() as session: await session.merge(Session( - userSub=user_sub, + userId=user_id, sessionType=SessionType.ACCESS_TOKEN, token=access_token, validUntil=datetime.now(UTC) + timedelta(minutes=OIDC_ACCESS_TOKEN_EXPIRE_TIME), )) await session.merge(Session( - userSub=user_sub, + userId=user_id, sessionType=SessionType.REFRESH_TOKEN, token=refresh_token, validUntil=datetime.now(UTC) + timedelta(minutes=OIDC_REFRESH_TOKEN_EXPIRE_TIME), diff --git a/apps/common/queue.py b/apps/common/queue.py index 5cd66e617956d16884f07c92ece5fd4f701e6141..befeab3c86c99c161540665785b04812667557d9 100644 --- a/apps/common/queue.py +++ b/apps/common/queue.py @@ -18,6 +18,8 @@ from apps.schemas.message import ( ) from apps.schemas.task import TaskData +from .encoder import UUIDEncoder + logger = logging.getLogger(__name__) @@ -76,7 +78,9 @@ class MessageQueue: content=data, ) - await self._queue.put(json.dumps(message.model_dump(by_alias=True, exclude_none=True), ensure_ascii=False)) + await self._queue.put( + json.dumps(message.model_dump(by_alias=True, exclude_none=True), ensure_ascii=False, cls=UUIDEncoder), + ) async def get(self) -> AsyncGenerator[str, None]: """从Queue中获取消息;变为async generator""" @@ -98,7 +102,7 @@ class MessageQueue: async def _heartbeat(self) -> None: """组装用于向用户(前端/Shell端)输出的心跳""" heartbeat_template = HeartbeatData() - heartbeat_msg = json.dumps(heartbeat_template.model_dump(by_alias=True), ensure_ascii=False) + heartbeat_msg = json.dumps(heartbeat_template.model_dump(by_alias=True), ensure_ascii=False, cls=UUIDEncoder) while True: # 如果关闭,则停止心跳 diff --git a/apps/llm/generator.py b/apps/llm/generator.py index 656ef2800580ab795ffd49cc9b7ec656f27e1411..da5933d5f2344972103235fc5bc7fdc78b1b32ba 100644 --- a/apps/llm/generator.py +++ b/apps/llm/generator.py @@ -21,6 +21,16 @@ _logger = logging.getLogger(__name__) JSON_GEN_MAX_TRIAL = 3 + +class JsonValidationError(Exception): + """JSON验证失败异常,携带生成的结果""" + + def __init__(self, message: str, result: dict[str, Any]) -> None: + """初始化异常""" + super().__init__(message) + self.result = result + + class JsonGenerator: """综合Json生成器(全局单例)""" @@ -56,7 +66,7 @@ class JsonGenerator: def _build_messages( self, - function: dict[str, Any], + functions: list[dict[str, Any]], conversation: list[dict[str, str]], language: LanguageType = LanguageType.CHINESE, ) -> list[dict[str, str]]: @@ -70,9 +80,8 @@ class 
JsonGenerator: template = self._env.from_string(JSON_GEN[language]) prompt = template.render( query=query, - conversation=conversation[:-1], - schema=function["parameters"], use_xml_format=False, + functions=functions, ) messages = [*conversation[:-1], {"role": "user", "content": prompt}] @@ -117,7 +126,7 @@ class JsonGenerator: async def _single_trial( self, function: dict[str, Any], - context: list[dict[str, str]], + conversation: list[dict[str, str]], language: LanguageType = LanguageType.CHINESE, ) -> dict[str, Any]: """单次尝试,包含校验逻辑;function使用OpenAI标准Function格式""" @@ -131,10 +140,10 @@ class JsonGenerator: # 执行生成 if self._support_function_call: # 如果支持FunctionCall - result = await self._call_with_function(function, context, language) + result = await self._call_with_function(function, conversation, language) else: # 如果不支持FunctionCall - result = await self._call_without_function(function, context, language) + result = await self._call_without_function(function, conversation, language) # 校验结果 try: @@ -144,11 +153,9 @@ class JsonGenerator: err_info = err_info.split("\n\n")[0] _logger.info("[JSONGenerator] 验证失败:%s", err_info) - context.append({ - "role": "assistant", - "content": f"Attempted to use tool but validation failed: {err_info}", - }) - raise + # 创建带结果的验证异常 + validation_error = JsonValidationError(err_info, result) + raise validation_error from err else: return result @@ -163,7 +170,7 @@ class JsonGenerator: err = "[JSONGenerator] 未初始化,请先调用init()方法" raise RuntimeError(err) - messages = self._build_messages(function, conversation, language) + messages = self._build_messages([function], conversation, language) tool = LLMFunctions( name=function["name"], @@ -172,7 +179,7 @@ class JsonGenerator: ) tool_call_result = {} - async for chunk in self._llm.call(messages, include_thinking=False, streaming=True, tools=[tool]): + async for chunk in self._llm.call(messages, include_thinking=False, streaming=False, tools=[tool]): if chunk.tool_call: tool_call_result.update(chunk.tool_call) @@ -193,11 +200,11 @@ class JsonGenerator: err = "[JSONGenerator] 未初始化,请先调用init()方法" raise RuntimeError(err) - messages = self._build_messages(function, conversation, language) + messages = self._build_messages([function], conversation, language) # 使用LLM的call方法获取响应 full_response = "" - async for chunk in self._llm.call(messages, include_thinking=False, streaming=True): + async for chunk in self._llm.call(messages, include_thinking=False, streaming=False): if chunk.content: full_response += chunk.content @@ -225,32 +232,50 @@ class JsonGenerator: schema = function["parameters"] Draft7Validator.check_schema(schema) - # 构建上下文 - context = [ - { - "role": "system", - "content": "You are a helpful assistant that can use tools to help answer user queries.", - }, - ] - if conversation: - context.extend(conversation) - count = 0 - original_context = context.copy() + + retry_conversation = conversation.copy() if conversation else [] while count < JSON_GEN_MAX_TRIAL: count += 1 try: # 如果_single_trial没有抛出异常,直接返回结果,不进行重试 - return await self._single_trial(function, context, language) - except Exception: + return await self._single_trial(function, retry_conversation, language) + except JsonValidationError as e: _logger.exception( "[JSONGenerator] 第 %d/%d 次尝试失败", count, JSON_GEN_MAX_TRIAL, ) + if count < JSON_GEN_MAX_TRIAL: + # 将错误信息添加到retry_conversation中,模拟完整的对话流程 + err_info = str(e) + err_info = err_info.split("\n\n")[0] + + # 模拟assistant调用了函数但失败,包含实际获得的参数 + function_name = function["name"] + result_json = json.dumps(e.result, 
ensure_ascii=False) + retry_conversation.append({ + "role": "assistant", + "content": f"I called function {function_name} with parameters: {result_json}", + }) + + # 模拟user要求重新调用函数并给出错误原因 + retry_conversation.append({ + "role": "user", + "content": ( + f"The previous function call failed with validation error: {err_info}. " + "Please try again and ensure the output strictly follows the required schema." + ), + }) + continue + except Exception: + _logger.exception( + "[JSONGenerator] 第 %d/%d 次尝试失败(非验证错误)", + count, + JSON_GEN_MAX_TRIAL, + ) if count < JSON_GEN_MAX_TRIAL: continue - context = original_context return {} diff --git a/apps/llm/llm.py b/apps/llm/llm.py index fc005cd7fe73cdc5c81e3e757fac794a542f7da9..7cba3b6b538cca7261b2430404108ac806ef65dd 100644 --- a/apps/llm/llm.py +++ b/apps/llm/llm.py @@ -45,27 +45,14 @@ class LLM: streaming: bool = True, tools: list[LLMFunctions] | None = None, ) -> AsyncGenerator[LLMChunk, None]: - """调用大模型,分为流式和非流式两种""" - if streaming: - async for chunk in await self._provider.chat( - messages, - include_thinking=include_thinking, - tools=tools, - ): - yield chunk - else: - # 非流式模式下,需要收集所有内容 - async for _ in await self._provider.chat( - messages, - include_thinking=include_thinking, - tools=tools, - ): - continue - # 直接yield最终结果 - yield LLMChunk( - content=self._provider.full_answer, - reasoning_content=self._provider.full_thinking, - ) + """调用大模型,统一处理流式和非流式""" + async for chunk in self._provider.chat( + messages, + include_thinking=include_thinking, + streaming=streaming, + tools=tools, + ): + yield chunk @property def input_tokens(self) -> int: diff --git a/apps/llm/prompt.py b/apps/llm/prompt.py index 6ece04539e6d50896980a8d6dc2bf956e420c900..fa424e731924cb2712470e881fc32ea96d047460 100644 --- a/apps/llm/prompt.py +++ b/apps/llm/prompt.py @@ -13,10 +13,9 @@ JSON_GEN: dict[LanguageType, str] = { - 你可以访问能够帮助收集信息的工具 - - 逐步使用工具,每次使用都基于之前的结果 + - 逐步使用工具,每次使用都基于之前的结果{% if use_xml_format %} - 用户的查询在 标签中提供 - {% if previous_trial %}- 查看 信息以避免重复错误{% endif %} - {% if use_xml_format %}- 使用 XML 样式的标签格式化工具调用,其中工具名称是根标签,每个参数是嵌套标签 + - 使用 XML 样式的标签格式化工具调用,其中工具名称是根标签,每个参数是嵌套标签 - 使用架构中指定的确切工具名称和参数名称 - 基本格式结构: <工具名称> @@ -27,7 +26,9 @@ JSON_GEN: dict[LanguageType, str] = { * 数字:10 * 布尔值:true * 数组(重复标签):项目1项目2 - * 对象(嵌套标签):{% endif %} + * 对象(嵌套标签):{% else %} + - 必须使用提供的工具来回答查询 + - 工具调用必须遵循工具架构中定义的JSON格式{% endif %} {% if use_xml_format %} @@ -36,72 +37,57 @@ JSON_GEN: dict[LanguageType, str] = { 杭州的天气怎么样? 
- - - get_weather: 获取指定城市的当前天气信息 - - - { - "name": "get_weather", - "description": "获取指定城市的当前天气信息", - "parameters": { - "type": "object", - "properties": { - "city": { - "type": "string", - "description": "要查询天气的城市名称" + + + get_weather + 获取指定城市的当前天气信息 + + { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "要查询天气的城市名称" + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "温度单位" + } }, - "unit": { - "type": "string", - "enum": ["celsius", "fahrenheit"], - "description": "温度单位" - }, - "include_forecast": { - "type": "boolean", - "description": "是否包含预报数据" - } - }, - "required": ["city"] - } - } - - + "required": ["city"] + } + + + 助手响应: 杭州 celsius - false - {% endif %} {{ query }} - {% if previous_trial %} - - - - 你之前的工具调用有不正确的参数。 - - - {{ previous_trial }} - - - {{ err_info }} - - - {% endif %} - - - - {{ tool_descriptions }} - - - {{ tool_schemas }} - - + + {% for func in functions %} + + {{ func.name }} + {{ func.description }} + {{ func.parameters | tojson(indent=2) }} + {% endfor %} + {% else %} + + ## 可用工具 + {% for func in functions %} + - **{{ func.name }}**: {{ func.description }} + {% endfor %} + + 请根据用户查询选择合适的工具来回答问题。你必须使用上述工具之一来处理查询。 + + **用户查询**: {{ query }}{% endif %} """, ), LanguageType.ENGLISH: dedent( @@ -111,11 +97,9 @@ JSON_GEN: dict[LanguageType, str] = { - You have access to tools that can help gather information - - Use tools step-by-step, with each use informed by previous results + - Use tools step-by-step, with each use informed by previous results{% if use_xml_format %} - The user's query is provided in the tags - {% if previous_trial %}- Review the information to avoid \ -repeating mistakes{% endif %} - {% if use_xml_format %}- Format tool calls using XML-style tags where the tool name is the root tag \ + - Format tool calls using XML-style tags where the tool name is the root tag \ and each parameter is a nested tag - Use the exact tool name and parameter names as specified in the schema - Basic format structure: @@ -127,7 +111,9 @@ and each parameter is a nested tag * Number: 10 * Boolean: true * Array (repeat tags): item1item2 - * Object (nest tags): value{% endif %} + * Object (nest tags): value{% else %} + - You must use the provided tools to answer the query + - Tool calls must follow the JSON format defined in the tool schemas{% endif %} {% if use_xml_format %} @@ -136,72 +122,58 @@ and each parameter is a nested tag What is the weather like in Hangzhou? 
- - - get_weather: Get current weather information for a specified city - - - { - "name": "get_weather", - "description": "Get current weather information for a specified city", - "parameters": { - "type": "object", - "properties": { - "city": { - "type": "string", - "description": "The city name to query weather for" + + + get_weather + Get current weather information for a specified city + + { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The city name to query weather for" + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "Temperature unit" + } }, - "unit": { - "type": "string", - "enum": ["celsius", "fahrenheit"], - "description": "Temperature unit" - }, - "include_forecast": { - "type": "boolean", - "description": "Whether to include forecast data" - } - }, - "required": ["city"] - } - } - - + "required": ["city"] + } + + + Assistant response: Hangzhou celsius - false - {% endif %} {{ query }} - {% if previous_trial %} - - - - Your previous tool call had incorrect arguments. - - - {{ previous_trial }} - - - {{ err_info }} - - - {% endif %} - - - - {{ tool_descriptions }} - - - {{ tool_schemas }} - - + + {% for func in functions %} + + {{ func.name }} + {{ func.description }} + {{ func.parameters | tojson(indent=2) }} + {% endfor %} + {% else %} + + ## Available Tools + {% for func in functions %} + - **{{ func.name }}**: {{ func.description }} + {% endfor %} + + Please select the appropriate tool(s) from above to answer the user's query. + You must use one of the tools listed above to process the query. + + **User Query**: {{ query }}{% endif %} """, ), } diff --git a/apps/llm/providers/base.py b/apps/llm/providers/base.py index 207fa0391c102b41a0c38d70b40a418454359f90..5c5757ae8622f38397e32601bc8bc4d00f7e75f1 100644 --- a/apps/llm/providers/base.py +++ b/apps/llm/providers/base.py @@ -51,10 +51,11 @@ class BaseProvider: async def chat( self, messages: list[dict[str, str]], *, include_thinking: bool = False, + streaming: bool = True, tools: list[LLMFunctions] | None = None, ) -> AsyncGenerator[LLMChunk, None]: """聊天""" - raise NotImplementedError + yield LLMChunk(content="") async def embedding(self, text: list[str]) -> list[list[float]]: diff --git a/apps/llm/providers/ollama.py b/apps/llm/providers/ollama.py index e721510694c2f7d264e6dea72b06e88af1c1047d..4b668536a681454110d1e6cfbd547ffac8905891 100644 --- a/apps/llm/providers/ollama.py +++ b/apps/llm/providers/ollama.py @@ -77,53 +77,25 @@ class OllamaProvider(BaseProvider): "content": "" + self.full_thinking + "" + self.full_answer, }]) - @override - async def chat( - self, messages: list[dict[str, str]], - *, include_thinking: bool = False, - tools: list[LLMFunctions] | None = None, - ) -> AsyncGenerator[LLMChunk, None]: - # 检查能力 - if not self._allow_chat: - err = "[OllamaProvider] 当前模型不支持Chat" - _logger.error(err) - raise RuntimeError(err) - - # 初始化Token计数 - self.input_tokens = 0 - self.output_tokens = 0 - - # 检查消息 - messages = self._validate_messages(messages) - - # 流式返回响应 - last_chunk = None - chat_kwargs = { - "model": self.config.modelName, - "messages": messages, - "options": { - "temperature": self.config.temperature, - "num_predict": self.config.maxToken, + def _build_tools_param(self, tools: list[LLMFunctions]) -> list[dict[str, Any]]: + """构建工具参数""" + return [{ + "type": "function", + "function": { + "name": tool.name, + "description": tool.description, + "parameters": tool.param_schema, }, - "stream": True, - **self.config.extraConfig, - 
} + } for tool in tools] - # 如果提供了tools,则传入以启用function-calling模式 - if tools: - functions = [] - for tool in tools: - functions += [{ - "type": "function", - "function": { - "name": tool.name, - "description": tool.description, - "parameters": tool.param_schema, - }, - }] - chat_kwargs["tools"] = functions - - async for chunk in await self._client.chat(**chat_kwargs): + async def _handle_streaming_response( + self, response: AsyncGenerator[ChatResponse, None], + messages: list[dict[str, str]], + tools: list[LLMFunctions] | None, + ) -> AsyncGenerator[LLMChunk, None]: + """处理流式响应""" + last_chunk = None + async for chunk in response: last_chunk = chunk if chunk.message.thinking: self.full_thinking += chunk.message.thinking @@ -145,6 +117,85 @@ class OllamaProvider(BaseProvider): # 使用最后一个chunk的usage数据 self._process_usage_data(last_chunk, messages) + async def _handle_non_streaming_response( + self, response: ChatResponse, + messages: list[dict[str, str]], + ) -> AsyncGenerator[LLMChunk, None]: + """处理非流式响应""" + if response.message.thinking: + self.full_thinking = response.message.thinking + if response.message.content: + self.full_answer = response.message.content + + # 处理usage数据 + self._process_usage_data(response if response.done else None, messages) + + # 一次性返回完整结果 + yield LLMChunk( + content=self.full_answer, + reasoning_content=self.full_thinking, + ) + + def _build_chat_kwargs( + self, messages: list[dict[str, str]], + *, + streaming: bool, + tools: list[LLMFunctions] | None, + ) -> dict[str, Any]: + """构建chat请求参数""" + chat_kwargs = { + "model": self.config.modelName, + "messages": messages, + "options": { + "temperature": self.config.temperature, + "num_predict": self.config.maxToken, + }, + "stream": streaming, + **self.config.extraConfig, + } + + # 如果提供了tools,则传入以启用function-calling模式 + if tools: + chat_kwargs["tools"] = self._build_tools_param(tools) + + return chat_kwargs + + @override + async def chat( + self, messages: list[dict[str, str]], + tools: list[LLMFunctions] | None = None, + *, include_thinking: bool = False, + streaming: bool = True, + ) -> AsyncGenerator[LLMChunk, None]: + # 检查能力 + if not self._allow_chat: + err = "[OllamaProvider] 当前模型不支持Chat" + _logger.error(err) + raise RuntimeError(err) + + # 初始化Token计数和累积内容 + self.input_tokens = 0 + self.output_tokens = 0 + self.full_thinking = "" + self.full_answer = "" + + # 检查消息 + messages = self._validate_messages(messages) + + # 构建chat请求参数 + chat_kwargs = self._build_chat_kwargs(messages, streaming=streaming, tools=tools) + + # 发送请求 + response = await self._client.chat(**chat_kwargs) + + # 根据streaming模式处理响应 + if streaming: + async for chunk in self._handle_streaming_response(response, messages, tools): + yield chunk + else: + async for chunk in self._handle_non_streaming_response(response, messages): + yield chunk + def _seq_to_list(self, seq: Sequence[Any]) -> list[Any]: """将Sequence转换为list""" result = [] @@ -168,5 +219,3 @@ class OllamaProvider(BaseProvider): input=text, ) return self._seq_to_list(result.embeddings) - - diff --git a/apps/llm/providers/openai.py b/apps/llm/providers/openai.py index b14e9d4b9bea6f35227f819252959811dd6fb0e5..9de76a29e2d6ad3d48efc0085010df813e7be431 100644 --- a/apps/llm/providers/openai.py +++ b/apps/llm/providers/openai.py @@ -8,6 +8,7 @@ from typing import cast import httpx from openai import AsyncOpenAI, AsyncStream from openai.types.chat import ( + ChatCompletion, ChatCompletionChunk, ChatCompletionMessageParam, ) @@ -92,11 +93,159 @@ class OpenAIProvider(BaseProvider): "content": "" + 
self.full_thinking + "" + self.full_answer, }]) + def _handle_usage_response(self, response: ChatCompletion, messages: list[dict[str, str]]) -> None: + """处理非流式响应的usage信息""" + if hasattr(response, "usage") and response.usage: + self.input_tokens = response.usage.prompt_tokens + self.output_tokens = response.usage.completion_tokens + else: + # 回退到本地估算 + self.input_tokens = token_calculator.calculate_token_length(messages) + self.output_tokens = token_calculator.calculate_token_length([{ + "role": "assistant", + "content": "" + self.full_thinking + "" + self.full_answer, + }]) + + def _build_request_kwargs( + self, + messages: list[dict[str, str]], + *, + streaming: bool, + ) -> dict: + """构建请求参数""" + request_kwargs = { + "model": self.config.modelName, + "messages": self._convert_messages(messages), + "max_tokens": self.config.maxToken, + "temperature": self.config.temperature, + "stream": streaming, + **self.config.extraConfig, + } + + # 流式模式下添加usage统计选项 + if streaming: + request_kwargs["stream_options"] = {"include_usage": True} + + return request_kwargs + + def _add_tools_to_request( + self, + request_kwargs: dict, + tools: list[LLMFunctions], + ) -> None: + """将工具添加到请求参数中""" + functions = [ + { + "type": "function", + "function": { + "name": tool.name, + "description": tool.description, + "parameters": tool.param_schema, + }, + } + for tool in tools + ] + request_kwargs["tools"] = functions + + async def _process_streaming_chunk( + self, + chunk: ChatCompletionChunk, + *, + include_thinking: bool, + ) -> AsyncGenerator[LLMChunk, None]: + """处理单个流式响应chunk""" + if not hasattr(chunk, "choices") or not chunk.choices: + return + + delta = chunk.choices[0].delta + + # 处理reasoning_content + if ( + hasattr(delta, "reasoning_content") and + getattr(delta, "reasoning_content", None) and + include_thinking + ): + reasoning_content = getattr(delta, "reasoning_content", "") + self.full_thinking += reasoning_content + yield LLMChunk(reasoning_content=reasoning_content) + + # 处理content + if hasattr(delta, "content") and delta.content: + self.full_answer += delta.content + yield LLMChunk(content=delta.content) + + async def _handle_streaming_response( + self, + request_kwargs: dict, + messages: list[dict[str, str]], + *, + include_thinking: bool, + ) -> AsyncGenerator[LLMChunk, None]: + """处理流式响应""" + stream: AsyncStream[ChatCompletionChunk] = await self._client.chat.completions.create(**request_kwargs) + last_chunk = None + + async for chunk in stream: + last_chunk = chunk + async for llm_chunk in self._process_streaming_chunk( + chunk, + include_thinking=include_thinking, + ): + yield llm_chunk + + # 处理最后一个Chunk的usage + self._handle_usage_chunk(last_chunk, messages) + + async def _handle_non_streaming_response( + self, + request_kwargs: dict, + messages: list[dict[str, str]], + tools: list[LLMFunctions] | None, + *, + include_thinking: bool, + ) -> AsyncGenerator[LLMChunk, None]: + """处理非流式响应""" + response: ChatCompletion = await self._client.chat.completions.create(**request_kwargs) + + tool_call_dict = {} + if response.choices: + message = response.choices[0].message + + # 处理reasoning_content + if ( + hasattr(message, "reasoning_content") and + getattr(message, "reasoning_content", None) and + include_thinking + ): + self.full_thinking = getattr(message, "reasoning_content", "") + + # 处理content + if hasattr(message, "content") and message.content: + self.full_answer = message.content + + # 处理工具调用 + if tools and hasattr(message, "tool_calls") and message.tool_calls: + for tool_call in 
message.tool_calls: + if tool_call.type == "function" and hasattr(tool_call, "function"): + func = tool_call.function + tool_call_dict[func.name] = func.arguments + + # 处理usage数据 + self._handle_usage_response(response, messages) + + # 一次性返回完整结果 + yield LLMChunk( + content=self.full_answer, + reasoning_content=self.full_thinking, + tool_call=tool_call_dict if tool_call_dict else None, + ) + @override - async def chat( # noqa: C901 + async def chat( self, messages: list[dict[str, str]], - *, include_thinking: bool = False, tools: list[LLMFunctions] | None = None, + *, include_thinking: bool = False, + streaming: bool = True, ) -> AsyncGenerator[LLMChunk, None]: """聊天""" # 检查能力 @@ -105,73 +254,38 @@ class OpenAIProvider(BaseProvider): _logger.error(err) raise RuntimeError(err) - # 初始化Token计数 + # 初始化Token计数和累积内容 self.input_tokens = 0 self.output_tokens = 0 + self.full_thinking = "" + self.full_answer = "" # 检查消息 messages = self._validate_messages(messages) - request_kwargs = { - "model": self.config.modelName, - "messages": self._convert_messages(messages), - "max_tokens": self.config.maxToken, - "temperature": self.config.temperature, - "stream": True, - "stream_options": {"include_usage": True}, - **self.config.extraConfig, - } + # 构建请求参数 + request_kwargs = self._build_request_kwargs(messages, streaming=streaming) # 如果提供了tools,则启用function-calling模式 if tools: - functions = [] - for tool in tools: - functions += [{ - "type": "function", - "function": { - "name": tool.name, - "description": tool.description, - "parameters": tool.param_schema, - }, - }] - request_kwargs["tools"] = functions + self._add_tools_to_request(request_kwargs, tools) - stream: AsyncStream[ChatCompletionChunk] = await self._client.chat.completions.create(**request_kwargs) - # 流式返回响应 - last_chunk = None - async for chunk in stream: - last_chunk = chunk - if hasattr(chunk, "choices") and chunk.choices: - delta = chunk.choices[0].delta - if ( - hasattr(delta, "reasoning_content") and - getattr(delta, "reasoning_content", None) and - include_thinking - ): - reasoning_content = getattr(delta, "reasoning_content", "") - self.full_thinking += reasoning_content - yield LLMChunk(reasoning_content=reasoning_content) - - if ( - hasattr(chunk.choices[0].delta, "content") and - chunk.choices[0].delta.content - ): - self.full_answer += chunk.choices[0].delta.content - yield LLMChunk(content=chunk.choices[0].delta.content) - - # 在chat中统一处理工具调用分块(当提供了tools时) - if tools and hasattr(delta, "tool_calls") and delta.tool_calls: - tool_call_dict = {} - for tool_call in delta.tool_calls: - if hasattr(tool_call, "function") and tool_call.function: - tool_call_dict.update({ - tool_call.function.name: tool_call.function.arguments, - }) - if tool_call_dict: - yield LLMChunk(tool_call=tool_call_dict) - - # 处理最后一个Chunk的usage(仅在最后一个chunk会出现) - self._handle_usage_chunk(last_chunk, messages) + # 根据模式选择处理方式 + if streaming: + async for chunk in self._handle_streaming_response( + request_kwargs, + messages, + include_thinking=include_thinking, + ): + yield chunk + else: + async for chunk in self._handle_non_streaming_response( + request_kwargs, + messages, + tools, + include_thinking=include_thinking, + ): + yield chunk @override async def embedding(self, text: list[str]) -> list[list[float]]: diff --git a/apps/main.py b/apps/main.py index 9f26b4a47463ec5671c8a21a2bd5e1759b2b81df..0241c0019503ff259bd4d086a34142d8b10198e5 100644 --- a/apps/main.py +++ b/apps/main.py @@ -33,6 +33,7 @@ from .routers import ( user, ) from .scheduler.pool.pool import pool +from 
.services.settings import SettingsManager @asynccontextmanager @@ -41,6 +42,8 @@ async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]: # 启动时初始化资源 await postgres.init() await pool.init() + # 初始化全局LLM设置 + await SettingsManager.init_global_llm_settings() yield # 关闭时释放资源 await postgres.close() diff --git a/apps/models/__init__.py b/apps/models/__init__.py index 6a0fcc3c016a92b70ab7cd134e249eef819ad5c2..4c0301575ba0bd2f7d23a5b79e3f447e6292551c 100644 --- a/apps/models/__init__.py +++ b/apps/models/__init__.py @@ -11,10 +11,10 @@ from .flow import Flow from .llm import LLMData, LLMProvider, LLMType from .mcp import MCPActivated, MCPInfo, MCPInstallStatus, MCPTools, MCPType from .node import NodeInfo -from .settings import GlobalSettings from .record import Record, RecordMetadata from .service import Service, ServiceACL, ServiceHashes from .session import Session, SessionActivity, SessionType +from .settings import GlobalSettings from .tag import Tag from .task import ( ExecutorCheckpoint, diff --git a/apps/models/task.py b/apps/models/task.py index 6d2f33a3e2b2b85235f7f6eed0edc530325a075e..51df70961b3236aa1734f78efb17e9d22d127545 100644 --- a/apps/models/task.py +++ b/apps/models/task.py @@ -98,10 +98,10 @@ class TaskRuntime(Base): """时间""" fullTime: Mapped[float] = mapped_column(Float, default=0.0, nullable=False) # noqa: N815 """完整时间成本""" - sessionId: Mapped[str | None] = mapped_column( # noqa: N815 - String(255), ForeignKey("framework_session.id"), nullable=True, default=None, + authHeader: Mapped[str | None] = mapped_column( # noqa: N815 + String(255), nullable=True, default=None, ) - """会话ID""" + """认证头""" userInput: Mapped[str] = mapped_column(Text, nullable=False, default="") # noqa: N815 """用户输入""" fullAnswer: Mapped[str] = mapped_column(Text, nullable=False, default="") # noqa: N815 diff --git a/apps/routers/auth.py b/apps/routers/auth.py index b74bf83e5c44ae7d8d6db3e10aea99395feaf9b1..5216aae960f193cb9dd7e5cf71494f4c15bcebd7 100644 --- a/apps/routers/auth.py +++ b/apps/routers/auth.py @@ -48,9 +48,9 @@ async def oidc_login(request: Request, code: str) -> HTMLResponse: token = await oidc_provider.get_oidc_token(code) user_info = await oidc_provider.get_oidc_user(token["access_token"]) - user_sub: str | None = user_info.get("user_sub", None) - if user_sub: - await oidc_provider.set_token(user_sub, token["access_token"], token["refresh_token"]) + user_id: str | None = user_info.get("user_sub", None) + if user_id: + await oidc_provider.set_token(user_id, token["access_token"], token["refresh_token"]) except Exception as e: _logger.exception("User login failed") status_code = status.HTTP_400_BAD_REQUEST if "auth error" in str(e) else status.HTTP_403_FORBIDDEN @@ -68,7 +68,7 @@ async def oidc_login(request: Request, code: str) -> HTMLResponse: ) user_host = request.client.host # 获取用户sub - if not user_sub: + if not user_id: _logger.error("OIDC no user_sub associated.") return _templates.TemplateResponse( "login_failed.html.j2", @@ -76,9 +76,10 @@ async def oidc_login(request: Request, code: str) -> HTMLResponse: status_code=status.HTTP_403_FORBIDDEN, ) # 创建或更新用户登录信息 - await UserManager.create_or_update_on_login(user_sub) + user_name: str | None = user_info.get("user_name", None) + await UserManager.create_or_update_on_login(user_id, user_name) # 创建会话 - current_session = await SessionManager.create_session(user_sub, user_host) + current_session = await SessionManager.create_session(user_id, user_host) return _templates.TemplateResponse( "login_success.html.j2", {"request": request, 
"current_session": current_session}, @@ -97,7 +98,7 @@ async def logout(request: Request) -> JSONResponse: result={}, ).model_dump(exclude_none=True, by_alias=True), ) - await TokenManager.delete_plugin_token(request.state.user_sub) + await TokenManager.delete_plugin_token(request.state.user_id) if hasattr(request.state, "session_id"): await SessionManager.delete_session(request.state.session_id) @@ -145,7 +146,7 @@ async def oidc_logout(token: Annotated[str, Body()]) -> JSONResponse: }, response_model=PostPersonalTokenRsp) async def change_personal_token(request: Request) -> JSONResponse: """POST /auth/key: 重置用户的API密钥""" - new_api_key: str | None = await PersonalTokenManager.update_personal_token(request.state.user_sub) + new_api_key: str | None = await PersonalTokenManager.update_personal_token(request.state.user_id) if not new_api_key: return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=ResponseData( code=status.HTTP_500_INTERNAL_SERVER_ERROR, diff --git a/apps/routers/chat.py b/apps/routers/chat.py index 530c9c137af374c6d42dd9c3f786489b10a21681..48f5ad3c84820cc570ebb464dd3ce6ed91ef9bf4 100644 --- a/apps/routers/chat.py +++ b/apps/routers/chat.py @@ -29,26 +29,28 @@ router = APIRouter( ], ) -async def chat_generator(post_body: RequestData, user_id: str, session_id: str) -> AsyncGenerator[str, None]: +async def chat_generator(post_body: RequestData, user_id: str, auth_header: str) -> AsyncGenerator[str, None]: """进行实际问答,并从MQ中获取消息""" try: - await Activity.set_active(user_id) - # 敏感词检查 if await words_check.check(post_body.question) != 1: yield "data: [SENSITIVE]\n\n" _logger.info("[Chat] 问题包含敏感词!") - await Activity.remove_active(user_id) return # 创建queue;由Scheduler进行关闭 queue = MessageQueue() await queue.init() - # 在单独Task中运行Scheduler,拉齐queue.get的时机 + # 1. 先初始化 Scheduler(确保初始化成功) scheduler = Scheduler() - await scheduler.init(queue, post_body, user_id) - _logger.info(f"[Chat] 用户是否活跃: {await Activity.is_active(user_id)}") + await scheduler.init(queue, post_body, user_id, auth_header) + + # 2. Scheduler初始化成功后,再设置用户处于活动状态 + await Activity.set_active(user_id) + _logger.info(f"[Chat] 用户是否活跃: {await Activity.can_active(user_id)}") + + # 3. 
最后判断并运行 Executor scheduler_task = asyncio.create_task(scheduler.run()) # 处理每一条消息 @@ -95,7 +97,7 @@ async def chat_generator(post_body: RequestData, user_id: str, session_id: str) @router.post("/chat") async def chat(request: Request, post_body: RequestData) -> StreamingResponse: """LLM流式对话接口""" - session_id = request.state.session_id + auth_header = request.headers.get("Authorization", "").replace("Bearer ", "") user_id = request.state.user_id # 问题黑名单检测 if (post_body.question is not None) and ( @@ -105,7 +107,7 @@ async def chat(request: Request, post_body: RequestData) -> StreamingResponse: await UserBlacklistManager.change_blacklisted_users(user_id, -10) raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="question is blacklisted") - res = chat_generator(post_body, user_id, session_id) + res = chat_generator(post_body, user_id, auth_header) return StreamingResponse( content=res, media_type="text/event-stream", diff --git a/apps/scheduler/call/core.py b/apps/scheduler/call/core.py index fca0d4ab344b89dced98c7aab740e49279705590..c677d70573d57a8e066f2a9b4a15a10d9665542e 100644 --- a/apps/scheduler/call/core.py +++ b/apps/scheduler/call/core.py @@ -6,6 +6,7 @@ Core Call类是定义了所有Call都应具有的方法和参数的PyDantic类 """ import logging +import uuid from collections.abc import AsyncGenerator from typing import TYPE_CHECKING, Any, ClassVar, Self @@ -89,8 +90,8 @@ class CoreCall(BaseModel): logger.error(err) raise ValueError(err) - history = {} - history_order = [] + history: dict[uuid.UUID, ExecutorHistory] = {} + history_order: list[uuid.UUID] = [] for item in executor.task.context: history[item.stepId] = item history_order.append(item.stepId) @@ -100,7 +101,7 @@ class CoreCall(BaseModel): ids=CallIds( task_id=executor.task.metadata.id, executor_id=executor.task.state.executorId, - auth_header=executor.task.runtime.sessionId, + auth_header=executor.task.runtime.authHeader, user_id=executor.task.metadata.userId, app_id=executor.task.state.appId, conversation_id=executor.task.metadata.conversationId, @@ -114,12 +115,12 @@ class CoreCall(BaseModel): @staticmethod - def _extract_history_variables(path: str, history: dict[str, ExecutorHistory]) -> Any: + def _extract_history_variables(path: str, history: dict[uuid.UUID, ExecutorHistory]) -> Any: """ 提取History中的变量 :param path: 路径,格式为:step_id/key/to/variable - :param history: Step历史,即call_vars.history + :param history: Step历史,即call_vars.step_data :return: 变量 """ split_path = path.split("/") @@ -127,12 +128,21 @@ class CoreCall(BaseModel): err = f"[CoreCall] 路径格式错误: {path}" logger.error(err) return None - if split_path[0] not in history: - err = f"[CoreCall] 步骤{split_path[0]}不存在" + + # 将字符串形式的步骤ID转换为UUID + try: + step_id = uuid.UUID(split_path[0]) + except ValueError: + err = f"[CoreCall] 步骤ID格式错误: {split_path[0]}" + logger.exception(err) + return None + + if step_id not in history: + err = f"[CoreCall] 步骤{step_id}不存在" logger.error(err) return None - data = history[split_path[0]].outputData + data = history[step_id].outputData for key in split_path[1:]: if key not in data: err = f"[CoreCall] 输出Key {key} 不存在" diff --git a/apps/scheduler/call/llm/llm.py b/apps/scheduler/call/llm/llm.py index 8140fb04ae973b2ddb174bb279f94f613e40d3e7..a3dc92b890bdb04520b565eb40b090be55e559f6 100644 --- a/apps/scheduler/call/llm/llm.py +++ b/apps/scheduler/call/llm/llm.py @@ -4,7 +4,7 @@ import logging from collections.abc import AsyncGenerator from datetime import datetime -from typing import Any +from typing import TYPE_CHECKING, Any import pytz from jinja2 
import BaseLoader @@ -24,6 +24,9 @@ from apps.schemas.scheduler import ( from .prompt import LLM_CONTEXT_PROMPT, LLM_DEFAULT_PROMPT from .schema import LLMInput, LLMOutput +if TYPE_CHECKING: + from apps.models.task import ExecutorHistory + logger = logging.getLogger(__name__) @@ -68,7 +71,7 @@ class LLM(CoreCall, input_model=LLMInput, output_model=LLMOutput): ) # 上下文信息 - step_history = [] + step_history: list[ExecutorHistory] = [] for ids in call_vars.step_order[-self.step_history_size:]: step_history += [call_vars.step_data[ids]] diff --git a/apps/scheduler/call/rag/rag.py b/apps/scheduler/call/rag/rag.py index 73e88cb936fbddf54d4f65cc98e2136458af05d6..aaf6c9d9c8424aae82ecb72d78af544ed64d4f1e 100644 --- a/apps/scheduler/call/rag/rag.py +++ b/apps/scheduler/call/rag/rag.py @@ -136,19 +136,25 @@ class RAG(CoreCall, input_model=RAGInput, output_model=RAGOutput): """获取文档信息,支持临时文档和知识库文档""" doc_chunk_list: list[DocItem] = [] - # 处理临时文档 + # 处理临时文档:若docIds为空,则不请求 if doc_ids: tmp_data = deepcopy(data) tmp_data.kbIds = [uuid.UUID("00000000-0000-0000-0000-000000000000")] tmp_data.docIds = doc_ids + _logger.info("[RAG] 获取临时文档: %s", tmp_data.docIds) doc_chunk_list.extend(await self._fetch_doc_chunks(tmp_data)) + else: + _logger.info("[RAG] docIds为空,跳过临时文档请求") - # 处理知识库ID的情况 + # 处理知识库:若kbIds为空,则不请求 if data.kbIds: kb_data = deepcopy(data) # 知识库查询时不使用docIds,只使用kbIds kb_data.docIds = None + _logger.info("[RAG] 获取知识库文档: %s", kb_data.kbIds) doc_chunk_list.extend(await self._fetch_doc_chunks(kb_data)) + else: + _logger.info("[RAG] kbIds为空,跳过知识库请求") return doc_chunk_list diff --git a/apps/scheduler/call/slot/slot.py b/apps/scheduler/call/slot/slot.py index a3930924b2a0c5e5fbed4739876591e4426832f8..af9afbc14bfc6025ee367826914e64762009b910 100644 --- a/apps/scheduler/call/slot/slot.py +++ b/apps/scheduler/call/slot/slot.py @@ -20,6 +20,7 @@ from .prompt import SLOT_GEN_PROMPT, SLOT_HISTORY_TEMPLATE, SLOT_SUMMARY_TEMPLAT from .schema import SlotInput, SlotOutput if TYPE_CHECKING: + from apps.models.task import ExecutorHistory from apps.scheduler.executor.step import StepExecutor @@ -148,7 +149,7 @@ class Slot(CoreCall, input_model=SlotInput, output_model=SlotOutput): async def _init(self, call_vars: CallVars) -> SlotInput: """初始化""" - self._flow_history = [] + self._flow_history: list[ExecutorHistory] = [] self._question = call_vars.question for key in call_vars.step_order[:-self.step_num]: self._flow_history += [call_vars.step_data[key]] diff --git a/apps/scheduler/executor/base.py b/apps/scheduler/executor/base.py index eb5c594d911dbc60602121fb019d985b0cd4c356..59c82a8e02d83d3989263dd08fdbb17a401d03af 100644 --- a/apps/scheduler/executor/base.py +++ b/apps/scheduler/executor/base.py @@ -2,29 +2,27 @@ """Executor基类""" from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any +from typing import Any from pydantic import BaseModel, ConfigDict +from apps.common.queue import MessageQueue from apps.common.security import Security +from apps.llm import LLM from apps.schemas.enum_var import EventType from apps.schemas.message import TextAddContent from apps.schemas.record import RecordContent from apps.schemas.scheduler import ExecutorBackground +from apps.schemas.task import TaskData from apps.services.record import RecordManager -if TYPE_CHECKING: - from apps.common.queue import MessageQueue - from apps.llm import LLM - from apps.schemas.task import TaskData - class BaseExecutor(BaseModel, ABC): """Executor基类""" - task: "TaskData" - msg_queue: "MessageQueue" - llm: "LLM" + task: TaskData + 
msg_queue: MessageQueue + llm: LLM question: str diff --git a/apps/scheduler/executor/qa.py b/apps/scheduler/executor/qa.py index bb6228188e8fee780ff471b4a3a15e8c596a6d1e..88e2442c4cd26ea201cb9f5354de58599a2950e0 100644 --- a/apps/scheduler/executor/qa.py +++ b/apps/scheduler/executor/qa.py @@ -5,7 +5,7 @@ import logging import uuid from datetime import UTC, datetime -from apps.models import ExecutorCheckpoint, ExecutorStatus, StepStatus +from apps.models import ExecutorCheckpoint, ExecutorStatus, StepStatus, StepType from apps.models.task import LanguageType from apps.scheduler.call.rag.schema import DocItem, RAGOutput from apps.schemas.document import DocumentInfo @@ -23,13 +23,13 @@ _RAG_STEP_LIST = [ name="RAG检索", description="从知识库中检索相关文档", node="RAG", - type="RAG", + type=StepType.RAG.value, ), LanguageType.ENGLISH: Step( name="RAG retrieval", description="Retrieve relevant documents from the knowledge base", node="RAG", - type="RAG", + type=StepType.RAG.value, ), }, { @@ -37,13 +37,13 @@ _RAG_STEP_LIST = [ name="LLM问答", description="基于检索到的文档生成答案", node="LLM", - type="LLM", + type=StepType.LLM.value, ), LanguageType.ENGLISH: Step( name="LLM answer", description="Generate answer based on the retrieved documents", node="LLM", - type="LLM", + type=StepType.LLM.value, ), }, { @@ -51,13 +51,13 @@ _RAG_STEP_LIST = [ name="问题推荐", description="根据对话答案,推荐相关问题", node="Suggestion", - type="Suggestion", + type=StepType.SUGGEST.value, ), LanguageType.ENGLISH: Step( name="Question Suggestion", description="Display the suggested next question under the answer", node="Suggestion", - type="Suggestion", + type=StepType.SUGGEST.value, ), }, { @@ -151,6 +151,7 @@ class QAExecutor(BaseExecutor): if self.task.state and self.task.state.stepStatus == StepStatus.ERROR: _logger.error("[QAExecutor] RAG检索步骤失败") + self.task.state.executorStatus = ExecutorStatus.ERROR return False rag_output_data = None @@ -199,6 +200,7 @@ class QAExecutor(BaseExecutor): if self.task.state and self.task.state.stepStatus == StepStatus.ERROR: _logger.error("[QAExecutor] LLM问答步骤失败") + self.task.state.executorStatus = ExecutorStatus.ERROR return False _logger.info("[QAExecutor] LLM问答步骤完成") @@ -225,6 +227,7 @@ class QAExecutor(BaseExecutor): if self.task.state and self.task.state.stepStatus == StepStatus.ERROR: _logger.error("[QAExecutor] 问题推荐步骤失败") + self.task.state.executorStatus = ExecutorStatus.ERROR return False _logger.info("[QAExecutor] 开始执行记忆存储步骤") @@ -246,6 +249,7 @@ class QAExecutor(BaseExecutor): if self.task.state and self.task.state.stepStatus == StepStatus.ERROR: _logger.error("[QAExecutor] 记忆存储步骤失败") + self.task.state.executorStatus = ExecutorStatus.ERROR return False return True diff --git a/apps/scheduler/executor/step.py b/apps/scheduler/executor/step.py index dbec9b4505209d607ea9cceab55fcbb9bfdb7126..2d9d72d87077a455cb04d806d8bbb295960b01b1 100644 --- a/apps/scheduler/executor/step.py +++ b/apps/scheduler/executor/step.py @@ -6,11 +6,11 @@ import logging import uuid from collections.abc import AsyncGenerator from datetime import UTC, datetime -from typing import TYPE_CHECKING, Any +from typing import Any from pydantic import ConfigDict -from apps.models import ExecutorHistory, StepStatus, StepType +from apps.models import ExecutorHistory, ExecutorStatus, StepStatus from apps.scheduler.call.core import CoreCall from apps.scheduler.call.empty import Empty from apps.scheduler.call.facts.facts import FactsCall @@ -24,21 +24,19 @@ from apps.schemas.enum_var import ( ) from apps.schemas.message import TextAddContent from 
apps.schemas.scheduler import CallError, CallOutputChunk, ExecutorBackground +from apps.schemas.task import StepQueueItem from apps.services.node import NodeManager from .base import BaseExecutor -if TYPE_CHECKING: - from apps.schemas.task import StepQueueItem - logger = logging.getLogger(__name__) class StepExecutor(BaseExecutor): """工作流中步骤相关函数""" - step: "StepQueueItem" - background: "ExecutorBackground" + step: StepQueueItem + background: ExecutorBackground model_config = ConfigDict( arbitrary_types_allowed=True, @@ -91,7 +89,7 @@ class StepExecutor(BaseExecutor): # State写入ID和运行状态 self.task.state.stepId = self.step.step_id - self.task.state.stepType = str(StepType(self.step.step.type)) + self.task.state.stepType = self.step.step.type self.task.state.stepName = self.step.step.name # 获取并验证Call类 @@ -235,6 +233,7 @@ class StepExecutor(BaseExecutor): except Exception as e: logger.exception("[StepExecutor] 运行步骤失败,进行异常处理步骤") self.task.state.stepStatus = StepStatus.ERROR + self.task.state.executorStatus = ExecutorStatus.ERROR await self._push_message(EventType.STEP_OUTPUT.value, {}) if isinstance(e, CallError): self.task.state.errorMessage = { diff --git a/apps/scheduler/scheduler/init.py b/apps/scheduler/scheduler/init.py index 05aba7cf1fb69d80be44bded928ae7f7dd0464b7..c7020f3648b754562686a0a1a568db3c7ce2b169 100644 --- a/apps/scheduler/scheduler/init.py +++ b/apps/scheduler/scheduler/init.py @@ -40,8 +40,9 @@ class InitMixin: raise ValueError(err) return LLM(reasoning_llm) - def _create_new_task(self, task_id: uuid.UUID, user_id: str, conversation_id: uuid.UUID | None) -> None: + def _create_new_task(self, user_id: str, conversation_id: uuid.UUID | None, auth_header: str) -> None: """创建新的TaskData""" + task_id = uuid.uuid4() self.task = TaskData( metadata=Task( id=task_id, @@ -50,18 +51,19 @@ class InitMixin: ), runtime=TaskRuntime( taskId=task_id, + authHeader=auth_header, ), state=None, context=[], ) - async def _init_task(self, task_id: uuid.UUID, user_id: str) -> None: + async def _init_task(self, user_id: str, auth_header: str) -> None: """初始化Task""" conversation_id = self.post_body.conversation_id if not conversation_id: _logger.info("[Scheduler] 无Conversation ID,直接创建新任务") - self._create_new_task(task_id, user_id, None) + self._create_new_task(user_id, None, auth_header) return _logger.info("[Scheduler] 尝试从Conversation ID %s 恢复任务", conversation_id) @@ -80,10 +82,8 @@ class InitMixin: task_data = await TaskManager.get_task_data_by_task_id(last_task.id) if task_data: self.task = task_data - self.task.metadata.id = task_id - self.task.runtime.taskId = task_id - if self.task.state: - self.task.state.taskId = task_id + # 恢复任务时保留原有的 task_id,但更新 authHeader + self.task.runtime.authHeader = auth_header return else: _logger.warning( @@ -94,7 +94,7 @@ class InitMixin: _logger.exception("[Scheduler] 从Conversation恢复任务失败,创建新任务") _logger.info("[Scheduler] 无法恢复任务,创建新任务") - self._create_new_task(task_id, user_id, conversation_id) + self._create_new_task(user_id, conversation_id, auth_header) async def _get_user(self, user_id: str) -> None: """初始化用户""" diff --git a/apps/scheduler/scheduler/scheduler.py b/apps/scheduler/scheduler/scheduler.py index 28b513dfaf718593cafc2d630be6b2c699f22009..41d4558ba394748a89b92bce5a0d27ade9ade1cd 100644 --- a/apps/scheduler/scheduler/scheduler.py +++ b/apps/scheduler/scheduler/scheduler.py @@ -42,17 +42,17 @@ class Scheduler( async def init( self, - task_id: uuid.UUID, queue: MessageQueue, post_body: RequestData, user_id: str, + auth_header: str, ) -> None: """初始化""" 
self.queue = queue self.post_body = post_body await self._get_user(user_id) - await self._init_task(task_id, user_id) + await self._init_task(user_id, auth_header) self._init_jinja2_env() self.llm = await self._get_scheduler_llm(post_body.llm_id) diff --git a/apps/scheduler/scheduler/util.py b/apps/scheduler/scheduler/util.py index 0f710155da3bf6053eef3e7128da20bcde04386f..524a576ef580ba8d64d150b8077a26c17a30146a 100644 --- a/apps/scheduler/scheduler/util.py +++ b/apps/scheduler/scheduler/util.py @@ -55,9 +55,9 @@ class UtilMixin: check_interval = 0.5 while not kill_event.is_set(): - is_active = await Activity.is_active(user_id) + can_active = await Activity.can_active(user_id) - if not is_active: + if not can_active: _logger.warning("[Scheduler] 用户 %s 不活跃,终止工作流", user_id) kill_event.set() break diff --git a/apps/schemas/message.py b/apps/schemas/message.py index e149adb1895a44efd27133daf70ef8e0ee0bb51e..1b742e2b24269c1ce040e407d4af043de4492cf3 100644 --- a/apps/schemas/message.py +++ b/apps/schemas/message.py @@ -57,8 +57,8 @@ class TextAddContent(BaseModel): class MessageBase(HeartbeatData): """基础消息事件结构""" - id: uuid.UUID = Field(min_length=36, max_length=36) - conversation_id: uuid.UUID | None = Field(min_length=36, max_length=36, alias="conversationId", default=None) + id: uuid.UUID + conversation_id: uuid.UUID | None = Field(alias="conversationId", default=None) flow: MessageExecutor | None = None content: Any | None = Field(default=None, description="消息内容") metadata: MessageMetadata diff --git a/apps/schemas/scheduler.py b/apps/schemas/scheduler.py index 0d8e4748dc473aa2fcdd82e3db05e58b6a36bece..cf3e62dce770e65bc349fc4e04fcc4db205a2d35 100644 --- a/apps/schemas/scheduler.py +++ b/apps/schemas/scheduler.py @@ -43,8 +43,8 @@ class CallVars(BaseModel): thinking: str = Field(description="上下文信息") question: str = Field(description="改写后的用户输入") - step_data: dict[str, ExecutorHistory] = Field(description="Executor中历史工具的结构化数据", default={}) - step_order: list[str] = Field(description="Executor中历史工具的顺序", default=[]) + step_data: dict[uuid.UUID, ExecutorHistory] = Field(description="Executor中历史工具的结构化数据", default={}) + step_order: list[uuid.UUID] = Field(description="Executor中历史工具的顺序", default=[]) background: ExecutorBackground = Field(description="Executor的背景信息") ids: CallIds = Field(description="Call的ID") language: LanguageType = Field(description="语言", default=LanguageType.CHINESE) diff --git a/apps/services/activity.py b/apps/services/activity.py index c83db3266523eb11d7acfe3343c70a8aa72c2029..687e9af75bf606bb780d6713ed2996bfb13dea37 100644 --- a/apps/services/activity.py +++ b/apps/services/activity.py @@ -15,12 +15,12 @@ class Activity: """活动控制:全局并发限制,同时最多有 n 个任务在执行(与用户无关)""" @staticmethod - async def is_active(user_id: str) -> bool: + async def can_active(user_id: str) -> bool: """ 判断系统是否达到全局并发上限 :param user_id: 用户实体ID(兼容现有接口签名) - :return: 达到并发上限返回 True,否则 False + :return: 达到并发上限返回 False,否则 True """ time = datetime.now(tz=UTC) @@ -32,11 +32,11 @@ class Activity: SessionActivity.timestamp <= time, ))).one() if count >= SLIDE_WINDOW_QUESTION_COUNT: - return True + return False # 全局并发检查:当前活跃任务数量是否达到上限 current_active = (await session.scalars(select(func.count(SessionActivity.id)))).one() - return current_active >= MAX_CONCURRENT_TASKS + return current_active < MAX_CONCURRENT_TASKS @staticmethod async def set_active(user_id: str) -> None: diff --git a/apps/services/settings.py b/apps/services/settings.py index ae856b0d54c50ae33c04552c608df67d2b440ed6..c21987e697b8052f530784637a33ccdeb4345854 
100644 --- a/apps/services/settings.py +++ b/apps/services/settings.py @@ -7,14 +7,14 @@ import logging from sqlalchemy import select from apps.common.postgres import postgres -from apps.llm import embedding +from apps.llm import LLM, embedding, json_generator from apps.models import GlobalSettings, LLMType from apps.scheduler.pool.pool import pool from apps.schemas.request_data import UpdateSpecialLlmReq from apps.schemas.response_data import SelectedSpecialLlmID from apps.services.llm import LLMManager -logger = logging.getLogger(__name__) +_logger = logging.getLogger(__name__) class SettingsManager: @@ -34,7 +34,7 @@ class SettingsManager: )).first() if not settings: - logger.warning("[SettingsManager] 全局设置不存在,返回默认值") + _logger.warning("[SettingsManager] 全局设置不存在,返回默认值") return SelectedSpecialLlmID(functionLLM=None, embeddingLLM=None) return SelectedSpecialLlmID( @@ -60,11 +60,89 @@ class SettingsManager: # 触发向量化 await pool.set_vector() - logger.info("[SettingsManager] Embedding模型已更新,向量化过程已完成") + _logger.info("[SettingsManager] Embedding模型已更新,向量化过程已完成") else: - logger.error("[SettingsManager] 选择的embedding模型 %s 不存在", embedding_llm_id) + _logger.error("[SettingsManager] 选择的embedding模型 %s 不存在", embedding_llm_id) except Exception: - logger.exception("[SettingsManager] Embedding模型向量化过程失败") + _logger.exception("[SettingsManager] Embedding模型向量化过程失败") + + + @staticmethod + async def _initialize_function_llm(function_llm_id: str) -> None: + """ + 初始化Function LLM + + :param function_llm_id: Function LLM ID + :raises ValueError: 如果LLM不存在或不支持Function Call + """ + function_llm = await LLMManager.get_llm(function_llm_id) + if not function_llm: + err = f"[SettingsManager] Function LLM {function_llm_id} 不存在" + raise ValueError(err) + if LLMType.FUNCTION not in function_llm.llmType: + err = f"[SettingsManager] LLM {function_llm_id} 不支持Function Call" + raise ValueError(err) + + # 更新json_generator单例 + llm_instance = LLM(function_llm) + json_generator.init(llm_instance) + _logger.info("[SettingsManager] Function LLM已初始化,json_generator已配置") + + + @staticmethod + async def _validate_embedding_llm(embedding_llm_id: str) -> None: + """ + 验证Embedding LLM是否有效 + + :param embedding_llm_id: Embedding LLM ID + :raises ValueError: 如果LLM不存在或不支持Embedding + """ + embedding_llm = await LLMManager.get_llm(embedding_llm_id) + if not embedding_llm: + err = f"[SettingsManager] Embedding LLM {embedding_llm_id} 不存在" + raise ValueError(err) + if LLMType.EMBEDDING not in embedding_llm.llmType: + err = f"[SettingsManager] LLM {embedding_llm_id} 不支持Embedding" + raise ValueError(err) + + + @staticmethod + async def init_global_llm_settings() -> None: + """ + 初始化全局LLM设置(用于项目启动时调用) + + 从数据库读取全局设置并初始化对应的LLM实例 + """ + async with postgres.session() as session: + settings = (await session.scalars(select(GlobalSettings))).first() + + if not settings: + _logger.warning("[SettingsManager] 全局设置不存在,跳过LLM初始化") + return + + # 初始化Function LLM + if settings.functionLlmId: + try: + await SettingsManager._initialize_function_llm(settings.functionLlmId) + _logger.info("[SettingsManager] Function LLM初始化完成: %s", settings.functionLlmId) + except ValueError: + _logger.exception("[SettingsManager] Function LLM初始化失败") + else: + _logger.info("[SettingsManager] Function LLM未配置,跳过初始化") + + # 初始化Embedding LLM + if settings.embeddingLlmId: + try: + await SettingsManager._validate_embedding_llm(settings.embeddingLlmId) + # 初始化embedding配置 + embedding_llm_config = await LLMManager.get_llm(settings.embeddingLlmId) + if embedding_llm_config: + await 
embedding.init(embedding_llm_config) + _logger.info("[SettingsManager] Embedding LLM初始化完成: %s", settings.embeddingLlmId) + except ValueError: + _logger.exception("[SettingsManager] Embedding LLM初始化失败") + else: + _logger.info("[SettingsManager] Embedding LLM未配置,跳过初始化") @staticmethod @@ -78,25 +156,13 @@ class SettingsManager: :param user_id: 操作的管理员user_id :param req: 更新请求体 """ - # 验证functionLLM是否支持Function Call + # 验证并初始化functionLLM if req.functionLLM: - function_llm = await LLMManager.get_llm(req.functionLLM) - if not function_llm: - err = f"[SettingsManager] Function LLM {req.functionLLM} 不存在" - raise ValueError(err) - if LLMType.FUNCTION not in function_llm.llmType: - err = f"[SettingsManager] LLM {req.functionLLM} 不支持Function Call" - raise ValueError(err) - - # 验证embeddingLLM是否支持Embedding + await SettingsManager._initialize_function_llm(req.functionLLM) + + # 验证embeddingLLM if req.embeddingLLM: - embedding_llm = await LLMManager.get_llm(req.embeddingLLM) - if not embedding_llm: - err = f"[SettingsManager] Embedding LLM {req.embeddingLLM} 不存在" - raise ValueError(err) - if LLMType.EMBEDDING not in embedding_llm.llmType: - err = f"[SettingsManager] LLM {req.embeddingLLM} 不支持Embedding" - raise ValueError(err) + await SettingsManager._validate_embedding_llm(req.embeddingLLM) # 读取旧的embedding配置 old_embedding_llm = None @@ -124,4 +190,4 @@ class SettingsManager: task = asyncio.create_task(SettingsManager._trigger_vectorization(req.embeddingLLM)) # 添加任务完成回调,避免未处理的异常 task.add_done_callback(lambda _: None) - logger.info("[SettingsManager] 已启动后台向量化任务") + _logger.info("[SettingsManager] 已启动后台向量化任务") diff --git a/design/call/core.md b/design/call/core.md index d165382b60fbe473859596c6d04dbac25921aa07..0cf71e5dec4f9e2b65e23bacc7cdbba2dd3126e5 100644 --- a/design/call/core.md +++ b/design/call/core.md @@ -312,8 +312,8 @@ classDiagram +language: LanguageType +ids: CallIds +question: str - +step_data: dict[str, ExecutorHistory] - +step_order: list[str] + +step_data: dict[UUID, ExecutorHistory] + +step_order: list[UUID] +background: ExecutorBackground +thinking: str } @@ -565,8 +565,8 @@ CallVars是CoreCall中最重要的数据结构,包含了执行Call所需的所 | `language` | LanguageType | ✅ | 当前使用的语言类型(中文/英文) | | `ids` | CallIds | ✅ | 包含任务ID、执行器ID、会话ID等标识信息 | | `question` | str | ✅ | 改写或原始的用户问题 | -| `step_data` | dict[str, ExecutorHistory] | ✅ | 步骤执行历史的字典,键为步骤ID | -| `step_order` | list[str] | ✅ | 步骤执行的顺序列表 | +| `step_data` | dict[UUID, ExecutorHistory] | ✅ | 步骤执行历史的字典,键为步骤ID | +| `step_order` | list[UUID] | ✅ | 步骤执行的顺序列表 | | `background` | ExecutorBackground | ✅ | 执行器的背景信息,包含对话历史和事实 | | `thinking` | str | ✅ | AI的推理思考过程文本 |
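Usage sketch for the new `UUIDEncoder` (a minimal illustration assuming the `apps.common.encoder` module added above; `json.dumps` would otherwise raise `TypeError` on `uuid.UUID` values such as the `MessageBase.id` field now typed as `uuid.UUID`):

```python
import json
import uuid

from apps.common.encoder import UUIDEncoder  # new module added in this change

payload = {"id": uuid.uuid4(), "content": "hello"}
# Without cls=UUIDEncoder this raises: TypeError: Object of type UUID is not JSON serializable
print(json.dumps(payload, ensure_ascii=False, cls=UUIDEncoder))
```

A similar minimal sketch of the unified `streaming` flag on `LLM.call` (hypothetical caller, not part of the diff; with `streaming=False` the providers buffer internally and yield a single `LLMChunk` carrying the full answer, per the `_handle_non_streaming_response` helpers above):

```python
from apps.llm import LLM


async def collect_answer(llm: LLM, messages: list[dict[str, str]]) -> str:
    # Non-streaming mode: one final chunk with the complete content.
    async for chunk in llm.call(messages, include_thinking=False, streaming=False):
        return chunk.content or ""
    return ""
```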