From 83227c71df1069f0993ccfe09b0254e6614b4d9e Mon Sep 17 00:00:00 2001 From: wuyulong17 <2284273586@qq.com> Date: Mon, 28 Jul 2025 15:32:46 +0800 Subject: [PATCH 1/2] =?UTF-8?q?=E6=96=B0=E5=A2=9E=E5=AF=B9db=E6=96=87?= =?UTF-8?q?=E4=BB=B6=E6=A0=BC=E5=BC=8F=E6=94=AF=E6=8C=81=EF=BC=8C=E5=89=8D?= =?UTF-8?q?=E5=90=8E=E7=AB=AF=E6=9E=B6=E6=9E=84=E9=87=8D=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../fe/src/graph_ascend/index.ts | 26 ++- .../fe/src/graph_ascend/type/index.d.ts | 6 +- .../fe/src/graph_ascend/useGraphAscend.ts | 21 +-- .../components/hierarchy/useGraph.ts | 8 +- .../components/tf_main_controler/index.ts | 3 +- .../components/tf_manual_match/useMatched.ts | 56 ++---- .../src/graph_controls_board/type/index.d.ts | 7 +- .../domain/useNodeInfoDomain.ts | 15 +- .../tb_graph_ascend/fe/src/utils/request.ts | 5 + .../server/app/service/__init__.py | 3 + .../server/app/service/base_graph_service.py | 147 +++++++++++++++ .../server/app/service/db_graph_service.py | 54 ++++++ .../server/app/service/file_check_wrapper.py | 47 +++++ .../app/service/graph_service_factory.py | 48 +++++ ...graph_service.py => json_graph_service.py} | 171 ++++++------------ .../server/app/utils/global_state.py | 15 +- .../server/app/utils/graph_utils.py | 5 +- .../server/app/views/graph_views.py | 113 ++++++++---- .../tb_graph_ascend/server/plugin.py | 11 +- 19 files changed, 508 insertions(+), 253 deletions(-) create mode 100644 plugins/tensorboard-plugins/tb_graph_ascend/server/app/service/base_graph_service.py create mode 100644 plugins/tensorboard-plugins/tb_graph_ascend/server/app/service/db_graph_service.py create mode 100644 plugins/tensorboard-plugins/tb_graph_ascend/server/app/service/file_check_wrapper.py create mode 100644 plugins/tensorboard-plugins/tb_graph_ascend/server/app/service/graph_service_factory.py rename plugins/tensorboard-plugins/tb_graph_ascend/server/app/service/{graph_service.py => json_graph_service.py} (79%) diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/graph_ascend/index.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/graph_ascend/index.ts index 2b093520e..3d895cea6 100644 --- a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/graph_ascend/index.ts +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/graph_ascend/index.ts @@ -225,10 +225,10 @@ class TfGraphDashboard extends LegacyElementMixin(PolymerElement) { return; } if (this.currentSelection?.run !== this.selection?.run || this.currentSelection?.tag !== this.selection?.tag) { - this.loadGraphData(this.selection.run, this.selection.tag); + this.loadGraphData(this.selection); } else if (this.currentSelection?.microStep !== this.selection?.microStep) { this.initGraphBoard(); // 只改变microsteps时,不重新加载图数据 - this.loadGraphAllNodeList(this.selection.run, this.selection.tag, this.selection.microStep); + this.loadGraphAllNodeList(this.selection); } this.currentSelection = this.selection; }; @@ -263,12 +263,12 @@ class TfGraphDashboard extends LegacyElementMixin(PolymerElement) { this.set('metaDir', data); } - loadGraphData = (run, tag) => { + loadGraphData = (metaData: SelectionType) => { if (this.eventSource) { this.eventSource.close(); this.eventSource = null; } - this.eventSource = new EventSource(`loadGraphData?run=${run}&tag=${tag}`); + this.eventSource = new EventSource(`loadGraphData?run=${metaData.run}&tag=${metaData.tag}`); this.eventSource.onmessage = async (e) => { const data = safeJSONParse(e.data); if (data?.error) { @@ -283,12 +283,8 @@ class TfGraphDashboard extends LegacyElementMixin(PolymerElement) { this.eventSource = null; try { await Promise.all([ - this.loadGraphConfig(this.selection?.run, this.selection?.tag), - this.loadGraphAllNodeList( - this.selection?.run, - this.selection?.tag, - this.selection?.microStep, - ), + this.loadGraphConfig(metaData), + this.loadGraphAllNodeList(metaData), ]); this.initGraphBoard(); // 先读取配置,再加载图,顺序很重要 this.progreesLoading('初始化完成', '请稍后', data); @@ -309,8 +305,8 @@ class TfGraphDashboard extends LegacyElementMixin(PolymerElement) { }; }; - loadGraphConfig = async (run, tag) => { - const { success, data, error } = await this.useGraphAscend.loadGraphConfig(run, tag); + loadGraphConfig = async (metaData) => { + const { success, data, error } = await this.useGraphAscend.loadGraphConfig(metaData); const config = data as GraphConfigType; if (success) { this.set('colors', config.colors); @@ -318,7 +314,7 @@ class TfGraphDashboard extends LegacyElementMixin(PolymerElement) { this.set('overflowcheck', config.overflowCheck); this.set('colorset', Object.entries(config.colors || {})); this.set('isSingleGraph', config.isSingleGraph); - this.set('task', config.task) + this.set('task', config.task); this.set('matchedConfigFiles', ['未选择', ...config.matchedConfigFiles]); const microstepsCount = Number(config.microSteps); if (microstepsCount) { @@ -339,8 +335,8 @@ class TfGraphDashboard extends LegacyElementMixin(PolymerElement) { } }; - loadGraphAllNodeList = async (run, tag, microStep) => { - const { success, data, error } = await this.useGraphAscend.loadGraphAllNodeList(run, tag, microStep); + loadGraphAllNodeList = async (metaData: SelectionType) => { + const { success, data, error } = await this.useGraphAscend.loadGraphAllNodeList(metaData); const allNodeList = data as GraphAllNodeType; if (success) { const nodelist = {} as NodeListType; diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/graph_ascend/type/index.d.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/graph_ascend/type/index.d.ts index e763e9734..04066670d 100644 --- a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/graph_ascend/type/index.d.ts +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/graph_ascend/type/index.d.ts @@ -24,8 +24,10 @@ export interface ProgressType { export interface SelectionType { run: string; tag: string; - microStep: number; - + type: string; + microStep?: number; + step?: string; + rank?: string; } export interface GraphConfigType { diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/graph_ascend/useGraphAscend.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/graph_ascend/useGraphAscend.ts index 88e7d344a..b83d74219 100644 --- a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/graph_ascend/useGraphAscend.ts +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/graph_ascend/useGraphAscend.ts @@ -14,7 +14,7 @@ * limitations under the License. */ import request from '../utils/request'; -import { LoadGraphFileInfoListType } from './type'; +import { LoadGraphFileInfoListType, SelectionType } from './type'; const useGraphAscend = () => { const loadGraphFileInfoList = async (isSafeCheck: boolean): Promise => { @@ -37,24 +37,13 @@ const useGraphAscend = () => { }; } }; - const loadGraphConfig = async (runName: string, tagName: string): Promise => { - const params = { - run: runName, - tag: tagName, - }; - - const result = await request({ url: 'loadGraphConfigInfo', method: 'GET', params: params }); // 获取异步的 ArrayBuffer + const loadGraphConfig = async (metaData: SelectionType): Promise => { + const result = await request({ url: 'loadGraphConfigInfo', method: 'POST', data: { metaData }}); // 获取异步的 ArrayBuffer return result; }; - const loadGraphAllNodeList = async (runName: string, tagName: string, microStep: number): Promise => { - const params = { - run: runName, - tag: tagName, - microStep: microStep, - }; - - const result = await request({ url: 'loadGraphAllNodeList', method: 'GET', params: params }); // 获取异步的 ArrayBuffer + const loadGraphAllNodeList = async (metaData: SelectionType): Promise => { + const result = await request({ url: 'loadGraphAllNodeList', method: 'POST', data: { metaData }}); // 获取异步的 ArrayBuffer return result; }; diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/graph_board/components/hierarchy/useGraph.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/graph_board/components/hierarchy/useGraph.ts index 481c804af..cfaa9e328 100644 --- a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/graph_board/components/hierarchy/useGraph.ts +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/graph_board/components/hierarchy/useGraph.ts @@ -282,13 +282,13 @@ const useGraph = (): UseGraphType => { try { const metaDataSafe = safeJSONParse(JSON.stringify(metaData)); const params = { - nodeInfo: JSON.stringify(nodeInfo), - metaData: JSON.stringify(metaDataSafe), + nodeInfo, + metaData: metaDataSafe, }; const result = await request({ url: 'changeNodeExpandState', - method: 'GET', - params: params, + method: 'POST', + data: params, timeout: 10000, }); return result; diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/graph_controls_board/components/tf_main_controler/index.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/graph_controls_board/components/tf_main_controler/index.ts index 847653905..6047bef96 100644 --- a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/graph_controls_board/components/tf_main_controler/index.ts +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/graph_controls_board/components/tf_main_controler/index.ts @@ -106,7 +106,7 @@ class MainController extends PolymerElement { if (isEmpty(this.metaDir)) { return; } - const tags = this.metaDir[this.selectedRun]; + const { type, tags } = this.metaDir[this.selectedRun]; this.set('tags', tags); this.set('selectedTag', tags[0]); const selection = { @@ -114,6 +114,7 @@ class MainController extends PolymerElement { run: this.selectedRun, tag: tags[0], microStep: -1, + type }; this.set('selectedTag', tags[0]); this.set('selectedMicroStep', -1); diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/graph_controls_board/components/tf_manual_match/useMatched.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/graph_controls_board/components/tf_manual_match/useMatched.ts index 3358e257a..adbbe5337 100644 --- a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/graph_controls_board/components/tf_manual_match/useMatched.ts +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/graph_controls_board/components/tf_manual_match/useMatched.ts @@ -15,8 +15,6 @@ */ import { isEmpty } from 'lodash'; -import { fetchPbTxt } from '../../../utils'; -import { safeJSONParse } from '../../../utils'; import request from '../../../utils/request'; import { UseMatchedType, MatchResultType } from '../../type'; const useMatched = (): UseMatchedType => { @@ -24,10 +22,10 @@ const useMatched = (): UseMatchedType => { const params = { npuNodeName: npuNodeName, benchNodeName: benchNodeName, - metaData: JSON.stringify(metaData), + metaData, isMatchChildren }; - const mactchResult = await request({ url: 'addMatchNodes', method: 'GET', params: params }); + const mactchResult = await request({ url: 'addMatchNodes', method: 'POST', data: params }); return mactchResult; }; @@ -35,54 +33,34 @@ const useMatched = (): UseMatchedType => { const params = { npuNodeName: npuNodeName, benchNodeName: benchNodeName, - metaData: JSON.stringify(metaData), + metaData, isUnMatchChildren }; - const mactchResult = await request({ url: 'deleteMatchNodes', method: 'GET', params: params }); + const mactchResult = await request({ url: 'deleteMatchNodes', method: 'POST', data: params }); return mactchResult; }; const saveMatchedNodesLink = async (selection: any): Promise => { - const metaData = { - run: selection.run, - tag: selection.tag, - }; - const params = new URLSearchParams(); - params.set('metaData', JSON.stringify(metaData)); - const precisionPath = `saveData?${String(params)}`; - const precisionStr = await fetchPbTxt(precisionPath); // 获取异步的 ArrayBuffer - const decoder = new TextDecoder(); - const decodedStr = decoder.decode(precisionStr); // 解码 ArrayBuffer 到字符串 - const saveResult = safeJSONParse(decodedStr.replace(/"None"/g, '{}')); + const saveResult = await request({ url: 'saveData', method: 'POST', data: { metaData: selection } }); return saveResult; }; const saveMatchedRelations = async (selection: any): Promise => { - const metaData = { - run: selection.run, - tag: selection.tag, - }; - const params = new URLSearchParams(); - params.set('metaData', JSON.stringify(metaData)); - const precisionPath = `saveMatchedRelations?${String(params)}`; - const precisionStr = await fetchPbTxt(precisionPath); // 获取异步的 ArrayBuffer - const decoder = new TextDecoder(); - const decodedStr = decoder.decode(precisionStr); // 解码 ArrayBuffer 到字符串 - const saveResult = safeJSONParse(decodedStr.replace(/"None"/g, '{}')); + const saveResult = await request({ url: 'saveMatchedRelations', method: 'POST', data: { metaData: selection } }); return saveResult; }; - const addMatchedNodesLinkByConfigFile = async (condfigFile: string, selection: any): Promise => { - if (isEmpty(condfigFile)) { + const addMatchedNodesLinkByConfigFile = async (configFile: string, selection: any): Promise => { + if (isEmpty(configFile)) { return { success: false, error: '请选择配置文件', }; } const params = { - configFile: condfigFile, - metaData: JSON.stringify(selection), + configFile: configFile, + metaData: selection, }; - const mactchResult = await request({ url: 'addMatchNodesByConfig', method: 'GET', params: params }); + const mactchResult = await request({ url: 'addMatchNodesByConfig', method: 'POST', data: params }); return mactchResult as MatchResultType; }; @@ -99,11 +77,7 @@ const useMatched = (): UseMatchedType => { error: '调试侧节点或标杆节点为空', }; } - const metaData = { - run: selection.run, - tag: selection.tag, - }; - const matchResult: MatchResultType = await requestAddMatchNodes(npuNodeName, benchNodeName, metaData, isMatchChildren); + const matchResult: MatchResultType = await requestAddMatchNodes(npuNodeName, benchNodeName, selection, isMatchChildren); return matchResult; }; @@ -114,11 +88,7 @@ const useMatched = (): UseMatchedType => { error: '调试侧节点或标杆节点为空', }; } - const metaData = { - run: selection.run, - tag: selection.tag, - }; - const matchResult: MatchResultType = await requestDeleteMatchNodes(npuNodeName, benchNodeName, metaData, isUnMatchChildren); + const matchResult: MatchResultType = await requestDeleteMatchNodes(npuNodeName, benchNodeName, selection, isUnMatchChildren); return matchResult; }; diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/graph_controls_board/type/index.d.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/graph_controls_board/type/index.d.ts index e4339fb2a..06146458b 100644 --- a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/graph_controls_board/type/index.d.ts +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/graph_controls_board/type/index.d.ts @@ -19,7 +19,12 @@ export interface MinimapVis { } export type Dataset = Array; -export type MetaDirType = Record>; +export type MetaDirType = { + [string]: { + type: string; + tags: string[]; + } +}; export interface UseMatchedType { saveMatchedNodesLink: (selection: any) => Promise; diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/graph_info_board/domain/useNodeInfoDomain.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/graph_info_board/domain/useNodeInfoDomain.ts index 05f6d64ee..b33a039bf 100644 --- a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/graph_info_board/domain/useNodeInfoDomain.ts +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/graph_info_board/domain/useNodeInfoDomain.ts @@ -13,21 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -import { fetchPbTxt, safeJSONParse } from '../../utils'; +import request from '../../utils/request'; const useNodeInfoDomain = (): { getMatchNodeInfo: (nodeInfo: any, metaData: any) => Promise } => { const getMatchNodeInfo = async (nodeInfo: any, metaData: any): Promise => { - const params = new URLSearchParams(); - params.set('nodeInfo', JSON.stringify(nodeInfo)); - params.set('metaData', JSON.stringify(metaData)); - // 接口请求 - const precisionPath = `getNodeInfo?${String(params)}`; - const precisionStr = await fetchPbTxt(precisionPath); // 获取异步的 ArrayBuffer - const decoder = new TextDecoder(); - const decodedStr = decoder.decode(precisionStr); // 解码 ArrayBuffer 到字符串 - // 接口返回 - const mactchResult = safeJSONParse(decodedStr.replace(/"None"/g, '{}')); - return mactchResult; + const matchResult = await request({ url: 'getNodeInfo', method: 'POST', data: { nodeInfo, metaData } }); + return matchResult; }; return { diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/utils/request.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/utils/request.ts index 84c23efb6..0f5d6b133 100644 --- a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/utils/request.ts +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/utils/request.ts @@ -34,6 +34,11 @@ export default async function request(options: RequestOptions): Promise const controller = new AbortController(); const signal = controller.signal; const timeoutId = setTimeout(() => controller.abort(), timeout); + if (typeof params === 'object' && params !== null) { + if ('metaData' in params) { + params.metaData.type = 'rank' in params.metaData ? 'db' : 'json'; + } + } const response: AxiosResponse = await axios({ url, method, diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/server/app/service/__init__.py b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/service/__init__.py index ee2432f47..ee251c408 100644 --- a/plugins/tensorboard-plugins/tb_graph_ascend/server/app/service/__init__.py +++ b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/service/__init__.py @@ -13,3 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +from .base_graph_service import GraphServiceStrategy +from .file_check_wrapper import check_file_type +from .graph_service_factory import ServiceFactory diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/server/app/service/base_graph_service.py b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/service/base_graph_service.py new file mode 100644 index 000000000..bf927f67a --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/service/base_graph_service.py @@ -0,0 +1,147 @@ +# Copyright (c) 2025, Huawei Technologies. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import os +from abc import ABC, abstractmethod + +from tensorboard.util import tb_logging + +from ..utils.graph_utils import GraphUtils +from ..utils.global_state import GraphState, NPU, BENCH, Extension, DataType +from ..controllers.layout_hierarchy_controller import LayoutHierarchyController + +logger = tb_logging.get_logger() +DB_EXT = Extension.DB.value +JSON_EXT = Extension.JSON.value +DB_TYPE = DataType.DB.value +JSON_TYPE = DataType.JSON.value + + +class GraphServiceStrategy(ABC): + run = '' + tag = '' + + def __init__(self, run, tag): + self.run = run + self.tag = tag + + @staticmethod + def load_meta_dir(is_safe_check): + """ + Scan logdir for directories containing .vis(.db) files. If the directory contains .vis.db files, + it is considered a db type and the .vis files are ignored. Otherwise, it is considered a json type. + """ + logdir = GraphState.get_global_value('logdir') + runs = GraphState.get_global_value('runs', {}) + first_run_tags = GraphState.get_global_value('first_run_tags', {}) + meta_dir = {} + error_list = [] + for root, _, files in GraphUtils.walk_with_max_depth(logdir, 2): + run_abs = os.path.abspath(root) + run = os.path.basename(run_abs) # 不允许同名目录,否则有问题 + for file in files: + if file.endswith(DB_EXT): + tag = file[:-len(DB_EXT)] + _, error = GraphUtils.safe_load_data(run_abs, file, True) + if error and is_safe_check: + error_list.append({ + 'run': run, + 'tag': tag, + 'info': f'Error: {error}' + }) + logger.error(f'Error: File run:"{run_abs},tag:{tag}" is not accessible. Error: {error}') + continue + runs[run] = run_abs + meta_dir[run] = {'type': DB_TYPE, 'tags': [tag]} + break + # 只取文件名,不包含扩展名 + if file.endswith(JSON_EXT): + tag = os.path.splitext(file)[0] + _, error = GraphUtils.safe_load_data(run_abs, file, True) + if error and is_safe_check: + error_list.append({ + 'run': run, + 'tag': tag, + 'info': f'Error: {error}' + }) + logger.error(f'Error: File run:"{run_abs},tag:{tag}" is not accessible. Error: {error}') + continue + runs[run] = run_abs + if meta_dir.get(run) is None: + meta_dir[run] = {'type': JSON_TYPE, 'tags': [tag]} + else: + meta_dir.get(run).get('tags').append(tag) + meta_dir = GraphUtils.sort_data(meta_dir) + for run, value in meta_dir.items(): + first_run_tags[run] = value.get('tags')[0] + GraphState.set_global_value('runs', runs) + GraphState.set_global_value('first_run_tags', first_run_tags) + result = { + 'data': meta_dir, + 'error': error_list + } + return result + + @staticmethod + def update_hierarchy_data(graph_type): + if graph_type == NPU or graph_type == BENCH: + hierarchy = LayoutHierarchyController.update_hierarchy_data(graph_type) + return {'success': True, 'data': hierarchy} + else: + return {'success': False, 'error': '节点类型错误'} + + @abstractmethod + def load_graph_data(self): + pass + + @abstractmethod + def load_graph_config_info(self): + pass + + @abstractmethod + def load_graph_all_node_list(self, meta_data): + pass + + @abstractmethod + def change_node_expand_state(self, node_info, meta_data): + pass + + @abstractmethod + def get_node_info(self, node_info, meta_data): + pass + + @abstractmethod + def add_match_nodes(self, npu_node_name, bench_node_name, meta_data, is_match_children): + pass + + @abstractmethod + def add_match_nodes_by_config(self, config_file, meta_data): + pass + + @abstractmethod + def delete_match_nodes(self, npu_node_name, bench_node_name, meta_data, is_unmatch_children): + pass + + @abstractmethod + def save_data(self, meta_data): + pass + + @abstractmethod + def update_colors(self, colors): + pass + + @abstractmethod + def save_matched_relations(self): + pass diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/server/app/service/db_graph_service.py b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/service/db_graph_service.py new file mode 100644 index 000000000..814624848 --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/service/db_graph_service.py @@ -0,0 +1,54 @@ +# Copyright (c) 2025, Huawei Technologies. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from .base_graph_service import GraphServiceStrategy + + +class DbGraphService(GraphServiceStrategy): + def __init__(self, run_path, tag): + super().__init__(run_path, tag) + + def load_graph_data(self): + pass + + def load_graph_config_info(self): + pass + + def load_graph_all_node_list(self, meta_data): + pass + + def change_node_expand_state(self, node_info, meta_data): + pass + + def get_node_info(self, node_info, meta_data): + pass + + def add_match_nodes(self, npu_node_name, bench_node_name, meta_data, is_match_children): + pass + + def add_match_nodes_by_config(self, config_file, meta_data): + pass + + def delete_match_nodes(self, npu_node_name, bench_node_name, meta_data, is_unmatch_children): + pass + + def save_data(self, meta_data): + pass + + def update_colors(self, colors): + pass + + def save_matched_relations(self): + pass diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/server/app/service/file_check_wrapper.py b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/service/file_check_wrapper.py new file mode 100644 index 000000000..69320dc0e --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/service/file_check_wrapper.py @@ -0,0 +1,47 @@ +# Copyright (c) 2025, Huawei Technologies. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from tensorboard.backend import http_util + +from ..utils.graph_utils import GraphUtils +from ..utils.global_state import DataType + + +def check_file_type(func): + def wrapper(*args, **kwargs): + if len(args) <= 0: + raise RuntimeError('Illegal function call, at least 1 parameter is required but got 0') + request = args[0] + if not hasattr(request, 'get_data'): + raise RuntimeError('Invalid parameters, "request" should have a "get_data" method') + meta_data = GraphUtils.safe_json_loads(request.get_data().decode('utf-8'), {}).get('metaData') + + result = {'success': False, 'error': ''} + if meta_data is None or not isinstance(meta_data, dict): + result['error'] = 'The query parameter "metaData" is required and must be a dictionary' + return http_util.Respond(request, result, "application/json") + data_type = meta_data.get('type') + run = meta_data.get('run') + tag = meta_data.get('tag') + if data_type is None or run is None or tag is None: + result['error'] = 'The query parameters "type", "run" and "tag" in "metaData" are required' + return http_util.Respond(request, result, "application/json") + if data_type not in [e.value for e in DataType]: + result['error'] = 'Unsupported file type' + return http_util.Respond(request, result, "application/json") + + return func(*args, **kwargs) + + return wrapper diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/server/app/service/graph_service_factory.py b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/service/graph_service_factory.py new file mode 100644 index 000000000..18ff06a64 --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/service/graph_service_factory.py @@ -0,0 +1,48 @@ +# Copyright (c) 2025, Huawei Technologies. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from .db_graph_service import DbGraphService +from .json_graph_service import JsonGraphService +from ..utils.global_state import GraphState, DataType + + +class ServiceFactory: + def __init__(self): + self.run = '' + self.tag = '' + self.data_type = None + self.strategy = JsonGraphService('', '') + + def create_strategy(self, data_type, run, tag): + if not (data_type == self.data_type and run == self.run and tag == self.tag): + self.data_type = data_type + self.run = run + self.tag = tag + if data_type == DataType.DB.value: + self.strategy = DbGraphService(run, tag) + else: + self.strategy = JsonGraphService(run, tag) + return self.strategy + + def create_strategy_without_tag(self, data_type, run): + if not (data_type == self.data_type and run == self.run): + self.data_type = data_type + self.run = run + self.tag = GraphState.get_global_value('first_run_tags', {}).get(self.run) + if data_type == DataType.DB.value: + self.strategy = DbGraphService(run, self.tag) + else: + self.strategy = JsonGraphService(run, self.tag) + return self.strategy diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/server/app/service/graph_service.py b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/service/json_graph_service.py similarity index 79% rename from plugins/tensorboard-plugins/tb_graph_ascend/server/app/service/graph_service.py rename to plugins/tensorboard-plugins/tb_graph_ascend/server/app/service/json_graph_service.py index 04586ffb3..a90085032 100644 --- a/plugins/tensorboard-plugins/tb_graph_ascend/server/app/service/graph_service.py +++ b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/service/json_graph_service.py @@ -16,65 +16,31 @@ import os import time import json + from tensorboard.util import tb_logging +from .db_graph_service import DbGraphService from ..utils.graph_utils import GraphUtils from ..utils.global_state import GraphState from ..controllers.match_nodes_controller import MatchNodesController from ..controllers.layout_hierarchy_controller import LayoutHierarchyController from ..utils.global_state import NPU_PREFIX, BENCH_PREFIX, NPU, BENCH, SINGLE +from .base_graph_service import GraphServiceStrategy logger = tb_logging.get_logger() -class GraphService: - - @staticmethod - def load_meta_dir(is_safe_check): - """Scan logdir for directories containing .vis files, modified to return a tuple of (run, tag).""" - - logdir = GraphState.get_global_value('logdir') - runs = GraphState.get_global_value('runs', {}) - first_run_tags = GraphState.get_global_value('first_run_tags', {}) - - meta_dir = {} - error_list = [] - for root, _, files in GraphUtils.walk_with_max_depth(logdir, 2): - for file in files: - if file.endswith('.vis'): # check for .vis extension - run_abs = os.path.abspath(root) - run = os.path.basename(run_abs) # 不允许同名目录,否则有问题 - tag = os.path.splitext(file)[0] # Use the filename without extension as tag - _, error = GraphUtils.safe_load_data(run_abs, f"{tag}.vis", True) - if error and is_safe_check: - error_list.append({ - 'run': run, - 'tag': tag, - 'info': f'Error: {error}' - }) - logger.error(f'Error: File run:"{run_abs},tag:{tag}" is not accessible. Error: {error}') - continue - runs[run] = run_abs - meta_dir.setdefault(run, []).append(tag) - meta_dir = GraphUtils.sort_data(meta_dir) - for run, tags in meta_dir.items(): - first_run_tags[run] = tags[0] - GraphState.set_global_value('runs', runs) - GraphState.set_global_value('first_run_tags', first_run_tags) - result = { - 'data': meta_dir, - 'error': error_list - } - return result +class JsonGraphService(GraphServiceStrategy): + def __init__(self, run_path, tag): + super().__init__(run_path, tag) - @staticmethod - def load_graph_data(run_name, tag): + def load_graph_data(self): runs = GraphState.get_global_value('runs') - run = runs.get(run_name) + run_path = runs.get(self.run) buffer = "" read_bytes = 0 chunk_size = 1024 * 1024 * 60 # 缓冲区 json_data = None # 最终存储的变量 - file_path = os.path.join(run, f"{tag}.vis") + file_path = os.path.join(run_path, f"{self.tag}.vis") file_path = os.path.normpath(file_path) # 标准化路径 file_size = os.path.getsize(file_path) with open(file_path, 'r', encoding='utf-8') as f: @@ -104,15 +70,16 @@ class GraphService: if json_data is not None: # 验证存储 GraphState.set_global_value('current_file_data', json_data) - GraphState.set_global_value('current_tag', tag) - GraphState.set_global_value('current_run', run) + GraphState.set_global_value('current_tag', self.tag) + GraphState.set_global_value('current_run', run_path) yield f"data: {json.dumps({'done': True, 'progress': 100, 'status': 'loading'})}\n\n" else: yield f"data: {json.dumps({'progress': current_progress, 'error': 'Failed to parse JSON'})}\n\n" - @staticmethod - def load_graph_config_info(run, tag): - graph_data, error_message = GraphUtils.get_graph_data({'run': run, 'tag': tag}) + def load_graph_config_info(self): + run_name = self.run + tag = self.tag + graph_data, error_message = GraphUtils.get_graph_data({'run': run_name, 'tag': tag}) if error_message or not graph_data: return {'success': False, 'error': error_message} config = {} @@ -126,30 +93,30 @@ class GraphService: config['isSingleGraph'] = False if graph_data.get(NPU) else True # 读取配置信息,run层面 config_data = GraphState.get_global_value("config_data", {}) - config_data_run = config_data.get(run, {}) + config_data_run = config_data.get(run_name, {}) if not config_data_run: # 如果没有run的配置信息,则读取第一个文件中的Colors first_run_tags = GraphState.get_global_value("first_run_tags") - first_tag = first_run_tags.get(run) + first_tag = first_run_tags.get(run_name) if not first_tag: return {'success': False, 'error': '获取配置信息失败,请检查目录中第一个文件'} - first_graph_data, error_message = GraphUtils.get_graph_data({'run': run, 'tag': first_tag}) + first_graph_data, error_message = GraphUtils.get_graph_data({'run': run_name, 'tag': first_tag}) config_data_run['colors'] = first_graph_data.get('Colors') - config_data[run] = config_data_run + config_data[run_name] = config_data_run GraphState.set_global_value('config_data', config_data) config['colors'] = first_graph_data.get('Colors') else: config['colors'] = config_data_run.get('colors') # 读取目录下配置文件列表 - config_files = GraphUtils.find_config_files(run) + config_files = GraphUtils.find_config_files(run_name) config['matchedConfigFiles'] = config_files or [] config['task'] = graph_data.get('task') return {'success': True, 'data': config} except Exception as e: return {'success': False, 'error': '获取配置信息失败,请检查目录中第一个文件'} - @staticmethod - def load_graph_all_node_list(run, tag, micro_step): - graph_data, error_message = GraphUtils.get_graph_data({'run': run, 'tag': tag}) + def load_graph_all_node_list(self, meta_data): + micro_step = meta_data.get('microStep') + graph_data, error_message = GraphUtils.get_graph_data({'run': self.run, 'tag': self.tag}) if error_message or not graph_data: return {'success': False, 'error': error_message} result = {} @@ -168,13 +135,13 @@ class GraphService: npu_node_name_list = list(npu_node.keys()) bench_node_name_list = list(bench_node.keys()) npu_unmatehed_name_list = [ - key - for key, value in npu_node.items() + key + for key, value in npu_node.items() if not value.get("matched_node_link") ] bench_unmatehed_name_list = [ - key - for key, value in bench_node.items() + key + for key, value in bench_node.items() if not value.get("matched_node_link") ] # 保存未匹配和已匹配的节点到全局变量 @@ -194,15 +161,14 @@ class GraphService: bench_node_name = npu_node.get('matched_node_link', [None])[-1].replace(BENCH_PREFIX, '', 1) result['npuMatchNodes'][npu_node_name] = bench_node_name result['benchMatchNodes'][bench_node_name] = npu_node_name - + GraphState.set_global_value('config_data', config_data) return {'success': True, 'data': result} except Exception as e: logger.error('获取节点列表失败:' + str(e)) return {'success': False, 'error': '获取节点列表失败:' + str(e)} - @staticmethod - def change_node_expand_state(node_info, meta_data): + def change_node_expand_state(self, node_info, meta_data): graph_data, error_message = GraphUtils.get_graph_data(meta_data) if error_message or not graph_data: return {'success': False, 'error': error_message} @@ -214,7 +180,7 @@ class GraphService: if not graph_data.get(NPU): hierarchy = LayoutHierarchyController.change_expand_state(node_name, SINGLE, graph_data, micro_step) # NPU - elif (graph_type == NPU): + elif graph_type == NPU: hierarchy = LayoutHierarchyController.change_expand_state(node_name, graph_type, graph_data.get(NPU, {}), micro_step) # 标杆 @@ -227,20 +193,11 @@ class GraphService: except Exception as e: logger.error('节点展开或收起发生错误:' + str(e)) node_type_name = "" - if graph_data.get(NPU): + if graph_data.get(NPU): node_type_name = '调试侧' if graph_type == NPU else '标杆侧' return {'success': False, 'error': f'{node_type_name}节点展开或收起发生错误', 'data': None} - @staticmethod - def update_hierarchy_data(graph_type): - if (graph_type == NPU or graph_type == BENCH): - hierarchy = LayoutHierarchyController.update_hierarchy_data(graph_type) - return {'success': True, 'data': hierarchy} - else: - return {'success': False, 'error': '节点类型错误'} - - @staticmethod - def get_node_info(node_info, meta_data): + def get_node_info(self, node_info, meta_data): graph_data, error_message = GraphUtils.get_graph_data(meta_data) if error_message: return {'success': False, 'error': error_message} @@ -267,8 +224,7 @@ class GraphService: logger.error('获取节点信息失败:' + str(e)) return {'success': False, 'error': '获取节点信息失败:' + str(e), 'data': None} - @staticmethod - def add_match_nodes(npu_node_name, bench_node_name, meta_data, is_match_children): + def add_match_nodes(self, npu_node_name, bench_node_name, meta_data, is_match_children): graph_data, error_message = GraphUtils.get_graph_data(meta_data) if error_message: return {'success': False, 'error': error_message} @@ -297,8 +253,7 @@ class GraphService: except Exception as e: return {'success': False, '操作失败': str(e), 'data': None} - @staticmethod - def add_match_nodes_by_config(config_file, meta_data): + def add_match_nodes_by_config(self, config_file, meta_data): graph_data, error_message = GraphUtils.get_graph_data(meta_data) if error_message: return {'success': False, 'error': '读取文件失败'} @@ -316,8 +271,7 @@ class GraphService: except Exception as e: return {'success': False, 'error': '操作失败', 'data': None} - @staticmethod - def delete_match_nodes(npu_node_name, bench_node_name, meta_data, is_unmatch_children): + def delete_match_nodes(self, npu_node_name, bench_node_name, meta_data, is_unmatch_children): graph_data, error_message = GraphUtils.get_graph_data(meta_data) if error_message: return {'success': False, 'error': error_message} @@ -345,53 +299,46 @@ class GraphService: except Exception as e: return {'success': False, '操作失败': str(e), 'data': None} - @staticmethod - def update_colors(run, colors): - """Set new colors in jsondata.""" - try: - config_data = GraphState.get_global_value("config_data", {}) - first_run_tags = GraphState.get_global_value("first_run_tags") - config_data_run = config_data.get(run, {}) - first_run_tag = first_run_tags.get(run) - first_file_data, error = GraphUtils.safe_load_data(run, f"{first_run_tag}.vis") - if error: - return {'success': False, 'error': '获取配置信息失败,请检查目录中第一个文件'} - first_file_data['Colors'] = colors - config_data_run['colors'] = colors - config_data[run] = config_data_run - GraphState.set_global_value("config_data", config_data) - GraphUtils.safe_save_data(first_file_data, run, f"{first_run_tag}.vis") - return {'success': True, 'error': None, 'data': {}} - except Exception as e: - return {'success': False, 'error': str(e), 'data': None} - - @staticmethod - def save_data(meta_data): + def save_data(self, meta_data): if not meta_data: return {'success': False, 'error': '参数为空'} graph_data, error_message = GraphUtils.get_graph_data(meta_data) if error_message: return {'success': False, 'error': error_message} - run = meta_data.get('run') - tag = meta_data.get('tag') try: - _, error = GraphUtils.safe_save_data(graph_data, run, f"{tag}.vis") + _, error = GraphUtils.safe_save_data(graph_data, self.run, f"{self.tag}.vis") if error: return {'success': False, 'error': error} except (ValueError, IOError, PermissionError) as e: return {'success': False, 'error': f"Error: {e}"} return {'success': True} - @staticmethod - def save_matched_relations(meta_data): - if not meta_data: - return {'success': False, 'error': '参数为空'} + def update_colors(self, colors): + """Set new colors in jsondata.""" + try: + config_data = GraphState.get_global_value("config_data", {}) + first_run_tags = GraphState.get_global_value("first_run_tags") + config_data_run = config_data.get(self.run, {}) + first_run_tag = first_run_tags.get(self.run) + first_file_data, error = GraphUtils.safe_load_data(self.run, f"{first_run_tag}.vis") + if error: + return {'success': False, 'error': '获取配置信息失败,请检查目录中第一个文件'} + first_file_data['Colors'] = colors + config_data_run['colors'] = colors + config_data[self.run] = config_data_run + GraphState.set_global_value("config_data", config_data) + GraphUtils.safe_save_data(first_file_data, self.run, f"{first_run_tag}.vis") + return {'success': True, 'error': None, 'data': {}} + except Exception as e: + return {'success': False, 'error': str(e), 'data': None} + + def save_matched_relations(self): + run = self.run + tag = self.tag config_data = GraphState.get_global_value("config_data") # 匹配列表和未匹配列表 npu_match_nodes_list = config_data.get('manualMatchNodes', {}) - run = meta_data.get('run') - tag = meta_data.get('tag') try: _, error = GraphUtils.safe_save_data(npu_match_nodes_list, run, f"{tag}.vis.config") if error: diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/server/app/utils/global_state.py b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/utils/global_state.py index 6e4fcabae..9ed8c4920 100644 --- a/plugins/tensorboard-plugins/tb_graph_ascend/server/app/utils/global_state.py +++ b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/utils/global_state.py @@ -15,6 +15,9 @@ # ============================================================================== # 全局常量 +from enum import Enum + + ADD_MATCH_KEYS = [ 'MaxAbsErr', 'MinAbsErr', @@ -53,7 +56,6 @@ class GraphState: 'logdir': '', 'current_tag': '', 'current_run': '', - 'current_file_path': '', 'current_file_data': {}, 'current_hierarchy': {}, 'config_data': { @@ -79,7 +81,6 @@ class GraphState: 'logdir': '', 'current_tag': '', 'current_run': '', - 'current_file_path': '', 'current_file_data': {}, 'current_hierarchy': {}, 'config_data': { @@ -115,3 +116,13 @@ class GraphState: 重置所有全局变量为默认值 """ GraphState.init_defaults() + + +class Extension(Enum): + DB = '.vis.db' + JSON = '.vis' + + +class DataType(Enum): + DB = 'db' + JSON = 'json' diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/server/app/utils/graph_utils.py b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/utils/graph_utils.py index a862b4dbc..569b55f99 100644 --- a/plugins/tensorboard-plugins/tb_graph_ascend/server/app/utils/graph_utils.py +++ b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/utils/graph_utils.py @@ -468,8 +468,7 @@ class GraphUtils: sorted_keys = sorted(data.keys(), key=cmp_to_key(GraphUtils.compare_tag_names)) for k in sorted_keys: # 对每个键对应的值列表进行排序 - sorted_values = sorted(data[k], key=cmp_to_key(GraphUtils.compare_tag_names)) - sorted_data[k] = sorted_values + sorted_values = sorted(data.get(k, {}).get('tags'), key=cmp_to_key(GraphUtils.compare_tag_names)) + sorted_data[k] = {'type': data.get(k, {}).get('type'), 'tags': sorted_values} return sorted_data - diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/server/app/views/graph_views.py b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/views/graph_views.py index 17d85b704..d590d3c53 100644 --- a/plugins/tensorboard-plugins/tb_graph_ascend/server/app/views/graph_views.py +++ b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/views/graph_views.py @@ -19,11 +19,12 @@ import json from pathlib import Path from werkzeug import wrappers, Response, exceptions from tensorboard.backend import http_util -from ..service.graph_service import GraphService +from ..service import ServiceFactory, GraphServiceStrategy, check_file_type from ..utils.graph_utils import GraphUtils class GraphView: + service_factory = ServiceFactory() # 静态文件路由 @staticmethod @@ -51,8 +52,8 @@ class GraphView: @wrappers.Request.application def load_meta_dir(request): """Scan logdir for directories containing .vis files, modified to return a tuple of (run, tag).""" - is_safe_check = GraphUtils.safe_json_loads(request.args.get("isSafeCheck")) - result = GraphService.load_meta_dir(is_safe_check) + is_safe_check = GraphUtils.safe_json_loads(request.args.get('isSafeCheck')) + result = GraphServiceStrategy.load_meta_dir(is_safe_check) response = http_util.Respond(request, result, "application/json") return response @@ -60,10 +61,13 @@ class GraphView: @staticmethod @wrappers.Request.application def load_graph_data(request): - run = request.args.get("run") - tag = request.args.get("tag") + meta_data = { + 'run': request.args.get('run'), + 'tag': request.args.get('tag') + } + strategy = GraphView._get_strategy(meta_data) return Response( - GraphService.load_graph_data(run, tag), + strategy.load_graph_data(), mimetype='text/event-stream', headers={ 'Cache-Control': 'no-cache', @@ -75,10 +79,11 @@ class GraphView: # 获取当前图数据配置信息 @staticmethod @wrappers.Request.application + @check_file_type def load_graph_config_info(request): - run = request.args.get("run") - tag = request.args.get("tag") - result = GraphService.load_graph_config_info(run, tag) + meta_data = GraphUtils.safe_json_loads(request.get_data().decode('utf-8'), {}).get('metaData') + strategy = GraphView._get_strategy(meta_data) + result = strategy.load_graph_config_info() # 创建响应对象 response = http_util.Respond(request, result, "application/json") return response @@ -86,21 +91,24 @@ class GraphView: # 获取当前图所有节点列表 @staticmethod @wrappers.Request.application + @check_file_type def load_graph_all_node_list(request): - run = request.args.get("run") - tag = request.args.get("tag") - micro_step = request.args.get('microStep') - result = GraphService.load_graph_all_node_list(run, tag, micro_step) + meta_data = GraphUtils.safe_json_loads(request.get_data().decode('utf-8'), {}).get('metaData') + strategy = GraphView._get_strategy(meta_data) + result = strategy.load_graph_all_node_list(meta_data) response = http_util.Respond(request, result, "application/json") return response # 展开关闭节点 @staticmethod @wrappers.Request.application + @check_file_type def change_node_expand_state(request): - node_info = GraphUtils.safe_json_loads(request.args.get("nodeInfo")) - meta_data = GraphUtils.safe_json_loads(request.args.get("metaData")) - hierarchy = GraphService.change_node_expand_state(node_info, meta_data) + data = GraphUtils.safe_json_loads(request.get_data().decode('utf-8'), {}) + node_info = data.get('nodeInfo') + meta_data = data.get('metaData') + strategy = GraphView._get_strategy(meta_data) + hierarchy = strategy.change_node_expand_state(node_info, meta_data) return http_util.Respond(request, json.dumps(hierarchy), "application/json") # 更新当前图节点信息 @@ -108,70 +116,97 @@ class GraphView: @wrappers.Request.application def update_hierarchy_data(request): graph_type = request.args.get("graphType") - hierarchy = GraphService.update_hierarchy_data(graph_type) + hierarchy = GraphServiceStrategy.update_hierarchy_data(graph_type) return http_util.Respond(request, json.dumps(hierarchy), "application/json") # 获取当前节点对应节点的信息看板数据 @staticmethod @wrappers.Request.application + @check_file_type def get_node_info(request): - node_info = GraphUtils.safe_json_loads(request.args.get("nodeInfo")) - meta_data = GraphUtils.safe_json_loads(request.args.get("metaData")) - node_detail = GraphService.get_node_info(node_info, meta_data) + data = GraphUtils.safe_json_loads(request.get_data().decode('utf-8'), {}) + node_info = data.get('nodeInfo') + meta_data = data.get('metaData') + strategy = GraphView._get_strategy(meta_data) + node_detail = strategy.get_node_info(node_info, meta_data) return http_util.Respond(request, json.dumps(node_detail), "application/json") # 根据配置文件添加匹配节点 @staticmethod @wrappers.Request.application + @check_file_type def add_match_nodes_by_config(request): - config_file = request.args.get("configFile") - meta_data = GraphUtils.safe_json_loads(request.args.get("metaData")) - match_result = GraphService.add_match_nodes_by_config(config_file, meta_data) + data = GraphUtils.safe_json_loads(request.get_data().decode('utf-8'), {}) + config_file = data.get('configFile') + meta_data = data.get('metaData') + strategy = GraphView._get_strategy(meta_data) + match_result = strategy.add_match_nodes_by_config(config_file, meta_data) return http_util.Respond(request, json.dumps(match_result), "application/json") # 添加匹配节点 @staticmethod @wrappers.Request.application + @check_file_type def add_match_nodes(request): - npu_node_name = request.args.get("npuNodeName") - bench_node_name = request.args.get("benchNodeName") - meta_data = GraphUtils.safe_json_loads(request.args.get("metaData")) - is_match_children = GraphUtils.safe_json_loads(request.args.get("isMatchChildren")) - match_result = GraphService.add_match_nodes(npu_node_name, bench_node_name, meta_data, is_match_children) + data = GraphUtils.safe_json_loads(request.get_data().decode('utf-8'), {}) + npu_node_name = data.get("npuNodeName") + bench_node_name = data.get("benchNodeName") + meta_data = data.get("metaData") + is_match_children = data.get("isMatchChildren") + strategy = GraphView._get_strategy(meta_data) + match_result = strategy.add_match_nodes(npu_node_name, bench_node_name, meta_data, is_match_children) return http_util.Respond(request, json.dumps(match_result), "application/json") # 取消节点匹配 @staticmethod @wrappers.Request.application + @check_file_type def delete_match_nodes(request): - npu_node_name = request.args.get("npuNodeName") - bench_node_name = request.args.get("benchNodeName") - meta_data = GraphUtils.safe_json_loads(request.args.get("metaData")) - is_unmatch_children = GraphUtils.safe_json_loads(request.args.get("isUnMatchChildren")) - match_result = GraphService.delete_match_nodes(npu_node_name, bench_node_name, meta_data, is_unmatch_children) + data = GraphUtils.safe_json_loads(request.get_data().decode('utf-8'), {}) + npu_node_name = data.get("npuNodeName") + bench_node_name = data.get("benchNodeName") + meta_data = data.get("metaData") + is_unmatch_children = data.get("isUnMatchChildren") + strategy = GraphView._get_strategy(meta_data) + match_result = strategy.delete_match_nodes(npu_node_name, bench_node_name, meta_data, is_unmatch_children) return http_util.Respond(request, json.dumps(match_result), "application/json") # 保存匹配节点列表 @staticmethod @wrappers.Request.application + @check_file_type def save_data(request): - meta_data = GraphUtils.safe_json_loads(request.args.get("metaData")) - save_result = GraphService.save_data(meta_data) + meta_data = GraphUtils.safe_json_loads(request.get_data().decode('utf-8'), {}).get('metaData') + strategy = GraphView._get_strategy(meta_data) + save_result = strategy.save_data(meta_data) return http_util.Respond(request, json.dumps(save_result), "application/json") # 更新颜色信息 @staticmethod @wrappers.Request.application def update_colors(request): - run = request.args.get('run') colors = GraphUtils.safe_json_loads(request.args.get('colors')) - update_result = GraphService.update_colors(run, colors) + run = request.args.get('run') + strategy = GraphView._get_strategy({'run': run}, no_tag=True) + update_result = strategy.update_colors(colors) return http_util.Respond(request, json.dumps(update_result), "application/json") # 保存匹配关系 @staticmethod @wrappers.Request.application + @check_file_type def save_matched_relations(request): - meta_data = GraphUtils.safe_json_loads(request.args.get("metaData")) - save_result = GraphService.save_matched_relations(meta_data) + meta_data = GraphUtils.safe_json_loads(request.get_data().decode('utf-8'), {}).get('metaData') + strategy = GraphView._get_strategy(meta_data) + save_result = strategy.save_matched_relations() return http_util.Respond(request, json.dumps(save_result), "application/json") + + @staticmethod + def _get_strategy(meta_data, no_tag=False): + run = meta_data.get('run') + data_type = meta_data.get('type') + if no_tag: + return GraphView.service_factory.create_strategy_without_tag(data_type, run) + + tag = meta_data.get('tag') + return GraphView.service_factory.create_strategy(data_type, run, tag) diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/server/plugin.py b/plugins/tensorboard-plugins/tb_graph_ascend/server/plugin.py index 2c65227a9..40d1e6612 100644 --- a/plugins/tensorboard-plugins/tb_graph_ascend/server/plugin.py +++ b/plugins/tensorboard-plugins/tb_graph_ascend/server/plugin.py @@ -26,12 +26,14 @@ from tensorboard.util import tb_logging from . import constants from .app.views.graph_views import GraphView from .app.utils.graph_utils import GraphUtils -from .app.utils.global_state import GraphState +from .app.utils.global_state import GraphState, Extension logger = tb_logging.get_logger() PLUGIN_NAME = 'graph_ascend' PLUGIN_NAME_RUN_METADATA_WITH_GRAPH = 'graph_ascend_run_metadata_graph' +DB_EXT = Extension.DB.value +JSON_EXT = Extension.JSON.value class GraphsPlugin(base_plugin.TBPlugin): @@ -84,13 +86,16 @@ class GraphsPlugin(base_plugin.TBPlugin): def is_active(self): """The graphs plugin is active if any run has a graph.""" + def _is_vis(path, file_name): + return os.path.isfile(path) and (file_name.endswith(DB_EXT) or file_name.endswith(JSON_EXT)) + for content in os.listdir(self.logdir): content_path = os.path.join(self.logdir, content) - if os.path.isfile(content_path) and content.endswith('.vis'): + if _is_vis(content_path, content): return True if os.path.isdir(content_path): for file in os.listdir(content_path): - if os.path.isfile(os.path.join(content_path, file)) and file.endswith('.vis'): + if _is_vis(os.path.join(content_path, file), file): return True return False -- Gitee From b8aa4154716ff49d4adab983c88579e8b75f9611 Mon Sep 17 00:00:00 2001 From: wuyulong17 <2284273586@qq.com> Date: Mon, 28 Jul 2025 20:03:11 +0800 Subject: [PATCH 2/2] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E6=A3=80=E8=A7=86?= =?UTF-8?q?=E6=84=8F=E8=A7=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../tb_graph_ascend/server/app/service/file_check_wrapper.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/server/app/service/file_check_wrapper.py b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/service/file_check_wrapper.py index 69320dc0e..47f4ab7f3 100644 --- a/plugins/tensorboard-plugins/tb_graph_ascend/server/app/service/file_check_wrapper.py +++ b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/service/file_check_wrapper.py @@ -14,6 +14,7 @@ # limitations under the License. # ============================================================================== from tensorboard.backend import http_util +from werkzeug.wrappers.request import Request from ..utils.graph_utils import GraphUtils from ..utils.global_state import DataType @@ -24,8 +25,8 @@ def check_file_type(func): if len(args) <= 0: raise RuntimeError('Illegal function call, at least 1 parameter is required but got 0') request = args[0] - if not hasattr(request, 'get_data'): - raise RuntimeError('Invalid parameters, "request" should have a "get_data" method') + if not isinstance(request, Request): + raise RuntimeError('The request "parameter" is not in a format supported by werkzeug') meta_data = GraphUtils.safe_json_loads(request.get_data().decode('utf-8'), {}).get('metaData') result = {'success': False, 'error': ''} -- Gitee