diff --git a/apps/common/encoder.py b/apps/common/encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..94f535dde56d23d65d26389e73867e48040de88c
--- /dev/null
+++ b/apps/common/encoder.py
@@ -0,0 +1,21 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
+"""JSON编码器模块"""
+
+import json
+from typing import Any
+from uuid import UUID
+
+
+class UUIDEncoder(json.JSONEncoder):
+ """支持UUID类型的JSON编码器"""
+
+ def default(self, obj: Any) -> Any:
+ """
+ 重写default方法以支持UUID类型序列化
+
+ :param obj: 待序列化的对象
+ :return: 序列化后的对象
+ """
+ if isinstance(obj, UUID):
+ return str(obj)
+ return super().default(obj)
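
Note (reviewer sketch, not part of the patch): json.dumps only calls JSONEncoder.default for values it cannot serialize natively, so UUIDs at any nesting depth are stringified while the rest of the payload is untouched. Minimal standalone usage:

    import json
    import uuid

    from apps.common.encoder import UUIDEncoder

    payload = {"taskId": uuid.uuid4(), "content": {"stepId": uuid.uuid4(), "text": "ok"}}
    # Without cls=UUIDEncoder this raises: TypeError: Object of type UUID is not JSON serializable
    print(json.dumps(payload, ensure_ascii=False, cls=UUIDEncoder))
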
diff --git a/apps/common/oidc.py b/apps/common/oidc.py
index a54f1a60d25f0d5ca68da0c64a249e5e246aca35..37317014144656778ed1d7950ba089582725e468 100644
--- a/apps/common/oidc.py
+++ b/apps/common/oidc.py
@@ -32,17 +32,17 @@ class OIDCProvider:
raise NotImplementedError(err)
@staticmethod
- async def set_token(user_sub: str, access_token: str, refresh_token: str) -> None:
+ async def set_token(user_id: str, access_token: str, refresh_token: str) -> None:
"""设置MongoDB中的OIDC Token到sessions集合"""
async with postgres.session() as session:
await session.merge(Session(
- userSub=user_sub,
+ userId=user_id,
sessionType=SessionType.ACCESS_TOKEN,
token=access_token,
validUntil=datetime.now(UTC) + timedelta(minutes=OIDC_ACCESS_TOKEN_EXPIRE_TIME),
))
await session.merge(Session(
- userSub=user_sub,
+ userId=user_id,
sessionType=SessionType.REFRESH_TOKEN,
token=refresh_token,
validUntil=datetime.now(UTC) + timedelta(minutes=OIDC_REFRESH_TOKEN_EXPIRE_TIME),
diff --git a/apps/common/queue.py b/apps/common/queue.py
index 5cd66e617956d16884f07c92ece5fd4f701e6141..befeab3c86c99c161540665785b04812667557d9 100644
--- a/apps/common/queue.py
+++ b/apps/common/queue.py
@@ -18,6 +18,8 @@ from apps.schemas.message import (
)
from apps.schemas.task import TaskData
+from .encoder import UUIDEncoder
+
logger = logging.getLogger(__name__)
@@ -76,7 +78,9 @@ class MessageQueue:
content=data,
)
- await self._queue.put(json.dumps(message.model_dump(by_alias=True, exclude_none=True), ensure_ascii=False))
+ await self._queue.put(
+ json.dumps(message.model_dump(by_alias=True, exclude_none=True), ensure_ascii=False, cls=UUIDEncoder),
+ )
async def get(self) -> AsyncGenerator[str, None]:
"""从Queue中获取消息;变为async generator"""
@@ -98,7 +102,7 @@ class MessageQueue:
async def _heartbeat(self) -> None:
"""组装用于向用户(前端/Shell端)输出的心跳"""
heartbeat_template = HeartbeatData()
- heartbeat_msg = json.dumps(heartbeat_template.model_dump(by_alias=True), ensure_ascii=False)
+ heartbeat_msg = json.dumps(heartbeat_template.model_dump(by_alias=True), ensure_ascii=False, cls=UUIDEncoder)
while True:
# 如果关闭,则停止心跳
diff --git a/apps/llm/generator.py b/apps/llm/generator.py
index 656ef2800580ab795ffd49cc9b7ec656f27e1411..da5933d5f2344972103235fc5bc7fdc78b1b32ba 100644
--- a/apps/llm/generator.py
+++ b/apps/llm/generator.py
@@ -21,6 +21,16 @@ _logger = logging.getLogger(__name__)
JSON_GEN_MAX_TRIAL = 3
+
+class JsonValidationError(Exception):
+ """JSON验证失败异常,携带生成的结果"""
+
+ def __init__(self, message: str, result: dict[str, Any]) -> None:
+ """初始化异常"""
+ super().__init__(message)
+ self.result = result
+
+
class JsonGenerator:
"""综合Json生成器(全局单例)"""
@@ -56,7 +66,7 @@ class JsonGenerator:
def _build_messages(
self,
- function: dict[str, Any],
+ functions: list[dict[str, Any]],
conversation: list[dict[str, str]],
language: LanguageType = LanguageType.CHINESE,
) -> list[dict[str, str]]:
@@ -70,9 +80,8 @@ class JsonGenerator:
template = self._env.from_string(JSON_GEN[language])
prompt = template.render(
query=query,
- conversation=conversation[:-1],
- schema=function["parameters"],
use_xml_format=False,
+ functions=functions,
)
messages = [*conversation[:-1], {"role": "user", "content": prompt}]
@@ -117,7 +126,7 @@ class JsonGenerator:
async def _single_trial(
self,
function: dict[str, Any],
- context: list[dict[str, str]],
+ conversation: list[dict[str, str]],
language: LanguageType = LanguageType.CHINESE,
) -> dict[str, Any]:
"""单次尝试,包含校验逻辑;function使用OpenAI标准Function格式"""
@@ -131,10 +140,10 @@ class JsonGenerator:
# 执行生成
if self._support_function_call:
# 如果支持FunctionCall
- result = await self._call_with_function(function, context, language)
+ result = await self._call_with_function(function, conversation, language)
else:
# 如果不支持FunctionCall
- result = await self._call_without_function(function, context, language)
+ result = await self._call_without_function(function, conversation, language)
# 校验结果
try:
@@ -144,11 +153,9 @@ class JsonGenerator:
err_info = err_info.split("\n\n")[0]
_logger.info("[JSONGenerator] 验证失败:%s", err_info)
- context.append({
- "role": "assistant",
- "content": f"Attempted to use tool but validation failed: {err_info}",
- })
- raise
+ # 创建带结果的验证异常
+ validation_error = JsonValidationError(err_info, result)
+ raise validation_error from err
else:
return result
@@ -163,7 +170,7 @@ class JsonGenerator:
err = "[JSONGenerator] 未初始化,请先调用init()方法"
raise RuntimeError(err)
- messages = self._build_messages(function, conversation, language)
+ messages = self._build_messages([function], conversation, language)
tool = LLMFunctions(
name=function["name"],
@@ -172,7 +179,7 @@ class JsonGenerator:
)
tool_call_result = {}
- async for chunk in self._llm.call(messages, include_thinking=False, streaming=True, tools=[tool]):
+ async for chunk in self._llm.call(messages, include_thinking=False, streaming=False, tools=[tool]):
if chunk.tool_call:
tool_call_result.update(chunk.tool_call)
@@ -193,11 +200,11 @@ class JsonGenerator:
err = "[JSONGenerator] 未初始化,请先调用init()方法"
raise RuntimeError(err)
- messages = self._build_messages(function, conversation, language)
+ messages = self._build_messages([function], conversation, language)
# 使用LLM的call方法获取响应
full_response = ""
- async for chunk in self._llm.call(messages, include_thinking=False, streaming=True):
+ async for chunk in self._llm.call(messages, include_thinking=False, streaming=False):
if chunk.content:
full_response += chunk.content
@@ -225,32 +232,50 @@ class JsonGenerator:
schema = function["parameters"]
Draft7Validator.check_schema(schema)
- # 构建上下文
- context = [
- {
- "role": "system",
- "content": "You are a helpful assistant that can use tools to help answer user queries.",
- },
- ]
- if conversation:
- context.extend(conversation)
-
count = 0
- original_context = context.copy()
+
+ retry_conversation = conversation.copy() if conversation else []
while count < JSON_GEN_MAX_TRIAL:
count += 1
try:
# 如果_single_trial没有抛出异常,直接返回结果,不进行重试
- return await self._single_trial(function, context, language)
- except Exception:
+ return await self._single_trial(function, retry_conversation, language)
+ except JsonValidationError as e:
_logger.exception(
"[JSONGenerator] 第 %d/%d 次尝试失败",
count,
JSON_GEN_MAX_TRIAL,
)
+ if count < JSON_GEN_MAX_TRIAL:
+ # 将错误信息添加到retry_conversation中,模拟完整的对话流程
+ err_info = str(e)
+ err_info = err_info.split("\n\n")[0]
+
+ # 模拟assistant调用了函数但失败,包含实际获得的参数
+ function_name = function["name"]
+ result_json = json.dumps(e.result, ensure_ascii=False)
+ retry_conversation.append({
+ "role": "assistant",
+ "content": f"I called function {function_name} with parameters: {result_json}",
+ })
+
+ # 模拟user要求重新调用函数并给出错误原因
+ retry_conversation.append({
+ "role": "user",
+ "content": (
+ f"The previous function call failed with validation error: {err_info}. "
+ "Please try again and ensure the output strictly follows the required schema."
+ ),
+ })
+ continue
+ except Exception:
+ _logger.exception(
+ "[JSONGenerator] 第 %d/%d 次尝试失败(非验证错误)",
+ count,
+ JSON_GEN_MAX_TRIAL,
+ )
if count < JSON_GEN_MAX_TRIAL:
continue
- context = original_context
return {}
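
Note (reviewer sketch, not part of the patch): the retry loop now distinguishes JsonValidationError from other failures and replays the failed attempt into retry_conversation as an assistant/user exchange before the next trial. A condensed, synchronous sketch of that control flow — generate and validate are hypothetical stand-ins for the real LLM call and jsonschema validation:

    import json

    def retry_json_gen(function: dict, conversation: list[dict], generate, validate, max_trial: int = 3) -> dict:
        retry_conversation = conversation.copy() if conversation else []
        for _ in range(max_trial):
            result = generate(function, retry_conversation)
            try:
                validate(result)  # e.g. jsonschema.validate(result, function["parameters"])
            except Exception as err:  # corresponds to JsonValidationError above
                err_info = str(err).split("\n\n")[0]
                retry_conversation.append({
                    "role": "assistant",
                    "content": f"I called function {function['name']} with parameters: "
                               f"{json.dumps(result, ensure_ascii=False)}",
                })
                retry_conversation.append({
                    "role": "user",
                    "content": f"The previous function call failed with validation error: {err_info}. "
                               "Please try again and ensure the output strictly follows the required schema.",
                })
                continue
            return result
        return {}
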
diff --git a/apps/llm/llm.py b/apps/llm/llm.py
index fc005cd7fe73cdc5c81e3e757fac794a542f7da9..7cba3b6b538cca7261b2430404108ac806ef65dd 100644
--- a/apps/llm/llm.py
+++ b/apps/llm/llm.py
@@ -45,27 +45,14 @@ class LLM:
streaming: bool = True,
tools: list[LLMFunctions] | None = None,
) -> AsyncGenerator[LLMChunk, None]:
- """调用大模型,分为流式和非流式两种"""
- if streaming:
- async for chunk in await self._provider.chat(
- messages,
- include_thinking=include_thinking,
- tools=tools,
- ):
- yield chunk
- else:
- # 非流式模式下,需要收集所有内容
- async for _ in await self._provider.chat(
- messages,
- include_thinking=include_thinking,
- tools=tools,
- ):
- continue
- # 直接yield最终结果
- yield LLMChunk(
- content=self._provider.full_answer,
- reasoning_content=self._provider.full_thinking,
- )
+ """调用大模型,统一处理流式和非流式"""
+ async for chunk in self._provider.chat(
+ messages,
+ include_thinking=include_thinking,
+ streaming=streaming,
+ tools=tools,
+ ):
+ yield chunk
@property
def input_tokens(self) -> int:
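
Note (reviewer sketch, not part of the patch): LLM.call no longer buffers output itself in non-streaming mode; it forwards the streaming flag to the provider, which is expected to yield a single final chunk when streaming=False. Consumption then looks the same in both modes — llm here is an assumed, already-initialized LLM instance:

    async def collect_answer(llm, messages: list[dict[str, str]]) -> str:
        parts: list[str] = []
        # streaming=False: the provider yields exactly one chunk carrying the full answer
        async for chunk in llm.call(messages, include_thinking=False, streaming=False):
            if chunk.content:
                parts.append(chunk.content)
        return "".join(parts)
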
diff --git a/apps/llm/prompt.py b/apps/llm/prompt.py
index 6ece04539e6d50896980a8d6dc2bf956e420c900..fa424e731924cb2712470e881fc32ea96d047460 100644
--- a/apps/llm/prompt.py
+++ b/apps/llm/prompt.py
@@ -13,10 +13,9 @@ JSON_GEN: dict[LanguageType, str] = {
- 你可以访问能够帮助收集信息的工具
- - 逐步使用工具,每次使用都基于之前的结果
+ - 逐步使用工具,每次使用都基于之前的结果{% if use_xml_format %}
- 用户的查询在 标签中提供
- {% if previous_trial %}- 查看 信息以避免重复错误{% endif %}
- {% if use_xml_format %}- 使用 XML 样式的标签格式化工具调用,其中工具名称是根标签,每个参数是嵌套标签
+ - 使用 XML 样式的标签格式化工具调用,其中工具名称是根标签,每个参数是嵌套标签
- 使用架构中指定的确切工具名称和参数名称
- 基本格式结构:
<工具名称>
@@ -27,7 +26,9 @@ JSON_GEN: dict[LanguageType, str] = {
* 数字:10
* 布尔值:true
* 数组(重复标签):项目1项目2
- * 对象(嵌套标签):值{% endif %}
+ * 对象(嵌套标签):值{% else %}
+ - 必须使用提供的工具来回答查询
+ - 工具调用必须遵循工具架构中定义的JSON格式{% endif %}
{% if use_xml_format %}
@@ -36,72 +37,57 @@ JSON_GEN: dict[LanguageType, str] = {
杭州的天气怎么样?
-
-
- get_weather: 获取指定城市的当前天气信息
-
-
- {
- "name": "get_weather",
- "description": "获取指定城市的当前天气信息",
- "parameters": {
- "type": "object",
- "properties": {
- "city": {
- "type": "string",
- "description": "要查询天气的城市名称"
+
+ -
+ get_weather
+ 获取指定城市的当前天气信息
+
+ {
+ "type": "object",
+ "properties": {
+ "city": {
+ "type": "string",
+ "description": "要查询天气的城市名称"
+ },
+ "unit": {
+ "type": "string",
+ "enum": ["celsius", "fahrenheit"],
+ "description": "温度单位"
+ }
},
- "unit": {
- "type": "string",
- "enum": ["celsius", "fahrenheit"],
- "description": "温度单位"
- },
- "include_forecast": {
- "type": "boolean",
- "description": "是否包含预报数据"
- }
- },
- "required": ["city"]
- }
- }
-
-
+ "required": ["city"]
+ }
+
+
+
助手响应:
杭州
celsius
- false
- {% endif %}
{{ query }}
- {% if previous_trial %}
-
-
-
- 你之前的工具调用有不正确的参数。
-
-
- {{ previous_trial }}
-
-
- {{ err_info }}
-
-
- {% endif %}
-
-
-
- {{ tool_descriptions }}
-
-
- {{ tool_schemas }}
-
-
+
+ {% for func in functions %}
+ -
+ {{ func.name }}
+ {{ func.description }}
+ {{ func.parameters | tojson(indent=2) }}
+
{% endfor %}
+ {% else %}
+
+ ## 可用工具
+ {% for func in functions %}
+ - **{{ func.name }}**: {{ func.description }}
+ {% endfor %}
+
+ 请根据用户查询选择合适的工具来回答问题。你必须使用上述工具之一来处理查询。
+
+ **用户查询**: {{ query }}{% endif %}
""",
),
LanguageType.ENGLISH: dedent(
@@ -111,11 +97,9 @@ JSON_GEN: dict[LanguageType, str] = {
- You have access to tools that can help gather information
- - Use tools step-by-step, with each use informed by previous results
+ - Use tools step-by-step, with each use informed by previous results{% if use_xml_format %}
- The user's query is provided in the tags
- {% if previous_trial %}- Review the information to avoid \
-repeating mistakes{% endif %}
- {% if use_xml_format %}- Format tool calls using XML-style tags where the tool name is the root tag \
+ - Format tool calls using XML-style tags where the tool name is the root tag \
and each parameter is a nested tag
- Use the exact tool name and parameter names as specified in the schema
- Basic format structure:
@@ -127,7 +111,9 @@ and each parameter is a nested tag
* Number: 10
* Boolean: true
* Array (repeat tags): item1item2
- * Object (nest tags): value{% endif %}
+ * Object (nest tags): value{% else %}
+ - You must use the provided tools to answer the query
+ - Tool calls must follow the JSON format defined in the tool schemas{% endif %}
{% if use_xml_format %}
@@ -136,72 +122,58 @@ and each parameter is a nested tag
What is the weather like in Hangzhou?
-
-
- get_weather: Get current weather information for a specified city
-
-
- {
- "name": "get_weather",
- "description": "Get current weather information for a specified city",
- "parameters": {
- "type": "object",
- "properties": {
- "city": {
- "type": "string",
- "description": "The city name to query weather for"
+
+ -
+ get_weather
+ Get current weather information for a specified city
+
+ {
+ "type": "object",
+ "properties": {
+ "city": {
+ "type": "string",
+ "description": "The city name to query weather for"
+ },
+ "unit": {
+ "type": "string",
+ "enum": ["celsius", "fahrenheit"],
+ "description": "Temperature unit"
+ }
},
- "unit": {
- "type": "string",
- "enum": ["celsius", "fahrenheit"],
- "description": "Temperature unit"
- },
- "include_forecast": {
- "type": "boolean",
- "description": "Whether to include forecast data"
- }
- },
- "required": ["city"]
- }
- }
-
-
+ "required": ["city"]
+ }
+
+
+
Assistant response:
Hangzhou
celsius
- false
- {% endif %}
{{ query }}
- {% if previous_trial %}
-
-
-
- Your previous tool call had incorrect arguments.
-
-
- {{ previous_trial }}
-
-
- {{ err_info }}
-
-
- {% endif %}
-
-
-
- {{ tool_descriptions }}
-
-
- {{ tool_schemas }}
-
-
+
+ {% for func in functions %}
+ -
+ {{ func.name }}
+ {{ func.description }}
+ {{ func.parameters | tojson(indent=2) }}
+
{% endfor %}
+ {% else %}
+
+ ## Available Tools
+ {% for func in functions %}
+ - **{{ func.name }}**: {{ func.description }}
+ {% endfor %}
+
+ Please select the appropriate tool(s) from above to answer the user's query.
+ You must use one of the tools listed above to process the query.
+
+ **User Query**: {{ query }}{% endif %}
""",
),
}
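
Note (reviewer sketch, not part of the patch): JSON_GEN now iterates over a functions list instead of a single schema, and the previous_trial/err_info variables are gone, so rendering only needs query, use_xml_format, and functions. A standalone rendering sketch — the environment options here are assumptions, _build_messages owns the real setup:

    from jinja2.sandbox import SandboxedEnvironment

    from apps.llm.prompt import JSON_GEN
    from apps.models.task import LanguageType

    env = SandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
    prompt = env.from_string(JSON_GEN[LanguageType.CHINESE]).render(
        query="杭州的天气怎么样?",
        use_xml_format=False,
        functions=[{
            "name": "get_weather",
            "description": "获取指定城市的当前天气信息",
            "parameters": {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
        }],
    )
    print(prompt)
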
diff --git a/apps/llm/providers/base.py b/apps/llm/providers/base.py
index 207fa0391c102b41a0c38d70b40a418454359f90..5c5757ae8622f38397e32601bc8bc4d00f7e75f1 100644
--- a/apps/llm/providers/base.py
+++ b/apps/llm/providers/base.py
@@ -51,10 +51,11 @@ class BaseProvider:
async def chat(
self, messages: list[dict[str, str]],
*, include_thinking: bool = False,
+ streaming: bool = True,
tools: list[LLMFunctions] | None = None,
) -> AsyncGenerator[LLMChunk, None]:
"""聊天"""
- raise NotImplementedError
+ yield LLMChunk(content="")
async def embedding(self, text: list[str]) -> list[list[float]]:
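
Note (reviewer aside, not part of the patch): a body that only raises never makes chat an async generator, so `async for` over the base implementation fails with a TypeError before the error is ever raised; yielding one empty LLMChunk keeps the base method iterable. Minimal illustration:

    import asyncio

    async def no_yield():
        raise NotImplementedError  # plain coroutine: `async for` over no_yield() is a TypeError

    async def with_yield():
        yield ""                   # async generator: `async for` works and produces one item

    async def main() -> None:
        async for item in with_yield():
            print(repr(item))

    asyncio.run(main())
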
diff --git a/apps/llm/providers/ollama.py b/apps/llm/providers/ollama.py
index e721510694c2f7d264e6dea72b06e88af1c1047d..4b668536a681454110d1e6cfbd547ffac8905891 100644
--- a/apps/llm/providers/ollama.py
+++ b/apps/llm/providers/ollama.py
@@ -77,53 +77,25 @@ class OllamaProvider(BaseProvider):
"content": "" + self.full_thinking + "" + self.full_answer,
}])
- @override
- async def chat(
- self, messages: list[dict[str, str]],
- *, include_thinking: bool = False,
- tools: list[LLMFunctions] | None = None,
- ) -> AsyncGenerator[LLMChunk, None]:
- # 检查能力
- if not self._allow_chat:
- err = "[OllamaProvider] 当前模型不支持Chat"
- _logger.error(err)
- raise RuntimeError(err)
-
- # 初始化Token计数
- self.input_tokens = 0
- self.output_tokens = 0
-
- # 检查消息
- messages = self._validate_messages(messages)
-
- # 流式返回响应
- last_chunk = None
- chat_kwargs = {
- "model": self.config.modelName,
- "messages": messages,
- "options": {
- "temperature": self.config.temperature,
- "num_predict": self.config.maxToken,
+ def _build_tools_param(self, tools: list[LLMFunctions]) -> list[dict[str, Any]]:
+ """构建工具参数"""
+ return [{
+ "type": "function",
+ "function": {
+ "name": tool.name,
+ "description": tool.description,
+ "parameters": tool.param_schema,
},
- "stream": True,
- **self.config.extraConfig,
- }
+ } for tool in tools]
- # 如果提供了tools,则传入以启用function-calling模式
- if tools:
- functions = []
- for tool in tools:
- functions += [{
- "type": "function",
- "function": {
- "name": tool.name,
- "description": tool.description,
- "parameters": tool.param_schema,
- },
- }]
- chat_kwargs["tools"] = functions
-
- async for chunk in await self._client.chat(**chat_kwargs):
+ async def _handle_streaming_response(
+ self, response: AsyncGenerator[ChatResponse, None],
+ messages: list[dict[str, str]],
+ tools: list[LLMFunctions] | None,
+ ) -> AsyncGenerator[LLMChunk, None]:
+ """处理流式响应"""
+ last_chunk = None
+ async for chunk in response:
last_chunk = chunk
if chunk.message.thinking:
self.full_thinking += chunk.message.thinking
@@ -145,6 +117,85 @@ class OllamaProvider(BaseProvider):
# 使用最后一个chunk的usage数据
self._process_usage_data(last_chunk, messages)
+ async def _handle_non_streaming_response(
+ self, response: ChatResponse,
+ messages: list[dict[str, str]],
+ ) -> AsyncGenerator[LLMChunk, None]:
+ """处理非流式响应"""
+ if response.message.thinking:
+ self.full_thinking = response.message.thinking
+ if response.message.content:
+ self.full_answer = response.message.content
+
+ # 处理usage数据
+ self._process_usage_data(response if response.done else None, messages)
+
+ # 一次性返回完整结果
+ yield LLMChunk(
+ content=self.full_answer,
+ reasoning_content=self.full_thinking,
+ )
+
+ def _build_chat_kwargs(
+ self, messages: list[dict[str, str]],
+ *,
+ streaming: bool,
+ tools: list[LLMFunctions] | None,
+ ) -> dict[str, Any]:
+ """构建chat请求参数"""
+ chat_kwargs = {
+ "model": self.config.modelName,
+ "messages": messages,
+ "options": {
+ "temperature": self.config.temperature,
+ "num_predict": self.config.maxToken,
+ },
+ "stream": streaming,
+ **self.config.extraConfig,
+ }
+
+ # 如果提供了tools,则传入以启用function-calling模式
+ if tools:
+ chat_kwargs["tools"] = self._build_tools_param(tools)
+
+ return chat_kwargs
+
+ @override
+ async def chat(
+ self, messages: list[dict[str, str]],
+ tools: list[LLMFunctions] | None = None,
+ *, include_thinking: bool = False,
+ streaming: bool = True,
+ ) -> AsyncGenerator[LLMChunk, None]:
+ # 检查能力
+ if not self._allow_chat:
+ err = "[OllamaProvider] 当前模型不支持Chat"
+ _logger.error(err)
+ raise RuntimeError(err)
+
+ # 初始化Token计数和累积内容
+ self.input_tokens = 0
+ self.output_tokens = 0
+ self.full_thinking = ""
+ self.full_answer = ""
+
+ # 检查消息
+ messages = self._validate_messages(messages)
+
+ # 构建chat请求参数
+ chat_kwargs = self._build_chat_kwargs(messages, streaming=streaming, tools=tools)
+
+ # 发送请求
+ response = await self._client.chat(**chat_kwargs)
+
+ # 根据streaming模式处理响应
+ if streaming:
+ async for chunk in self._handle_streaming_response(response, messages, tools):
+ yield chunk
+ else:
+ async for chunk in self._handle_non_streaming_response(response, messages):
+ yield chunk
+
def _seq_to_list(self, seq: Sequence[Any]) -> list[Any]:
"""将Sequence转换为list"""
result = []
@@ -168,5 +219,3 @@ class OllamaProvider(BaseProvider):
input=text,
)
return self._seq_to_list(result.embeddings)
-
-
diff --git a/apps/llm/providers/openai.py b/apps/llm/providers/openai.py
index b14e9d4b9bea6f35227f819252959811dd6fb0e5..9de76a29e2d6ad3d48efc0085010df813e7be431 100644
--- a/apps/llm/providers/openai.py
+++ b/apps/llm/providers/openai.py
@@ -8,6 +8,7 @@ from typing import cast
import httpx
from openai import AsyncOpenAI, AsyncStream
from openai.types.chat import (
+ ChatCompletion,
ChatCompletionChunk,
ChatCompletionMessageParam,
)
@@ -92,11 +93,159 @@ class OpenAIProvider(BaseProvider):
"content": "" + self.full_thinking + "" + self.full_answer,
}])
+ def _handle_usage_response(self, response: ChatCompletion, messages: list[dict[str, str]]) -> None:
+ """处理非流式响应的usage信息"""
+ if hasattr(response, "usage") and response.usage:
+ self.input_tokens = response.usage.prompt_tokens
+ self.output_tokens = response.usage.completion_tokens
+ else:
+ # 回退到本地估算
+ self.input_tokens = token_calculator.calculate_token_length(messages)
+ self.output_tokens = token_calculator.calculate_token_length([{
+ "role": "assistant",
+ "content": "" + self.full_thinking + "" + self.full_answer,
+ }])
+
+ def _build_request_kwargs(
+ self,
+ messages: list[dict[str, str]],
+ *,
+ streaming: bool,
+ ) -> dict:
+ """构建请求参数"""
+ request_kwargs = {
+ "model": self.config.modelName,
+ "messages": self._convert_messages(messages),
+ "max_tokens": self.config.maxToken,
+ "temperature": self.config.temperature,
+ "stream": streaming,
+ **self.config.extraConfig,
+ }
+
+ # 流式模式下添加usage统计选项
+ if streaming:
+ request_kwargs["stream_options"] = {"include_usage": True}
+
+ return request_kwargs
+
+ def _add_tools_to_request(
+ self,
+ request_kwargs: dict,
+ tools: list[LLMFunctions],
+ ) -> None:
+ """将工具添加到请求参数中"""
+ functions = [
+ {
+ "type": "function",
+ "function": {
+ "name": tool.name,
+ "description": tool.description,
+ "parameters": tool.param_schema,
+ },
+ }
+ for tool in tools
+ ]
+ request_kwargs["tools"] = functions
+
+ async def _process_streaming_chunk(
+ self,
+ chunk: ChatCompletionChunk,
+ *,
+ include_thinking: bool,
+ ) -> AsyncGenerator[LLMChunk, None]:
+ """处理单个流式响应chunk"""
+ if not hasattr(chunk, "choices") or not chunk.choices:
+ return
+
+ delta = chunk.choices[0].delta
+
+ # 处理reasoning_content
+ if (
+ hasattr(delta, "reasoning_content") and
+ getattr(delta, "reasoning_content", None) and
+ include_thinking
+ ):
+ reasoning_content = getattr(delta, "reasoning_content", "")
+ self.full_thinking += reasoning_content
+ yield LLMChunk(reasoning_content=reasoning_content)
+
+ # 处理content
+ if hasattr(delta, "content") and delta.content:
+ self.full_answer += delta.content
+ yield LLMChunk(content=delta.content)
+
+ async def _handle_streaming_response(
+ self,
+ request_kwargs: dict,
+ messages: list[dict[str, str]],
+ *,
+ include_thinking: bool,
+ ) -> AsyncGenerator[LLMChunk, None]:
+ """处理流式响应"""
+ stream: AsyncStream[ChatCompletionChunk] = await self._client.chat.completions.create(**request_kwargs)
+ last_chunk = None
+
+ async for chunk in stream:
+ last_chunk = chunk
+ async for llm_chunk in self._process_streaming_chunk(
+ chunk,
+ include_thinking=include_thinking,
+ ):
+ yield llm_chunk
+
+ # 处理最后一个Chunk的usage
+ self._handle_usage_chunk(last_chunk, messages)
+
+ async def _handle_non_streaming_response(
+ self,
+ request_kwargs: dict,
+ messages: list[dict[str, str]],
+ tools: list[LLMFunctions] | None,
+ *,
+ include_thinking: bool,
+ ) -> AsyncGenerator[LLMChunk, None]:
+ """处理非流式响应"""
+ response: ChatCompletion = await self._client.chat.completions.create(**request_kwargs)
+
+ tool_call_dict = {}
+ if response.choices:
+ message = response.choices[0].message
+
+ # 处理reasoning_content
+ if (
+ hasattr(message, "reasoning_content") and
+ getattr(message, "reasoning_content", None) and
+ include_thinking
+ ):
+ self.full_thinking = getattr(message, "reasoning_content", "")
+
+ # 处理content
+ if hasattr(message, "content") and message.content:
+ self.full_answer = message.content
+
+ # 处理工具调用
+ if tools and hasattr(message, "tool_calls") and message.tool_calls:
+ for tool_call in message.tool_calls:
+ if tool_call.type == "function" and hasattr(tool_call, "function"):
+ func = tool_call.function
+ tool_call_dict[func.name] = func.arguments
+
+ # 处理usage数据
+ self._handle_usage_response(response, messages)
+
+ # 一次性返回完整结果
+ yield LLMChunk(
+ content=self.full_answer,
+ reasoning_content=self.full_thinking,
+ tool_call=tool_call_dict if tool_call_dict else None,
+ )
+
@override
- async def chat( # noqa: C901
+ async def chat(
self, messages: list[dict[str, str]],
- *, include_thinking: bool = False,
tools: list[LLMFunctions] | None = None,
+ *, include_thinking: bool = False,
+ streaming: bool = True,
) -> AsyncGenerator[LLMChunk, None]:
"""聊天"""
# 检查能力
@@ -105,73 +254,38 @@ class OpenAIProvider(BaseProvider):
_logger.error(err)
raise RuntimeError(err)
- # 初始化Token计数
+ # 初始化Token计数和累积内容
self.input_tokens = 0
self.output_tokens = 0
+ self.full_thinking = ""
+ self.full_answer = ""
# 检查消息
messages = self._validate_messages(messages)
- request_kwargs = {
- "model": self.config.modelName,
- "messages": self._convert_messages(messages),
- "max_tokens": self.config.maxToken,
- "temperature": self.config.temperature,
- "stream": True,
- "stream_options": {"include_usage": True},
- **self.config.extraConfig,
- }
+ # 构建请求参数
+ request_kwargs = self._build_request_kwargs(messages, streaming=streaming)
# 如果提供了tools,则启用function-calling模式
if tools:
- functions = []
- for tool in tools:
- functions += [{
- "type": "function",
- "function": {
- "name": tool.name,
- "description": tool.description,
- "parameters": tool.param_schema,
- },
- }]
- request_kwargs["tools"] = functions
+ self._add_tools_to_request(request_kwargs, tools)
- stream: AsyncStream[ChatCompletionChunk] = await self._client.chat.completions.create(**request_kwargs)
- # 流式返回响应
- last_chunk = None
- async for chunk in stream:
- last_chunk = chunk
- if hasattr(chunk, "choices") and chunk.choices:
- delta = chunk.choices[0].delta
- if (
- hasattr(delta, "reasoning_content") and
- getattr(delta, "reasoning_content", None) and
- include_thinking
- ):
- reasoning_content = getattr(delta, "reasoning_content", "")
- self.full_thinking += reasoning_content
- yield LLMChunk(reasoning_content=reasoning_content)
-
- if (
- hasattr(chunk.choices[0].delta, "content") and
- chunk.choices[0].delta.content
- ):
- self.full_answer += chunk.choices[0].delta.content
- yield LLMChunk(content=chunk.choices[0].delta.content)
-
- # 在chat中统一处理工具调用分块(当提供了tools时)
- if tools and hasattr(delta, "tool_calls") and delta.tool_calls:
- tool_call_dict = {}
- for tool_call in delta.tool_calls:
- if hasattr(tool_call, "function") and tool_call.function:
- tool_call_dict.update({
- tool_call.function.name: tool_call.function.arguments,
- })
- if tool_call_dict:
- yield LLMChunk(tool_call=tool_call_dict)
-
- # 处理最后一个Chunk的usage(仅在最后一个chunk会出现)
- self._handle_usage_chunk(last_chunk, messages)
+ # 根据模式选择处理方式
+ if streaming:
+ async for chunk in self._handle_streaming_response(
+ request_kwargs,
+ messages,
+ include_thinking=include_thinking,
+ ):
+ yield chunk
+ else:
+ async for chunk in self._handle_non_streaming_response(
+ request_kwargs,
+ messages,
+ tools,
+ include_thinking=include_thinking,
+ ):
+ yield chunk
@override
async def embedding(self, text: list[str]) -> list[list[float]]:
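
Note (reviewer sketch, not part of the patch): in non-streaming mode the tool call arrives fully assembled on response.choices[0].message.tool_calls, and the handler above maps each function name to its raw JSON-string arguments. A hedged sketch of consuming that single chunk downstream — provider is an assumed OpenAIProvider instance:

    import json

    async def extract_tool_args(provider, messages: list[dict[str, str]], tools) -> dict:
        async for chunk in provider.chat(messages, tools, include_thinking=False, streaming=False):
            if chunk.tool_call:
                # e.g. {"get_weather": '{"city": "杭州"}'} -> {"get_weather": {"city": "杭州"}}
                return {name: json.loads(args) for name, args in chunk.tool_call.items()}
        return {}
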
diff --git a/apps/main.py b/apps/main.py
index 9f26b4a47463ec5671c8a21a2bd5e1759b2b81df..0241c0019503ff259bd4d086a34142d8b10198e5 100644
--- a/apps/main.py
+++ b/apps/main.py
@@ -33,6 +33,7 @@ from .routers import (
user,
)
from .scheduler.pool.pool import pool
+from .services.settings import SettingsManager
@asynccontextmanager
@@ -41,6 +42,8 @@ async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]:
# 启动时初始化资源
await postgres.init()
await pool.init()
+ # 初始化全局LLM设置
+ await SettingsManager.init_global_llm_settings()
yield
# 关闭时释放资源
await postgres.close()
diff --git a/apps/models/__init__.py b/apps/models/__init__.py
index 6a0fcc3c016a92b70ab7cd134e249eef819ad5c2..4c0301575ba0bd2f7d23a5b79e3f447e6292551c 100644
--- a/apps/models/__init__.py
+++ b/apps/models/__init__.py
@@ -11,10 +11,10 @@ from .flow import Flow
from .llm import LLMData, LLMProvider, LLMType
from .mcp import MCPActivated, MCPInfo, MCPInstallStatus, MCPTools, MCPType
from .node import NodeInfo
-from .settings import GlobalSettings
from .record import Record, RecordMetadata
from .service import Service, ServiceACL, ServiceHashes
from .session import Session, SessionActivity, SessionType
+from .settings import GlobalSettings
from .tag import Tag
from .task import (
ExecutorCheckpoint,
diff --git a/apps/models/task.py b/apps/models/task.py
index 6d2f33a3e2b2b85235f7f6eed0edc530325a075e..51df70961b3236aa1734f78efb17e9d22d127545 100644
--- a/apps/models/task.py
+++ b/apps/models/task.py
@@ -98,10 +98,10 @@ class TaskRuntime(Base):
"""时间"""
fullTime: Mapped[float] = mapped_column(Float, default=0.0, nullable=False) # noqa: N815
"""完整时间成本"""
- sessionId: Mapped[str | None] = mapped_column( # noqa: N815
- String(255), ForeignKey("framework_session.id"), nullable=True, default=None,
+ authHeader: Mapped[str | None] = mapped_column( # noqa: N815
+ String(255), nullable=True, default=None,
)
- """会话ID"""
+ """认证头"""
userInput: Mapped[str] = mapped_column(Text, nullable=False, default="") # noqa: N815
"""用户输入"""
fullAnswer: Mapped[str] = mapped_column(Text, nullable=False, default="") # noqa: N815
diff --git a/apps/routers/auth.py b/apps/routers/auth.py
index b74bf83e5c44ae7d8d6db3e10aea99395feaf9b1..5216aae960f193cb9dd7e5cf71494f4c15bcebd7 100644
--- a/apps/routers/auth.py
+++ b/apps/routers/auth.py
@@ -48,9 +48,9 @@ async def oidc_login(request: Request, code: str) -> HTMLResponse:
token = await oidc_provider.get_oidc_token(code)
user_info = await oidc_provider.get_oidc_user(token["access_token"])
- user_sub: str | None = user_info.get("user_sub", None)
- if user_sub:
- await oidc_provider.set_token(user_sub, token["access_token"], token["refresh_token"])
+ user_id: str | None = user_info.get("user_sub", None)
+ if user_id:
+ await oidc_provider.set_token(user_id, token["access_token"], token["refresh_token"])
except Exception as e:
_logger.exception("User login failed")
status_code = status.HTTP_400_BAD_REQUEST if "auth error" in str(e) else status.HTTP_403_FORBIDDEN
@@ -68,7 +68,7 @@ async def oidc_login(request: Request, code: str) -> HTMLResponse:
)
user_host = request.client.host
# 获取用户sub
- if not user_sub:
+ if not user_id:
_logger.error("OIDC no user_sub associated.")
return _templates.TemplateResponse(
"login_failed.html.j2",
@@ -76,9 +76,10 @@ async def oidc_login(request: Request, code: str) -> HTMLResponse:
status_code=status.HTTP_403_FORBIDDEN,
)
# 创建或更新用户登录信息
- await UserManager.create_or_update_on_login(user_sub)
+ user_name: str | None = user_info.get("user_name", None)
+ await UserManager.create_or_update_on_login(user_id, user_name)
# 创建会话
- current_session = await SessionManager.create_session(user_sub, user_host)
+ current_session = await SessionManager.create_session(user_id, user_host)
return _templates.TemplateResponse(
"login_success.html.j2",
{"request": request, "current_session": current_session},
@@ -97,7 +98,7 @@ async def logout(request: Request) -> JSONResponse:
result={},
).model_dump(exclude_none=True, by_alias=True),
)
- await TokenManager.delete_plugin_token(request.state.user_sub)
+ await TokenManager.delete_plugin_token(request.state.user_id)
if hasattr(request.state, "session_id"):
await SessionManager.delete_session(request.state.session_id)
@@ -145,7 +146,7 @@ async def oidc_logout(token: Annotated[str, Body()]) -> JSONResponse:
}, response_model=PostPersonalTokenRsp)
async def change_personal_token(request: Request) -> JSONResponse:
"""POST /auth/key: 重置用户的API密钥"""
- new_api_key: str | None = await PersonalTokenManager.update_personal_token(request.state.user_sub)
+ new_api_key: str | None = await PersonalTokenManager.update_personal_token(request.state.user_id)
if not new_api_key:
return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=ResponseData(
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
diff --git a/apps/routers/chat.py b/apps/routers/chat.py
index 530c9c137af374c6d42dd9c3f786489b10a21681..48f5ad3c84820cc570ebb464dd3ce6ed91ef9bf4 100644
--- a/apps/routers/chat.py
+++ b/apps/routers/chat.py
@@ -29,26 +29,28 @@ router = APIRouter(
],
)
-async def chat_generator(post_body: RequestData, user_id: str, session_id: str) -> AsyncGenerator[str, None]:
+async def chat_generator(post_body: RequestData, user_id: str, auth_header: str) -> AsyncGenerator[str, None]:
"""进行实际问答,并从MQ中获取消息"""
try:
- await Activity.set_active(user_id)
-
# 敏感词检查
if await words_check.check(post_body.question) != 1:
yield "data: [SENSITIVE]\n\n"
_logger.info("[Chat] 问题包含敏感词!")
- await Activity.remove_active(user_id)
return
# 创建queue;由Scheduler进行关闭
queue = MessageQueue()
await queue.init()
- # 在单独Task中运行Scheduler,拉齐queue.get的时机
+ # 1. 先初始化 Scheduler(确保初始化成功)
scheduler = Scheduler()
- await scheduler.init(queue, post_body, user_id)
- _logger.info(f"[Chat] 用户是否活跃: {await Activity.is_active(user_id)}")
+ await scheduler.init(queue, post_body, user_id, auth_header)
+
+ # 2. Scheduler初始化成功后,再设置用户处于活动状态
+ await Activity.set_active(user_id)
+        _logger.info(f"[Chat] 是否可激活新任务: {await Activity.can_active(user_id)}")
+
+        # 3. 最后在独立Task中运行Scheduler,拉齐queue.get的时机
scheduler_task = asyncio.create_task(scheduler.run())
# 处理每一条消息
@@ -95,7 +97,7 @@ async def chat_generator(post_body: RequestData, user_id: str, session_id: str)
@router.post("/chat")
async def chat(request: Request, post_body: RequestData) -> StreamingResponse:
"""LLM流式对话接口"""
- session_id = request.state.session_id
+ auth_header = request.headers.get("Authorization", "").replace("Bearer ", "")
user_id = request.state.user_id
# 问题黑名单检测
if (post_body.question is not None) and (
@@ -105,7 +107,7 @@ async def chat(request: Request, post_body: RequestData) -> StreamingResponse:
await UserBlacklistManager.change_blacklisted_users(user_id, -10)
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="question is blacklisted")
- res = chat_generator(post_body, user_id, session_id)
+ res = chat_generator(post_body, user_id, auth_header)
return StreamingResponse(
content=res,
media_type="text/event-stream",
diff --git a/apps/scheduler/call/core.py b/apps/scheduler/call/core.py
index fca0d4ab344b89dced98c7aab740e49279705590..c677d70573d57a8e066f2a9b4a15a10d9665542e 100644
--- a/apps/scheduler/call/core.py
+++ b/apps/scheduler/call/core.py
@@ -6,6 +6,7 @@ Core Call类是定义了所有Call都应具有的方法和参数的PyDantic类
"""
import logging
+import uuid
from collections.abc import AsyncGenerator
from typing import TYPE_CHECKING, Any, ClassVar, Self
@@ -89,8 +90,8 @@ class CoreCall(BaseModel):
logger.error(err)
raise ValueError(err)
- history = {}
- history_order = []
+ history: dict[uuid.UUID, ExecutorHistory] = {}
+ history_order: list[uuid.UUID] = []
for item in executor.task.context:
history[item.stepId] = item
history_order.append(item.stepId)
@@ -100,7 +101,7 @@ class CoreCall(BaseModel):
ids=CallIds(
task_id=executor.task.metadata.id,
executor_id=executor.task.state.executorId,
- auth_header=executor.task.runtime.sessionId,
+ auth_header=executor.task.runtime.authHeader,
user_id=executor.task.metadata.userId,
app_id=executor.task.state.appId,
conversation_id=executor.task.metadata.conversationId,
@@ -114,12 +115,12 @@ class CoreCall(BaseModel):
@staticmethod
- def _extract_history_variables(path: str, history: dict[str, ExecutorHistory]) -> Any:
+ def _extract_history_variables(path: str, history: dict[uuid.UUID, ExecutorHistory]) -> Any:
"""
提取History中的变量
:param path: 路径,格式为:step_id/key/to/variable
- :param history: Step历史,即call_vars.history
+ :param history: Step历史,即call_vars.step_data
:return: 变量
"""
split_path = path.split("/")
@@ -127,12 +128,21 @@ class CoreCall(BaseModel):
err = f"[CoreCall] 路径格式错误: {path}"
logger.error(err)
return None
- if split_path[0] not in history:
- err = f"[CoreCall] 步骤{split_path[0]}不存在"
+
+ # 将字符串形式的步骤ID转换为UUID
+ try:
+ step_id = uuid.UUID(split_path[0])
+ except ValueError:
+ err = f"[CoreCall] 步骤ID格式错误: {split_path[0]}"
+ logger.exception(err)
+ return None
+
+ if step_id not in history:
+ err = f"[CoreCall] 步骤{step_id}不存在"
logger.error(err)
return None
- data = history[split_path[0]].outputData
+ data = history[step_id].outputData
for key in split_path[1:]:
if key not in data:
err = f"[CoreCall] 输出Key {key} 不存在"
diff --git a/apps/scheduler/call/llm/llm.py b/apps/scheduler/call/llm/llm.py
index 8140fb04ae973b2ddb174bb279f94f613e40d3e7..a3dc92b890bdb04520b565eb40b090be55e559f6 100644
--- a/apps/scheduler/call/llm/llm.py
+++ b/apps/scheduler/call/llm/llm.py
@@ -4,7 +4,7 @@
import logging
from collections.abc import AsyncGenerator
from datetime import datetime
-from typing import Any
+from typing import TYPE_CHECKING, Any
import pytz
from jinja2 import BaseLoader
@@ -24,6 +24,9 @@ from apps.schemas.scheduler import (
from .prompt import LLM_CONTEXT_PROMPT, LLM_DEFAULT_PROMPT
from .schema import LLMInput, LLMOutput
+if TYPE_CHECKING:
+ from apps.models.task import ExecutorHistory
+
logger = logging.getLogger(__name__)
@@ -68,7 +71,7 @@ class LLM(CoreCall, input_model=LLMInput, output_model=LLMOutput):
)
# 上下文信息
- step_history = []
+ step_history: list[ExecutorHistory] = []
for ids in call_vars.step_order[-self.step_history_size:]:
step_history += [call_vars.step_data[ids]]
diff --git a/apps/scheduler/call/rag/rag.py b/apps/scheduler/call/rag/rag.py
index 73e88cb936fbddf54d4f65cc98e2136458af05d6..aaf6c9d9c8424aae82ecb72d78af544ed64d4f1e 100644
--- a/apps/scheduler/call/rag/rag.py
+++ b/apps/scheduler/call/rag/rag.py
@@ -136,19 +136,25 @@ class RAG(CoreCall, input_model=RAGInput, output_model=RAGOutput):
"""获取文档信息,支持临时文档和知识库文档"""
doc_chunk_list: list[DocItem] = []
- # 处理临时文档
+ # 处理临时文档:若docIds为空,则不请求
if doc_ids:
tmp_data = deepcopy(data)
tmp_data.kbIds = [uuid.UUID("00000000-0000-0000-0000-000000000000")]
tmp_data.docIds = doc_ids
+ _logger.info("[RAG] 获取临时文档: %s", tmp_data.docIds)
doc_chunk_list.extend(await self._fetch_doc_chunks(tmp_data))
+ else:
+ _logger.info("[RAG] docIds为空,跳过临时文档请求")
- # 处理知识库ID的情况
+ # 处理知识库:若kbIds为空,则不请求
if data.kbIds:
kb_data = deepcopy(data)
# 知识库查询时不使用docIds,只使用kbIds
kb_data.docIds = None
+ _logger.info("[RAG] 获取知识库文档: %s", kb_data.kbIds)
doc_chunk_list.extend(await self._fetch_doc_chunks(kb_data))
+ else:
+ _logger.info("[RAG] kbIds为空,跳过知识库请求")
return doc_chunk_list
diff --git a/apps/scheduler/call/slot/slot.py b/apps/scheduler/call/slot/slot.py
index a3930924b2a0c5e5fbed4739876591e4426832f8..af9afbc14bfc6025ee367826914e64762009b910 100644
--- a/apps/scheduler/call/slot/slot.py
+++ b/apps/scheduler/call/slot/slot.py
@@ -20,6 +20,7 @@ from .prompt import SLOT_GEN_PROMPT, SLOT_HISTORY_TEMPLATE, SLOT_SUMMARY_TEMPLAT
from .schema import SlotInput, SlotOutput
if TYPE_CHECKING:
+ from apps.models.task import ExecutorHistory
from apps.scheduler.executor.step import StepExecutor
@@ -148,7 +149,7 @@ class Slot(CoreCall, input_model=SlotInput, output_model=SlotOutput):
async def _init(self, call_vars: CallVars) -> SlotInput:
"""初始化"""
- self._flow_history = []
+ self._flow_history: list[ExecutorHistory] = []
self._question = call_vars.question
for key in call_vars.step_order[:-self.step_num]:
self._flow_history += [call_vars.step_data[key]]
diff --git a/apps/scheduler/executor/base.py b/apps/scheduler/executor/base.py
index eb5c594d911dbc60602121fb019d985b0cd4c356..59c82a8e02d83d3989263dd08fdbb17a401d03af 100644
--- a/apps/scheduler/executor/base.py
+++ b/apps/scheduler/executor/base.py
@@ -2,29 +2,27 @@
"""Executor基类"""
from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Any
+from typing import Any
from pydantic import BaseModel, ConfigDict
+from apps.common.queue import MessageQueue
from apps.common.security import Security
+from apps.llm import LLM
from apps.schemas.enum_var import EventType
from apps.schemas.message import TextAddContent
from apps.schemas.record import RecordContent
from apps.schemas.scheduler import ExecutorBackground
+from apps.schemas.task import TaskData
from apps.services.record import RecordManager
-if TYPE_CHECKING:
- from apps.common.queue import MessageQueue
- from apps.llm import LLM
- from apps.schemas.task import TaskData
-
class BaseExecutor(BaseModel, ABC):
"""Executor基类"""
- task: "TaskData"
- msg_queue: "MessageQueue"
- llm: "LLM"
+ task: TaskData
+ msg_queue: MessageQueue
+ llm: LLM
question: str
diff --git a/apps/scheduler/executor/qa.py b/apps/scheduler/executor/qa.py
index bb6228188e8fee780ff471b4a3a15e8c596a6d1e..88e2442c4cd26ea201cb9f5354de58599a2950e0 100644
--- a/apps/scheduler/executor/qa.py
+++ b/apps/scheduler/executor/qa.py
@@ -5,7 +5,7 @@ import logging
import uuid
from datetime import UTC, datetime
-from apps.models import ExecutorCheckpoint, ExecutorStatus, StepStatus
+from apps.models import ExecutorCheckpoint, ExecutorStatus, StepStatus, StepType
from apps.models.task import LanguageType
from apps.scheduler.call.rag.schema import DocItem, RAGOutput
from apps.schemas.document import DocumentInfo
@@ -23,13 +23,13 @@ _RAG_STEP_LIST = [
name="RAG检索",
description="从知识库中检索相关文档",
node="RAG",
- type="RAG",
+ type=StepType.RAG.value,
),
LanguageType.ENGLISH: Step(
name="RAG retrieval",
description="Retrieve relevant documents from the knowledge base",
node="RAG",
- type="RAG",
+ type=StepType.RAG.value,
),
},
{
@@ -37,13 +37,13 @@ _RAG_STEP_LIST = [
name="LLM问答",
description="基于检索到的文档生成答案",
node="LLM",
- type="LLM",
+ type=StepType.LLM.value,
),
LanguageType.ENGLISH: Step(
name="LLM answer",
description="Generate answer based on the retrieved documents",
node="LLM",
- type="LLM",
+ type=StepType.LLM.value,
),
},
{
@@ -51,13 +51,13 @@ _RAG_STEP_LIST = [
name="问题推荐",
description="根据对话答案,推荐相关问题",
node="Suggestion",
- type="Suggestion",
+ type=StepType.SUGGEST.value,
),
LanguageType.ENGLISH: Step(
name="Question Suggestion",
description="Display the suggested next question under the answer",
node="Suggestion",
- type="Suggestion",
+ type=StepType.SUGGEST.value,
),
},
{
@@ -151,6 +151,7 @@ class QAExecutor(BaseExecutor):
if self.task.state and self.task.state.stepStatus == StepStatus.ERROR:
_logger.error("[QAExecutor] RAG检索步骤失败")
+ self.task.state.executorStatus = ExecutorStatus.ERROR
return False
rag_output_data = None
@@ -199,6 +200,7 @@ class QAExecutor(BaseExecutor):
if self.task.state and self.task.state.stepStatus == StepStatus.ERROR:
_logger.error("[QAExecutor] LLM问答步骤失败")
+ self.task.state.executorStatus = ExecutorStatus.ERROR
return False
_logger.info("[QAExecutor] LLM问答步骤完成")
@@ -225,6 +227,7 @@ class QAExecutor(BaseExecutor):
if self.task.state and self.task.state.stepStatus == StepStatus.ERROR:
_logger.error("[QAExecutor] 问题推荐步骤失败")
+ self.task.state.executorStatus = ExecutorStatus.ERROR
return False
_logger.info("[QAExecutor] 开始执行记忆存储步骤")
@@ -246,6 +249,7 @@ class QAExecutor(BaseExecutor):
if self.task.state and self.task.state.stepStatus == StepStatus.ERROR:
_logger.error("[QAExecutor] 记忆存储步骤失败")
+ self.task.state.executorStatus = ExecutorStatus.ERROR
return False
return True
diff --git a/apps/scheduler/executor/step.py b/apps/scheduler/executor/step.py
index dbec9b4505209d607ea9cceab55fcbb9bfdb7126..2d9d72d87077a455cb04d806d8bbb295960b01b1 100644
--- a/apps/scheduler/executor/step.py
+++ b/apps/scheduler/executor/step.py
@@ -6,11 +6,11 @@ import logging
import uuid
from collections.abc import AsyncGenerator
from datetime import UTC, datetime
-from typing import TYPE_CHECKING, Any
+from typing import Any
from pydantic import ConfigDict
-from apps.models import ExecutorHistory, StepStatus, StepType
+from apps.models import ExecutorHistory, ExecutorStatus, StepStatus
from apps.scheduler.call.core import CoreCall
from apps.scheduler.call.empty import Empty
from apps.scheduler.call.facts.facts import FactsCall
@@ -24,21 +24,19 @@ from apps.schemas.enum_var import (
)
from apps.schemas.message import TextAddContent
from apps.schemas.scheduler import CallError, CallOutputChunk, ExecutorBackground
+from apps.schemas.task import StepQueueItem
from apps.services.node import NodeManager
from .base import BaseExecutor
-if TYPE_CHECKING:
- from apps.schemas.task import StepQueueItem
-
logger = logging.getLogger(__name__)
class StepExecutor(BaseExecutor):
"""工作流中步骤相关函数"""
- step: "StepQueueItem"
- background: "ExecutorBackground"
+ step: StepQueueItem
+ background: ExecutorBackground
model_config = ConfigDict(
arbitrary_types_allowed=True,
@@ -91,7 +89,7 @@ class StepExecutor(BaseExecutor):
# State写入ID和运行状态
self.task.state.stepId = self.step.step_id
- self.task.state.stepType = str(StepType(self.step.step.type))
+ self.task.state.stepType = self.step.step.type
self.task.state.stepName = self.step.step.name
# 获取并验证Call类
@@ -235,6 +233,7 @@ class StepExecutor(BaseExecutor):
except Exception as e:
logger.exception("[StepExecutor] 运行步骤失败,进行异常处理步骤")
self.task.state.stepStatus = StepStatus.ERROR
+ self.task.state.executorStatus = ExecutorStatus.ERROR
await self._push_message(EventType.STEP_OUTPUT.value, {})
if isinstance(e, CallError):
self.task.state.errorMessage = {
diff --git a/apps/scheduler/scheduler/init.py b/apps/scheduler/scheduler/init.py
index 05aba7cf1fb69d80be44bded928ae7f7dd0464b7..c7020f3648b754562686a0a1a568db3c7ce2b169 100644
--- a/apps/scheduler/scheduler/init.py
+++ b/apps/scheduler/scheduler/init.py
@@ -40,8 +40,9 @@ class InitMixin:
raise ValueError(err)
return LLM(reasoning_llm)
- def _create_new_task(self, task_id: uuid.UUID, user_id: str, conversation_id: uuid.UUID | None) -> None:
+ def _create_new_task(self, user_id: str, conversation_id: uuid.UUID | None, auth_header: str) -> None:
"""创建新的TaskData"""
+ task_id = uuid.uuid4()
self.task = TaskData(
metadata=Task(
id=task_id,
@@ -50,18 +51,19 @@ class InitMixin:
),
runtime=TaskRuntime(
taskId=task_id,
+ authHeader=auth_header,
),
state=None,
context=[],
)
- async def _init_task(self, task_id: uuid.UUID, user_id: str) -> None:
+ async def _init_task(self, user_id: str, auth_header: str) -> None:
"""初始化Task"""
conversation_id = self.post_body.conversation_id
if not conversation_id:
_logger.info("[Scheduler] 无Conversation ID,直接创建新任务")
- self._create_new_task(task_id, user_id, None)
+ self._create_new_task(user_id, None, auth_header)
return
_logger.info("[Scheduler] 尝试从Conversation ID %s 恢复任务", conversation_id)
@@ -80,10 +82,8 @@ class InitMixin:
task_data = await TaskManager.get_task_data_by_task_id(last_task.id)
if task_data:
self.task = task_data
- self.task.metadata.id = task_id
- self.task.runtime.taskId = task_id
- if self.task.state:
- self.task.state.taskId = task_id
+ # 恢复任务时保留原有的 task_id,但更新 authHeader
+ self.task.runtime.authHeader = auth_header
return
else:
_logger.warning(
@@ -94,7 +94,7 @@ class InitMixin:
_logger.exception("[Scheduler] 从Conversation恢复任务失败,创建新任务")
_logger.info("[Scheduler] 无法恢复任务,创建新任务")
- self._create_new_task(task_id, user_id, conversation_id)
+ self._create_new_task(user_id, conversation_id, auth_header)
async def _get_user(self, user_id: str) -> None:
"""初始化用户"""
diff --git a/apps/scheduler/scheduler/scheduler.py b/apps/scheduler/scheduler/scheduler.py
index 28b513dfaf718593cafc2d630be6b2c699f22009..41d4558ba394748a89b92bce5a0d27ade9ade1cd 100644
--- a/apps/scheduler/scheduler/scheduler.py
+++ b/apps/scheduler/scheduler/scheduler.py
@@ -42,17 +42,17 @@ class Scheduler(
async def init(
self,
- task_id: uuid.UUID,
queue: MessageQueue,
post_body: RequestData,
user_id: str,
+ auth_header: str,
) -> None:
"""初始化"""
self.queue = queue
self.post_body = post_body
await self._get_user(user_id)
- await self._init_task(task_id, user_id)
+ await self._init_task(user_id, auth_header)
self._init_jinja2_env()
self.llm = await self._get_scheduler_llm(post_body.llm_id)
diff --git a/apps/scheduler/scheduler/util.py b/apps/scheduler/scheduler/util.py
index 0f710155da3bf6053eef3e7128da20bcde04386f..524a576ef580ba8d64d150b8077a26c17a30146a 100644
--- a/apps/scheduler/scheduler/util.py
+++ b/apps/scheduler/scheduler/util.py
@@ -55,9 +55,9 @@ class UtilMixin:
check_interval = 0.5
while not kill_event.is_set():
- is_active = await Activity.is_active(user_id)
+ can_active = await Activity.can_active(user_id)
- if not is_active:
+ if not can_active:
_logger.warning("[Scheduler] 用户 %s 不活跃,终止工作流", user_id)
kill_event.set()
break
diff --git a/apps/schemas/message.py b/apps/schemas/message.py
index e149adb1895a44efd27133daf70ef8e0ee0bb51e..1b742e2b24269c1ce040e407d4af043de4492cf3 100644
--- a/apps/schemas/message.py
+++ b/apps/schemas/message.py
@@ -57,8 +57,8 @@ class TextAddContent(BaseModel):
class MessageBase(HeartbeatData):
"""基础消息事件结构"""
- id: uuid.UUID = Field(min_length=36, max_length=36)
- conversation_id: uuid.UUID | None = Field(min_length=36, max_length=36, alias="conversationId", default=None)
+ id: uuid.UUID
+ conversation_id: uuid.UUID | None = Field(alias="conversationId", default=None)
flow: MessageExecutor | None = None
content: Any | None = Field(default=None, description="消息内容")
metadata: MessageMetadata
diff --git a/apps/schemas/scheduler.py b/apps/schemas/scheduler.py
index 0d8e4748dc473aa2fcdd82e3db05e58b6a36bece..cf3e62dce770e65bc349fc4e04fcc4db205a2d35 100644
--- a/apps/schemas/scheduler.py
+++ b/apps/schemas/scheduler.py
@@ -43,8 +43,8 @@ class CallVars(BaseModel):
thinking: str = Field(description="上下文信息")
question: str = Field(description="改写后的用户输入")
- step_data: dict[str, ExecutorHistory] = Field(description="Executor中历史工具的结构化数据", default={})
- step_order: list[str] = Field(description="Executor中历史工具的顺序", default=[])
+ step_data: dict[uuid.UUID, ExecutorHistory] = Field(description="Executor中历史工具的结构化数据", default={})
+ step_order: list[uuid.UUID] = Field(description="Executor中历史工具的顺序", default=[])
background: ExecutorBackground = Field(description="Executor的背景信息")
ids: CallIds = Field(description="Call的ID")
language: LanguageType = Field(description="语言", default=LanguageType.CHINESE)
diff --git a/apps/services/activity.py b/apps/services/activity.py
index c83db3266523eb11d7acfe3343c70a8aa72c2029..687e9af75bf606bb780d6713ed2996bfb13dea37 100644
--- a/apps/services/activity.py
+++ b/apps/services/activity.py
@@ -15,12 +15,12 @@ class Activity:
"""活动控制:全局并发限制,同时最多有 n 个任务在执行(与用户无关)"""
@staticmethod
- async def is_active(user_id: str) -> bool:
+ async def can_active(user_id: str) -> bool:
"""
判断系统是否达到全局并发上限
:param user_id: 用户实体ID(兼容现有接口签名)
- :return: 达到并发上限返回 True,否则 False
+ :return: 达到并发上限返回 False,否则 True
"""
time = datetime.now(tz=UTC)
@@ -32,11 +32,11 @@ class Activity:
SessionActivity.timestamp <= time,
))).one()
if count >= SLIDE_WINDOW_QUESTION_COUNT:
- return True
+ return False
# 全局并发检查:当前活跃任务数量是否达到上限
current_active = (await session.scalars(select(func.count(SessionActivity.id)))).one()
- return current_active >= MAX_CONCURRENT_TASKS
+ return current_active < MAX_CONCURRENT_TASKS
@staticmethod
async def set_active(user_id: str) -> None:
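
Note (reviewer sketch, not part of the patch): the rename flips the meaning of the return value — is_active answered "is the limit already hit?", while can_active answers "may another task be activated?". Equivalent pure-Python form of the new check, with illustrative limit values rather than the real config:

    SLIDE_WINDOW_QUESTION_COUNT = 10
    MAX_CONCURRENT_TASKS = 5

    def can_active(window_count: int, current_active: int) -> bool:
        if window_count >= SLIDE_WINDOW_QUESTION_COUNT:
            return False
        return current_active < MAX_CONCURRENT_TASKS

    assert can_active(window_count=3, current_active=2) is True
    assert can_active(window_count=3, current_active=5) is False
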
diff --git a/apps/services/settings.py b/apps/services/settings.py
index ae856b0d54c50ae33c04552c608df67d2b440ed6..c21987e697b8052f530784637a33ccdeb4345854 100644
--- a/apps/services/settings.py
+++ b/apps/services/settings.py
@@ -7,14 +7,14 @@ import logging
from sqlalchemy import select
from apps.common.postgres import postgres
-from apps.llm import embedding
+from apps.llm import LLM, embedding, json_generator
from apps.models import GlobalSettings, LLMType
from apps.scheduler.pool.pool import pool
from apps.schemas.request_data import UpdateSpecialLlmReq
from apps.schemas.response_data import SelectedSpecialLlmID
from apps.services.llm import LLMManager
-logger = logging.getLogger(__name__)
+_logger = logging.getLogger(__name__)
class SettingsManager:
@@ -34,7 +34,7 @@ class SettingsManager:
)).first()
if not settings:
- logger.warning("[SettingsManager] 全局设置不存在,返回默认值")
+ _logger.warning("[SettingsManager] 全局设置不存在,返回默认值")
return SelectedSpecialLlmID(functionLLM=None, embeddingLLM=None)
return SelectedSpecialLlmID(
@@ -60,11 +60,89 @@ class SettingsManager:
# 触发向量化
await pool.set_vector()
- logger.info("[SettingsManager] Embedding模型已更新,向量化过程已完成")
+ _logger.info("[SettingsManager] Embedding模型已更新,向量化过程已完成")
else:
- logger.error("[SettingsManager] 选择的embedding模型 %s 不存在", embedding_llm_id)
+ _logger.error("[SettingsManager] 选择的embedding模型 %s 不存在", embedding_llm_id)
except Exception:
- logger.exception("[SettingsManager] Embedding模型向量化过程失败")
+ _logger.exception("[SettingsManager] Embedding模型向量化过程失败")
+
+
+ @staticmethod
+ async def _initialize_function_llm(function_llm_id: str) -> None:
+ """
+ 初始化Function LLM
+
+ :param function_llm_id: Function LLM ID
+ :raises ValueError: 如果LLM不存在或不支持Function Call
+ """
+ function_llm = await LLMManager.get_llm(function_llm_id)
+ if not function_llm:
+ err = f"[SettingsManager] Function LLM {function_llm_id} 不存在"
+ raise ValueError(err)
+ if LLMType.FUNCTION not in function_llm.llmType:
+ err = f"[SettingsManager] LLM {function_llm_id} 不支持Function Call"
+ raise ValueError(err)
+
+ # 更新json_generator单例
+ llm_instance = LLM(function_llm)
+ json_generator.init(llm_instance)
+ _logger.info("[SettingsManager] Function LLM已初始化,json_generator已配置")
+
+
+ @staticmethod
+ async def _validate_embedding_llm(embedding_llm_id: str) -> None:
+ """
+ 验证Embedding LLM是否有效
+
+ :param embedding_llm_id: Embedding LLM ID
+ :raises ValueError: 如果LLM不存在或不支持Embedding
+ """
+ embedding_llm = await LLMManager.get_llm(embedding_llm_id)
+ if not embedding_llm:
+ err = f"[SettingsManager] Embedding LLM {embedding_llm_id} 不存在"
+ raise ValueError(err)
+ if LLMType.EMBEDDING not in embedding_llm.llmType:
+ err = f"[SettingsManager] LLM {embedding_llm_id} 不支持Embedding"
+ raise ValueError(err)
+
+
+ @staticmethod
+ async def init_global_llm_settings() -> None:
+ """
+ 初始化全局LLM设置(用于项目启动时调用)
+
+ 从数据库读取全局设置并初始化对应的LLM实例
+ """
+ async with postgres.session() as session:
+ settings = (await session.scalars(select(GlobalSettings))).first()
+
+ if not settings:
+ _logger.warning("[SettingsManager] 全局设置不存在,跳过LLM初始化")
+ return
+
+ # 初始化Function LLM
+ if settings.functionLlmId:
+ try:
+ await SettingsManager._initialize_function_llm(settings.functionLlmId)
+ _logger.info("[SettingsManager] Function LLM初始化完成: %s", settings.functionLlmId)
+ except ValueError:
+ _logger.exception("[SettingsManager] Function LLM初始化失败")
+ else:
+ _logger.info("[SettingsManager] Function LLM未配置,跳过初始化")
+
+ # 初始化Embedding LLM
+ if settings.embeddingLlmId:
+ try:
+ await SettingsManager._validate_embedding_llm(settings.embeddingLlmId)
+ # 初始化embedding配置
+ embedding_llm_config = await LLMManager.get_llm(settings.embeddingLlmId)
+ if embedding_llm_config:
+ await embedding.init(embedding_llm_config)
+ _logger.info("[SettingsManager] Embedding LLM初始化完成: %s", settings.embeddingLlmId)
+ except ValueError:
+ _logger.exception("[SettingsManager] Embedding LLM初始化失败")
+ else:
+ _logger.info("[SettingsManager] Embedding LLM未配置,跳过初始化")
@staticmethod
@@ -78,25 +156,13 @@ class SettingsManager:
:param user_id: 操作的管理员user_id
:param req: 更新请求体
"""
- # 验证functionLLM是否支持Function Call
+ # 验证并初始化functionLLM
if req.functionLLM:
- function_llm = await LLMManager.get_llm(req.functionLLM)
- if not function_llm:
- err = f"[SettingsManager] Function LLM {req.functionLLM} 不存在"
- raise ValueError(err)
- if LLMType.FUNCTION not in function_llm.llmType:
- err = f"[SettingsManager] LLM {req.functionLLM} 不支持Function Call"
- raise ValueError(err)
-
- # 验证embeddingLLM是否支持Embedding
+ await SettingsManager._initialize_function_llm(req.functionLLM)
+
+ # 验证embeddingLLM
if req.embeddingLLM:
- embedding_llm = await LLMManager.get_llm(req.embeddingLLM)
- if not embedding_llm:
- err = f"[SettingsManager] Embedding LLM {req.embeddingLLM} 不存在"
- raise ValueError(err)
- if LLMType.EMBEDDING not in embedding_llm.llmType:
- err = f"[SettingsManager] LLM {req.embeddingLLM} 不支持Embedding"
- raise ValueError(err)
+ await SettingsManager._validate_embedding_llm(req.embeddingLLM)
# 读取旧的embedding配置
old_embedding_llm = None
@@ -124,4 +190,4 @@ class SettingsManager:
task = asyncio.create_task(SettingsManager._trigger_vectorization(req.embeddingLLM))
# 添加任务完成回调,避免未处理的异常
task.add_done_callback(lambda _: None)
- logger.info("[SettingsManager] 已启动后台向量化任务")
+ _logger.info("[SettingsManager] 已启动后台向量化任务")
diff --git a/design/call/core.md b/design/call/core.md
index d165382b60fbe473859596c6d04dbac25921aa07..0cf71e5dec4f9e2b65e23bacc7cdbba2dd3126e5 100644
--- a/design/call/core.md
+++ b/design/call/core.md
@@ -312,8 +312,8 @@ classDiagram
+language: LanguageType
+ids: CallIds
+question: str
- +step_data: dict[str, ExecutorHistory]
- +step_order: list[str]
+ +step_data: dict[UUID, ExecutorHistory]
+ +step_order: list[UUID]
+background: ExecutorBackground
+thinking: str
}
@@ -565,8 +565,8 @@ CallVars是CoreCall中最重要的数据结构,包含了执行Call所需的所
| `language` | LanguageType | ✅ | 当前使用的语言类型(中文/英文) |
| `ids` | CallIds | ✅ | 包含任务ID、执行器ID、会话ID等标识信息 |
| `question` | str | ✅ | 改写或原始的用户问题 |
-| `step_data` | dict[str, ExecutorHistory] | ✅ | 步骤执行历史的字典,键为步骤ID |
-| `step_order` | list[str] | ✅ | 步骤执行的顺序列表 |
+| `step_data` | dict[UUID, ExecutorHistory] | ✅ | 步骤执行历史的字典,键为步骤ID |
+| `step_order` | list[UUID] | ✅ | 步骤执行的顺序列表 |
| `background` | ExecutorBackground | ✅ | 执行器的背景信息,包含对话历史和事实 |
| `thinking` | str | ✅ | AI的推理思考过程文本 |