diff --git a/apps/llm/function.py b/apps/llm/function.py index b8a629c9bc7d79b9795e4b02c598c7b43a7f9188..b23ad1caf377559e1b75ccb954484246c6dc1bce 100644 --- a/apps/llm/function.py +++ b/apps/llm/function.py @@ -37,10 +37,10 @@ class FunctionLLM: if config["SCHEDULER_BACKEND"] == "vllm" or config["SCHEDULER_BACKEND"] == "openai": import openai if not config["SCHEDULER_API_KEY"]: - self._client = openai.AsyncOpenAI(base_url=config["SCHEDULER_URL"]) + self._client = openai.AsyncOpenAI(base_url=config["SCHEDULER_URL"] + "/v1") else: self._client = openai.AsyncOpenAI( - base_url=config["SCHEDULER_URL"], + base_url=config["SCHEDULER_URL"] + "/v1", api_key=config["SCHEDULER_API_KEY"], ) @@ -199,7 +199,7 @@ class FunctionLLM: sglang.set_default_backend(self._client) sglang_func = sglang.function(self._sglang_func) - state = await asyncify(sglang_func.run)(messages, schema, max_tokens, temperature) + state = await asyncify(sglang_func.run)(messages, schema, max_tokens, temperature) #type: ignore[arg-type] return state["output"] diff --git a/apps/manager/document.py b/apps/manager/document.py index 9193ed7532727f0d7bd8452c966a8b68c4e6c957..8dff7d40b6840d3545014f420c284faa42920ef6 100644 --- a/apps/manager/document.py +++ b/apps/manager/document.py @@ -119,7 +119,8 @@ class DocumentManager: logger.error("[DocumentManager] 记录组不存在: %s", record_group_id) return [] - doc_ids = RecordGroup.model_validate(record_group).docs + docs = RecordGroup.model_validate(record_group).docs + doc_ids = [doc.id for doc in docs] doc_infos = [Document.model_validate(doc) async for doc in docs_collection.find({"_id": {"$in": doc_ids}})] return [ RecordDocument( @@ -129,7 +130,7 @@ class DocumentManager: size=item[1].size, conversation_id=item[1].conversation_id, associated=item[0].associated, - ) for item in zip(doc_ids, doc_infos) + ) for item in zip(docs, doc_infos) ] except Exception: logger.exception("[DocumentManager] 获取使用文件失败") diff --git a/apps/scheduler/scheduler/context.py 
async def get_docs(user_sub: str, post_body: RequestData) -> tuple[Union[list[RecordDocument], list[Document]], list[str]]:
    """Collect the documents that can be associated with the current question.

    The record group named by ``post_body.group_id`` is consulted first. If it
    already has documents, the request is treated as a regeneration and exactly
    those documents are reused. Otherwise the request is treated as a new
    question: the documents just uploaded to the conversation are used together
    with the documents of its 10 most recent records.

    Args:
        user_sub: Identifier of the requesting user, forwarded to DocumentManager.
        post_body: Request payload; ``group_id`` and ``conversation_id`` are read.

    Returns:
        A ``(docs, doc_ids)`` tuple in which ``doc_ids`` holds ``doc.id`` for
        every entry of ``docs``, in the same order.
    """
    docs = []
    # Without a group id there is no record group to look up, so skip the query.
    if post_body.group_id:
        docs = await DocumentManager.get_used_docs_by_record_group(user_sub, post_body.group_id)

    if not docs:
        # New question: start from the documents just uploaded to the conversation ...
        unused_docs = await DocumentManager.get_unused_docs(user_sub, post_body.conversation_id)
        # ... then add the documents referenced by the 10 most recent records.
        recent_docs = await DocumentManager.get_used_docs(user_sub, post_body.conversation_id, 10)
        # Build a fresh list instead of extending the manager's return value in place.
        docs = [*unused_docs, *recent_docs]

    # Regeneration and new-question paths derive the ids in exactly the same way.
    return docs, [doc.id for doc in docs]
