diff --git a/src/app/deployment/models.py b/src/app/deployment/models.py
index 440d49fb2ce6c0355ebb95fff35acfcd529b17db..1de0b5fda4a10ad2a5e7a80e5ec7f23f08b81d31 100644
--- a/src/app/deployment/models.py
+++ b/src/app/deployment/models.py
@@ -106,21 +106,21 @@ class DeploymentConfig:
         验证 LLM API 连接性和功能

         单独验证 LLM 配置的有效性,包括模型可用性和 function_call 支持。
-        当 LLM 的3个核心字段(endpoint、api_key、model)填完后调用。
+        当 LLM 的端点填写后调用,API Key 和模型名称允许为空。

         Returns:
             tuple[bool, str, dict]: (是否验证成功, 消息, 验证详细信息)
         """
-        # 检查必要字段是否完整
-        if not (self.llm.endpoint.strip() and self.llm.api_key.strip() and self.llm.model.strip()):
-            return False, "LLM 基础配置不完整", {}
+        # 检查必要字段是否完整(只要求端点)
+        if not self.llm.endpoint.strip():
+            return False, "LLM API 端点不能为空", {}

         validator = APIValidator()
         llm_valid, llm_msg, llm_info = await validator.validate_llm_config(
             self.llm.endpoint,
-            self.llm.api_key,
-            self.llm.model,
+            self.llm.api_key,  # 允许为空
+            self.llm.model,  # 允许为空
             self.llm.request_timeout,
         )

@@ -131,21 +131,21 @@ class DeploymentConfig:
         验证 Embedding API 连接性和功能

         单独验证 Embedding 配置的有效性。
-        当 Embedding 的3个核心字段(endpoint、api_key、model)填完后调用。
+        当 Embedding 的端点填写后调用,API Key 和模型名称允许为空。

         Returns:
             tuple[bool, str, dict]: (是否验证成功, 消息, 验证详细信息)
         """
-        # 检查必要字段是否完整
-        if not (self.embedding.endpoint.strip() and self.embedding.api_key.strip() and self.embedding.model.strip()):
-            return False, "Embedding 基础配置不完整", {}
+        # 检查必要字段是否完整(只要求端点)
+        if not self.embedding.endpoint.strip():
+            return False, "Embedding API 端点不能为空", {}

         validator = APIValidator()
         embed_valid, embed_msg, embed_info = await validator.validate_embedding_config(
             self.embedding.endpoint,
-            self.embedding.api_key,
-            self.embedding.model,
+            self.embedding.api_key,  # 允许为空
+            self.embedding.model,  # 允许为空
             self.llm.request_timeout,  # 使用相同的超时设置
         )

@@ -163,10 +163,6 @@ class DeploymentConfig:
         errors = []
         if not self.llm.endpoint.strip():
             errors.append("LLM API 端点不能为空")
-        if not self.llm.api_key.strip():
-            errors.append("LLM API 密钥不能为空")
-        if not self.llm.model.strip():
-            errors.append("LLM 模型名称不能为空")
         return errors

     def _validate_embedding_fields(self) -> list[str]:
@@ -184,29 +180,12 @@ class DeploymentConfig:

         # 轻量部署模式下,Embedding 配置是可选的
         if self.deployment_mode == "light":
-            # 如果用户填了任何 Embedding 字段,则所有字段都必须完整
-            if has_embedding_config:
-                if not self.embedding.endpoint.strip():
-                    errors.append(
-                        "Embedding API 端点不能为空(轻量部署模式下,如果填写 Embedding 配置,所有字段都必须完整)",
-                    )
-                if not self.embedding.api_key.strip():
-                    errors.append(
-                        "Embedding API 密钥不能为空(轻量部署模式下,如果填写 Embedding 配置,所有字段都必须完整)",
-                    )
-                if not self.embedding.model.strip():
-                    errors.append(
-                        "Embedding 模型名称不能为空(轻量部署模式下,如果填写 Embedding 配置,所有字段都必须完整)",
-                    )
-            # 如果没有填写,则跳过验证
-        else:
-            # 全量部署模式下,Embedding 配置是必需的
-            if not self.embedding.endpoint.strip():
+            # 如果用户填了任何 Embedding 字段,则端点必须填写,API Key 和模型名称允许为空
+            if has_embedding_config and not self.embedding.endpoint.strip():
                 errors.append("Embedding API 端点不能为空")
-            if not self.embedding.api_key.strip():
-                errors.append("Embedding API 密钥不能为空")
-            if not self.embedding.model.strip():
-                errors.append("Embedding 模型名称不能为空")
+        elif not self.embedding.endpoint.strip():
+            # 全量部署模式下,Embedding 配置是必需的,但只要求端点必填
+            errors.append("Embedding API 端点不能为空")

         return errors
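The relaxed rule above makes the endpoint the only mandatory field for both the LLM and Embedding configs. A minimal runnable sketch of that rule follows; the FakeLLMConfig stub is hypothetical, only the validation logic mirrors the patch:

    from dataclasses import dataclass


    @dataclass
    class FakeLLMConfig:
        """Hypothetical stand-in for the real LLM config object."""
        endpoint: str = ""
        api_key: str = ""
        model: str = ""


    def validate_llm_fields(cfg: FakeLLMConfig) -> list[str]:
        # Same rule as _validate_llm_fields after the patch: only the endpoint is required.
        errors: list[str] = []
        if not cfg.endpoint.strip():
            errors.append("LLM API 端点不能为空")
        return errors


    # Endpoint alone passes; api_key/model may stay empty.
    assert validate_llm_fields(FakeLLMConfig(endpoint="https://api.example.com/v1")) == []
    # A filled key/model without an endpoint still fails.
    assert validate_llm_fields(FakeLLMConfig(api_key="sk-xxx", model="gpt-4o")) == ["LLM API 端点不能为空"]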
diff --git a/src/app/settings.py b/src/app/settings.py
index 9a0ba52db5539192f248d7254c9f951167c2c368..d20048c7292e8d82ea13837c00c24c5469afcd66 100644
--- a/src/app/settings.py
+++ b/src/app/settings.py
@@ -7,6 +7,7 @@
 from typing import TYPE_CHECKING

 from textual import on
 from textual.containers import Container, Horizontal
+from textual.css.query import NoMatches
 from textual.screen import ModalScreen
 from textual.widgets import Button, Input, Label, Static
@@ -34,7 +35,6 @@ class SettingsScreen(ModalScreen):
         self.config_manager = config_manager
         self.llm_client = llm_client
         self.backend = self.config_manager.get_backend()
-        self.models: list[str] = []
         self.selected_model = self.config_manager.get_model()
         # 添加保存任务的集合
         self.background_tasks: set[asyncio.Task] = set()
@@ -88,6 +88,7 @@ class SettingsScreen(ModalScreen):
                     else self.config_manager.get_eulerintelli_key(),
                     classes="settings-input",
                     id="api-key",
+                    placeholder="API 访问密钥,可选",
                 ),
                 classes="settings-option",
             ),
@@ -96,7 +97,12 @@ class SettingsScreen(ModalScreen):
             [
                 Horizontal(
                     Label("模型:", classes="settings-label"),
-                    Button(f"{self.selected_model}", id="model-btn", classes="settings-button"),
+                    Input(
+                        value=self.selected_model,
+                        classes="settings-input",
+                        id="model-input",
+                        placeholder="模型名称,可选",
+                    ),
                     id="model-section",
                     classes="settings-option",
                 ),
@@ -132,12 +138,7 @@ class SettingsScreen(ModalScreen):

     def on_mount(self) -> None:
-        """组件挂载时加载可用模型"""
-        if self.backend == Backend.OPENAI:
-            task = asyncio.create_task(self.load_models())
-            # 保存任务引用
-            self.background_tasks.add(task)
-            task.add_done_callback(self.background_tasks.discard)
-        else:  # EULERINTELLI
+        """组件挂载时加载 MCP 工具授权状态"""
+        if self.backend == Backend.EULERINTELLI:
             task = asyncio.create_task(self.load_mcp_status())
             # 保存任务引用
             self.background_tasks.add(task)
@@ -149,33 +150,6 @@ class SettingsScreen(ModalScreen):
         # 确保操作按钮始终可见
         self._ensure_buttons_visible()

-    async def load_models(self) -> None:
-        """异步加载当前选中后端的可用模型列表"""
-        try:
-            # 如果是 EULERINTELLI 后端,直接返回(不需要模型选择)
-            if self.backend == Backend.EULERINTELLI:
-                return
-
-            # 使用当前选中的客户端获取模型列表
-            self.models = await self.llm_client.get_available_models()
-
-            # 过滤掉嵌入模型,只保留语言模型
-            self.models = [
-                model
-                for model in self.models
-                if not any(keyword in model.lower() for keyword in ["text-embedding-", "embedding", "embed", "bge"])
-            ]
-
-            if self.models and self.selected_model not in self.models:
-                self.selected_model = self.models[0]
-
-            # 更新模型按钮文本
-            model_btn = self.query_one("#model-btn", Button)
-            model_btn.label = self.selected_model
-        except (OSError, ValueError, RuntimeError):
-            model_btn = self.query_one("#model-btn", Button)
-            model_btn.label = "暂无可用模型"
-
     async def load_mcp_status(self) -> None:
         """异步加载 MCP 工具授权状态"""
         try:
@@ -203,15 +177,19 @@ class SettingsScreen(ModalScreen):
             mcp_btn.label = "手动确认"
             mcp_btn.disabled = True

-    @on(Input.Changed, "#base-url, #api-key")
+    @on(Input.Changed, "#base-url, #api-key, #model-input")
     def on_config_changed(self) -> None:
-        """当 Base URL 或 API Key 改变时更新客户端并验证配置"""
+        """当 Base URL、API Key 或模型改变时更新客户端并验证配置"""
         if self.backend == Backend.OPENAI:
+            # 获取当前模型输入值
+            try:
+                model_input = self.query_one("#model-input", Input)
+                self.selected_model = model_input.value
+            except NoMatches:
+                # 如果模型输入框不存在,跳过
+                pass
+
             self._update_llm_client()
         else:  # EULERINTELLI
             self._update_llm_client()
             # 重新加载 MCP 状态
@@ -255,22 +233,21 @@ class SettingsScreen(ModalScreen):
             spacer = self.query_one("#spacer")
             model_section = Horizontal(
                 Label("模型:", classes="settings-label"),
-                Button(self.selected_model, id="model-btn", classes="settings-button"),
+                Input(
+                    value=self.selected_model,
+                    classes="settings-input",
+                    id="model-input",
+                    placeholder="模型名称,可选",
+                ),
                 id="model-section",
                 classes="settings-option",
             )
-            # 在spacer前面添加model_section
+            # 在 spacer 前面添加 model_section
             if spacer:
                 container.mount(model_section, before=spacer)
             else:
                 container.mount(model_section)
-
-            # 重新加载模型
-            task = asyncio.create_task(self.load_models())
-            # 保存任务引用
-            self.background_tasks.add(task)
-            task.add_done_callback(self.background_tasks.discard)
         else:
             base_url.value = self.config_manager.get_eulerintelli_url()
             api_key.value = self.config_manager.get_eulerintelli_key()
@@ -317,34 +294,6 @@ class SettingsScreen(ModalScreen):
         # 切换后端后重新验证配置
         self._schedule_validation()

-    @on(Button.Pressed, "#model-btn")
-    def toggle_model(self) -> None:
-        """循环切换模型"""
-        if not self.models:
-            return
-
-        try:
-            # 如果当前选择的模型在列表中,则找到它的索引
-            if self.selected_model in self.models:
-                idx = self.models.index(self.selected_model)
-                idx = (idx + 1) % len(self.models)
-            else:
-                # 如果不在列表中,则从第一个模型开始
-                idx = 0
-            self.selected_model = self.models[idx]
-
-            # 更新按钮文本
-            model_btn = self.query_one("#model-btn", Button)
-            model_btn.label = self.selected_model
-
-            # 模型改变时重新验证配置
-            self._schedule_validation()
-        except (IndexError, ValueError):
-            # 处理任何可能的异常
-            self.selected_model = self.models[0] if self.models else "默认模型"
-            model_btn = self.query_one("#model-btn", Button)
-            model_btn.label = self.selected_model
-
     @on(Button.Pressed, "#mcp-btn")
     def toggle_mcp_authorization(self) -> None:
         """切换 MCP 工具授权模式"""
@@ -375,6 +324,14 @@ class SettingsScreen(ModalScreen):
         api_key = self.query_one("#api-key", Input).value

         if self.backend == Backend.OPENAI:
+            # 获取模型输入值
+            try:
+                model_input = self.query_one("#model-input", Input)
+                self.selected_model = model_input.value.strip()
+            except NoMatches:
+                # 如果模型输入框不存在,保持当前选择的模型
+                pass
+
             self.config_manager.set_base_url(base_url)
             self.config_manager.set_api_key(api_key)
             self.config_manager.set_model(self.selected_model)
@@ -466,16 +423,19 @@ class SettingsScreen(ModalScreen):

         try:
             if self.backend == Backend.OPENAI:
-                # 检查是否有有效的模型选择
-                if not self.selected_model or not self.models:
-                    # 如果没有模型,跳过验证,不更改验证状态
-                    return
-
-                # 验证 OpenAI 配置
+                # 获取模型输入值(可以为空)
+                try:
+                    model_input = self.query_one("#model-input", Input)
+                    model = model_input.value.strip()
+                except NoMatches:
+                    # 如果模型输入框不存在,使用当前选择的模型
+                    model = self.selected_model
+
+                # 验证 OpenAI 配置(模型和 API Key 都可以为空)
                 valid, message, _ = await self.validator.validate_llm_config(
                     endpoint=base_url,
                     api_key=api_key,
-                    model=self.selected_model,
+                    model=model,
                     timeout=10,
                 )
                 self.is_validated = valid
@@ -507,9 +467,16 @@ class SettingsScreen(ModalScreen):
         api_key_input = self.query_one("#api-key", Input)

         if self.backend == Backend.OPENAI:
+            # 获取模型输入值,如果输入框不存在则使用当前选择的模型
+            try:
+                model_input = self.query_one("#model-input", Input)
+                model = model_input.value.strip()
+            except NoMatches:
+                model = self.selected_model
+
             self.llm_client = OpenAIClient(
                 base_url=base_url_input.value,
-                model=self.selected_model,
+                model=model,
                 api_key=api_key_input.value,
             )
         else:  # EULERINTELLI
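The settings screen now reads the model from an optional #model-input Input and falls back when the widget is absent. That query_one / NoMatches dance appears three times above; a small helper along these lines (hypothetical, not part of the patch) would remove the duplication — a sketch assuming Textual's query_one API:

    from textual.css.query import NoMatches
    from textual.widgets import Input


    def read_optional_input(screen, selector: str, fallback: str) -> str:
        """Return the stripped value of an Input, or `fallback` when it is not mounted.

        Hypothetical helper; `screen` is any Textual widget exposing query_one().
        """
        try:
            return screen.query_one(selector, Input).value.strip()
        except NoMatches:
            return fallback


    # e.g. model = read_optional_input(self, "#model-input", self.selected_model)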
Binding(key="tab", action="toggle_focus", description="切换焦点"), ] @@ -336,6 +337,40 @@ class IntelligentTerminal(App): # 否则聚焦到当前的输入组件 self._focus_current_input_widget() + def action_interrupt(self) -> None: + """中断当前正在进行的操作(命令执行或AI问答)""" + if not self.processing: + # 如果当前没有正在处理的操作,只显示提示 + self.logger.debug("当前没有正在进行的操作可以中断") + return + + self.logger.info("用户请求中断当前操作") + + # 中断当前所有的后台任务 + interrupted_count = 0 + for task in list(self.background_tasks): + if not task.done(): + task.cancel() + interrupted_count += 1 + self.logger.debug("已取消后台任务") + + # 中断 LLM 客户端 + if self._llm_client is not None: + # 异步调用中断方法 + interrupt_task = asyncio.create_task(self._interrupt_llm_client()) + self.background_tasks.add(interrupt_task) + interrupt_task.add_done_callback(self._task_done_callback) + + if interrupted_count > 0: + # 显示中断消息 + output_container = self.query_one("#output-container") + interrupt_line = OutputLine("[已中断]") + output_container.mount(interrupt_line) + # 异步滚动到底部 + scroll_task = asyncio.create_task(self._scroll_to_end()) + self.background_tasks.add(scroll_task) + scroll_task.add_done_callback(self._task_done_callback) + def on_mount(self) -> None: """初始化完成时设置焦点和绑定""" # 确保初始状态是正常模式 @@ -492,6 +527,15 @@ class IntelligentTerminal(App): # 确保处理标志被重置 self.processing = False + async def _interrupt_llm_client(self) -> None: + """异步中断 LLM 客户端""" + try: + if self._llm_client is not None: + await self._llm_client.interrupt() + self.logger.info("LLM 客户端中断完成") + except Exception: + self.logger.exception("中断 LLM 客户端时出错") + async def _process_command(self, user_input: str) -> None: """异步处理命令""" try: @@ -673,10 +717,7 @@ class IntelligentTerminal(App): def _handle_cancelled_error(self, output_container: Container, stream_state: dict) -> bool: """处理取消错误""" self.logger.info("Command stream was cancelled") - received_any_content = stream_state["received_any_content"] - if received_any_content and hasattr(self, "is_running") and self.is_running: - output_container.mount(OutputLine("[处理被中断]", command=False)) - return received_any_content + return stream_state["received_any_content"] async def _process_content_chunk( self, diff --git a/src/backend/base.py b/src/backend/base.py index 472b11efb500146c5a6e6df59c23739bdb101917..65dab5c1a489431bc32c94a22e2aac29001075bf 100644 --- a/src/backend/base.py +++ b/src/backend/base.py @@ -3,9 +3,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING - -from typing_extensions import Self +from typing import TYPE_CHECKING, Self if TYPE_CHECKING: from collections.abc import AsyncGenerator @@ -28,6 +26,14 @@ class LLMClientBase(ABC): """ + @abstractmethod + async def interrupt(self) -> None: + """ + 中断当前正在进行的请求 + + 子类应该实现具体的中断逻辑,例如取消 HTTP 请求、停止流式响应等。 + """ + @abstractmethod async def get_available_models(self) -> list[str]: """ diff --git a/src/backend/hermes/client.py b/src/backend/hermes/client.py index 85b532ce05ab7aeb5d6fae0bcf5630cb4b7262e3..fb8bfae40da2d92c600151e62c94894d9e7024fc 100644 --- a/src/backend/hermes/client.py +++ b/src/backend/hermes/client.py @@ -286,6 +286,15 @@ class HermesChatClient(LLMClientBase): """ await self.user_manager.update_auto_execute(auto_execute=False) + async def interrupt(self) -> None: + """ + 中断当前正在进行的请求 + + 调用后端的 stop 能力来中断当前会话。 + """ + self.logger.info("中断 Hermes 客户端当前请求") + await self._stop() + async def close(self) -> None: """关闭 HTTP 客户端""" # 如果有未完成的会话,先停止它 diff --git a/src/backend/openai.py b/src/backend/openai.py index 
ba0ffb7991fd44154d216bba65a7825f32efe441..ac3b9dc1cf3dbada88e6e3edcd7d7aebb6cb1217 100644
--- a/src/backend/openai.py
+++ b/src/backend/openai.py
@@ -1,10 +1,11 @@
 """OpenAI 大模型客户端"""

+import asyncio
 import time
 from collections.abc import AsyncGenerator
 from typing import TYPE_CHECKING

-from openai import AsyncOpenAI
+from openai import AsyncOpenAI, OpenAIError

 from backend.base import LLMClientBase
 from log.manager import get_logger, log_api_request, log_exception
@@ -30,6 +31,9 @@ class OpenAIClient(LLMClientBase):
         # 添加历史记录管理
         self._conversation_history: list[ChatCompletionMessageParam] = []

+        # 用于中断的任务跟踪
+        self._current_task: asyncio.Task | None = None
+
         self.logger.info("OpenAI 客户端初始化成功 - URL: %s, Model: %s", base_url, model)

     async def get_llm_response(self, prompt: str) -> AsyncGenerator[str, None]:
@@ -69,11 +73,24 @@ class OpenAIClient(LLMClientBase):
             # 收集助手的完整回复
             assistant_response = ""

-            async for chunk in response:
-                content = chunk.choices[0].delta.content
-                if content:
-                    assistant_response += content
-                    yield content
+            # 记录当前正在消费该流的任务,供 interrupt() 取消
+            self._current_task = asyncio.current_task()
+
+            try:
+                async for chunk in response:
+                    content = chunk.choices[0].delta.content
+                    if content:
+                        assistant_response += content
+                        yield content
+            except asyncio.CancelledError:
+                self.logger.info("OpenAI 流式响应被中断")
+                # 如果被中断,移除刚添加的用户消息
+                if (
+                    self._conversation_history
+                    and self._conversation_history[-1].get("content") == prompt
+                ):
+                    self._conversation_history.pop()
+                raise

             # 将助手回复添加到历史记录
             if assistant_response:
@@ -84,7 +101,10 @@ class OpenAIClient(LLMClientBase):
                 self._conversation_history.append(assistant_message)
                 self.logger.info("对话历史记录已更新,当前消息数: %d", len(self._conversation_history))

-        except Exception as e:
+        except asyncio.CancelledError:
+            # 重新抛出取消异常
+            raise
+        except OpenAIError as e:
             # 如果请求失败,移除刚添加的用户消息
             if (
                 self._conversation_history
@@ -107,6 +127,29 @@ class OpenAIClient(LLMClientBase):
                 error=str(e),
             )
             raise
+        finally:
+            # 清理当前任务引用
+            self._current_task = None
+
+    async def interrupt(self) -> None:
+        """
+        中断当前正在进行的请求
+
+        取消当前正在进行的流式请求。
+        """
+        if self._current_task is not None and not self._current_task.done():
+            self.logger.info("中断 OpenAI 客户端当前请求")
+            self._current_task.cancel()
+            try:
+                await self._current_task
+            except asyncio.CancelledError:
+                self.logger.info("OpenAI 客户端请求已成功中断")
+            except (OSError, TimeoutError) as e:
+                self.logger.warning("中断 OpenAI 客户端请求时出错: %s", e)
+            finally:
+                self._current_task = None
+        else:
+            self.logger.debug("OpenAI 客户端当前无正在进行的请求")

     def reset_conversation(self) -> None:
         """
@@ -122,6 +165,7 @@ class OpenAIClient(LLMClientBase):
         获取当前 LLM 服务中可用的模型,返回名称列表

         调用 LLM 服务的模型列表接口,并解析返回结果提取模型名称。
+        如果服务不支持模型列表接口,返回空列表。
         """
         start_time = time.time()
         self.logger.info("开始请求 OpenAI 模型列表 API")
@@ -139,10 +183,9 @@ class OpenAIClient(LLMClientBase):
                 duration,
                 model_count=len(models),
             )
-        except Exception as e:
+        except OpenAIError as e:
             duration = time.time() - start_time
             log_exception(self.logger, "OpenAI 模型列表 API 请求失败", e)
-            # 记录失败的API请求
             log_api_request(
                 self.logger,
                 "GET",
@@ -151,7 +194,7 @@ class OpenAIClient(LLMClientBase):
                 duration,
                 error=str(e),
             )
-            raise
+            return []
         else:
             self.logger.info("获取到 %d 个可用模型", len(models))
             return models
@@ -161,6 +204,6 @@ class OpenAIClient(LLMClientBase):
         try:
             await self.client.close()
             self.logger.info("OpenAI 客户端已关闭")
-        except Exception as e:
+        except OpenAIError as e:
             log_exception(self.logger, "关闭 OpenAI 客户端失败", e)
             raise
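The interrupt() implementation above relies on standard asyncio cancellation: cancelling the task that consumes the stream raises CancelledError inside the async-for, and awaiting the cancelled task lets the canceller observe completion. A self-contained sketch of that mechanism:

    import asyncio


    async def main() -> None:
        async def consume_stream() -> None:
            # Stands in for the task iterating get_llm_response().
            await asyncio.sleep(3600)

        task = asyncio.create_task(consume_stream())
        await asyncio.sleep(0)  # let the task start running
        task.cancel()           # what interrupt() does via self._current_task.cancel()
        try:
            await task          # what interrupt() does via await self._current_task
        except asyncio.CancelledError:
            print("request successfully interrupted")


    asyncio.run(main())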
f54433633e9f605de28cf4c110013377f29c48b1..0f42f25eb572922b22304afbd21da138a900f737 100644
--- a/src/tool/__init__.py
+++ b/src/tool/__init__.py
@@ -1,6 +1,6 @@
 """工具模块"""

-from .command_processor import execute_command, is_command_safe, process_command
+from .command_processor import is_command_safe, process_command
 from .oi_backend_init import oi_backend_init

-__all__ = ["execute_command", "is_command_safe", "oi_backend_init", "process_command"]
+__all__ = ["is_command_safe", "oi_backend_init", "process_command"]
diff --git a/src/tool/command_processor.py b/src/tool/command_processor.py
index d270eb842b2521422fcb6f202514ef5813fb2c31..17a2a44af8df1bcb552d10db635af43635bc9761 100644
--- a/src/tool/command_processor.py
+++ b/src/tool/command_processor.py
@@ -1,7 +1,15 @@
-"""命令处理器"""
+"""
+命令处理器
+
+功能说明:
+1. 异步流式执行系统命令: 逐行输出 STDOUT。
+2. 结束后输出总结状态(退出码,成功/失败)。
+3. 失败时自动向 LLM 请求分析建议并继续流式输出建议。
+"""

+import asyncio
+import logging
 import shutil
-import subprocess
 from collections.abc import AsyncGenerator

 from backend.base import LLMClientBase
@@ -21,30 +29,6 @@ def is_command_safe(command: str) -> bool:
     return all(dangerous not in command for dangerous in BLACKLIST)


-def execute_command(command: str) -> tuple[bool, str]:
-    """
-    执行命令并返回结果
-
-    尝试执行命令:
-    返回 (True, 命令标准输出) 或 (False, 错误信息)。
-    """
-    try:
-        result = subprocess.run(  # noqa: S602
-            command,
-            shell=True,
-            capture_output=True,
-            text=True,
-            timeout=1800,
-            check=False,
-        )
-        success = result.returncode == 0
-        output = result.stdout if success else result.stderr
-    except (subprocess.TimeoutExpired, subprocess.SubprocessError, OSError) as e:
-        return False, str(e)
-    else:
-        return success, output
-
-
 async def process_command(command: str, llm_client: LLMClientBase) -> AsyncGenerator[tuple[str, bool], None]:
     """
     处理用户输入的命令
@@ -66,31 +50,172 @@ async def process_command(command: str, llm_client: LLMClientBase) -> AsyncGener
         return
     prog = tokens[0]

-    if shutil.which(prog) is not None:
-        logger.info("检测到系统命令: %s", prog)
-        # 检查命令安全性
-        if not is_command_safe(command):
-            logger.warning("命令被安全检查阻止: %s", command)
-            yield ("检测到不安全命令,已阻止执行。", True)  # 作为LLM输出处理
-            return
-
-        logger.info("执行系统命令: %s", command)
-        success, output = execute_command(command)
-        if success:
-            logger.debug("命令执行成功,输出长度: %d", len(output))
-            yield (output, False)  # 系统命令输出,使用纯文本
-        else:
-            # 执行失败,将错误信息反馈给大模型
-            logger.info("命令执行失败,向 LLM 请求建议")
-            query = f"命令 '{command}' 执行失败,错误信息如下:\n{output}\n请帮忙分析原因并提供解决建议。"
-            async for suggestion in llm_client.get_llm_response(query):
-                # 检查是否为 MCP 状态消息,如果是则使用特殊的处理方式
-                is_mcp_message_flag = is_mcp_message(suggestion)
-                yield (suggestion, not is_mcp_message_flag)  # MCP消息不作为LLM输出处理
-    else:
-        # 不是已安装的命令,直接询问大模型
+    if shutil.which(prog) is None:
+        # 非系统命令 -> 直接走 LLM
         logger.debug("向 LLM 发送问题: %s", command)
-        async for suggestion in llm_client.get_llm_response(command):
-            # 检查是否为 MCP 状态消息,如果是则使用特殊的处理方式
-            is_mcp_message_flag = is_mcp_message(suggestion)
-            yield (suggestion, not is_mcp_message_flag)  # MCP消息不作为LLM输出处理
+        try:
+            async for suggestion in llm_client.get_llm_response(command):
+                is_mcp_message_flag = is_mcp_message(suggestion)
+                yield (suggestion, not is_mcp_message_flag)
+        except asyncio.CancelledError:
+            logger.info("LLM 响应被用户中断")
+            raise
+        return
+
+    logger.info("检测到系统命令: %s", prog)
+    if not is_command_safe(command):
+        logger.warning("命令被安全检查阻止: %s", command)
+        yield ("检测到不安全命令,已阻止执行。", True)
+        return
+
+    # 流式执行
+    try:
+        async for item in _stream_system_command(command, llm_client, logger):
+            yield item
+    except asyncio.CancelledError:
+        logger.info("命令执行被用户中断")
+        raise
logger.info("命令执行被用户中断") + raise + + +async def _stream_system_command( + command: str, + llm_client: LLMClientBase, + logger: logging.Logger, +) -> AsyncGenerator[tuple[str, bool], None]: + """ + 流式执行系统命令。 + + 逐行产出 STDOUT (is_llm_output=False)。结束后追加一条状态行: 成功 / 失败。 + 若失败随后继续产出 LLM 建议 (is_llm_output=True,除非是 MCP 消息)。 + 支持中断处理,会正确终止子进程。 + """ + logger.info("(流式) 执行系统命令: %s", command) + + # 创建子进程 + proc = await _create_subprocess(command, logger) + if proc is None: + async for item in _handle_subprocess_creation_error(command, llm_client): + yield item + return + + # 执行命令并处理输出 + try: + async for item in _execute_and_stream_output(proc, command, llm_client, logger): + yield item + except asyncio.CancelledError: + await _handle_process_interruption(proc, logger) + raise + + +async def _create_subprocess(command: str, logger: logging.Logger) -> asyncio.subprocess.Process | None: + """创建子进程,返回 None 表示创建失败""" + try: + return await asyncio.create_subprocess_shell( + command, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + except OSError: + logger.exception("创建子进程失败") + return None + + +async def _handle_subprocess_creation_error( + command: str, + llm_client: LLMClientBase, +) -> AsyncGenerator[tuple[str, bool], None]: + """处理子进程创建失败的情况""" + yield ("[命令启动失败] 无法创建子进程", False) + query = f"无法启动命令 '{command}',请分析可能原因并给出解决建议。" + async for suggestion in llm_client.get_llm_response(query): + is_mcp_message_flag = is_mcp_message(suggestion) + yield (suggestion, not is_mcp_message_flag) + + +async def _execute_and_stream_output( + proc: asyncio.subprocess.Process, + command: str, + llm_client: LLMClientBase, + logger: logging.Logger, +) -> AsyncGenerator[tuple[str, bool], None]: + """执行命令并流式输出结果""" + assert proc.stdout is not None # 类型提示 + + # 流式读取输出 + while True: + line = await proc.stdout.readline() + if not line: + break + # CR -> LF 规范化 + text = line.decode(errors="replace").replace("\r\n", "\n").replace("\r", "\n") + yield (text, False) + + # 等待进程结束 + returncode = await proc.wait() + success = returncode == 0 + + if success: + yield (f"\n[命令完成] 退出码: {returncode}", False) + return + + # 处理命令失败的情况 + async for item in _handle_command_failure(proc, command, returncode, llm_client, logger): + yield item + + +async def _handle_command_failure( + proc: asyncio.subprocess.Process, + command: str, + returncode: int, + llm_client: LLMClientBase, + logger: logging.Logger, +) -> AsyncGenerator[tuple[str, bool], None]: + """处理命令执行失败的情况""" + # 读取 stderr + stderr_text = await _read_stderr(proc) + yield (f"[命令失败] 退出码: {returncode}", False) + + # 获取 LLM 建议 + logger.info("命令执行失败(returncode=%s),向 LLM 请求建议", returncode) + query = ( + f"命令 '{command}' 以非零状态 {returncode} 退出。\n" + f"标准错误输出如下:\n{stderr_text}\n" + "请分析原因并提供解决建议。" + ) + async for suggestion in llm_client.get_llm_response(query): + is_mcp_message_flag = is_mcp_message(suggestion) + yield (suggestion, not is_mcp_message_flag) + + +async def _read_stderr(proc: asyncio.subprocess.Process) -> str: + """读取进程的标准错误输出""" + if proc.stderr is None: + return "" + + try: + stderr_bytes = await proc.stderr.read() + return stderr_bytes.decode(errors="replace") + except (OSError, asyncio.CancelledError): + return "读取 stderr 失败" + + +async def _handle_process_interruption(proc: asyncio.subprocess.Process, logger: logging.Logger) -> None: + """处理进程中断,确保正确终止子进程""" + logger.info("命令执行被中断,正在终止子进程") + + if proc.returncode is not None: + return # 进程已经结束 + + # 尝试正常终止进程 + proc.terminate() + try: + await asyncio.wait_for(proc.wait(), timeout=5.0) + 
logger.info("子进程已正常终止") + except TimeoutError: + # 强制杀死进程 + logger.warning("子进程未在5秒内终止,强制杀死") + proc.kill() + try: + await asyncio.wait_for(proc.wait(), timeout=2.0) + except TimeoutError: + logger.exception("无法强制终止子进程") diff --git a/src/tool/validators.py b/src/tool/validators.py index d13f38994410d559aaffa6841363d44d67f73cda..26f8dc57b036caa5fdbda292e3cac083324e04f7 100644 --- a/src/tool/validators.py +++ b/src/tool/validators.py @@ -51,47 +51,34 @@ class APIValidator: try: client = AsyncOpenAI(api_key=api_key, base_url=endpoint, timeout=timeout) - # 1. 验证模型是否存在 - models_valid, models_msg, available_models = await self._check_model_availability( - client, - model, - ) - if not models_valid: - await client.close() - return False, models_msg, {"available_models": available_models} - - # 2. 验证基本对话功能 + # 测试基本对话功能 chat_valid, chat_msg = await self._test_basic_chat(client, model) if not chat_valid: await client.close() - return False, chat_msg, {"available_models": available_models} + return False, chat_msg, {} - # 3. 验证 function_call 支持 + # 测试 function_call 支持 func_valid, func_msg, func_info = await self._test_function_call(client, model) await client.close() - if chat_valid: - success_msg = "LLM 配置验证成功 - 模型" - if func_valid: - success_msg += "支持 function_call" - else: - success_msg += f"不支持 function_call: {func_msg}" - - return True, success_msg, { - "available_models": available_models, - "supports_function_call": func_valid, - "function_call_info": func_info, - } - except TimeoutError: return False, f"连接超时 - 无法在 {timeout} 秒内连接到 {endpoint}", {} except (AuthenticationError, APIError, OpenAIError) as e: error_msg = f"LLM 配置验证失败: {e!s}" self.logger.exception(error_msg) return False, error_msg, {} - - return False, "未知错误", {} + else: + success_msg = "LLM 配置验证成功" + if func_valid: + success_msg += " - 支持 function_call" + else: + success_msg += f" - 不支持 function_call: {func_msg}" + + return True, success_msg, { + "supports_function_call": func_valid, + "function_call_info": func_info, + } async def validate_embedding_config( self, @@ -141,31 +128,6 @@ class APIValidator: return False, "Embedding 响应为空", {} - async def _check_model_availability( - self, - client: AsyncOpenAI, - target_model: str, - ) -> tuple[bool, str, list[str]]: - """检查模型是否可用""" - try: - models_response = await client.models.list() - available_models = [model.id for model in models_response.data] - - if target_model in available_models: - return True, f"模型 {target_model} 可用", available_models - - return ( - False, - ( - f"模型 {target_model} 不可用。可用模型: " - f"{', '.join(available_models[:MAX_MODEL_DISPLAY])}" - f"{'...' if len(available_models) > MAX_MODEL_DISPLAY else ''}" - ), - available_models, - ) - except (AuthenticationError, APIError, OpenAIError) as e: - return False, f"获取模型列表失败: {e!s}", [] - async def _test_basic_chat( self, client: AsyncOpenAI,