diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/server/app/service/file_check_wrapper.py b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/service/file_check_wrapper.py index 47f4ab7f369410ac973bb8ada01409543cbefa2e..0fa9adef57be7d5104e0975850ccdb4f166e618a 100644 --- a/plugins/tensorboard-plugins/tb_graph_ascend/server/app/service/file_check_wrapper.py +++ b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/service/file_check_wrapper.py @@ -27,21 +27,12 @@ def check_file_type(func): request = args[0] if not isinstance(request, Request): raise RuntimeError('The request "parameter" is not in a format supported by werkzeug') - meta_data = GraphUtils.safe_json_loads(request.get_data().decode('utf-8'), {}).get('metaData') - - result = {'success': False, 'error': ''} + data = GraphUtils.safe_json_loads(request.get_data().decode('utf-8'), {}) + meta_data = GraphUtils.safe_get_meta_data(data) if meta_data is None or not isinstance(meta_data, dict): - result['error'] = 'The query parameter "metaData" is required and must be a dictionary' - return http_util.Respond(request, result, "application/json") - data_type = meta_data.get('type') - run = meta_data.get('run') - tag = meta_data.get('tag') - if data_type is None or run is None or tag is None: - result['error'] = 'The query parameters "type", "run" and "tag" in "metaData" are required' - return http_util.Respond(request, result, "application/json") - if data_type not in [e.value for e in DataType]: - result['error'] = 'Unsupported file type' + result['error'] = 'The query parameter “metaData” and its sub-items must exist and must be a dictionary.' return http_util.Respond(request, result, "application/json") + result = {'success': False, 'error': ''} return func(*args, **kwargs) diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/server/app/utils/graph_utils.py b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/utils/graph_utils.py index 74ff3399d3bb7d0b95cda5a35cba6f82af60309d..3e4e0436926333f253c1dacf930abeab5ac253d5 100644 --- a/plugins/tensorboard-plugins/tb_graph_ascend/server/app/utils/graph_utils.py +++ b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/utils/graph_utils.py @@ -140,6 +140,62 @@ class GraphUtils: logger.error(f"Unexpected error: {e}") return default_value + @staticmethod + def safe_get_node_info(data, default_value=None): + + node_info = data.get('nodeInfo') + + try: + # 长度限制 - 检查字典转为字符串后的长度 + node_info_str = str(node_info) + if len(node_info_str) > MAX_FILE_SIZE: + logger.error(f"Input length exceeds {MAX_FILE_SIZE} characters.") + return default_value + + # 验证必要字段是否存在 + required_fields = ["nodeName", "nodeType"] + for field in required_fields: + if field not in node_info: + logger.error(f"Field {field} is missing in metadata.") + return default_value + + return node_info + + except json.JSONDecodeError: + logger.error("NodeInfo parameter is not in valid JSON format.") + return default_value + except Exception as e: + logger.error(f"An error occurred while parsing the nodeInfo parameter: {str(e)}") + return default_value + + @staticmethod + def safe_get_meta_data(data, default_value=None): + + meta_data = data.get('metaData') + + try: + # 长度限制 + meta_data_str = str(meta_data) + if len(meta_data_str) > MAX_FILE_SIZE: + logger.error(f"Input length exceeds {MAX_FILE_SIZE} characters.") + return default_value + + # 验证必要字段是否存在 + required_fields = ["tag", "microStep", "run", "type"] + for field in required_fields: + if field not in meta_data: + logger.error(f"Field {field} is missing in metadata.") + return default_value + + return meta_data + + except json.JSONDecodeError: + logger.error("MetaData parameter is not in valid JSON format.") + return default_value + except Exception as e: + logger.error(f"An error occurred while parsing the metatdata parameter: {str(e)}") + return default_value + @staticmethod def remove_prefix(node_data, prefix): if node_data is None: diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/server/app/views/graph_views.py b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/views/graph_views.py index d3fe234c564309afa401c3df22f947d85b4cd767..af14d915112808f1850bae2274122aff483bd7a2 100644 --- a/plugins/tensorboard-plugins/tb_graph_ascend/server/app/views/graph_views.py +++ b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/views/graph_views.py @@ -81,7 +81,8 @@ class GraphView: @wrappers.Request.application @check_file_type def load_graph_config_info(request): - meta_data = GraphUtils.safe_json_loads(request.get_data().decode('utf-8'), {}).get('metaData') + data = GraphUtils.safe_json_loads(request.get_data().decode('utf-8'), {}) + meta_data = data.get('metaData') strategy = GraphView._get_strategy(meta_data) result = strategy.load_graph_config_info() # 创建响应对象 @@ -93,7 +94,8 @@ class GraphView: @wrappers.Request.application @check_file_type def load_graph_all_node_list(request): - meta_data = GraphUtils.safe_json_loads(request.get_data().decode('utf-8'), {}).get('metaData') + data = GraphUtils.safe_json_loads(request.get_data().decode('utf-8'), {}) + meta_data = data.get('metaData') strategy = GraphView._get_strategy(meta_data) result = strategy.load_graph_all_node_list(meta_data) response = http_util.Respond(request, result, "application/json") @@ -117,7 +119,11 @@ class GraphView: @check_file_type def change_node_expand_state(request): data = GraphUtils.safe_json_loads(request.get_data().decode('utf-8'), {}) - node_info = data.get('nodeInfo') + node_info = GraphUtils.safe_get_node_info(data) + result = {'success': False, 'error': ''} + if node_info is None or not isinstance(node_info, dict): + result['error'] = 'The query parameter nodeInfo and its sub-items must exist and must be a dictionary.' + return http_util.Respond(request, result, "application/json") meta_data = data.get('metaData') strategy = GraphView._get_strategy(meta_data) hierarchy = strategy.change_node_expand_state(node_info, meta_data) @@ -137,7 +143,11 @@ class GraphView: @check_file_type def get_node_info(request): data = GraphUtils.safe_json_loads(request.get_data().decode('utf-8'), {}) - node_info = data.get('nodeInfo') + node_info = GraphUtils.safe_get_node_info(data) + result = {'success': False, 'error': ''} + if node_info is None or not isinstance(node_info, dict): + result['error'] = 'The query parameter nodeInfo and its sub-items must exist and must be a dictionary.' + return http_util.Respond(request, result, "application/json") meta_data = data.get('metaData') strategy = GraphView._get_strategy(meta_data) node_detail = strategy.get_node_info(node_info, meta_data) @@ -163,7 +173,7 @@ class GraphView: data = GraphUtils.safe_json_loads(request.get_data().decode('utf-8'), {}) npu_node_name = data.get("npuNodeName") bench_node_name = data.get("benchNodeName") - meta_data = data.get("metaData") + meta_data = data.get('metaData') is_match_children = data.get("isMatchChildren") strategy = GraphView._get_strategy(meta_data) match_result = strategy.add_match_nodes(npu_node_name, bench_node_name, meta_data, is_match_children) @@ -177,7 +187,7 @@ class GraphView: data = GraphUtils.safe_json_loads(request.get_data().decode('utf-8'), {}) npu_node_name = data.get("npuNodeName") bench_node_name = data.get("benchNodeName") - meta_data = data.get("metaData") + meta_data = data.get('metaData') is_unmatch_children = data.get("isUnMatchChildren") strategy = GraphView._get_strategy(meta_data) match_result = strategy.delete_match_nodes(npu_node_name, bench_node_name, meta_data, is_unmatch_children) @@ -188,7 +198,8 @@ class GraphView: @wrappers.Request.application @check_file_type def save_data(request): - meta_data = GraphUtils.safe_json_loads(request.get_data().decode('utf-8'), {}).get('metaData') + data = GraphUtils.safe_json_loads(request.get_data().decode('utf-8'), {}) + meta_data = data.get('metaData') strategy = GraphView._get_strategy(meta_data) save_result = strategy.save_data(meta_data) return http_util.Respond(request, json.dumps(save_result), "application/json") @@ -208,7 +219,8 @@ class GraphView: @wrappers.Request.application @check_file_type def save_matched_relations(request): - meta_data = GraphUtils.safe_json_loads(request.get_data().decode('utf-8'), {}).get('metaData') + data = GraphUtils.safe_json_loads(request.get_data().decode('utf-8'), {}) + meta_data = data.get('metaData') strategy = GraphView._get_strategy(meta_data) save_result = strategy.save_matched_relations() return http_util.Respond(request, json.dumps(save_result), "application/json")