class FlowService:
    # NOTE(review): only the validation methods visible in this chunk are
    # reproduced here; the class has further members outside this view.

    @staticmethod
    async def _validate_node_ids(nodes: list[NodeItem]) -> tuple[str, str]:
        """Check that node ids are unique and locate the start and end nodes.

        Returns:
            ``(start_id, end_id)`` - the ``step_id`` of the single start node
            and of the single end node.

        Raises:
            Exception: if a ``step_id`` occurs twice, if the number of start or
                end nodes is not exactly one, or if either id is ``None``.
        """
        seen_step_ids = set()
        start_ids = []
        end_ids = []

        for item in nodes:
            if item.step_id in seen_step_ids:
                err = f"[FlowService] 节点{item.name}的id重复"
                logger.error(err)
                raise Exception(err)
            seen_step_ids.add(item.step_id)

            # Two independent checks, mirroring the original (not an elif chain).
            if item.call_id == NodeType.START.value:
                start_ids.append(item.step_id)
            if item.call_id == NodeType.END.value:
                end_ids.append(item.step_id)

        if not (len(start_ids) == 1 and len(end_ids) == 1):
            err = "[FlowService] 起始节点和终止节点数量不为1"
            logger.error(err)
            raise Exception(err)

        start_id = start_ids[0]
        end_id = end_ids[0]
        if start_id is None or end_id is None:
            err = "[FlowService] 起始节点或终止节点ID为空"
            logger.error(err)
            raise Exception(err)

        return start_id, end_id

    @staticmethod
    async def _validate_edges(edges: list[EdgeItem]) -> tuple[dict[str, int], dict[str, int]]:
        """Validate the edges and count in-/out-degrees per node.

        Returns:
            ``(in_deg, out_deg)`` - dictionaries mapping a node id to the number
            of edges that target it / originate from it. Nodes without such
            edges have no entry.

        Raises:
            Exception: if an ``edge_id`` occurs twice, if an edge's source and
                target node are the same, or if a ``branch_id`` is used twice
                for the same source node.
        """
        seen_edge_ids = set()
        branches_by_source = {}
        incoming = {}
        outgoing = {}

        for edge in edges:
            if edge.edge_id in seen_edge_ids:
                err = f"[FlowService] 边{edge.edge_id}的id重复"
                logger.error(err)
                raise Exception(err)
            seen_edge_ids.add(edge.edge_id)

            if edge.source_node == edge.target_node:
                err = f"[FlowService] 边{edge.edge_id}的起始节点和终止节点相同"
                logger.error(err)
                raise Exception(err)

            # One set of already-used branch ids per source node.
            used_branches = branches_by_source.setdefault(edge.source_node, set())
            if edge.branch_id in used_branches:
                err = f"[FlowService] 边{edge.edge_id}的分支{edge.branch_id}重复"
                logger.error(err)
                raise Exception(err)
            used_branches.add(edge.branch_id)

            incoming[edge.target_node] = incoming.get(edge.target_node, 0) + 1
            outgoing[edge.source_node] = outgoing.get(edge.source_node, 0) + 1

        return incoming, outgoing

    @staticmethod
    async def _validate_node_degrees(start_id: str, end_id: str,
                                     in_deg: dict[str, int], out_deg: dict[str, int]) -> None:
        """Check the degrees of the start and end nodes.

        Raises:
            Exception: if the start node has a non-zero in-degree or the end
                node has a non-zero out-degree.
        """
        # A node missing from the table counts as degree 0.
        if in_deg.get(start_id, 0) != 0:
            err = f"[FlowService] 起始节点{start_id}的入度不为0"
            logger.error(err)
            raise Exception(err)

        if out_deg.get(end_id, 0) != 0:
            err = f"[FlowService] 终止节点{end_id}的出度不为0"
            logger.error(err)
            raise Exception(err)

    @staticmethod
    async def validate_flow_illegal(flow_item: FlowItem) -> None:
        """Validate a flow graph; raise an exception when it is not legal.

        Runs, in order: the node-id check, the edge check, and the start/end
        degree check. The first failing check raises.
        """
        # Step 1: unique node ids, exactly one start node and one end node.
        entry_id, exit_id = await FlowService._validate_node_ids(flow_item.nodes)

        # Step 2: edge sanity plus the per-node degree tables.
        incoming, outgoing = await FlowService._validate_edges(flow_item.edges)

        # Step 3: nothing may enter the start node or leave the end node.
        await FlowService._validate_node_degrees(entry_id, exit_id, incoming, outgoing)

    @staticmethod
    async def validate_flow_connectivity(flow_item: FlowItem) -> None:
        # NOTE(review): this definition continues past the end of the visible
        # chunk; only its first statement is shown and it is kept unchanged.
        id_of_start_node = None