From 98f96d5305c6cf017407dff26ad8d6ab4704dff1 Mon Sep 17 00:00:00 2001 From: l30044004 Date: Thu, 26 Jun 2025 14:27:12 +0800 Subject: [PATCH] =?UTF-8?q?=E5=88=86=E7=BA=A7=E5=8F=AF=E8=A7=86=E5=8C=96?= =?UTF-8?q?=E4=BD=BF=E7=94=A8=E7=BB=9F=E4=B8=80=E7=9A=84load=20json?= =?UTF-8?q?=E5=87=BD=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../test_visualization_utils.py | 14 ++--------- .../visualization/compare/graph_comparator.py | 10 ++++---- .../msprobe/visualization/utils.py | 25 +------------------ 3 files changed, 8 insertions(+), 41 deletions(-) diff --git a/debug/accuracy_tools/msprobe/test/visualization_ut/test_visualization_utils.py b/debug/accuracy_tools/msprobe/test/visualization_ut/test_visualization_utils.py index 41ea145208..41a5b48419 100644 --- a/debug/accuracy_tools/msprobe/test/visualization_ut/test_visualization_utils.py +++ b/debug/accuracy_tools/msprobe/test/visualization_ut/test_visualization_utils.py @@ -1,7 +1,6 @@ import os import unittest -from msprobe.visualization.utils import (load_json_file, load_data_json_file, str2float, check_directory_content, - GraphConst, SerializableArgs) +from msprobe.visualization.utils import str2float, check_directory_content, GraphConst, SerializableArgs class TestMappingConfig(unittest.TestCase): @@ -10,14 +9,6 @@ class TestMappingConfig(unittest.TestCase): self.yaml_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mapping.yaml") self.input = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input") - def test_load_json_file(self): - result = load_json_file(self.yaml_path) - self.assertEqual(result, {}) - - def test_load_data_json_file(self): - result = load_data_json_file(self.yaml_path) - self.assertEqual(result, {}) - def test_str2float(self): result = str2float('23.4%') self.assertAlmostEqual(result, 0.234) @@ -43,6 +34,7 @@ class TestMappingConfig(unittest.TestCase): self.a = a self.b = b self.c = c + input_args1 = TmpArgs('a', 123, [1, 2, 3]) serializable_args1 = SerializableArgs(input_args1) self.assertEqual(serializable_args1.__dict__, input_args1.__dict__) @@ -51,7 +43,5 @@ class TestMappingConfig(unittest.TestCase): self.assertNotEqual(serializable_args2.__dict__, input_args2.__dict__) - - if __name__ == '__main__': unittest.main() diff --git a/debug/accuracy_tools/msprobe/visualization/compare/graph_comparator.py b/debug/accuracy_tools/msprobe/visualization/compare/graph_comparator.py index 95982658d2..62ffa8c66d 100644 --- a/debug/accuracy_tools/msprobe/visualization/compare/graph_comparator.py +++ b/debug/accuracy_tools/msprobe/visualization/compare/graph_comparator.py @@ -15,11 +15,11 @@ import re from msprobe.visualization.builder.msprobe_adapter import compare_node, get_compare_mode, run_real_data, get_csv_df -from msprobe.visualization.utils import GraphConst, load_json_file, load_data_json_file +from msprobe.visualization.utils import GraphConst from msprobe.visualization.graph.graph import Graph, NodeOp from msprobe.visualization.compare.mode_adapter import ModeAdapter from msprobe.core.common.const import Const -from msprobe.core.common.decorator import recursion_depth_decorator +from msprobe.core.common.file_utils import load_json class GraphComparator: @@ -127,9 +127,9 @@ class GraphComparator: self.output_path = output_path compare_mode = get_compare_mode(self.dump_path_param) self.ma = ModeAdapter(compare_mode) - self.data_n_dict = load_data_json_file(dump_path_param.get('npu_json_path')) - self.data_b_dict = load_data_json_file(dump_path_param.get('bench_json_path')) - self.stack_json_data = load_json_file(dump_path_param.get('stack_json_path')) + self.data_n_dict = load_json(dump_path_param.get('npu_json_path')).get(GraphConst.DATA_KEY, {}) + self.data_b_dict = load_json(dump_path_param.get('bench_json_path')).get(GraphConst.DATA_KEY, {}) + self.stack_json_data = load_json(dump_path_param.get('stack_json_path')) def _postcompare(self): self._handle_api_collection_index() diff --git a/debug/accuracy_tools/msprobe/visualization/utils.py b/debug/accuracy_tools/msprobe/visualization/utils.py index 4ee05b1254..5f1e5bf639 100644 --- a/debug/accuracy_tools/msprobe/visualization/utils.py +++ b/debug/accuracy_tools/msprobe/visualization/utils.py @@ -15,34 +15,11 @@ import os import re -import json import pickle -from msprobe.core.common.file_utils import FileOpen from msprobe.core.common.const import CompareConst, Const from msprobe.core.common.log import logger -def load_json_file(file_path): - """ - 加载json文件 - """ - try: - with FileOpen(file_path, 'r') as f: - file_dict = json.load(f) - if not isinstance(file_dict, dict): - return {} - return file_dict - except json.JSONDecodeError: - return {} - - -def load_data_json_file(file_path): - """ - 加载dump.json中的data字段 - """ - return load_json_file(file_path).get(GraphConst.DATA_KEY, {}) - - def str2float(percentage_str): """ 百分比字符串转换转换为浮点型 @@ -184,7 +161,7 @@ class GraphConst: OP = 'op' PEER = 'peer' GROUP_ID = 'group_id' - + IGNORE_PRECISION_INDEX = {'empty', 'empty_like', 'empty_with_format', 'new_empty_strided', 'new_empty', 'empty_strided'} -- Gitee