From ace7bb37e55bc71186b5876152a6546504c57763 Mon Sep 17 00:00:00 2001 From: houxu Date: Thu, 14 Aug 2025 10:19:58 +0800 Subject: [PATCH] update mcp parameter --- systrace_mcp/setup.py | 3 +- .../systrace_mcp/fail_slow_detection_api.py | 3 +- systrace_mcp/systrace_mcp/mcp_data.py | 77 +++++++++++++++---- systrace_mcp/systrace_mcp/mcp_server.py | 40 ++++++---- systrace_mcp/systrace_mcp/openapi_server.py | 54 ++++++------- systrace_mcp/systrace_mcp/report_api.py | 32 +++----- 6 files changed, 125 insertions(+), 84 deletions(-) diff --git a/systrace_mcp/setup.py b/systrace_mcp/setup.py index 20fe073..8d4954f 100644 --- a/systrace_mcp/setup.py +++ b/systrace_mcp/setup.py @@ -35,7 +35,8 @@ setup( install_requires=[ "systrace_failslow", "mcp", - "paramiko" + "paramiko", + "fastapi" ], entry_points={ "console_scripts": [ diff --git a/systrace_mcp/systrace_mcp/fail_slow_detection_api.py b/systrace_mcp/systrace_mcp/fail_slow_detection_api.py index 2d88d76..2556b7d 100644 --- a/systrace_mcp/systrace_mcp/fail_slow_detection_api.py +++ b/systrace_mcp/systrace_mcp/fail_slow_detection_api.py @@ -63,8 +63,7 @@ def detect_step_time_anomalies(data_df: pd.DataFrame, model_args: Dict): next_anomaly_degree = next_diff / anomaly_degree_thr if next_anomaly_degree > anomaly_degree_thr: anomalies.append( - {"training_step": i, - "anomaly_time": datetime.fromtimestamp(timestamps[i]/1000).strftime('%Y-%m-%d %H:%M:%S'), + {"training_step": i, "anomaly_time": datetime.fromtimestamp(timestamps[i]/1000).strftime('%Y-%m-%d %H:%M:%S'), "anomaly_degree": round(anomaly_degree, 3), "anomaly_training_time": f"{current_step_time}ms", "normal_training_time": f"{moving_average}ms"}) diff --git a/systrace_mcp/systrace_mcp/mcp_data.py b/systrace_mcp/systrace_mcp/mcp_data.py index 4ab38a4..04837c4 100644 --- a/systrace_mcp/systrace_mcp/mcp_data.py +++ b/systrace_mcp/systrace_mcp/mcp_data.py @@ -1,20 +1,65 @@ -from typing_extensions import TypedDict, List -from dataclasses import dataclass, field -from typing import List, Dict, Any +from typing import List +from pydantic import BaseModel, Field +from enum import Enum -class AnomalyInfo(TypedDict): + +class ReportType(str, Enum): + normal = "normal" + anomaly = "anomaly" + + +class AnomalyInfo(BaseModel): """劣化详细信息结构""" - metric_name: str #是否发生性能劣化 - threshold: float - actual_value: float - timestamp: int + training_step: int = Field(default=0, description="训练步骤(默认0)") + anomaly_time: str = Field(default="", description="劣化时间(默认空字符串)") + anomaly_degree: float = Field(default=0.0, description="劣化程度(默认0.0)") + anomaly_training_time: str = Field(default="", description="劣化训练step时间(默认空字符串)") + normal_training_time: str = Field(default="", description="正常训练step时间(默认空字符串)") -class PerceptionResult(TypedDict): + +class PerceptionResult(BaseModel): """慢节点感知结果结构""" - is_anomaly: bool #是否发生性能劣化 - anomaly_count_times: int #劣化次数 - anomaly_info: List[AnomalyInfo] #劣化详细信息 - start_time: int # Unix timestamp in milliseconds 劣化开始时间 - end_time: int # Unix timestamp in milliseconds 劣化结束时间 - anomaly_type: str # 劣化类型 - task_id: str #服务ip + is_anomaly: bool = Field(default=False, description="是否发生性能劣化(默认false)") + anomaly_count_times: int = Field(default=0, description="劣化次数(默认0)") + # 列表类型使用 default_factory 避免 mutable 默认值问题 + anomaly_info: List[AnomalyInfo] = Field( + default_factory=list, + description="劣化详细信息(默认空列表)" + ) + start_time: int = Field(default=0, description="劣化开始时间(默认0,单位毫秒)") + end_time: int = Field(default=0, description="劣化结束时间(默认0,单位毫秒)") + anomaly_type: str = Field(default="", description="劣化类型(默认空字符串)") + task_id: str = Field(default="", description="服务器ip(默认空字符串)") + + +class DetailItem(BaseModel): + objectId: str = Field(default="", alias="objectId", description="对象ID") + serverIp: str = Field(default="", alias="serverIp", description="服务器IP") + deviceInfo: str = Field(default="", alias="deviceInfo", description="设备信息") + kpiId: str = Field(default="", alias="kpiId", description="KPI指标ID") + methodType: str = Field(default="", alias="methodType", description="方法类型") + kpiData: list = Field(default_factory=list, alias="kpiData", description="KPI数据列表") + relaIds: List[int] = Field(default_factory=list, alias="relaIds", description="关联ID列表") + omittedDevices: list = Field(default_factory=list, alias="omittedDevices", description="忽略的设备列表") + + +# 主模型引用嵌套模型(为所有字段添加默认值) +class AIJobDetectResult(BaseModel): + timestamp: int = Field(default=0, description="时间戳(默认0)") + result_code: int = Field(default=0, alias="resultCode", description="结果编码(默认0)") + compute: bool = Field(default=False, description="计算状态(默认False)") + network: bool = Field(default=False, description="网络状态(默认False)") + storage: bool = Field(default=False, description="存储状态(默认False)") + # 列表类型推荐用 default_factory=list 而非 [],避免 mutable 默认值的潜在问题 + abnormal_detail: List[DetailItem] = Field( + default_factory=list, + alias="abnormalDetail", + description="异常详情列表(默认空列表)" + ) + normal_detail: List[DetailItem] = Field( + default_factory=list, + alias="normalDetail", + description="正常详情列表(默认空列表)" + ) + error_msg: str = Field(default="", alias="errorMsg", description="错误信息(默认空字符串)") + diff --git a/systrace_mcp/systrace_mcp/mcp_server.py b/systrace_mcp/systrace_mcp/mcp_server.py index bed7f4e..ada6926 100644 --- a/systrace_mcp/systrace_mcp/mcp_server.py +++ b/systrace_mcp/systrace_mcp/mcp_server.py @@ -4,13 +4,12 @@ import json from mcp.server import FastMCP -from failslow.response.response import AIJobDetectResult from failslow.util.logging_utils import get_default_logger from failslow.util.constant import MODEL_CONFIG_PATH from failslow.main import main as slow_node_detection_api -from systrace_mcp.report_api import generate_normal_report, generate_degraded_report, generate_default_report -from systrace_mcp.mcp_data import PerceptionResult +from systrace_mcp.report_api import generate_normal_report, generate_degraded_report +from systrace_mcp.mcp_data import PerceptionResult,ReportType,AIJobDetectResult from systrace_mcp.fail_slow_detection_api import run_slow_node_perception from systrace_mcp.remote_file_fetcher import sync_server_by_ip_and_type @@ -18,6 +17,7 @@ logger = get_default_logger(__name__) # 仅在 Linux 环境下强制使用 spawn 方式 import multiprocessing import os + if os.name == "posix": # posix 表示 Linux/macOS multiprocessing.set_start_method("spawn", force=True) # 创建MCP Server @@ -48,11 +48,13 @@ def slow_node_perception_tool(task_id: str) -> PerceptionResult: with open(MODEL_CONFIG_PATH, 'r', encoding='utf-8') as reader: model_args = json.load(reader) sync_server_by_ip_and_type(task_id, "perception") - res = run_slow_node_perception(model_args,task_id) - return res + _res = run_slow_node_perception(model_args, task_id) + _res = PerceptionResult.model_validate(_res) + return _res -@mcp.prompt(description="调用逻辑:1. 仅在感知工具返回is_anomaly=True时调用。2. 接收感知工具的全量性能数据作为输入。 3. 本方法得到的结果必须再调用generate_report 生成报告给到用户") +@mcp.prompt( + description="调用逻辑:1. 仅在感知工具返回is_anomaly=True时调用。2. 接收感知工具的全量性能数据作为输入。 3. 本方法得到的结果必须再调用generate_report 生成报告给到用户") @mcp.tool(name="slow_node_detection_tool") def slow_node_detection_tool(performance_data: PerceptionResult) -> AIJobDetectResult: """ @@ -63,15 +65,17 @@ def slow_node_detection_tool(performance_data: PerceptionResult) -> AIJobDetectR """ print("慢卡定界工具") print("performance_data = " + str(performance_data)) - sync_server_by_ip_and_type(performance_data["task_id"], "detection") + print("task_id = " + performance_data.task_id) + sync_server_by_ip_and_type(performance_data.task_id, "detection") _res = slow_node_detection_api() - print(json.dumps(_res)) + _res = AIJobDetectResult.model_validate(_res) + print("result = " + str(_res)) return _res @mcp.prompt(description="调用slow_node_perception_tool 或 slow_node_detection_tool 后把结果传入generate_report ") @mcp.tool() -def generate_report_tool(source_data: Union[dict, str], report_type: str) -> dict: +def generate_report_tool(source_data: Union[PerceptionResult, AIJobDetectResult], report_type: ReportType) -> dict: """ 使用 报告工具:生成最终Markdown格式报告 输入: @@ -87,18 +91,22 @@ def generate_report_tool(source_data: Union[dict, str], report_type: str) -> dic 2、细节:每条节点的具体卡号{objectId}、异常指标{kpiId}(其中:HcclAllGather表示集合通信库的AllGather时序序列指标;HcclReduceScatter表示集合通信库的ReduceScatter时序序列指标;HcclAllReduce表示集合通信库的AllReduce时序序列指标;),检测方法{methodType}(SPACE 多节点空间对比检测器,TIME 单节点时间检测器),以表格形式呈现; 3、针对这个节点给出检测建议,如果是计算类型,建议检测卡的状态,算子下发以及算子执行的代码,对慢节点进行隔离;如果是网络问题,建议检测组网的状态,使用压测节点之间的连通状态;如果是存储问题,建议检测存储的磁盘以及用户脚本中的dataloader和保存模型代码。 """ - print("调用了报告工具,report_type = " + report_type) + print("调用了报告工具,report_type = " + report_type.value) # 根据报告类型调用对应的生成方法 - if report_type == "normal": - return json.dumps(generate_normal_report(source_data)) - elif report_type == "anomaly": - return json.dumps(generate_degraded_report(source_data)) + if report_type == ReportType.normal: + result = generate_normal_report(source_data) + elif report_type == ReportType.anomaly: + result = generate_degraded_report(source_data) else: - # 默认报告类型 - return generate_default_report(source_data) + raise Exception("不支持的报告类型") + print("报告:", result) + return result + def main(): # 初始化并启动服务 mcp.run(transport='sse') + + if __name__ == "__main__": main() diff --git a/systrace_mcp/systrace_mcp/openapi_server.py b/systrace_mcp/systrace_mcp/openapi_server.py index a352d33..92d73cf 100644 --- a/systrace_mcp/systrace_mcp/openapi_server.py +++ b/systrace_mcp/systrace_mcp/openapi_server.py @@ -1,23 +1,22 @@ import json import os -from time import sleep from typing import Union, Dict, Any, Optional from pydantic import BaseModel import uvicorn from fastapi import FastAPI, HTTPException, Query from fastapi.middleware.cors import CORSMiddleware -from failslow.response.response import AIJobDetectResult from failslow.main import main as slow_node_detection_api from failslow.util.logging_utils import get_default_logger from failslow.util.constant import MODEL_CONFIG_PATH -from systrace_mcp.mcp_data import PerceptionResult +from systrace_mcp.mcp_data import PerceptionResult, AIJobDetectResult from systrace_mcp.fail_slow_detection_api import run_slow_node_perception from systrace_mcp.remote_file_fetcher import sync_server_by_ip_and_type -from systrace_mcp.report_api import generate_normal_report, generate_degraded_report, generate_default_report +from systrace_mcp.report_api import generate_normal_report, generate_degraded_report # 仅在 Linux 环境下强制使用 spawn 方式 import multiprocessing + if os.name == "posix": # posix 表示 Linux/macOS multiprocessing.set_start_method("spawn", force=True) # 初始化日志 @@ -58,10 +57,11 @@ def slow_node_perception_tool(task_id: str) -> PerceptionResult: try: sync_server_by_ip_and_type(task_id, "perception") - res = run_slow_node_perception(model_args,task_id) - res["task_id"] = task_id - logger.info(f"性能感知结果: {str(res)}") - return res + _res = run_slow_node_perception(model_args, task_id) + _res["task_id"] = task_id + logger.info(f"性能感知结果: {str(_res)}") + _res = PerceptionResult.model_validate(_res) + return _res except Exception as e: logger.error(f"性能劣化感知工具出错: {str(e)}") raise HTTPException(status_code=500, detail=f"性能劣化感知工具出错: {str(e)}") @@ -72,31 +72,27 @@ def slow_node_detection_tool(performance_data: PerceptionResult) -> AIJobDetectR logger.info(f"慢卡定界工具开启,performance_data = {str(performance_data)}") try: - sync_server_by_ip_and_type(performance_data["task_id"], "detection") - _res = slow_node_detection_api(performance_data) + sync_server_by_ip_and_type(performance_data.task_id, "detection") + _res = slow_node_detection_api() logger.info(f"慢卡定界结果: {json.dumps(_res)}") + _res = AIJobDetectResult.model_validate(_res) return _res except Exception as e: logger.error(f"慢卡定界工具出错: {str(e)}") raise HTTPException(status_code=500, detail=f"慢卡定界工具出错: {str(e)}") -def generate_report_tool(source_data: Union[dict, str], report_type: str) -> Union[str, Dict[str, Any]]: - """生成最终报告的工具""" - logger.info(f"调用报告工具,report_type = {report_type}") - - try: - if report_type == "normal": - report_content = generate_normal_report(source_data) - elif report_type == "anomaly": - report_content = generate_degraded_report(source_data) - else: - report_content = generate_default_report(source_data) - - return report_content - except Exception as e: - logger.error(f"报告生成工具出错: {str(e)}") - raise HTTPException(status_code=500, detail=f"报告生成工具出错: {str(e)}") +def generate_report_tool(source_data: Union[PerceptionResult, AIJobDetectResult], report_type: str) -> dict: + print("调用了报告工具,report_type = " + report_type) + # 根据报告类型调用对应的生成方法 + if report_type == "normal": + result = generate_normal_report(source_data) + elif report_type == "anomaly": + result = generate_degraded_report(source_data) + else: + raise Exception("不支持的报告类型") + print("报告:", result) + return result @app.get("/slow-node/systrace", response_model=ApiResponse) @@ -106,8 +102,8 @@ async def slow_node_perception(ip: str = Query("127.0.0.1", description="节点I """ result = slow_node_perception_tool(ip) # 判断是否劣化 - report_type = "anomaly" if result.get("is_anomaly", True) else "normal" - if True is result["is_anomaly"]: + report_type = "anomaly" if result.is_anomaly else "normal" + if True is result.is_anomaly: result = slow_node_detection_tool(result) # 3. 自动调用报告生成 report_content = generate_report_tool(result, report_type) @@ -126,4 +122,4 @@ def main(): if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/systrace_mcp/systrace_mcp/report_api.py b/systrace_mcp/systrace_mcp/report_api.py index a5a9448..e2b3cd8 100644 --- a/systrace_mcp/systrace_mcp/report_api.py +++ b/systrace_mcp/systrace_mcp/report_api.py @@ -1,10 +1,12 @@ import json from datetime import datetime +from systrace_mcp.mcp_data import PerceptionResult, AIJobDetectResult -def generate_normal_report(data: dict) -> dict: + +def generate_normal_report(data: PerceptionResult) -> dict: """生成无劣化的正常报告""" - # 解析时间戳为可读格式 + data = data.model_dump() timestamp = data.get("start_time") start_time = datetime.fromtimestamp(timestamp // 1000).strftime("%Y-%m-%d %H:%M:%S") if timestamp else "未知时间" timestamp = data.get("end_time") @@ -15,7 +17,7 @@ def generate_normal_report(data: dict) -> dict: return data -def generate_degraded_report(data: dict) -> dict: +def generate_degraded_report(data: AIJobDetectResult) -> dict: """ 生成设备异常状态的JSON报告 @@ -26,13 +28,13 @@ def generate_degraded_report(data: dict) -> dict: 格式化的JSON报告字典 """ # 解析时间戳为可读格式 + data = data.model_dump() timestamp = data.get("timestamp") detect_time = datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S") if timestamp else "未知时间" # 提取异常信息 - abnormalDetail = data.get("abnormalDetail", []) + abnormalDetail = data.get("abnormal_detail", []) abnormal_count = len(abnormalDetail) - # 整理异常节点详情 abnormal_nodes = [] for abnormal in abnormalDetail: @@ -41,12 +43,12 @@ def generate_degraded_report(data: dict) -> dict: "serverIp": abnormal.get("serverIp"), "deviceInfo": abnormal.get("deviceInfo"), "methodType": abnormal.get("methodType"), - "kpiId":abnormal.get("kpiId"), + "kpiId": abnormal.get("kpiId"), "relaIds": abnormal.get("relaIds", []) }) # 整理正常节点信息 - normal_nodes = [item["deviceInfo"] for item in data.get("normalDetail", [])] + normal_nodes = [item["deviceInfo"] for item in data.get("normal_detail", [])] # 构建JSON报告 report = { @@ -54,9 +56,9 @@ def generate_degraded_report(data: dict) -> dict: "overview": { "detectTime": detect_time, "abnormalNodeCount": abnormal_count, - "compute": data.get("compute") , - "network": data.get("network") , - "storage": data.get("storage") , + "compute": data.get("compute"), + "network": data.get("network"), + "storage": data.get("storage"), }, "abnormalNodes": abnormal_nodes, "normalNodes": { @@ -68,13 +70,3 @@ def generate_degraded_report(data: dict) -> dict: return report - -def generate_default_report(data: dict) -> dict: - """生成默认报告(当类型不匹配时),返回JSON格式字典""" - return { - "report_title": "机器性能分析报告", - "warning": "报告类型未识别,以下是原始数据摘要", - "raw_data": data, - "report_type": "default" - } - -- Gitee