diff --git a/debug/accuracy_tools/msprobe/test/visualization_ut/test_visualization_utils.py b/debug/accuracy_tools/msprobe/test/visualization_ut/test_visualization_utils.py index 974fcdf8a19d81a6c22a0396f45fc1725b00a39a..41a5b48419a762ff1e4e382ec2fc736b76bce309 100644 --- a/debug/accuracy_tools/msprobe/test/visualization_ut/test_visualization_utils.py +++ b/debug/accuracy_tools/msprobe/test/visualization_ut/test_visualization_utils.py @@ -1,7 +1,6 @@ import os import unittest -from msprobe.visualization.utils import (load_json_file, load_data_json_file, str2float, check_directory_content, - GraphConst, SerializableArgs) +from msprobe.visualization.utils import str2float, check_directory_content, GraphConst, SerializableArgs class TestMappingConfig(unittest.TestCase): @@ -10,14 +9,6 @@ class TestMappingConfig(unittest.TestCase): self.yaml_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mapping.yaml") self.input = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input") - def test_load_json_file(self): - result = load_json_file(self.yaml_path) - self.assertEqual(result, {}) - - def test_load_data_json_file(self): - result = load_data_json_file(self.yaml_path) - self.assertEqual(result, {}) - def test_str2float(self): result = str2float('23.4%') self.assertAlmostEqual(result, 0.234) diff --git a/debug/accuracy_tools/msprobe/visualization/compare/graph_comparator.py b/debug/accuracy_tools/msprobe/visualization/compare/graph_comparator.py index 20c97d717b1cc87e69523144aca0ec93d564e682..2bcea0fd4343289b01dd1f92b1e4cd9e9bb8c48a 100644 --- a/debug/accuracy_tools/msprobe/visualization/compare/graph_comparator.py +++ b/debug/accuracy_tools/msprobe/visualization/compare/graph_comparator.py @@ -16,12 +16,12 @@ import re from msprobe.visualization.builder.msprobe_adapter import compare_node, get_compare_mode, run_real_data, \ run_real_data_single, get_csv_df, compare_node_by_dump_data -from msprobe.visualization.utils import GraphConst, load_json_file, load_data_json_file +from msprobe.visualization.utils import GraphConst from msprobe.visualization.graph.graph import Graph, NodeOp from msprobe.visualization.compare.mode_adapter import ModeAdapter from msprobe.core.common.const import Const, CompareConst from msprobe.core.common.log import logger -from msprobe.core.common.file_utils import load_yaml +from msprobe.core.common.file_utils import load_yaml, load_json from msprobe.visualization.compare.multi_mapping import MultiMapping @@ -205,9 +205,9 @@ class GraphComparator: self.output_path = output_path compare_mode = get_compare_mode(self.dump_path_param) self.ma = ModeAdapter(compare_mode) - self.data_n_dict = load_data_json_file(dump_path_param.get('npu_json_path')) - self.data_b_dict = load_data_json_file(dump_path_param.get('bench_json_path')) - self.stack_json_data = load_json_file(dump_path_param.get('stack_json_path')) + self.data_n_dict = load_json(dump_path_param.get('npu_json_path')).get(GraphConst.DATA_KEY, {}) + self.data_b_dict = load_json(dump_path_param.get('bench_json_path')).get(GraphConst.DATA_KEY, {}) + self.stack_json_data = load_json(dump_path_param.get('stack_json_path')) def _postcompare(self): self._handle_api_collection_index() diff --git a/debug/accuracy_tools/msprobe/visualization/utils.py b/debug/accuracy_tools/msprobe/visualization/utils.py index 5a08921392dac136b0437f878a8b30710e113b00..d5c96bdc7b4b18618e83a6323f034f7fd8315084 100644 --- a/debug/accuracy_tools/msprobe/visualization/utils.py +++ b/debug/accuracy_tools/msprobe/visualization/utils.py @@ -15,36 +15,13 @@ import os import re -import json import pickle -from msprobe.core.common.file_utils import FileOpen from msprobe.core.common.const import CompareConst, Const from msprobe.core.common.log import logger from msprobe.core.common.exceptions import MsprobeException from msprobe.core.compare.utils import check_and_return_dir_contents -def load_json_file(file_path): - """ - 加载json文件 - """ - try: - with FileOpen(file_path, 'r') as f: - file_dict = json.load(f) - if not isinstance(file_dict, dict): - return {} - return file_dict - except json.JSONDecodeError: - return {} - - -def load_data_json_file(file_path): - """ - 加载dump.json中的data字段 - """ - return load_json_file(file_path).get(GraphConst.DATA_KEY, {}) - - def str2float(percentage_str): """ 百分比字符串转换转换为浮点型