diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/server/app/service/graph_service.py b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/service/graph_service.py index 04727baac923f3fa00cea4756e36f39f9de4db32..7d0202e363f46b3db2265616b5b6f9f2d5cbecd6 100644 --- a/plugins/tensorboard-plugins/tb_graph_ascend/server/app/service/graph_service.py +++ b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/service/graph_service.py @@ -98,7 +98,7 @@ class GraphService: json_data = GraphUtils.safe_json_loads(buffer) yield f"data: {json.dumps({'progress': 99, 'status': 'loading'})}\n\n" except json.JSONDecodeError as e: - yield f"data: {json.dumps({'progress': current_progress, 'error': str(e)})}\n\n" + yield f"data: {json.dumps({'progress': current_progress, 'error': '文件读取错误'})}\n\n" if json_data is not None: # 验证存储 GraphState.set_global_value('current_file_data', json_data) @@ -112,7 +112,7 @@ class GraphService: def load_graph_config_info(run, tag): graph_data, error_message = GraphUtils.get_graph_data({'run': run, 'tag': tag}) if error_message or not graph_data: - return {'success': False, 'error': error_message} + return {'success': False, 'error': '读取文件失败'} config = {} try: # 读取全局信息,tag层面 @@ -149,7 +149,7 @@ class GraphService: def load_graph_all_node_list(run, tag, micro_step): graph_data, error_message = GraphUtils.get_graph_data({'run': run, 'tag': tag}) if error_message or not graph_data: - return {'success': False, 'error': error_message} + return {'success': False, 'error': '读取文件失败'} result = {} try: if not graph_data.get(NPU): @@ -197,13 +197,13 @@ class GraphService: return {'success': True, 'data': result} except Exception as e: logger.error('获取节点列表失败:' + str(e)) - return {'success': False, 'error': '获取节点列表失败:' + str(e)} + return {'success': False, 'error': '获取节点列表失败'} @staticmethod def change_node_expand_state(node_info, meta_data): graph_data, error_message = GraphUtils.get_graph_data(meta_data) if error_message or not graph_data: - return {'success': False, 'error': error_message} + return {'success': False, 'error': '读取文件失败'} graph_type = node_info.get('nodeType') node_name = node_info.get('nodeName') micro_step = meta_data.get('microStep') @@ -241,7 +241,7 @@ class GraphService: def get_node_info(node_info, meta_data): graph_data, error_message = GraphUtils.get_graph_data(meta_data) if error_message: - return {'success': False, 'error': error_message} + return {'success': False, 'error': '读取文件失败'} try: graph_type = node_info.get('nodeType') node_name = node_info.get('nodeName') @@ -263,13 +263,13 @@ class GraphService: return {'success': True, 'data': result} except Exception as e: logger.error('获取节点信息失败:' + str(e)) - return {'success': False, 'error': '获取节点信息失败:' + str(e), 'data': None} + return {'success': False, 'error': '获取节点信息失败', 'data': None} @staticmethod def add_match_nodes(npu_node_name, bench_node_name, meta_data, is_match_children): graph_data, error_message = GraphUtils.get_graph_data(meta_data) if error_message: - return {'success': False, 'error': error_message} + return {'success': False, 'error': '读取文件失败'} task = graph_data.get('task') result = {} try: @@ -293,7 +293,7 @@ class GraphService: else: return {'success': False, 'error': '任务类型不支持(Task type not supported) '} except Exception as e: - return {'success': False, '操作失败': str(e), 'data': None} + return {'success': False, 'error': '操作失败', 'data': None} @staticmethod def add_match_nodes_by_config(config_file, meta_data): @@ -318,7 +318,7 @@ class GraphService: def delete_match_nodes(npu_node_name, bench_node_name, meta_data, is_unmatch_children): graph_data, error_message = GraphUtils.get_graph_data(meta_data) if error_message: - return {'success': False, 'error': error_message} + return {'success': False, 'error': '读取文件失败'} task = graph_data.get('task') result = {} try: @@ -341,7 +341,7 @@ class GraphService: else: return {'success': False, 'error': '任务类型不支持(Task type not supported) '} except Exception as e: - return {'success': False, '操作失败': str(e), 'data': None} + return {'success': False, 'error': '操作失败', 'data': None} @staticmethod def update_colors(run, colors): @@ -361,13 +361,13 @@ class GraphService: GraphUtils.safe_save_data(first_file_data, run, f"{first_run_tag}.vis") return {'success': True, 'error': None, 'data': {}} except Exception as e: - return {'success': False, 'error': str(e), 'data': None} + return {'success': False, 'error': '颜色更新失败', 'data': None} @staticmethod def save_data(meta_data): graph_data, error_message = GraphUtils.get_graph_data(meta_data) if error_message: - return {'success': False, 'error': error_message} + return {'success': False, 'error': '读取文件失败'} run = meta_data.get('run') tag = meta_data.get('tag') diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/server/app/utils/graph_utils.py b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/utils/graph_utils.py index 497f11e37fbe2f5b49564868518c5885ec492bb4..9d644eaf43a4f8f70020c7a24366b99b6dd6c8d8 100644 --- a/plugins/tensorboard-plugins/tb_graph_ascend/server/app/utils/graph_utils.py +++ b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/utils/graph_utils.py @@ -14,6 +14,7 @@ # limitations under the License. # ============================================================================== +import html import math import os import json @@ -141,6 +142,78 @@ class GraphUtils: logger.error(f"Unexpected error: {e}") return default_value + @staticmethod + def safe_get_node_info(request, default_value=None): + + node_info = request.args.get('nodeInfo') + + try: + node_info_str = str(node_info) + + # 验证必要字段是否存在 + required_fields = ["nodeName", "nodeType"] + for field in required_fields: + if field not in node_info: + logger.error(f"Field {field} is missing in nodeInfo.") + return default_value + + # 安全解析JSON + parsed_info = GraphUtils.safe_json_loads(node_info_str) + + # 特殊字符过滤/转义 + sanitized = {} + for key, value in parsed_info.items(): + if isinstance(value, str): + # HTML转义 + sanitized[key] = html.escape(value) + else: + sanitized[key] = value + + return sanitized + + except json.JSONDecodeError: + logger.error("NodeInfo parameter is not in valid JSON format.") + return default_value + except Exception as e: + logger.error(f"An error occurred while parsing the nodeInfo parameter: {str(e)}") + return default_value + + @staticmethod + def safe_get_meta_data(request, default_value=None): + + meta_data = request.args.get('metaData') + + try: + meta_data_str = str(meta_data) + + # 验证必要字段是否存在 + required_fields = ["tag", "run"] + for field in required_fields: + if field not in meta_data: + logger.error(f"Field {field} is missing in metadata.") + return default_value + + # 安全解析JSON + parsed_info = GraphUtils.safe_json_loads(meta_data_str) + + # 特殊字符过滤/转义 + sanitized = {} + for key, value in parsed_info.items(): + if isinstance(value, str): + #HTML转义 + sanitized[key] = html.escape(value) + else: + sanitized[key] = value + + return sanitized + + except json.JSONDecodeError: + logger.error("MetaData parameter is not in valid JSON format.") + return default_value + except Exception as e: + logger.error(f"An error occurred while parsing the metatdata parameter: {str(e)}") + return default_value + @staticmethod def remove_prototype_pollution(obj, current_depth=1, max_depth=200): """ diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/server/app/views/graph_views.py b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/views/graph_views.py index 17d85b704ecd4b1b4d3b3800b9bd533bd31bf3cd..aa02ae9ab76ab8653b0bc66f04048c33c8819864 100644 --- a/plugins/tensorboard-plugins/tb_graph_ascend/server/app/views/graph_views.py +++ b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/views/graph_views.py @@ -98,8 +98,8 @@ class GraphView: @staticmethod @wrappers.Request.application def change_node_expand_state(request): - node_info = GraphUtils.safe_json_loads(request.args.get("nodeInfo")) - meta_data = GraphUtils.safe_json_loads(request.args.get("metaData")) + node_info = GraphUtils.safe_get_node_info(request) + meta_data = GraphUtils.safe_get_meta_data(request) hierarchy = GraphService.change_node_expand_state(node_info, meta_data) return http_util.Respond(request, json.dumps(hierarchy), "application/json") @@ -115,8 +115,8 @@ class GraphView: @staticmethod @wrappers.Request.application def get_node_info(request): - node_info = GraphUtils.safe_json_loads(request.args.get("nodeInfo")) - meta_data = GraphUtils.safe_json_loads(request.args.get("metaData")) + node_info = GraphUtils.safe_get_node_info(request) + meta_data = GraphUtils.safe_get_meta_data(request) node_detail = GraphService.get_node_info(node_info, meta_data) return http_util.Respond(request, json.dumps(node_detail), "application/json") @@ -125,7 +125,7 @@ class GraphView: @wrappers.Request.application def add_match_nodes_by_config(request): config_file = request.args.get("configFile") - meta_data = GraphUtils.safe_json_loads(request.args.get("metaData")) + meta_data = GraphUtils.safe_get_meta_data(request) match_result = GraphService.add_match_nodes_by_config(config_file, meta_data) return http_util.Respond(request, json.dumps(match_result), "application/json") @@ -135,7 +135,7 @@ class GraphView: def add_match_nodes(request): npu_node_name = request.args.get("npuNodeName") bench_node_name = request.args.get("benchNodeName") - meta_data = GraphUtils.safe_json_loads(request.args.get("metaData")) + meta_data = GraphUtils.safe_get_meta_data(request) is_match_children = GraphUtils.safe_json_loads(request.args.get("isMatchChildren")) match_result = GraphService.add_match_nodes(npu_node_name, bench_node_name, meta_data, is_match_children) return http_util.Respond(request, json.dumps(match_result), "application/json") @@ -146,7 +146,7 @@ class GraphView: def delete_match_nodes(request): npu_node_name = request.args.get("npuNodeName") bench_node_name = request.args.get("benchNodeName") - meta_data = GraphUtils.safe_json_loads(request.args.get("metaData")) + meta_data = GraphUtils.safe_get_meta_data(request) is_unmatch_children = GraphUtils.safe_json_loads(request.args.get("isUnMatchChildren")) match_result = GraphService.delete_match_nodes(npu_node_name, bench_node_name, meta_data, is_unmatch_children) return http_util.Respond(request, json.dumps(match_result), "application/json") @@ -155,7 +155,7 @@ class GraphView: @staticmethod @wrappers.Request.application def save_data(request): - meta_data = GraphUtils.safe_json_loads(request.args.get("metaData")) + meta_data = GraphUtils.safe_get_meta_data(request) save_result = GraphService.save_data(meta_data) return http_util.Respond(request, json.dumps(save_result), "application/json") @@ -172,6 +172,6 @@ class GraphView: @staticmethod @wrappers.Request.application def save_matched_relations(request): - meta_data = GraphUtils.safe_json_loads(request.args.get("metaData")) + meta_data = GraphUtils.safe_get_meta_data(request) save_result = GraphService.save_matched_relations(meta_data) return http_util.Respond(request, json.dumps(save_result), "application/json")