diff --git a/debug/accuracy_tools/msprobe/test/visualization_ut/test_visualization_utils.py b/debug/accuracy_tools/msprobe/test/visualization_ut/test_visualization_utils.py index 974fcdf8a19d81a6c22a0396f45fc1725b00a39a..81b87a117815d4c9605877fa07f2d2f882c5b60f 100644 --- a/debug/accuracy_tools/msprobe/test/visualization_ut/test_visualization_utils.py +++ b/debug/accuracy_tools/msprobe/test/visualization_ut/test_visualization_utils.py @@ -1,7 +1,6 @@ import os import unittest -from msprobe.visualization.utils import (load_json_file, load_data_json_file, str2float, check_directory_content, - GraphConst, SerializableArgs) +from msprobe.visualization.utils import (str2float, check_directory_content, GraphConst, SerializableArgs) class TestMappingConfig(unittest.TestCase): @@ -10,14 +9,6 @@ class TestMappingConfig(unittest.TestCase): self.yaml_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mapping.yaml") self.input = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input") - def test_load_json_file(self): - result = load_json_file(self.yaml_path) - self.assertEqual(result, {}) - - def test_load_data_json_file(self): - result = load_data_json_file(self.yaml_path) - self.assertEqual(result, {}) - def test_str2float(self): result = str2float('23.4%') self.assertAlmostEqual(result, 0.234) diff --git a/debug/accuracy_tools/msprobe/visualization/compare/graph_comparator.py b/debug/accuracy_tools/msprobe/visualization/compare/graph_comparator.py index 5c8a489245bcae66c6440d1961c4025405211dce..ac6c9c3eccaad4b2bde997b38bb2e2367d5909d5 100644 --- a/debug/accuracy_tools/msprobe/visualization/compare/graph_comparator.py +++ b/debug/accuracy_tools/msprobe/visualization/compare/graph_comparator.py @@ -15,11 +15,11 @@ import re from msprobe.visualization.builder.msprobe_adapter import compare_node, get_compare_mode, run_real_data, get_csv_df -from msprobe.visualization.utils import GraphConst, load_json_file, load_data_json_file +from msprobe.visualization.utils import GraphConst from msprobe.visualization.graph.graph import Graph, NodeOp from msprobe.visualization.compare.mode_adapter import ModeAdapter from msprobe.core.common.const import Const -from msprobe.core.common.decorator import recursion_depth_decorator +from msprobe.core.common.file_utils import load_json class GraphComparator: @@ -129,9 +129,9 @@ class GraphComparator: self.output_path = output_path compare_mode = get_compare_mode(self.dump_path_param) self.ma = ModeAdapter(compare_mode) - self.data_n_dict = load_data_json_file(dump_path_param.get('npu_json_path')) - self.data_b_dict = load_data_json_file(dump_path_param.get('bench_json_path')) - self.stack_json_data = load_json_file(dump_path_param.get('stack_json_path')) + self.data_n_dict = load_json(dump_path_param.get('npu_json_path')).get(GraphConst.DATA_KEY, {}) + self.data_b_dict = load_json(dump_path_param.get('bench_json_path')).get(GraphConst.DATA_KEY, {}) + self.stack_json_data = load_json(dump_path_param.get('stack_json_path')) def _postcompare(self): self._handle_api_collection_index() diff --git a/debug/accuracy_tools/msprobe/visualization/utils.py b/debug/accuracy_tools/msprobe/visualization/utils.py index 52a5a52fc3e2ad8ca6f3ea0db552663e901c43d5..d682af3c2a8875156b4c60e567d65c5e17acb0f6 100644 --- a/debug/accuracy_tools/msprobe/visualization/utils.py +++ b/debug/accuracy_tools/msprobe/visualization/utils.py @@ -16,35 +16,12 @@ import os import re import pickle -import json -from msprobe.core.common.file_utils import FileOpen from msprobe.core.common.const import CompareConst, Const from msprobe.core.common.log import logger from msprobe.core.common.exceptions import MsprobeException from msprobe.core.compare.utils import check_and_return_dir_contents -def load_json_file(file_path): - """ - 加载json文件 - """ - try: - with FileOpen(file_path, 'r') as f: - file_dict = json.load(f) - if not isinstance(file_dict, dict): - return {} - return file_dict - except json.JSONDecodeError: - return {} - - -def load_data_json_file(file_path): - """ - 加载dump.json中的data字段 - """ - return load_json_file(file_path).get(GraphConst.DATA_KEY, {}) - - def str2float(percentage_str): """ 百分比字符串转换转换为浮点型