diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/package.json b/plugins/tensorboard-plugins/tb_graph_ascend/fe/package.json index 787f5b5549c7379d9e5750d7800ec31706df6f6c..6f948a941f03f44e3c7f5f8c8274f333e82339bd 100644 --- a/plugins/tensorboard-plugins/tb_graph_ascend/fe/package.json +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/package.json @@ -63,6 +63,11 @@ "@polymer/paper-tooltip": "^3.0.1", "@polymer/polymer": "^3.5.1", "@types/lodash": "^4.17.1", + "@vaadin/notification": "^23.5.11", + "@vaadin/tabs": "^23.5.11", + "@vaadin/tabsheet": "^23.5.11", + "@vaadin/tooltip": "^23.5.11", + "@vaadin/vaadin-grid": "^23.5.11", "clean-webpack-plugin": "^4.0.0", "cross-env": "^7.0.3", "css-loader": "^7.1.2", diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph/tf-graph-scene.html.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph/tf-graph-scene.html.ts index b8250e300a67b5e5abf409fb54a0a7315024699e..e0f378303f21925f323d51150115f06657aebeb0 100644 --- a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph/tf-graph-scene.html.ts +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph/tf-graph-scene.html.ts @@ -552,7 +552,7 @@ export const template = html` #minimap { position: absolute; right: 20px; - bottom: 20px; + bottom: 28%; } .context-menu { diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_board/tf-graph-board.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_board/tf-graph-board.ts index f43273d7b8f529d106f2726497886f776df25628..6f79db5d416ae40c9c68d23db36579f826db9b51 100644 --- a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_board/tf-graph-board.ts +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_board/tf-graph-board.ts @@ -21,6 +21,7 @@ import * as tf_graph from '../tf_graph_common/graph'; import * as tf_graph_hierarchy from '../tf_graph_common/hierarchy'; import * as tf_graph_render from '../tf_graph_common/render'; import '../tf_graph_info/tf-graph-info'; +import '../tf_graph_node_info/index'; import * as _ from 'lodash'; /** @@ -48,6 +49,8 @@ class TfGraphBoard extends LegacyElementMixin(PolymerElement) { width: 100%; height: 100%; opacity: 1; + display: flex; + flex-direction: column; } .container.loading { @@ -94,6 +97,12 @@ class TfGraphBoard extends LegacyElementMixin(PolymerElement) { height: 100%; } + #tab-info { + /* height: 30%; */ + background-color: #ffffff; + border-top: 2px solid rgb(153, 152, 152); + } + #progress-bar { display: flex; flex-direction: column; @@ -173,40 +182,17 @@ class TfGraphBoard extends LegacyElementMixin(PolymerElement) { auto-extract-nodes="[[autoExtractNodes]]" >
-            [unrecovered markup: the old resizable #graph-info info panel and its #resize-handle are removed here]
+            [unrecovered markup: the new #tab-info node-info panel (the imported tf-graph-vaadin-tab element) is added here]
`; @@ -297,49 +283,6 @@ class TfGraphBoard extends LegacyElementMixin(PolymerElement) { ready() { super.ready(); - - const handle = this.shadowRoot?.getElementById('resize-handle') as HTMLElement; - - if (handle) { - let isResizing = false; - let lastDownX = 0; - let lastDownY = 0; - - handle.addEventListener('mousedown', (e: MouseEvent) => { - e.preventDefault(); - isResizing = true; - lastDownX = e.clientX; - lastDownY = e.clientY; - document.addEventListener('mousemove', handleMouseMove); - document.addEventListener('mouseup', handleMouseUp); - }); - - const handleMouseMove = (e: MouseEvent) => { - if (isResizing) { - const graphInfoDiv = this.shadowRoot?.getElementById('graph-info') as HTMLElement; - if (graphInfoDiv) { - const rect = graphInfoDiv.getBoundingClientRect(); - const width = rect.width + (lastDownX - e.clientX); - const height = rect.height + (e.clientY - lastDownY); - graphInfoDiv.style.width = `${Math.max(width, 100)}px`; - graphInfoDiv.style.height = `${Math.max(height, 65)}px`; - lastDownX = e.clientX; - lastDownY = e.clientY; - } - } - }; - - const handleMouseUp = () => { - isResizing = false; - //鼠标抬起重新渲染一遍组件 - this.selectNodeCopy = _.cloneDeep(this.selectedNode); - this.set('selectedNode', 'refresh'); // 临时切换到指定node名称,触发信息看板临时刷新 - this.set('selectedNode', this.selectNodeCopy); // 恢复原值 - - document.removeEventListener('mousemove', handleMouseMove); - document.removeEventListener('mouseup', handleMouseUp); - }; - } } fit() { (this.$.graph as any).fit(); diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/node.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/node.ts index 2b1e6d889a2806c2848860972ddc8d6bd148fb07..f3b44a0965a04dd3faa202a5018576a87f709277 100644 --- a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/node.ts +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/node.ts @@ -1466,10 +1466,18 @@ export function getColors(colors) { colorStorage = colors; } -export function getMatched(matched) { +export function setMatched(matched) { matchStorage = matched; } -export function getUnMatched(unmatched) { +export function setUnMatched(unmatched) { unMatchStorage = unmatched; } + +export function getMatched() { + return matchStorage; +} + +export function getUnMatched() { + return unMatchStorage; +} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_controls/tf-graph-controls.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_controls/tf-graph-controls.ts index 6f92ca8c16a3fdbc9654016a28e932583be9f079..26242802f597fef5988cd484a860ab9997471650 100644 --- a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_controls/tf-graph-controls.ts +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_controls/tf-graph-controls.ts @@ -34,6 +34,8 @@ import * as tf_graph_parser from '../tf_graph_common/parser'; import * as tf_graph_node from '../tf_graph_common/node'; import * as tf_node_info from '../tf_graph_info/tf-node-info'; +import type { MactchResult } from '../types/service'; + export interface Selection { run: string; tag: string | null; @@ -1028,7 +1030,7 @@ class TfGraphControls extends LegacyElementMixin(DarkModeMixin(PolymerElement)) _selectedBenchMatchMenu: number = -1; _selectedUnMatchMenu: number = -1; - ready() { + override ready() { super.ready(); document.addEventListener('contextMenuTag-changed', this._getTagChanged.bind(this)); } @@ -1378,6 +1380,7 @@ class TfGraphControls extends 
LegacyElementMixin(DarkModeMixin(PolymerElement)) this.showDynamicDialog('节点不可匹配'); return; } + // 请求参数 const params = new URLSearchParams(); const run = this.datasets[this._selectedRunIndex].name; const tag = this.datasets[this._selectedRunIndex].tags[this._selectedTagIndex].tag; @@ -1385,33 +1388,24 @@ class TfGraphControls extends LegacyElementMixin(DarkModeMixin(PolymerElement)) if (tag) params.set('tag', tag); params.set('NPU', this.selectedNPUNode); params.set('Bench', this.selectedBenchNode); + // 接口请求 const precisionPath = getRouter().pluginRouteForSrc('graph_ascend', '/match', params); const precisionStr = await tf_graph_parser.fetchPbTxt(precisionPath); // 获取异步的 ArrayBuffer const decoder = new TextDecoder(); const decodedStr = decoder.decode(precisionStr); // 解码 ArrayBuffer 到字符串 - const dialogMessageMap: { [key: string]: string } = { - inputshape: '匹配失败,inputshape不匹配', - outputshape: '匹配失败,outputshape不匹配', - inputNone: '匹配失败,input中包含None值', - outputNone: '匹配失败,output中包含None值', - }; + // 接口返回 + const mactchResult: MactchResult = JSON.parse(decodedStr); - const message = dialogMessageMap[decodedStr]; - if (message) { - this.showDynamicDialog(message); // 使用动态生成的对话框 - } - let resultArray: any[] = []; - resultArray = JSON.parse(decodedStr) as any[]; - if (resultArray.length !== 0) { - this.push('matchednodeset', resultArray); + if (mactchResult.success) { + const mactchData = mactchResult.data; + this.push('matchednodeset', mactchData); if (this.unMatchednodeset.length !== 0) { const has = this.unMatchednodeset.indexOf(this.selectedNPUNode); if (has !== -1) { this.unMatchednodeset = [...this.unMatchednodeset.slice(0, has), ...this.unMatchednodeset.slice(has + 1)]; } } - tf_graph_node.getMatched(this.matchednodeset); - tf_node_info.getMatched(this.matchednodeset); + tf_graph_node.setMatched(this.matchednodeset); //节点上色 const index_N = this.NPU_unmatched.indexOf(this.selectedNPUNode.slice(4)); if (index_N !== -1) { this.splice('NPU_unmatched', index_N, 1); @@ -1433,7 +1427,7 @@ class TfGraphControls extends LegacyElementMixin(DarkModeMixin(PolymerElement)) this.selectedNPUNode = ''; this.selectedBenchNode = ''; } else { - this.showDynamicDialog('节点不可匹配'); + this.showDynamicDialog(mactchResult.error); } } @@ -1490,10 +1484,8 @@ class TfGraphControls extends LegacyElementMixin(DarkModeMixin(PolymerElement)) const precisionStr = await tf_graph_parser.fetchPbTxt(precisionPath); // 获取异步的 ArrayBuffer const decoder = new TextDecoder(); const decodedStr = decoder.decode(precisionStr); // 解码 ArrayBuffer 到字符串 - tf_graph_node.getMatched(this.matchednodeset); - tf_node_info.getMatched(this.matchednodeset); - tf_graph_node.getUnMatched(this.unMatchednodeset); - tf_node_info.getUnMatched(this.unMatchednodeset); + tf_graph_node.setMatched(this.matchednodeset); + tf_graph_node.setUnMatched(this.unMatchednodeset); this.set('selectedNode', ''); this.showDynamicDialog('已取消匹配'); } diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_node_info/components/tf_resize_height/index.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_node_info/components/tf_resize_height/index.ts new file mode 100644 index 0000000000000000000000000000000000000000..dd6ce05800b24766f6eff77e6c6180390d6c9924 --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_node_info/components/tf_resize_height/index.ts @@ -0,0 +1,99 @@ +import { PolymerElement, html } from '@polymer/polymer'; +import { customElement, property, observe } from '@polymer/decorators'; + 
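+/**
+ * Drag-to-resize wrapper: a mousedown on the handle starts listening for
+ * mousemove/mouseup on the document, the vertical drag distance updates the
+ * `height` property (in px, floored at 10), and the `height` observer writes
+ * the value into the `--tabsheet-height` CSS custom property.
+ */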
+@customElement('tf-resize-height') +class ResizableTabsheet extends PolymerElement { + static get template() { + return html` + + +
+        [unrecovered template markup: styles using --tabsheet-height, an element with id="tabsheet", and a drag handle with id="resizeHandle"]
+ `; + } + + @property({ + type: Number, + notify: true, + }) + height: number = 300; + + _resize!: (event: MouseEvent) => void; + _stopResize!: (this: Document, ev: MouseEvent) => any; + + override ready() { + super.ready(); + this._initResizeHandle(); + } + + _initResizeHandle() { + const tabsheet = this.$.tabsheet; + const resizeHandle = this.$.resizeHandle; + + let isResizing = false; + let startY = 0; + let startHeight = 0; + + // 开始拖拽 + resizeHandle.addEventListener('mousedown', (event) => { + isResizing = true; + //@ts-ignore + startY = event.clientY; + //@ts-ignore + startHeight = tabsheet.offsetHeight; + document.body.style.cursor = 'ns-resize'; + document.addEventListener('mousemove', this._resize); + document.addEventListener('mouseup', this._stopResize); + }); + + // 拖拽过程 + this._resize = (event) => { + if (!isResizing) return; + const deltaY = startY - event.clientY; // 向上拖拽为正 + this.set('height', Math.max(10, startHeight + deltaY)); // 更新高度 + }; + + // 停止拖拽 + this._stopResize = () => { + isResizing = false; + document.body.style.cursor = ''; + document.removeEventListener('mousemove', this._resize); + document.removeEventListener('mouseup', this._stopResize); + }; + } + + @observe('height') + _updateHeight(newHeight) { + this.updateStyles({ '--tabsheet-height': `${newHeight}px` }); + } +} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_node_info/components/tf_vaadin_table/index.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_node_info/components/tf_vaadin_table/index.ts new file mode 100644 index 0000000000000000000000000000000000000000..1f9e98ef4a550630ae648523ead0db445748a6cc --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_node_info/components/tf_vaadin_table/index.ts @@ -0,0 +1,152 @@ +import { PolymerElement, html } from '@polymer/polymer'; +import { customElement, property, query } from '@polymer/decorators'; +import '@vaadin/grid'; // 引入新的 Vaadin Grid 组件 +import '@vaadin/tooltip'; +import type { GridEventContext } from '@vaadin/grid'; +@customElement('tf-vaadin-table') +class TfVaadinTable extends PolymerElement { + static get template() { + return html` + + + + `; + } + + renderDefaultValue!: (root: HTMLElement, column: any, rowData: any) => void; + tooltipGenerator!: (context: GridEventContext<{}>) => string; + + @property({ type: Object }) + syncGrid!: HTMLElement; // 点击高亮需要同步的表格元素 + + @property({ type: Object }) + handleCellClick!: (e: MouseEvent, syncGrid: HTMLElement) => void; + + @property({ + type: Array, + computed: '_computeHeaders(ioDataset)', + }) + headers!: any[]; + + @property({ + type: Boolean, + computed: '_isEmptyGrid(ioDataset)', + }) + isEmptyGrid!: false; + + override connectedCallback() { + super.connectedCallback(); + this.renderDefaultValue = this._renderDefaultValue.bind(this); + // this.tooltipGenerator = this._tooltipGenerator.bind(this); + } + /** + * 计算表头(所有唯一的键) + * @param {Array} data 数据源 + * @return {Array} 表头数组 + */ + _computeHeaders(data) { + if (!Array.isArray(data) || data.length === 0) { + return []; + } + const ignoreDataIndex = ['data_name', 'isBeach', 'isMatched', 'value']; + const headers = Array.from( + data.slice(0, 5).reduce((keys, item) => { + // 只取前5个数据项,避免性能问题 + Object.keys(item).forEach((key) => { + if (!ignoreDataIndex.includes(key)) { + keys.add(key); + } + }); + return keys; + }, new Set()), + ); + return headers; + } + + _isEmptyGrid(data) { + return !Array.isArray(data) || data.length === 0; + } + + _renderDefaultValue(root: HTMLElement, column: 
any, rowData: any) { + const selectedColor = this._getCssVariable('--selected-color'); + const matchedColor = this._getCssVariable('--matched-color'); + if (rowData.item['isBeach']) root.style.backgroundColor = matchedColor; + else root.style.backgroundColor = selectedColor; + if (column.path === 'name') { + const className = rowData.item['isMatched'] ? 'avater-matched' : 'avater-unmatched'; + root.innerHTML = `${rowData.item[column.path]}`; + return; + } + + root.textContent = rowData.item[column.path] || 'null'; + } + + _tooltipGenerator = (context: GridEventContext<{}>): string => { + const { column, item } = context; + return item?.[column?.path || ''] || ''; + }; + + handleGridClick(e: MouseEvent) { + this.handleCellClick(e, this.syncGrid); // 调用后方法的this会指向当前组件,无法拿到同级别的表格组件,所以需要回传 + } + _getCssVariable(variableName) { + const computedStyle = getComputedStyle(this); + return computedStyle.getPropertyValue(variableName).trim(); // 去掉多余的空格 + } +} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_node_info/components/tf_vaddin_text_table/index.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_node_info/components/tf_vaddin_text_table/index.ts new file mode 100644 index 0000000000000000000000000000000000000000..9665479856a071ed2621a6fc0e284ceca5832e9a --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_node_info/components/tf_vaddin_text_table/index.ts @@ -0,0 +1,209 @@ +import { PolymerElement, html } from '@polymer/polymer'; +import { customElement, property, query } from '@polymer/decorators'; +import '@vaadin/grid'; // 引入新的 Vaadin Grid 组件 +import '@vaadin/tooltip'; +import type { GridEventContext } from '@vaadin/grid'; +import { Notification } from '@vaadin/notification'; +@customElement('tf-vaadin-text-table') +class TfVaadinTable extends PolymerElement { + static get template() { + return html` + + + + `; + } + + renderDefaultValue!: (root: HTMLElement, column: any, rowData: any) => void; + tooltipGenerator!: (context: GridEventContext<{}>) => string; + + @property({ type: Object }) + syncGrid!: HTMLElement; // 点击高亮需要同步的表格元素 + + @property({ type: Object }) + handleCellClick!: (e: MouseEvent, syncGrid: HTMLElement) => void; + + @property({ + type: Array, + computed: '_computeHeaders(dataset)', + }) + headers!: any[]; + + @property({ + type: Boolean, + computed: '_isEmptyGrid(dataset)', + }) + isEmptyGrid!: false; + + override connectedCallback() { + super.connectedCallback(); + this.renderDefaultValue = this._renderDefaultValue.bind(this); + // this.tooltipGenerator = this._tooltipGenerator.bind(this); + } + /** + * 计算表头(所有唯一的键) + * @param {Array} data 数据源 + * @return {Array} 表头数组 + */ + _computeHeaders(data) { + if (!Array.isArray(data) || data.length === 0) { + return []; + } + const ignoreDataIndex = ['title']; + const headers = Array.from( + data.slice(0, 5).reduce((keys, item) => { + // 只取前5个数据项,避免性能问题 + Object.keys(item).forEach((key) => { + if (!ignoreDataIndex.includes(key)) { + keys.add(key); + } + }); + return keys; + }, new Set()), + ); + return headers; + } + + _isEmptyGrid(data) { + return !Array.isArray(data) || data.length === 0; + } + + _renderDefaultValue(root: HTMLElement, column: any, rowData: any) { + const property = column.path; + const titleName = rowData.item['title']; + // root.textContent = rowData.item[column.path] || 'null'; + if (!root.firstElementChild) { + switch (titleName) { + case 'stackInfo': + case 'suggestions': + this._createCopyableTextarea(root, property, rowData); + break; + 
default: + root.style.fontWeight = 'bold'; + root.textContent = titleName + ':' + (rowData.item[property] || 'null'); + break; + } + } else { + switch (titleName) { + case 'stackInfo': + case 'suggestions': + this._updateCopyableTextarea(root, property, rowData); + break; + default: + root.textContent = titleName + ':' + (rowData.item[property] || 'null'); + } + } + } + + _createCopyableTextarea(root: HTMLElement, property: any, rowData: any) { + const container = document.createElement('div'); + container.className = 'copyable-input'; + + const title = document.createElement('div'); + title.className = 'copyable-input-title'; + title.style.fontWeight = 'bold'; + title.textContent = rowData.item['title'] + ':'; + container.appendChild(title); + + const textarea = document.createElement('textarea'); + textarea.readOnly = true; + textarea.value = rowData.item[property]; + container.appendChild(textarea); + + const button = document.createElement('button'); + button.className = 'copy-button'; + button.textContent = '复制'; + button.addEventListener('click', () => { + textarea.select(); + document.execCommand('copy'); + Notification.show('复制成功', { + position: 'middle', + duration: 1000, + theme: 'success', + }); + }); + container.appendChild(button); + + root.appendChild(container); + } + + _updateCopyableTextarea(root: HTMLElement, property: any, rowData: any) { + const title = root.querySelector('.copyable-input-title'); + if (title) title.textContent = rowData.item['title'] + ':'; + const textarea = root.querySelector('textarea'); + if (textarea) textarea.value = rowData.item[property]; + } + + _tooltipGenerator = (context: GridEventContext<{}>): string => { + const { column, item } = context; + return item?.[column?.path || ''] || ''; + }; + + handleGridClick(e: MouseEvent) { + this.handleCellClick(e, this.syncGrid); // 调用后方法的this会指向当前组件,无法拿到同级别的表格组件,所以需要回传 + } +} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_node_info/domain/useNodeInfoDomain.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_node_info/domain/useNodeInfoDomain.ts new file mode 100644 index 0000000000000000000000000000000000000000..ba768dc110b6fcceb76de3a111ab000f2adf6221 --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_node_info/domain/useNodeInfoDomain.ts @@ -0,0 +1,24 @@ +import { getRouter } from '../../tf_backend/router'; +import { fetchPbTxt } from '../../tf_graph_common/parser'; + +const useNodeInfoDomain = () => { + const getMatchNodeInfo = async (nodeInfo: any, metaData: any) => { + const params = new URLSearchParams(); + params.set('nodeInfo', JSON.stringify(nodeInfo)); + params.set('metaData', JSON.stringify(metaData)); + // 接口请求 + const precisionPath = getRouter().pluginRouteForSrc('graph_ascend', '/getNodeInfo', params); + const precisionStr = await fetchPbTxt(precisionPath); // 获取异步的 ArrayBuffer + const decoder = new TextDecoder(); + const decodedStr = decoder.decode(precisionStr); // 解码 ArrayBuffer 到字符串 + // 接口返回 + const mactchResult = JSON.parse(decodedStr); + return mactchResult; + }; + + return { + getMatchNodeInfo, + }; +}; + +export default useNodeInfoDomain; diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_node_info/index.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_node_info/index.ts new file mode 100644 index 0000000000000000000000000000000000000000..e4cce26e63d2807a84b69fe104e168e053734d79 --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_node_info/index.ts 
@@ -0,0 +1,274 @@ +import '@vaadin/tabs'; +import '@vaadin/tabsheet'; +import { Notification } from '@vaadin/notification'; +import { isEmpty } from 'lodash'; +import { PolymerElement, html } from '@polymer/polymer'; +import { observe, customElement, property } from '@polymer/decorators'; + +import useNodeInfo from './useNodeInfo'; + +import './components/tf_vaadin_table/index'; +import './components/tf_vaddin_text_table/index'; +import './components/tf_resize_height/index'; +import * as tf_graph_hierarchy from '../tf_graph_common/hierarchy'; + +import type { useNodeInfoType } from './useNodeInfo'; + +@customElement('tf-graph-vaadin-tab') +class TfGraphNodeInfo extends PolymerElement { + static get template() { + return html` + + + + 比对详情 + 节点信息 + + +
+        [unrecovered template markup: the two tab panels, including the "NPU节点:[[npuNodeName]]" label, the 已匹配 / 未匹配 legend, and the tf-vaadin-table / tf-vaadin-text-table grids bound to ioDataset and detailData]
+ `; + } + + @property({ + type: Object, + computed: '_getNode(selectedNode)', + }) + _node: any; + + @property({ + type: Object, + }) + _matchedNode: any; + + @property({ type: Object }) + selectedNode: any; + + @property({ type: Object }) + selection: any; + + @property({ type: Object }) + graphHierarchy?: tf_graph_hierarchy.Hierarchy; + + @property({ type: Array }) + ioDataset: any[] = []; + + @property({ type: Array }) + detailData: any[] = []; + + useNodeInfo: useNodeInfoType = useNodeInfo(); + + @property({ type: String }) + npuNodeName!: string; + + @property({ type: String }) + beachNodeName!: string; + + @observe('_node') + async _updateInputsData() { + if (!this._node) { + this.set('ioDataset', []); + this.set('detailData', []); + return; + } + const isSingleGraphNode = !this._node?.name?.startsWith('N___') && !this._node?.name?.startsWith('B___'); + if (!isSingleGraphNode) await this._updateMatchedNodeLink(this.selectedNode); + // 考虑选中的节点是匹配节点的情况 + const npuNode = this._node.name.startsWith('N___') || isSingleGraphNode ? this._node : this._matchedNode; + const beachNode = isSingleGraphNode + ? undefined + : this._node.name.startsWith('N___') + ? this._matchedNode + : this._node; + this.set('npuNodeName', npuNode?.name); + this.set('beachNodeName', beachNode?.name); + const inputDataset = this.useNodeInfo.getIoDataSet(npuNode, beachNode, 'inputData'); + const outputDataSet = this.useNodeInfo.getIoDataSet(npuNode, beachNode, 'outputData'); + const ioDataset = [ + ...inputDataset.matchedDataset, + ...outputDataSet.matchedDataset, + ...inputDataset.unMatchedNpuDataset, + ...outputDataSet.unMatchedNpuDataset, + ...inputDataset.unMatchedBeachDataset, + ...outputDataSet.unMatchedBeachDataset, + ]; + this.set('ioDataset', ioDataset); + const detailData = this.useNodeInfo.getDetailDataSet(npuNode, beachNode); + this.set('detailData', detailData); + } + + _getNode(selectedNode) { + return this.graphHierarchy?.node(selectedNode); + } + + async _updateMatchedNodeLink(selectedNode) { + //@ts-ignore + const matchedNodes: Array = this.graphHierarchy?.node(selectedNode)?.matchedNodeLink; + if (isEmpty(matchedNodes)) { + this.set('_matchedNode', null); + return; + } + const matchNodeName = matchedNodes[matchedNodes.length - 1]; + let matchNode = matchNodeName ? this.graphHierarchy?.node(matchNodeName) : null; + // 如果没有匹配节点,则通过接口获取匹配节点的信息 + if (isEmpty(matchNode)) { + const nodeInfo = { + nodeName: matchNodeName?.replace(/^(N___|B___)/, ''), // 去掉前缀 + nodeType: this._node.name.startsWith('N___') ? 
'Bench' : 'NPU', + }; + const { success, data, error } = await this.useNodeInfo.getNodeInfo(nodeInfo, this.selection); + if (success) { + //@ts-ignore + matchNode = data; + } else { + Notification.show(`获取匹配节点失败:${error}`, { + position: 'middle', + duration: 2000, + theme: 'error', + }); + } + } + this.set('_matchedNode', matchNode); + } + + // 点击单元格高亮 + handleGridCellClick(e: MouseEvent, syncGrid: HTMLElement) { + const target = e.composedPath()[0] as HTMLElement; // 获取点击的目标元素 + const slotValue = target.getAttribute('slot'); // 提取 slot 属性 + if (!slotValue || !slotValue.startsWith('vaadin-grid-cell-content-')) return; + const cellIndex = parseInt(slotValue.split('-').pop() || '0', 10); + if (cellIndex <= 8) return; // 前8个元素是表头不可选中,所以跳过 + const highlightedCells = this.shadowRoot?.querySelectorAll('.highlight-cell'); + highlightedCells?.forEach((cell) => cell.classList.remove('highlight-cell')); // 清除所有高亮样式 + target.classList.add('highlight-cell'); // 添加高亮样式到当前单元格 + } +} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_node_info/type/index.d.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_node_info/type/index.d.ts new file mode 100644 index 0000000000000000000000000000000000000000..61af59f8a79ef35d37e4a5c47e9a1db00b2968fc --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_node_info/type/index.d.ts @@ -0,0 +1,13 @@ +export type MatchNodeInfo = { + success: boolean; + data?: { + name?: string; + inputData?: {}; + outputData?: {}; + stackData?: string; + suggestions?: {}; + include?: number; + subnodes?: Array; + }; + error?: string; +}; diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_node_info/useNodeInfo.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_node_info/useNodeInfo.ts new file mode 100644 index 0000000000000000000000000000000000000000..57b970e89d80245dc8a29da80d51cf27caa798ac --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_node_info/useNodeInfo.ts @@ -0,0 +1,229 @@ +import { isEmpty, cloneDeep } from 'lodash'; +import { MactchResult } from '../types/service'; +import useNodeInfoDomain from './domain/useNodeInfoDomain'; +import * as tf_graph_node_common from '../tf_graph_common/node'; +import { MatchNodeInfo } from './type'; + +const useNodeInfo = () => { + const useNodeInfoService = useNodeInfoDomain(); + const nodeInfoCache = new Map(); + /** + * 获取节点信息 + * @param node_name + * @param graph_name + * @param run_name + * @returns + */ + const getNodeInfo = async ( + nodeInfo: { nodeName: string; nodeType: string }, + metaData: any, + ): Promise => { + if (sessionStorage.getItem(JSON.stringify(nodeInfo))) { + //缓存中存在 + return JSON.parse(sessionStorage.getItem(JSON.stringify(nodeInfo)) || '') as MatchNodeInfo; + } + const mactchResult = await useNodeInfoService.getMatchNodeInfo(nodeInfo, metaData); + mactchResult.data = convertNodeInfo(mactchResult.data); //提取有效数据,统一命名 + sessionStorage.setItem(JSON.stringify(nodeInfo), JSON.stringify(mactchResult)); //缓存 + return mactchResult; + }; + + const convertNodeInfo = (nodeInfo: any): MatchNodeInfo['data'] => { + if (isEmpty(nodeInfo)) { + return {}; + } + return { + name: nodeInfo.id, + inputData: nodeInfo.input_data, + outputData: nodeInfo.output_data, + stackData: !isEmpty(nodeInfo.stack_info) ? 
JSON.stringify(nodeInfo.stack_info) : '', + suggestions: nodeInfo.suggestions, + subnodes: nodeInfo.subnodes, + include: nodeInfo.subnodes?.length, + }; + }; + + /** + * 将匹配结果数组转换为对象 + * @param resArray + * @example + * [request.args.get("NPU"), + * max_precision_index, + * [ + * [ + ['npu_key', npu_key], + ['Max diff', NPU_max - Bench_max], + ['Min diff', NPU_min - Bench_min], + ['Mean diff', NPU_mean - Bench_mean], + ['L2norm diff', NPU_norm - Bench_norm], + ['MaxRelativeErr', max_relative_err], + ['MinRelativeErr', min_relative_err], + ['MeanRelativeErr', mean_relative_err], + ['NormRelativeErr', norm_relative_err] + ] + ..... + ], + * output(同上)] + * @returns + * { + * NPU_node_name: string, + * max_precision_error: number, + * input: {inputargs:{ + * Max diff: number, + * Min diff: number, + * Mean diff: number, + * },{}} + * output: [{ + * Max diff: number, + * Min diff: number, + * Mean diff: number, + * }] + */ + const converMatchArrayToObject = (resArray: Array): Array => { + if (resArray.length == 0) return []; + //内部函数,将二维数组转换为对象 + const _covertIo = (arrayData: Array) => { + if (isEmpty(arrayData)) return undefined; + const inputItems = {}; + arrayData.forEach((inputArray: any) => { + const inputKey = inputArray[0]; + const inputValue = {}; + inputArray.slice(1).map((item: any) => { + inputValue[item[0]] = item[1]; + }); + inputItems[inputKey] = inputValue; + }); + return inputItems; + }; + + //将数组转换为对象 + const matchedeDataset = resArray.map((item) => { + return { + NPU_node_name: item?.[0], + max_precision_error: item?.[1], + inputData: _covertIo(item?.[2]), + outputData: _covertIo(item?.[3]), + }; + }); + + return matchedeDataset as Array; + }; + + /** + * 获取匹配的输入输出数据 + * @param npuNode NPU节点信息 + * @param beachNode 匹配的节点信息 + * @param name 'inputData' | 'outputData' + * @returns { matchedDataset: Array<{}>, unMatchedNpuDataset: Array<{}>, unMatchedBeachDataset: Array<{}> } + */ + const getIoDataSet = ( + npuNode: any, + beachNode: any, + type: 'inputData' | 'outputData', + ): { matchedDataset: Array<{}>; unMatchedNpuDataset: Array<{}>; unMatchedBeachDataset: Array<{}> } => { + if (isEmpty(npuNode?.[type]) && isEmpty(beachNode?.[type])) { + return { + matchedDataset: [], + unMatchedNpuDataset: [], + unMatchedBeachDataset: [], + }; + } + const npuNodeName = npuNode?.name.replace(/^(N___|B___)/, ''); + const beachNodeName = beachNode?.name?.replace(/^(N___|B___)/, ''); + const npuData = cloneDeep(npuNode?.[type]); // 获取当前节点的输入数据 + const beachData = cloneDeep(beachNode?.[type]); // 获取匹配节点的输入数据 + const matchStorage = converMatchArrayToObject(tf_graph_node_common.getMatched()); // 将匹配数组转换为对象 + const matchItem: MactchResult['result'] = matchStorage.find( + (matchedItem) => matchedItem?.NPU_node_name === npuNode?.name, + ); // 获取两个节点数据的不同 + const matchedDataset: Array<{}> = []; // 初始化输入数据集 + const unMatchedBeachDataset: Array<{}> = []; + const unMatchedNpuDataset: Array<{}> = []; + const npuKeys = Object.keys(npuData || {}); + const beachKeys = Object.keys(beachData || {}); + const minLength = Math.min(npuKeys.length, beachKeys.length); + for (let i = 0; i < minLength; i++) { + const npuKey = npuKeys[i]; + const beachKey = beachKeys[i]; + matchedDataset.push({ + name: npuKey.replace(`${npuNodeName}.`, ''), + isMatched: true, + ...npuData[npuKey], + ...matchItem?.[type]?.[npuKey], + }); + matchedDataset.push({ + name: beachKey.replace(`${beachNodeName}.`, ''), + isBeach: true, + isMatched: true, + ...beachData[beachKey], + }); + delete npuData[npuKey]; + delete beachData[beachKey]; + } 
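+    // Keys still left in npuData / beachData had no positional counterpart in the
+    // loop above, so they are collected below as the unmatched NPU / bench rows.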
+ Object.keys(npuData || {}).forEach((key) => { + unMatchedNpuDataset.push({ + name: key.replace(`${npuNodeName}.`, ''), + ...npuData[key], + }); + }); + Object.keys(beachData || {}).forEach((key) => { + unMatchedBeachDataset.push({ + name: key.replace(`${beachNodeName}.`, ''), + isBeach: true, + ...beachData[key], + }); + }); + console.log(matchedDataset, unMatchedNpuDataset, unMatchedBeachDataset); + return { matchedDataset, unMatchedNpuDataset, unMatchedBeachDataset }; + }; + + const getDetailDataSet = (npuNode: any, beachNode: any): {} => { + console.log(npuNode, beachNode); + if (isEmpty(npuNode) && isEmpty(beachNode)) { + return []; + } + + const nodeName = `NPU节点:${npuNode?.name?.replace(/^(N___|B___)/, '')}`; + const beachNodeName = `标杆节点:${beachNode?.name?.replace(/^(N___|B___)/, '')}`; + const detailData: Array<{}> = []; + // 获取stackInfo + const stackInfo: {} = {}; + const npustackInfo = npuNode?.stackData; + const beachstackInfo = beachNode?.stackData; + if (!isEmpty(npustackInfo)) stackInfo[nodeName] = JSON.parse(npustackInfo.replace(/'/g, '"')).join('\n'); + if (!isEmpty(beachstackInfo)) stackInfo[beachNodeName] = JSON.parse(beachstackInfo.replace(/'/g, '"')); + if (!isEmpty(stackInfo)) { + stackInfo['title'] = 'stackInfo'; + detailData.push(stackInfo); + } + // 获取suggestions + const suggestion: {} = {}; + const npusuggestion = npuNode.suggestions; + const beachsuggestion = beachNode?.suggestions; + if (!isEmpty(npusuggestion)) suggestion[nodeName] = converObjectToString(npusuggestion); + if (!isEmpty(beachsuggestion)) suggestion[`**${beachNodeName}`] = converObjectToString(beachsuggestion); + if (!isEmpty(suggestion)) { + suggestion['title'] = 'suggestions'; + detailData.push(suggestion); + } + console.log(detailData); + return detailData; + }; + + const converObjectToString = (obj: any) => { + return Object.entries(obj) + .map(([key, value]) => `${key}: ${value}`) + .join('\n'); + }; + + return { + getNodeInfo, + getIoDataSet, + getDetailDataSet, + converMatchArrayToObject, + }; +}; + +export default useNodeInfo; + +export type useNodeInfoType = ReturnType; diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/types/service.d.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/types/service.d.ts new file mode 100644 index 0000000000000000000000000000000000000000..66b146ab11fa35fb14ba40996c900f3744adfb6b --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/types/service.d.ts @@ -0,0 +1,41 @@ +/** + * /mactch 接口返回值 + * @typedef {Object} mactchResult + * @property {boolean} success 是否匹配成功 + * @property {string} error 错误信息 + * @property {Object} data 匹配成功的数据 + */ +export type MactchResult = { + success: boolean; + error?: string; + data?: []; + result?: { + NPU_node_name: string; + max_precision_error: string; + inputData: { + [key: string]: { + 'Max diff': number; + 'Min diff': number; + 'Mean diff': number; + 'L2norm diff': number; + MaxRelativeErr: number; + MinRelativeErr: number; + MeanRelativeErr: number; + NormRelativeErr: number; + }; + }; + outputData: { + [key: string]: { + 'Max diff': number; + 'Min diff': number; + 'Mean diff': number; + 'L2norm diff': number; + MaxRelativeErr: number; + MinRelativeErr: number; + MeanRelativeErr: number; + NormRelativeErr: number; + }; + }; + result_array: []; + }; +}; diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/server/app/controllers/graph_controller.py b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/controllers/graph_controller.py new file mode 100644 index 
0000000000000000000000000000000000000000..e1c49b7ad34147479fd46e192c2a49186b379d36 --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/controllers/graph_controller.py @@ -0,0 +1,17 @@ +from ..utils.graph_utils import GraphUtils + + +class GraphController: + + def get_node_info(node_info, meta_data): + run = meta_data.get('run') + tag = meta_data.get('tag') + nodeType = node_info.get('nodeType') + nodeName = node_info.get('nodeName') + graph_data = GraphUtils.check_jsondata(tag) + if graph_data is None: + graph_data, error_message = GraphUtils.get_jsondata(run, tag) + if error_message is not None: + return {'success': False, 'error': error_message} + node_details = graph_data.get(nodeType).get('node').get(nodeName) + return {'success': True, 'data': node_details} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/server/app/utils/global_state.py b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/utils/global_state.py new file mode 100644 index 0000000000000000000000000000000000000000..b4162e5f116ce6a542feca42a6436f89164284a9 --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/utils/global_state.py @@ -0,0 +1,33 @@ +# utils/global_state.py + +# 模块级全局变量 +_state = {'logdir': '', "current_tag": '', "current_file_path": '', 'current_file_data': {}} + + +def init_defaults(): + """ + 初始化全局变量的默认值 + """ + global _state + _state = {'logdir': '', "current_tag": '', "current_file_path": '', 'current_file_data': {}} + + +def set_global_value(key, value): + """ + 设置全局变量的值 + """ + _state[key] = value + + +def get_global_value(key, default=None): + """ + 获取全局变量的值 + """ + return _state.get(key, default) + + +def reset_global_state(): + """ + 重置所有全局变量为默认值 + """ + init_defaults() diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/server/app/utils/graph_utils.py b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/utils/graph_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..425aa1d05acc389d06002a85be9c5233d70d189d --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/utils/graph_utils.py @@ -0,0 +1,59 @@ +import json +import os +from tensorboard.util import tb_logging +from .global_state import get_global_value, set_global_value + +logger = tb_logging.get_logger() + + +class GraphUtils: + + # 检查是否使用缓存 + @staticmethod + def check_jsondata(tag): + _current_tag = get_global_value('current_tag') + if _current_tag == tag: + return get_global_value('current_file_data') + else: + return None + + # 读取json文件 + @staticmethod + def get_jsondata(run, tag): + json_data = None + error_message = None + logdir = get_global_value('logdir') + if run is None or tag is None: + error_message = 'The query parameters "run" and "tag" are required' + return json_data, error_message + run_dir = os.path.join(logdir, run) + file_path = GraphUtils._load_json_file(run_dir, tag) + json_data = GraphUtils._read_json_file(file_path) + set_global_value('current_file_data', json_data) + set_global_value('current_tag', tag) + if json_data is None: + error_message = f'vis file for tag "{tag}" not found in run "{run}"' + return json_data, error_message + + @staticmethod + def _load_json_file(run_dir, tag): + """Load a single .vis file from a given directory based on the tag.""" + file_path = os.path.join(run_dir, f"{tag}.vis") + if os.path.exists(file_path): + # Store the path of the current file instead of loading it into memory + set_global_value('current_file_path', tag) + return file_path + return None + + @staticmethod + def 
_read_json_file(file_path): + """Read and parse a JSON file from disk.""" + if file_path and os.path.exists(file_path): + with open(file_path, 'r', encoding='utf-8') as f: + try: + return json.load(f) + except Exception as e: + logger.error(f'Error: the vis file "{file_path}" is not a legal JSON file!') + else: + logger.error(f'Error: the vis file "{file_path}" is not a legal JSON file!') + return None diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/server/app/views/graph_views.py b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/views/graph_views.py new file mode 100644 index 0000000000000000000000000000000000000000..4d1399c6a6cd50bc045d19b82142db79a31e3ba8 --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/views/graph_views.py @@ -0,0 +1,15 @@ +import json + +from werkzeug import wrappers, Response, exceptions +from tensorboard.backend import http_util +from ..controllers.graph_controller import GraphController + + +class GraphView: + # 获取当前节点对应节点的信息看板数据 + @wrappers.Request.application + def get_node_info(request): + node_info = json.loads(request.args.get("nodeInfo")) + meta_data = json.loads(request.args.get("metaData")) + node_detail = GraphController.get_node_info(node_info, meta_data) + return http_util.Respond(request, node_detail, "application/json") diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/server/constants.py b/plugins/tensorboard-plugins/tb_graph_ascend/server/constants.py index a795e7f55a2f794290ba5487cddebe42556a61b9..04457951cfa96258ac66c8e90ed0a567f3a67b4c 100644 --- a/plugins/tensorboard-plugins/tb_graph_ascend/server/constants.py +++ b/plugins/tensorboard-plugins/tb_graph_ascend/server/constants.py @@ -21,7 +21,7 @@ SETS = { 'Bench': ('Bench', 'B___', 'N___'), 'NPU': ('NPU', 'N___', 'B___'), 'B___': ('Bench', 'N___'), - 'N___': ('NPU', 'B___') + 'N___': ('NPU', 'B___'), } NA_DATA = [ ['Max diff', 'N/A'], @@ -31,15 +31,22 @@ NA_DATA = [ ['MaxRelativeErr', 'N/A'], ['MinRelativeErr', 'N/A'], ['MeanRelativeErr', 'N/A'], - ['NormRelativeErr', 'N/A'] + ['NormRelativeErr', 'N/A'], ] -PREFIX_MAP = { - 'N___': 'NPU', - 'B___': 'Bench' -} -UNMATCH_NAME_SET =['Max diff', 'Min diff', 'Mean diff', 'L2norm diff', 'MaxRelativeErr', 'MinRelativeErr', 'MeanRelativeErr', 'NormRelativeErr', - 'Cosine', 'MaxAbsErr', 'MaxRelativeErr', 'One Thousandth Err Ratio', 'Five Thousandth Err Ratio'] -SCREEN_MAP = { - 'precision_index': 'precision_index', - 'overflow_level': 'overflow_level' - } \ No newline at end of file +PREFIX_MAP = {'N___': 'NPU', 'B___': 'Bench'} +UNMATCH_NAME_SET = [ + 'Max diff', + 'Min diff', + 'Mean diff', + 'L2norm diff', + 'MaxRelativeErr', + 'MinRelativeErr', + 'MeanRelativeErr', + 'NormRelativeErr', + 'Cosine', + 'MaxAbsErr', + 'MaxRelativeErr', + 'One Thousandth Err Ratio', + 'Five Thousandth Err Ratio', +] +SCREEN_MAP = {'precision_index': 'precision_index', 'overflow_level': 'overflow_level'} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/server/plugin.py b/plugins/tensorboard-plugins/tb_graph_ascend/server/plugin.py index f34a939576cfc3cfcca8101708f9a2a9a91204ee..cd2f34f8d770604477ded32431d0dd81327b347b 100644 --- a/plugins/tensorboard-plugins/tb_graph_ascend/server/plugin.py +++ b/plugins/tensorboard-plugins/tb_graph_ascend/server/plugin.py @@ -20,15 +20,18 @@ import json import os from werkzeug import wrappers, Response, exceptions - from tensorboard import errors from tensorboard.backend import http_util from tensorboard.plugins import base_plugin from tensorboard.util import tb_logging + from . 
import constants +from .app.views.graph_views import GraphView +from .app.utils.global_state import set_global_value logger = tb_logging.get_logger() + class GraphsPlugin(base_plugin.TBPlugin): """Graphs Plugin for TensorBoard.""" @@ -44,14 +47,16 @@ class GraphsPlugin(base_plugin.TBPlugin): super().__init__(context) self._data_provider = context.data_provider self.logdir = os.path.abspath(os.path.expanduser(context.logdir.rstrip('/'))) + # 将logdir赋值给global_state中的logdir属性,方便其他模块使用 + set_global_value('logdir', os.path.abspath(os.path.expanduser(context.logdir.rstrip('/')))) self._current_file_path = None # Store the path of the currently loaded file self._current_file_data = None # Store the data of the currently loaded file - self._current_tag = None # Store the tag of the currently loaded file + self._current_tag = None # Store the tag of the currently loaded file self.batch_id = 0 # 将 batch_id 声明为实例变量 - self.step_id = 0 # 可以同样声明 step_id - self.dfs_node_ids = [] # batch和step没变的话就将所有的nodename存起来,方便快速读取 + self.step_id = 0 # 可以同样声明 step_id + self.dfs_node_ids = [] # batch和step没变的话就将所有的nodename存起来,方便快速读取 self.check_batch_id = 0 # 来配合node_ids监察用的,他不变node_ids就不用重新读取了 - self.check_step_id = 0 # 同上 + self.check_step_id = 0 # 同上 self.check_tag = None def get_plugin_apps(self): @@ -67,6 +72,7 @@ class GraphsPlugin(base_plugin.TBPlugin): "/parent": self.get_parent_node, "/rank": self.get_rank, "/subgraph": self.subgraph_route, + '/getNodeInfo': GraphView.get_node_info, } def is_active(self): @@ -96,10 +102,7 @@ class GraphsPlugin(base_plugin.TBPlugin): def add_row_item(run, tag=None, is_vis=False): run_item = result.setdefault( run, - { - "run": run, - "tags": {} - }, + {"run": run, "tags": {}}, ) tag_item = None @@ -130,7 +133,7 @@ class GraphsPlugin(base_plugin.TBPlugin): if run is None or tag is None: error_message = 'The query parameters "run" and "tag" are required' return json_data, error_message - + run_dir = os.path.join(self.logdir, run) file_path = self._load_json_file(run_dir, tag) json_data = self._read_json_file(file_path) @@ -139,7 +142,7 @@ class GraphsPlugin(base_plugin.TBPlugin): error_message = f'vis file for tag "{tag}" not found in run "{run}"' return json_data, error_message - + # 拿所有nodename的 def get_all_nodeName(self, json_data, request): npu_ids, bench_ids = [], [] @@ -151,12 +154,16 @@ class GraphsPlugin(base_plugin.TBPlugin): # 获取 NPU 和 Bench 数据 npu_data = self.json_get(json_data, 'NPU') bench_data = self.json_get(json_data, 'Bench') + def extract_ids(nodes_data, id_list): for node_name in nodes_data.get("node"): id_list.append(node_name) + def traverse_npu(subnodes): for node in subnodes: - node_data = self.json_get(npu_data, 'node', node) if npu_data else self.json_get(json_data, 'node', node) + node_data = ( + self.json_get(npu_data, 'node', node) if npu_data else self.json_get(json_data, 'node', node) + ) micro_step_id = node_data.get('micro_step_id') if micro_step_id == batch or micro_step_id is None: npu_ids.append(node) @@ -182,8 +189,11 @@ class GraphsPlugin(base_plugin.TBPlugin): step = int(request.args.get("step")) except ValueError: logger.error('The param "batch" or "step" does not exist or not a valid value') + def should_include_node(micro_step_id, step_id): - return (micro_step_id is batch or batch == -1 or micro_step_id is None) and (step_id is step or step == -1 or step_id is None) + return (micro_step_id is batch or batch == -1 or micro_step_id is None) and ( + step_id is step or step == -1 or step_id is None + ) nodes_data = self.json_get(json_data, 
'NPU', 'node') or self.json_get(json_data, 'node') for node in nodes_data: @@ -192,8 +202,8 @@ class GraphsPlugin(base_plugin.TBPlugin): if should_include_node(micro_step_id, step_id) and not self.json_get(nodes_data, node, 'subnodes'): all_node_names.append(node) return all_node_names - - #拿所有precisonNodes的,与controls的精度筛选联动 + + # 拿所有precisonNodes的,与controls的精度筛选联动 @wrappers.Request.application def get_all_screenNodes(self, request): grouped_screen_set, inaccuracy_node_ids = [], [] @@ -211,18 +221,22 @@ class GraphsPlugin(base_plugin.TBPlugin): if '无匹配节点' in precision_set_str: precision_set_str = [p for p in precision_set_str if p != '无匹配节点'] precision_none = 1 - grouped_screen_set = [list(map(float, precision_set_str[i:i+2])) for i in range(0, len(precision_set_str), 2)] + grouped_screen_set = [ + list(map(float, precision_set_str[i : i + 2])) for i in range(0, len(precision_set_str), 2) + ] else: - grouped_screen_set = screen_set + grouped_screen_set = screen_set tag = request.args.get("tag") json_data = self.check_jsondata(request) + def has_conditions_changed(tag): return ( - self.check_batch_id != self.batch_id or - self.check_step_id != self.step_id or - self.check_tag != tag or - self.check_tag is None + self.check_batch_id != self.batch_id + or self.check_step_id != self.step_id + or self.check_tag != tag + or self.check_tag is None ) + if has_conditions_changed(tag): self.dfs_node_ids = self.dfs_collect_nodes(json_data, request) self.check_batch_id = self.batch_id @@ -230,7 +244,9 @@ class GraphsPlugin(base_plugin.TBPlugin): self.check_tag = tag node_ids = self.dfs_node_ids for node in node_ids: - node_data = self.json_get(json_data, 'NPU', 'node', node, 'data') or self.json_get(json_data, 'node', node, 'data') + node_data = self.json_get(json_data, 'NPU', 'node', node, 'data') or self.json_get( + json_data, 'node', node, 'data' + ) inaccuracy = node_data.get(screen) if node_data is not None else None # 如果 inaccuracy 为 None,直接检查是否符合条件 if inaccuracy is None: @@ -251,20 +267,25 @@ class GraphsPlugin(base_plugin.TBPlugin): logger.error(f'The inaccuracy in {node} is not a valid value') return http_util.Respond(request, inaccuracy_node_ids, "application/json") - + def group_precision_set(self, precision_set): if len(precision_set) % 2 != 0: raise ValueError('The number of elements in precision_set is not even') - grouped_precision_set = [precision_set[i:i+2] for i in range(0, len(precision_set), 2)] + grouped_precision_set = [precision_set[i : i + 2] for i in range(0, len(precision_set), 2)] return grouped_precision_set - + def get_all_unmatchedNodes(self, allNodeName, request): json_data = self.check_jsondata(request) is_npu_present = 'NPU' in json_data + def collect_unmatched_nodes(node_list, *path): return [node for node in node_list if not self.json_get(json_data, *path, node, 'matched_node_link')] - NPU_Unmatched = collect_unmatched_nodes(allNodeName[0], 'NPU', 'node') if is_npu_present else \ - collect_unmatched_nodes(allNodeName[0], 'node') + + NPU_Unmatched = ( + collect_unmatched_nodes(allNodeName[0], 'NPU', 'node') + if is_npu_present + else collect_unmatched_nodes(allNodeName[0], 'node') + ) Bench_Unmatched = collect_unmatched_nodes(allNodeName[1], 'Bench', 'node') if is_npu_present else [] return [NPU_Unmatched, Bench_Unmatched] @@ -280,11 +301,13 @@ class GraphsPlugin(base_plugin.TBPlugin): Bench_node_data = self.json_get(json_data, 'Bench', 'node', Bench_node) # 获取输入输出数据 NPU_input_data, Bench_input_data, NPU_output_data, Bench_output_data = self.get_input_output_data( - 
NPU_node_data, Bench_node_data) + NPU_node_data, Bench_node_data + ) # 计算输入输出的最小长度 - input_min_length, output_min_length = self.calculate_min_length(NPU_input_data, Bench_input_data, - NPU_output_data, Bench_output_data) - precision_index, max_precision_index, result = -1, -1, 0 + input_min_length, output_min_length = self.calculate_min_length( + NPU_input_data, Bench_input_data, NPU_output_data, Bench_output_data + ) + max_precision_index = -1 input, output = [], [] if input_min_length > 0: npu_input_keys, bench_input_keys = self.get_keys(NPU_input_data, Bench_input_data) @@ -296,12 +319,11 @@ class GraphsPlugin(base_plugin.TBPlugin): 'npu_keys': npu_input_keys, 'bench_keys': bench_input_keys, 'data_type': data_type, - 'precision_index': precision_index, - 'file_path': file_path + 'precision_index': -1, # 初始值 + 'file_path': file_path, } - max_precision_index = self.process_input_data(input, data_set, NPU_node) - if max_precision_index > precision_index: - precision_index = max_precision_index + max_input_precision_index = self.process_input_data(input, data_set, NPU_node) + max_precision_index = max(max_precision_index, max_input_precision_index) if output_min_length > 0: npu_output_keys, bench_output_keys = self.get_keys(NPU_output_data, Bench_output_data) @@ -313,32 +335,38 @@ class GraphsPlugin(base_plugin.TBPlugin): 'npu_keys': npu_output_keys, 'bench_keys': bench_output_keys, 'data_type': data_type, - 'precision_index': precision_index, - 'file_path': file_path + 'precision_index': -1, # 初始值 + 'file_path': file_path, } - max_precision_index = self.process_output_data(output, data_set, NPU_node) - if max_precision_index > precision_index: - precision_index = max_precision_index - - result = [request.args.get("NPU"), precision_index, input, output] - - if not input and not output and precision_index == -1: - return http_util.Respond(request, [], "application/json") - - # 添加匹配的节点链接 - matched_result = self.add_matched_node_link(NPU_node, Bench_node) - if not matched_result: - result = {"error": f"Unable to link matched nodes for NPU_node: {NPU_node} and Bench_node: {Bench_node}"} - return http_util.Respond(request, result, "application/json", status_code=400) - - precision_result = self.add_precision_index(NPU_node, precision_index) - if not precision_result: - result = {"error": f"Unable to add precision index for NPU_node: {NPU_node} with precision_index: {precision_index}"} - return http_util.Respond(request, result, "application/json", status_code=400) - + max_output_precision_index = self.process_output_data(output, data_set, NPU_node) + max_precision_index = max(max_precision_index, max_output_precision_index) + + if not input and not output and max_precision_index == -1: + result = {"success": False, "error": "input and output are empty"} + return http_util.Respond(request, result, "application/json") + + if not self.add_matched_node_link(NPU_node, Bench_node): + result = { + "success": False, + "error": f"Unable to link matched nodes for NPU_node: {NPU_node} and Bench_node: {Bench_node}", + } + return http_util.Respond(request, result, "application/json") + + if not self.add_precision_index(NPU_node, max_precision_index): + result = { + "success": False, + "error": f"Unable to add precision index for NPU_node: {NPU_node} with precision_index: {max_precision_index}", + } + return http_util.Respond(request, result, "application/json") + # 保存修改后的内容 self.save_results(file_path) + result = { + "success": True, + "data": [request.args.get("NPU"), max_precision_index, input, output], + 
} + return http_util.Respond(request, result, "application/json") def get_input_output_data(self, NPU_node_data, Bench_node_data): @@ -374,11 +402,11 @@ class GraphsPlugin(base_plugin.TBPlugin): def calculate_diff_and_relative_error(self, npu_data, bench_data): """计算最大值、最小值、均值和L2范数的差异以及相对误差""" results = {} - + for key in ['Max', 'Min', 'Mean', 'Norm']: NPU_value = self.safe_float_conversion(npu_data.get(key, 0)) Bench_value = self.safe_float_conversion(bench_data.get(key, 0)) - + # 计算差异 if NPU_value is not None: results[f'NPU_{key}'] = NPU_value @@ -388,19 +416,24 @@ class GraphsPlugin(base_plugin.TBPlugin): results[f'Bench_{key}'] = Bench_value else: results[f'Bench_{key}'] = 'N/A' - + # 计算相对误差 if NPU_value is not None and Bench_value is not None and Bench_value != 0: results[f'{key}_relative_err'] = f"{abs(NPU_value - Bench_value) / float(Bench_value) * 100:.6f}%" else: results[f'{key}_relative_err'] = "N/A" - + return results def process_data(self, data, data_set, NPU_node): """处理数据的公共逻辑""" for npu_key, bench_key in zip(data_set['npu_keys'], data_set['bench_keys']): - if data_set['input_data'].get(npu_key) not in ['None', None] and data_set['output_data'].get(bench_key) not in ['None', None]: + if data_set['input_data'].get(npu_key) not in ['None', None] and data_set['output_data'].get( + bench_key + ) not in [ + 'None', + None, + ]: npu_value = data_set['input_data'][npu_key] bench_value = data_set['output_data'][bench_key] @@ -418,7 +451,7 @@ class GraphsPlugin(base_plugin.TBPlugin): NPU_min, Bench_min = results['NPU_Min'], results['Bench_Min'] NPU_mean, Bench_mean = results['NPU_Mean'], results['Bench_Mean'] NPU_norm, Bench_norm = results['NPU_Norm'], results['Bench_Norm'] - + max_relative_err = results['Max_relative_err'] min_relative_err = results['Min_relative_err'] mean_relative_err = results['Mean_relative_err'] @@ -433,16 +466,21 @@ class GraphsPlugin(base_plugin.TBPlugin): ['MaxRelativeErr', max_relative_err], ['MinRelativeErr', min_relative_err], ['MeanRelativeErr', mean_relative_err], - ['NormRelativeErr', norm_relative_err] + ['NormRelativeErr', norm_relative_err], ] # 更新最大差异 - Max_value = max(abs(NPU_max - Bench_max), abs(NPU_min - Bench_min), abs(NPU_mean - Bench_mean), abs(NPU_norm - Bench_norm)) + Max_value = max( + abs(NPU_max - Bench_max), + abs(NPU_min - Bench_min), + abs(NPU_mean - Bench_mean), + abs(NPU_norm - Bench_norm), + ) if Max_value > data_set['precision_index']: data_set['precision_index'] = Max_value # 将数据添加到列表中 - data.append(data_entry) + data.append([npu_key, *data_entry]) # 保存数据到文件 self.save_data(data_set['file_path'], NPU_node, npu_key, data_entry, data_set['data_type']) @@ -468,7 +506,9 @@ class GraphsPlugin(base_plugin.TBPlugin): """添加匹配的节点链接""" # 获取匹配的节点链接 NPU_matched_node_link = self.json_get(self._current_file_data, "NPU", "node", NPU_node, "matched_node_link") - Bench_matched_node_link = self.json_get(self._current_file_data, "Bench", "node", Bench_node, "matched_node_link") + Bench_matched_node_link = self.json_get( + self._current_file_data, "Bench", "node", Bench_node, "matched_node_link" + ) # 检查是否为可变列表类型 if isinstance(NPU_matched_node_link, list) and isinstance(Bench_matched_node_link, list): @@ -497,13 +537,12 @@ class GraphsPlugin(base_plugin.TBPlugin): logger.error(f'Error: add precision_index failed.') return False return True - + def save_results(self, file_path): """保存修改后的结果""" with open(file_path, "w", encoding="utf-8") as file: json.dump(self._current_file_data, file, ensure_ascii=False, indent=4) - @wrappers.Request.application def 
get_unmatch(self, request): # 需要处理的项目为input, output, matched_node_link data里面的precision_index 和 来判断的 @@ -556,7 +595,7 @@ class GraphsPlugin(base_plugin.TBPlugin): json.dump(self._current_file_data, file, ensure_ascii=False, indent=4) result = 'unmatched!!!' return http_util.Respond(request, result, "application/json") - + @wrappers.Request.application def get_parent_node(self, request): node = request.args.get("node")[4:] # 获取节点信息 @@ -564,8 +603,10 @@ class GraphsPlugin(base_plugin.TBPlugin): json_data = self.check_jsondata(request) # 检查请求中的 JSON 数据 def find_upnode(node): - matched_node_link_list = self.json_get(json_data, constants.PREFIX_MAP[prefix], 'node', node, 'matched_node_link') - + matched_node_link_list = self.json_get( + json_data, constants.PREFIX_MAP[prefix], 'node', node, 'matched_node_link' + ) + if matched_node_link_list: result = matched_node_link_list[-1] # 获取匹配的最后一个节点 return http_util.Respond(request, result, "application/json") # 返回响应 @@ -579,8 +620,8 @@ class GraphsPlugin(base_plugin.TBPlugin): return http_util.Respond(request, {}, "application/json") # 如果没有找到上级节点,返回空响应 return find_upnode(node) - - #多卡跳转后端的方法,负责拿到rank数据 + + # 多卡跳转后端的方法,负责拿到rank数据 @wrappers.Request.application def get_rank(self, request): node_name = request.args.get('node') @@ -599,7 +640,7 @@ class GraphsPlugin(base_plugin.TBPlugin): else: return http_util.Respond(request, {}, "application/json") - #拿json_data里面所有配置数据的 + # 拿json_data里面所有配置数据的 @wrappers.Request.application def get_all_data(self, request): """Returns all data in json format.""" @@ -609,7 +650,7 @@ class GraphsPlugin(base_plugin.TBPlugin): json_data = self.check_jsondata(request) allNodeName = self.get_all_nodeName(json_data, request) response_data['Menu'] = allNodeName - response_data['UnMatchedNode'] = self.get_all_unmatchedNodes(allNodeName , request) + response_data['UnMatchedNode'] = self.get_all_unmatchedNodes(allNodeName, request) self._current_tag = tag for field in ['MicroSteps', 'StepList', 'match']: if json_data.get(field, {}): @@ -619,7 +660,7 @@ class GraphsPlugin(base_plugin.TBPlugin): json_data[key].insert(0, 'ALL') response_data[key] = json_data.get(key, {}) return http_util.Respond(request, response_data, "application/json") - + @wrappers.Request.application def static_file_route(self, request): filename = os.path.basename(request.path) @@ -636,11 +677,9 @@ class GraphsPlugin(base_plugin.TBPlugin): contents = infile.read() except IOError as e: raise exceptions.NotFound('404 Not Found') from e - return Response( - contents, content_type=mimetype, headers=GraphsPlugin.headers - ) - - #方便多层级展开的upnodes节点集合,与tf-graph的_menuSelectedNodeExpand联动 + return Response(contents, content_type=mimetype, headers=GraphsPlugin.headers) + + # 方便多层级展开的upnodes节点集合,与tf-graph的_menuSelectedNodeExpand联动 @wrappers.Request.application def get_all_upnodes(self, request): npu_upnodes_list, matched_upnodes_list, node_list = [], [], [] @@ -658,6 +697,7 @@ class GraphsPlugin(base_plugin.TBPlugin): if node_list.get(node, {}).get('matched_node_link') else None ) + def get_upnodes(node, prefix): upnodes_list = [] if prefix == '': @@ -671,6 +711,7 @@ class GraphsPlugin(base_plugin.TBPlugin): upnodes_list.insert(0, upnode) node = upnode return upnodes_list + npu_upnodes_list = get_upnodes(node, prefix) # 如果 matched_node 是 None 的话 if matched_node is None: @@ -683,7 +724,7 @@ class GraphsPlugin(base_plugin.TBPlugin): if prefix in constants.PREFIX_MAP: matched_upnodes_list = get_upnodes(matched_node, prefix) return http_util.Respond(request, [[prefix], 
npu_upnodes_list, matched_upnodes_list], "application/json") - + # 检查到底是读一般还是用之前存的 def check_jsondata(self, request): tag = request.args.get("tag") @@ -694,7 +735,7 @@ class GraphsPlugin(base_plugin.TBPlugin): else: json_data = self._current_file_data return json_data - + # 处理xx.get def json_get(self, data, *args): result = data @@ -713,16 +754,14 @@ class GraphsPlugin(base_plugin.TBPlugin): self.batch_id = request.args.get("batch") self.step_id = request.args.get("step") if node_id is None: - return http_util.Respond( - request, 'The query parameter "node" is required', "text/plain", 400 - ) + return http_util.Respond(request, 'The query parameter "node" is required', "text/plain", 400) if node_id == 'root': if json_data.get('Bench', {}): subgraph_pbtxt_set = {} for node_type in ('Bench', 'NPU'): subgraph = {'node': {}, 'edge': {}} node = self.json_get(json_data, constants.SETS[node_type][0], 'root') - node_data = self.json_get(json_data ,constants.SETS[node_type][0], 'node', node) + node_data = self.json_get(json_data, constants.SETS[node_type][0], 'node', node) node = constants.SETS[node_type][1] + node matched_node_link = node_data['matched_node_link'] if matched_node_link[0][:4] != constants.SETS[node_type][2]: @@ -740,13 +779,13 @@ class GraphsPlugin(base_plugin.TBPlugin): subgraph = self._extract_subgraph(json_data, node_id) subgraph_pbtxt = self._convert_to_protobuf_format(subgraph) return http_util.Respond(request, subgraph_pbtxt, "text/x-protobuf") - + @wrappers.Request.application def info_route(self, request): info = self.info_impl() return http_util.Respond(request, info, "application/json") - #同上二者一体 + # 同上二者一体 def _extract_subgraph(self, json_data, node_id): """提取子图,支持多种节点前缀逻辑""" subgraph = {'node': {}, 'edge': []} @@ -759,11 +798,11 @@ class GraphsPlugin(base_plugin.TBPlugin): else: prefix = '' node_set = json_data.get('node', {}) - + # 获取当前节点数据 node_data = node_set.get(node_id, {}) subnodes = node_data.get('subnodes', []) - + # 遍历子节点 for subnode_id in subnodes: subnode_id_data = node_set.get(subnode_id, {}) @@ -771,19 +810,19 @@ class GraphsPlugin(base_plugin.TBPlugin): self._process_subnode(subgraph, prefix, subnode_id, subnode_id_data, json_data) else: self._process_non_root_subnode(subgraph, prefix, subnode_id, subnode_id_data) - + return subgraph def _process_non_root_subnode(self, subgraph, prefix, subnode_id, subnode_id_data): """处理非根子节点""" # 更新匹配的节点链接 self._update_matched_node_links(subnode_id_data, prefix) - + # 添加前缀并存入子图 full_subnode_id = prefix + subnode_id subgraph['node'][full_subnode_id] = subnode_id_data - #针对分micro_step_id和step_id取的部分节点 + # 针对分micro_step_id和step_id取的部分节点 def _process_subnode(self, subgraph, prefix, subnode_id, subnode_id_data, json_data): batchid = subnode_id_data.get('micro_step_id') stepid = subnode_id_data.get('step_id') @@ -812,7 +851,7 @@ class GraphsPlugin(base_plugin.TBPlugin): matched_node_link = constants.SETS[prefix][1] + matched_node_link subnode_id_data['matched_node_link'][index] = matched_node_link - #拼接成类json + # 拼接成类json def _convert_to_protobuf_format(self, subgraph): """Converts subgraph data to the protobuf text format expected by the frontend.""" nodes = subgraph.get('node', {}) @@ -822,7 +861,9 @@ class GraphsPlugin(base_plugin.TBPlugin): protobuf_format += f' node_type: {node_data.get("node_type", 0)}\n' if node_data.get("matched_node_link"): protobuf_format += f' matched_node_link: {node_data.get("matched_node_link")}\n' - protobuf_format += f' attr: "{node_data.get("data", "{}")}"\n'.replace('True', 
'true').replace('False', 'false') + protobuf_format += f' attr: "{node_data.get("data", "{}")}"\n'.replace('True', 'true').replace( + 'False', 'false' + ) protobuf_format += f' precision_index: {(node_data.get("data", {}).get("precision_index"))}\n' if node_data.get("input_data"): protobuf_format += f' input_data: "{node_data.get("input_data", "{}")}"\n' @@ -850,11 +891,13 @@ class GraphsPlugin(base_plugin.TBPlugin): file_path = os.path.join(root, file) file_size = os.path.getsize(file_path) if file_size > constants.MAX_FILE_SIZE: - logger.error(f'Error: the vis file "{file_path}" exceeds the maximum limit size of 1GB and will be skipped.') + logger.error( + f'Error: the vis file "{file_path}" exceeds the maximum limit size of 1GB and will be skipped.' + ) continue run_tag_pairs.append((run, tag)) return run_tag_pairs - + def _load_json_file(self, run_dir, tag): """Load a single .vis file from a given directory based on the tag.""" file_path = os.path.join(run_dir, f"{tag}.vis")