From a34127e57b48456d05447d050c955860f62dc35c Mon Sep 17 00:00:00 2001 From: Linwei-Ying Date: Thu, 6 Mar 2025 10:46:09 +0800 Subject: [PATCH] compare read real data improve --- .../msprobe/core/compare/acc_compare.py | 16 ++----- .../msprobe/mindspore/compare/ms_compare.py | 27 +++++------ .../msprobe/mindspore/compare/utils.py | 30 ++++++++++++ .../msprobe/pytorch/__init__.py | 1 - .../pytorch/compare/distributed_compare.py | 12 ++--- .../msprobe/pytorch/compare/pt_compare.py | 37 +++------------ .../msprobe/pytorch/compare/utils.py | 47 +++++++++++++++++++ .../mindspore_ut/compare/test_ms_compare.py | 28 ----------- .../test/mindspore_ut/compare/test_utils.py | 24 ++++++++++ .../pytorch_ut/compare/test_pt_compare.py | 35 +------------- .../test/pytorch_ut/compare/test_utils.py | 43 +++++++++++++++++ 11 files changed, 171 insertions(+), 129 deletions(-) create mode 100644 debug/accuracy_tools/msprobe/mindspore/compare/utils.py create mode 100644 debug/accuracy_tools/msprobe/pytorch/compare/utils.py create mode 100644 debug/accuracy_tools/msprobe/test/mindspore_ut/compare/test_utils.py create mode 100644 debug/accuracy_tools/msprobe/test/pytorch_ut/compare/test_utils.py diff --git a/debug/accuracy_tools/msprobe/core/compare/acc_compare.py b/debug/accuracy_tools/msprobe/core/compare/acc_compare.py index f2aa8c479e..4ffbb225b0 100644 --- a/debug/accuracy_tools/msprobe/core/compare/acc_compare.py +++ b/debug/accuracy_tools/msprobe/core/compare/acc_compare.py @@ -15,7 +15,6 @@ import multiprocessing import os -import re from copy import deepcopy import pandas as pd @@ -351,6 +350,9 @@ class Comparator: result_df = self.make_result_table(result) return result_df + def read_real_data(self, npu_dir, npu_data_name, bench_dir, bench_data_name) -> tuple: + pass + def compare_by_op(self, npu_op_name, bench_op_name, op_name_mapping_dict, input_param): """ :param npu_op_name: excel中的NPU_Name,例如:MintFunctional.conv2d.0.forward.input.3.0 @@ -376,17 +378,7 @@ class Comparator: npu_dir = input_param.get("npu_dump_data_dir") bench_dir = input_param.get("bench_dump_data_dir") try: - frame_name = getattr(self, "frame_name") - read_npy_data = getattr(self, "read_npy_data") - if frame_name == "MSComparator": - n_value = read_npy_data(npu_dir, npu_data_name) - if self.cross_frame: - b_value = read_npy_data(bench_dir, bench_data_name, load_pt_file=True) - else: - b_value = read_npy_data(bench_dir, bench_data_name) - else: - n_value = read_npy_data(npu_dir, npu_data_name) - b_value = read_npy_data(bench_dir, bench_data_name) + n_value, b_value = self.read_real_data(npu_dir, npu_data_name, bench_dir, bench_data_name) except IOError as error: error_file = error.filename n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE diff --git a/debug/accuracy_tools/msprobe/mindspore/compare/ms_compare.py b/debug/accuracy_tools/msprobe/mindspore/compare/ms_compare.py index c3767abf87..a7694bf853 100644 --- a/debug/accuracy_tools/msprobe/mindspore/compare/ms_compare.py +++ b/debug/accuracy_tools/msprobe/mindspore/compare/ms_compare.py @@ -22,7 +22,7 @@ import pandas as pd from msprobe.core.common.const import CompareConst, Const from msprobe.core.common.exceptions import FileCheckException -from msprobe.core.common.file_utils import create_directory, load_json, load_npy, load_yaml +from msprobe.core.common.file_utils import create_directory, load_json, load_yaml from msprobe.core.common.log import logger from msprobe.core.common.utils import CompareException, check_compare_param, check_configuration_param, \ check_op_str_pattern_valid, get_dump_mode, set_dump_path, detect_framework_by_dump_json @@ -30,6 +30,8 @@ from msprobe.core.compare.acc_compare import Comparator, ModeConfig from msprobe.core.compare.check import dtype_mapping from msprobe.core.compare.layer_mapping import generate_data_mapping_by_layer_mapping from msprobe.core.compare.utils import set_stack_json_path, reorder_op_x_list +from msprobe.pytorch.compare.utils import read_pt_data +from msprobe.mindspore.compare.utils import read_npy_data class MappingConfig: @@ -211,21 +213,6 @@ class MSComparator(Comparator): npu_op_name = npu_op_name.replace(cell_name, self.cell_mapping_dict[cell_name], 1) return npu_op_name - def read_npy_data(self, dir_path, file_name, load_pt_file=False): - if not file_name: - return None - data_path = os.path.join(dir_path, file_name) - if load_pt_file: - import torch - from msprobe.pytorch.common.utils import load_pt - data_value = load_pt(data_path, True).detach() - if data_value.dtype == torch.bfloat16: - data_value = data_value.to(torch.float32) - data_value = data_value.numpy() - else: - data_value = load_npy(data_path) - return data_value - def process_internal_api_mapping(self, npu_op_name): # get api name & class name from op_name # Functional.addcmul.0.forward.input.0 @@ -389,6 +376,14 @@ class MSComparator(Comparator): result['data_name'].append(data_name_reorder.pop(0)) return pd.DataFrame(result) + def read_real_data(self, npu_dir, npu_data_name, bench_dir, bench_data_name) -> tuple: + n_value = read_npy_data(npu_dir, npu_data_name) + if self.cross_frame: + b_value = read_pt_data(bench_dir, bench_data_name) + else: + b_value = read_npy_data(bench_dir, bench_data_name) + return n_value, b_value + def check_cross_framework(bench_json_path): framework = detect_framework_by_dump_json(bench_json_path) diff --git a/debug/accuracy_tools/msprobe/mindspore/compare/utils.py b/debug/accuracy_tools/msprobe/mindspore/compare/utils.py new file mode 100644 index 0000000000..737cdb55d2 --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/compare/utils.py @@ -0,0 +1,30 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from msprobe.core.common.file_utils import load_npy, FileChecker, FileCheckConst + + +def read_npy_data(dir_path, file_name): + if not file_name: + return None + + data_path = os.path.join(dir_path, file_name) + path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE, + FileCheckConst.NUMPY_SUFFIX, False) + data_path = path_checker.common_check() + data_value = load_npy(data_path) + return data_value diff --git a/debug/accuracy_tools/msprobe/pytorch/__init__.py b/debug/accuracy_tools/msprobe/pytorch/__init__.py index ce84e6b35b..20fbfeed0f 100644 --- a/debug/accuracy_tools/msprobe/pytorch/__init__.py +++ b/debug/accuracy_tools/msprobe/pytorch/__init__.py @@ -15,7 +15,6 @@ import torch from .compare.distributed_compare import compare_distributed -from .compare.pt_compare import compare from .common.utils import seed_all from .debugger.precision_debugger import PrecisionDebugger, module_dump, module_dump_end diff --git a/debug/accuracy_tools/msprobe/pytorch/compare/distributed_compare.py b/debug/accuracy_tools/msprobe/pytorch/compare/distributed_compare.py index de62af421b..a484ad5cee 100644 --- a/debug/accuracy_tools/msprobe/pytorch/compare/distributed_compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/compare/distributed_compare.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,14 +15,10 @@ import os -from msprobe.core.common.exceptions import FileCheckException -from msprobe.core.common.file_utils import create_directory -from msprobe.core.common.utils import CompareException, check_compare_param, check_configuration_param, get_dump_mode, \ - set_dump_path -from msprobe.core.compare.acc_compare import ModeConfig -from msprobe.core.compare.utils import check_and_return_dir_contents, extract_json, set_stack_json_path +from msprobe.core.common.utils import CompareException +from msprobe.core.compare.utils import check_and_return_dir_contents, extract_json from msprobe.pytorch.common.log import logger -from msprobe.pytorch.compare.pt_compare import PTComparator, compare +from msprobe.pytorch.compare.pt_compare import compare def compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs): diff --git a/debug/accuracy_tools/msprobe/pytorch/compare/pt_compare.py b/debug/accuracy_tools/msprobe/pytorch/compare/pt_compare.py index 308a82b3d6..38176ec57c 100644 --- a/debug/accuracy_tools/msprobe/pytorch/compare/pt_compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/compare/pt_compare.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,19 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os.path - -import torch - -from msprobe.core.common.const import FileCheckConst from msprobe.core.common.exceptions import FileCheckException -from msprobe.core.common.file_utils import FileChecker, create_directory, load_yaml +from msprobe.core.common.file_utils import create_directory, load_yaml from msprobe.core.common.utils import CompareException, check_compare_param, check_configuration_param, get_dump_mode, \ set_dump_path from msprobe.core.compare.acc_compare import Comparator, ModeConfig from msprobe.core.compare.utils import set_stack_json_path from msprobe.pytorch.common.log import logger -from msprobe.pytorch.common.utils import load_pt +from msprobe.pytorch.compare.utils import read_pt_data class PTComparator(Comparator): @@ -55,28 +50,10 @@ class PTComparator(Comparator): mapping_dict = {} return mapping_dict - def read_npy_data(self, dir_path, file_name): - if not file_name: - return None - data_path = os.path.join(dir_path, file_name) - path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE, - FileCheckConst.PT_SUFFIX, False) - data_path = path_checker.common_check() - try: - # detach because numpy can not process gradient information - data_value = load_pt(data_path, to_cpu=True).detach() - except RuntimeError as e: - # 这里捕获 load_pt 中抛出的异常 - logger.error(f"Failed to load the .pt file at {data_path}.") - raise CompareException(CompareException.INVALID_FILE_ERROR) from e - except AttributeError as e: - # 这里捕获 detach 方法抛出的异常 - logger.error(f"Failed to detach the loaded tensor.") - raise CompareException(CompareException.DETACH_ERROR) from e - if data_value.dtype == torch.bfloat16: - data_value = data_value.to(torch.float32) - data_value = data_value.numpy() - return data_value + def read_real_data(self, npu_dir, npu_data_name, bench_dir, bench_data_name) -> tuple: + n_value = read_pt_data(npu_dir, npu_data_name) + b_value = read_pt_data(bench_dir, bench_data_name) + return n_value, b_value def compare(input_param, output_path, **kwargs): diff --git a/debug/accuracy_tools/msprobe/pytorch/compare/utils.py b/debug/accuracy_tools/msprobe/pytorch/compare/utils.py new file mode 100644 index 0000000000..16473ff386 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/compare/utils.py @@ -0,0 +1,47 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import torch + +from msprobe.core.common.utils import logger, CompareException +from msprobe.core.common.file_utils import FileChecker, FileCheckConst +from msprobe.pytorch.common.utils import load_pt + + +def read_pt_data(dir_path, file_name): + if not file_name: + return None + + data_path = os.path.join(dir_path, file_name) + path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE, + FileCheckConst.PT_SUFFIX, False) + data_path = path_checker.common_check() + try: + # detach because numpy can not process gradient information + data_value = load_pt(data_path, to_cpu=True).detach() + except RuntimeError as e: + # 这里捕获 load_pt 中抛出的异常 + logger.error(f"Failed to load the .pt file at {data_path}.") + raise CompareException(CompareException.INVALID_FILE_ERROR) from e + except AttributeError as e: + # 这里捕获 detach 方法抛出的异常 + logger.error(f"Failed to detach the loaded tensor.") + raise CompareException(CompareException.DETACH_ERROR) from e + if data_value.dtype == torch.bfloat16: + data_value = data_value.to(torch.float32) + data_value = data_value.numpy() + return data_value diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/test_ms_compare.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/test_ms_compare.py index 6f73778940..69fae7f9a9 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/test_ms_compare.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/test_ms_compare.py @@ -9,7 +9,6 @@ from unittest.mock import patch import numpy as np import pandas as pd -import torch import yaml from msprobe.core.common.utils import CompareException @@ -468,33 +467,6 @@ class TestUtilsMethods(unittest.TestCase): npu_op_name = ms_comparator.process_cell_mapping(npu_cell_dict.get('op_name')[0]) self.assertEqual(npu_op_name, 'Module.fc1.Linear.forward.0.input.0') - def test_read_npy_data(self): - stack_mode = True - auto_analyze = True - fuzzy_match = False - dump_mode = Const.ALL - - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - mapping_config = MappingConfig() - - ms_comparator = MSComparator(mode_config, mapping_config) - - self.temp_file = tempfile.NamedTemporaryFile(suffix='.pt') - tensor = torch.Tensor([1, 2, 3]) - filename = self.temp_file.name.split('/')[-1] - torch.save(tensor, self.temp_file.name) - result = ms_comparator.read_npy_data('/tmp', filename, load_pt_file=True) - self.assertTrue(np.array_equal(result, np.array([1, 2, 3]))) - self.temp_file.close() - - self.temp_file = tempfile.NamedTemporaryFile(suffix='.npy') - tensor = np.array([1, 2, 3]) - filename = self.temp_file.name.split('/')[-1] - np.save(self.temp_file.name, tensor) - result = ms_comparator.read_npy_data('/tmp', filename, load_pt_file=False) - self.assertTrue(np.array_equal(result, np.array([1, 2, 3]))) - self.temp_file.close() - def test_process_internal_api_mapping(self): stack_mode = True auto_analyze = True diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/test_utils.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/test_utils.py new file mode 100644 index 0000000000..d7fb5e38fb --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/test_utils.py @@ -0,0 +1,24 @@ +import unittest +from unittest.mock import patch + +import numpy as np + +from msprobe.core.common.file_utils import FileCheckConst +from msprobe.mindspore.compare.utils import read_npy_data + + +class TestReadNpyData(unittest.TestCase): + + @patch('msprobe.mindspore.compare.utils.load_npy') + @patch('msprobe.mindspore.compare.utils.FileChecker') + @patch('os.path.join', return_value='/fake/path/to/file.npy') + def test_read_real_data_ms(self, mock_os, mock_file_checker, mock_load_npy): + mock_file_checker.return_value.common_check.return_value = '/fake/path/to/file.npy' + + mock_load_npy.return_value = np.array([1.0, 2.0, 3.0]) + + result = read_npy_data('/fake/dir', 'file_name.npy') + + mock_file_checker.assert_called_once_with('/fake/path/to/file.npy', FileCheckConst.FILE, FileCheckConst.READ_ABLE, FileCheckConst.NUMPY_SUFFIX, False) + mock_load_npy.assert_called_once_with('/fake/path/to/file.npy') + self.assertTrue(np.array_equal(result, np.array([1.0, 2.0, 3.0]))) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/compare/test_pt_compare.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/compare/test_pt_compare.py index b079e646c4..4eda1d6d97 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/compare/test_pt_compare.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/compare/test_pt_compare.py @@ -3,13 +3,10 @@ import os import shutil import unittest -import numpy as np import torch -from msprobe.core.common.const import Const from msprobe.core.common.utils import CompareException -from msprobe.core.compare.acc_compare import ModeConfig -from msprobe.pytorch.compare.pt_compare import PTComparator, compare +from msprobe.pytorch.compare.pt_compare import compare from msprobe.test.core_ut.compare.test_acc_compare import generate_dump_json, generate_stack_json @@ -40,36 +37,6 @@ class TestUtilsMethods(unittest.TestCase): if os.path.exists(base_dir2): shutil.rmtree(base_dir2) - def test_read_npy_data_bf16(self): - generate_bf16_pt(base_dir1) - - stack_mode = True - auto_analyze = True - fuzzy_match = False - dump_mode = Const.ALL - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - - pt_comparator = PTComparator(mode_config) - result = pt_comparator.read_npy_data(base_dir1, 'bf16.pt') - - target_result = torch.tensor([1, 2, 3, 4], dtype=torch.float32).numpy() - self.assertTrue(np.array_equal(result, target_result)) - - def test_read_npy_data_dict(self): - generate_dict_pt(base_dir1) - - stack_mode = True - auto_analyze = True - fuzzy_match = False - dump_mode = Const.ALL - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - - pt_comparator = PTComparator(mode_config) - - with self.assertRaises(CompareException) as context: - result = pt_comparator.read_npy_data(base_dir1, 'dict.pt') - self.assertEqual(context.exception.code, CompareException.DETACH_ERROR) - def test_compare(self): generate_dump_json(base_dir2) generate_stack_json(base_dir2) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/compare/test_utils.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/compare/test_utils.py new file mode 100644 index 0000000000..405503d898 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/compare/test_utils.py @@ -0,0 +1,43 @@ +import unittest +from unittest.mock import patch, MagicMock + +import torch +import numpy as np + +from msprobe.core.common.utils import CompareException +from msprobe.core.common.file_utils import FileCheckConst +from msprobe.pytorch.compare.utils import read_pt_data + + +class TestReadPtData(unittest.TestCase): + + @patch('msprobe.pytorch.compare.utils.load_pt') + @patch('msprobe.pytorch.compare.utils.FileChecker') + @patch('os.path.join', return_value='/fake/path/to/file.pt') + def test_read_pt_data(self, mock_os, mock_file_checker, mock_load_pt): + mock_file_checker.return_value.common_check.return_value = '/fake/path/to/file.pt' + + mock_tensor = MagicMock() + mock_tensor.detach.return_value = mock_tensor + mock_tensor.to.return_value = mock_tensor + mock_tensor.dtype = torch.bfloat16 + mock_tensor.numpy.return_value = np.array([1.0, 2.0, 3.0]) + mock_load_pt.return_value = mock_tensor + + result = read_pt_data('/fake/dir', 'file_name.pt') + + mock_file_checker.assert_called_once_with('/fake/path/to/file.pt', FileCheckConst.FILE, FileCheckConst.READ_ABLE, FileCheckConst.PT_SUFFIX, False) + mock_load_pt.assert_called_once_with('/fake/path/to/file.pt', to_cpu=True) + mock_tensor.to.assert_called_once_with(torch.float32) + self.assertTrue(np.array_equal(result, np.array([1.0, 2.0, 3.0]))) + + @patch('os.path.join', return_value='/fake/path/to/file.pt') + @patch('msprobe.pytorch.compare.utils.FileChecker') + @patch('msprobe.pytorch.compare.utils.load_pt') + def test_read_real_data_pt_exception(self, mock_load_pt, mock_file_checker, mock_os): + mock_file_checker.return_value.common_check.return_value = '/fake/path/to/file.pt' + + mock_load_pt.side_effect = RuntimeError("Test Error") + + with self.assertRaises(CompareException): + read_pt_data('/fake/dir', 'file_name.pt') -- Gitee