diff --git a/debug/accuracy_tools/msprobe/core/common/file_utils.py b/debug/accuracy_tools/msprobe/core/common/file_utils.py index fdc626ca6a1a90e9060cefa237f9d5d8d7e42844..89d33a6a3e6fe830b981483edbe2aa6a4e5aa41f 100644 --- a/debug/accuracy_tools/msprobe/core/common/file_utils.py +++ b/debug/accuracy_tools/msprobe/core/common/file_utils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -23,6 +23,7 @@ import shutil from datetime import datetime, timezone from dateutil import parser import yaml + import numpy as np import pandas as pd @@ -446,8 +447,6 @@ def save_excel(path, data): change_mode(path, FileCheckConst.DATA_FILE_AUTHORITY) - - def move_file(src_path, dst_path): check_file_or_directory_path(src_path) check_path_before_create(dst_path) diff --git a/debug/accuracy_tools/msprobe/core/compare/acc_compare.py b/debug/accuracy_tools/msprobe/core/compare/acc_compare.py index f2aa8c479ecd0e40f3585708f82ff48a0e74832c..5e646c5352982ea25373808887cc938276fcb677 100644 --- a/debug/accuracy_tools/msprobe/core/compare/acc_compare.py +++ b/debug/accuracy_tools/msprobe/core/compare/acc_compare.py @@ -15,7 +15,6 @@ import multiprocessing import os -import re from copy import deepcopy import pandas as pd @@ -351,6 +350,9 @@ class Comparator: result_df = self.make_result_table(result) return result_df + def read_real_data(self, npu_dir, npu_data_name, bench_dir, bench_data_name) -> tuple: + return None, None + def compare_by_op(self, npu_op_name, bench_op_name, op_name_mapping_dict, input_param): """ :param npu_op_name: excel中的NPU_Name,例如:MintFunctional.conv2d.0.forward.input.3.0 @@ -376,17 +378,7 @@ class Comparator: npu_dir = input_param.get("npu_dump_data_dir") bench_dir = input_param.get("bench_dump_data_dir") try: - frame_name = getattr(self, "frame_name") - read_npy_data = getattr(self, "read_npy_data") - if frame_name == "MSComparator": - n_value = read_npy_data(npu_dir, npu_data_name) - if self.cross_frame: - b_value = read_npy_data(bench_dir, bench_data_name, load_pt_file=True) - else: - b_value = read_npy_data(bench_dir, bench_data_name) - else: - n_value = read_npy_data(npu_dir, npu_data_name) - b_value = read_npy_data(bench_dir, bench_data_name) + n_value, b_value = self.read_real_data(npu_dir, npu_data_name, bench_dir, bench_data_name) except IOError as error: error_file = error.filename n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py index 4c56419dcb17b78918e4d46a3aaa50b12ef32777..c6985c3316617fa28537a42b825158fdebfe4335 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py @@ -29,7 +29,7 @@ from msprobe.core.common.log import logger from msprobe.core.common.utils import convert_tuple from msprobe.core.data_dump.data_processor.base import BaseDataProcessor, ModuleBackwardInputsOutputs, \ ModuleForwardInputsOutputs, TensorStatInfo -from msprobe.pytorch.common.utils import save_pt, load_pt +from msprobe.pytorch.common.utils import save_pt from msprobe.pytorch.free_benchmark import FreeBenchmarkCheck, UnequalRow from msprobe.core.common.utils import recursion_depth_decorator diff --git a/debug/accuracy_tools/msprobe/mindspore/compare/ms_compare.py b/debug/accuracy_tools/msprobe/mindspore/compare/ms_compare.py index c3767abf87158ec0ff02ad594d340597496e1aba..0060a4d1be60546a49655f83e916ba1418a2954a 100644 --- a/debug/accuracy_tools/msprobe/mindspore/compare/ms_compare.py +++ b/debug/accuracy_tools/msprobe/mindspore/compare/ms_compare.py @@ -22,7 +22,7 @@ import pandas as pd from msprobe.core.common.const import CompareConst, Const from msprobe.core.common.exceptions import FileCheckException -from msprobe.core.common.file_utils import create_directory, load_json, load_npy, load_yaml +from msprobe.core.common.file_utils import create_directory, load_json, load_yaml from msprobe.core.common.log import logger from msprobe.core.common.utils import CompareException, check_compare_param, check_configuration_param, \ check_op_str_pattern_valid, get_dump_mode, set_dump_path, detect_framework_by_dump_json @@ -30,6 +30,8 @@ from msprobe.core.compare.acc_compare import Comparator, ModeConfig from msprobe.core.compare.check import dtype_mapping from msprobe.core.compare.layer_mapping import generate_data_mapping_by_layer_mapping from msprobe.core.compare.utils import set_stack_json_path, reorder_op_x_list +from msprobe.pytorch.compare.utils import read_pt_data +from msprobe.mindspore.compare.utils import read_npy_data class MappingConfig: @@ -211,20 +213,6 @@ class MSComparator(Comparator): npu_op_name = npu_op_name.replace(cell_name, self.cell_mapping_dict[cell_name], 1) return npu_op_name - def read_npy_data(self, dir_path, file_name, load_pt_file=False): - if not file_name: - return None - data_path = os.path.join(dir_path, file_name) - if load_pt_file: - import torch - from msprobe.pytorch.common.utils import load_pt - data_value = load_pt(data_path, True).detach() - if data_value.dtype == torch.bfloat16: - data_value = data_value.to(torch.float32) - data_value = data_value.numpy() - else: - data_value = load_npy(data_path) - return data_value def process_internal_api_mapping(self, npu_op_name): # get api name & class name from op_name @@ -389,6 +377,14 @@ class MSComparator(Comparator): result['data_name'].append(data_name_reorder.pop(0)) return pd.DataFrame(result) + def read_real_data(self, npu_dir, npu_data_name, bench_dir, bench_data_name) -> tuple: + n_value = read_npy_data(npu_dir, npu_data_name) + if self.cross_frame: + b_value = read_pt_data(bench_dir, bench_data_name) + else: + b_value = read_npy_data(bench_dir, bench_data_name) + return n_value, b_value + def check_cross_framework(bench_json_path): framework = detect_framework_by_dump_json(bench_json_path) diff --git a/debug/accuracy_tools/msprobe/mindspore/compare/utils.py b/debug/accuracy_tools/msprobe/mindspore/compare/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..737cdb55d29c755d4626ea8ef7d12b9b7df13a39 --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/compare/utils.py @@ -0,0 +1,30 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from msprobe.core.common.file_utils import load_npy, FileChecker, FileCheckConst + + +def read_npy_data(dir_path, file_name): + if not file_name: + return None + + data_path = os.path.join(dir_path, file_name) + path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE, + FileCheckConst.NUMPY_SUFFIX, False) + data_path = path_checker.common_check() + data_value = load_npy(data_path) + return data_value diff --git a/debug/accuracy_tools/msprobe/pytorch/__init__.py b/debug/accuracy_tools/msprobe/pytorch/__init__.py index ce84e6b35b74e55a90915350ff3ef2da3f7ba441..20fbfeed0fe60293e5030aff86f005b8b28f395a 100644 --- a/debug/accuracy_tools/msprobe/pytorch/__init__.py +++ b/debug/accuracy_tools/msprobe/pytorch/__init__.py @@ -15,7 +15,6 @@ import torch from .compare.distributed_compare import compare_distributed -from .compare.pt_compare import compare from .common.utils import seed_all from .debugger.precision_debugger import PrecisionDebugger, module_dump, module_dump_end diff --git a/debug/accuracy_tools/msprobe/pytorch/common/utils.py b/debug/accuracy_tools/msprobe/pytorch/common/utils.py index 4e82bee4a04d9ffe2be8aebe1a85791eccae4070..3fb9474fa9fda8dac9cd79db3ae0d11d7587b661 100644 --- a/debug/accuracy_tools/msprobe/pytorch/common/utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/common/utils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -25,8 +25,8 @@ import numpy as np import torch import torch.distributed as dist from msprobe.core.common.exceptions import DistributedNotInitializedError -from msprobe.core.common.file_utils import (FileCheckConst, change_mode, - check_file_or_directory_path, check_path_before_create, FileOpen) +from msprobe.core.common.file_utils import FileCheckConst, change_mode, check_file_or_directory_path, \ + check_path_before_create, FileOpen from msprobe.core.common.log import logger from msprobe.core.common.utils import check_seed_all from packaging import version diff --git a/debug/accuracy_tools/msprobe/pytorch/compare/distributed_compare.py b/debug/accuracy_tools/msprobe/pytorch/compare/distributed_compare.py index de62af421b5a37e39140a9836fb16853443740d7..1b49df0653a59c2a60e1abf5a544d3e294bae70e 100644 --- a/debug/accuracy_tools/msprobe/pytorch/compare/distributed_compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/compare/distributed_compare.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2014-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,14 +15,10 @@ import os -from msprobe.core.common.exceptions import FileCheckException -from msprobe.core.common.file_utils import create_directory -from msprobe.core.common.utils import CompareException, check_compare_param, check_configuration_param, get_dump_mode, \ - set_dump_path -from msprobe.core.compare.acc_compare import ModeConfig -from msprobe.core.compare.utils import check_and_return_dir_contents, extract_json, set_stack_json_path +from msprobe.core.common.utils import CompareException +from msprobe.core.compare.utils import check_and_return_dir_contents, extract_json from msprobe.pytorch.common.log import logger -from msprobe.pytorch.compare.pt_compare import PTComparator, compare +from msprobe.pytorch.compare.pt_compare import compare def compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs): diff --git a/debug/accuracy_tools/msprobe/pytorch/compare/pt_compare.py b/debug/accuracy_tools/msprobe/pytorch/compare/pt_compare.py index 308a82b3d6e9beb67a669ea05b83d7b8a6eddc90..38176ec57c077cb8369300af657224110e00bcdf 100644 --- a/debug/accuracy_tools/msprobe/pytorch/compare/pt_compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/compare/pt_compare.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,19 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os.path - -import torch - -from msprobe.core.common.const import FileCheckConst from msprobe.core.common.exceptions import FileCheckException -from msprobe.core.common.file_utils import FileChecker, create_directory, load_yaml +from msprobe.core.common.file_utils import create_directory, load_yaml from msprobe.core.common.utils import CompareException, check_compare_param, check_configuration_param, get_dump_mode, \ set_dump_path from msprobe.core.compare.acc_compare import Comparator, ModeConfig from msprobe.core.compare.utils import set_stack_json_path from msprobe.pytorch.common.log import logger -from msprobe.pytorch.common.utils import load_pt +from msprobe.pytorch.compare.utils import read_pt_data class PTComparator(Comparator): @@ -55,28 +50,10 @@ class PTComparator(Comparator): mapping_dict = {} return mapping_dict - def read_npy_data(self, dir_path, file_name): - if not file_name: - return None - data_path = os.path.join(dir_path, file_name) - path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE, - FileCheckConst.PT_SUFFIX, False) - data_path = path_checker.common_check() - try: - # detach because numpy can not process gradient information - data_value = load_pt(data_path, to_cpu=True).detach() - except RuntimeError as e: - # 这里捕获 load_pt 中抛出的异常 - logger.error(f"Failed to load the .pt file at {data_path}.") - raise CompareException(CompareException.INVALID_FILE_ERROR) from e - except AttributeError as e: - # 这里捕获 detach 方法抛出的异常 - logger.error(f"Failed to detach the loaded tensor.") - raise CompareException(CompareException.DETACH_ERROR) from e - if data_value.dtype == torch.bfloat16: - data_value = data_value.to(torch.float32) - data_value = data_value.numpy() - return data_value + def read_real_data(self, npu_dir, npu_data_name, bench_dir, bench_data_name) -> tuple: + n_value = read_pt_data(npu_dir, npu_data_name) + b_value = read_pt_data(bench_dir, bench_data_name) + return n_value, b_value def compare(input_param, output_path, **kwargs): diff --git a/debug/accuracy_tools/msprobe/pytorch/compare/utils.py b/debug/accuracy_tools/msprobe/pytorch/compare/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..16473ff386d89de5f3bbb269e69837c07a950ea5 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/compare/utils.py @@ -0,0 +1,47 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import torch + +from msprobe.core.common.utils import logger, CompareException +from msprobe.core.common.file_utils import FileChecker, FileCheckConst +from msprobe.pytorch.common.utils import load_pt + + +def read_pt_data(dir_path, file_name): + if not file_name: + return None + + data_path = os.path.join(dir_path, file_name) + path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE, + FileCheckConst.PT_SUFFIX, False) + data_path = path_checker.common_check() + try: + # detach because numpy can not process gradient information + data_value = load_pt(data_path, to_cpu=True).detach() + except RuntimeError as e: + # 这里捕获 load_pt 中抛出的异常 + logger.error(f"Failed to load the .pt file at {data_path}.") + raise CompareException(CompareException.INVALID_FILE_ERROR) from e + except AttributeError as e: + # 这里捕获 detach 方法抛出的异常 + logger.error(f"Failed to detach the loaded tensor.") + raise CompareException(CompareException.DETACH_ERROR) from e + if data_value.dtype == torch.bfloat16: + data_value = data_value.to(torch.float32) + data_value = data_value.numpy() + return data_value diff --git a/debug/accuracy_tools/msprobe/test/core_ut/common/test_file_utils.py b/debug/accuracy_tools/msprobe/test/core_ut/common/test_file_utils.py index 9ed13f78aed57fd4d8153e2f005ea14d4fb33643..ac3a859bf4b2da478e92650cfe3267cf90c23146 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/common/test_file_utils.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/common/test_file_utils.py @@ -1,7 +1,5 @@ from unittest.mock import patch, mock_open, MagicMock -import numpy as np -import pandas as pd import pytest from msprobe.core.common.file_utils import * @@ -533,4 +531,4 @@ class TestDirectoryChecks: # Test file path check_file_or_directory_path(self.test_file, isdir=False) # Test directory path - check_file_or_directory_path(self.test_dir, isdir=True) \ No newline at end of file + check_file_or_directory_path(self.test_dir, isdir=True) diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/common/test_ms_utils.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/common/test_ms_utils.py index 1ed3ca016108519fb3f643c9d4bb768f63a52d40..80f91a53f79c81c4e79947bc66b7bf932b774bd0 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/common/test_ms_utils.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/common/test_ms_utils.py @@ -15,21 +15,13 @@ # limitations under the License. """ import unittest -from unittest.mock import MagicMock, patch, call +from unittest.mock import patch import numpy as np import mindspore as ms -import os -import random - -from msprobe.core.common.exceptions import DistributedNotInitializedError -from msprobe.mindspore.common.utils import (get_rank_if_initialized, - convert_bf16_to_fp32, - save_tensor_as_npy, - convert_to_int, - list_lowest_level_directories, - seed_all, - remove_dropout, - MsprobeStep) + +from msprobe.mindspore.common.utils import get_rank_if_initialized, convert_bf16_to_fp32, convert_to_int, \ + list_lowest_level_directories, seed_all, remove_dropout, MsprobeStep + class MockCell: def __init__(self): @@ -136,8 +128,3 @@ class TestMsprobeFunctions(unittest.TestCase): from mindspore.mint.nn.functional import dropout self.assertTrue((Dropout(0.5)(x1d).numpy() == x1d.numpy()).all()) self.assertTrue((dropout(x1d, p=0.5).numpy() == x1d.numpy()).all()) - - - -if __name__ == "__main__": - unittest.main() \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/test_ms_compare.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/test_ms_compare.py index 6f7377894002e60add41dc7b2d3c1d3d68391e0b..4a3f01f988443a3ceea42831fd542709b9560170 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/test_ms_compare.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/test_ms_compare.py @@ -9,7 +9,6 @@ from unittest.mock import patch import numpy as np import pandas as pd -import torch import yaml from msprobe.core.common.utils import CompareException @@ -468,32 +467,6 @@ class TestUtilsMethods(unittest.TestCase): npu_op_name = ms_comparator.process_cell_mapping(npu_cell_dict.get('op_name')[0]) self.assertEqual(npu_op_name, 'Module.fc1.Linear.forward.0.input.0') - def test_read_npy_data(self): - stack_mode = True - auto_analyze = True - fuzzy_match = False - dump_mode = Const.ALL - - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - mapping_config = MappingConfig() - - ms_comparator = MSComparator(mode_config, mapping_config) - - self.temp_file = tempfile.NamedTemporaryFile(suffix='.pt') - tensor = torch.Tensor([1, 2, 3]) - filename = self.temp_file.name.split('/')[-1] - torch.save(tensor, self.temp_file.name) - result = ms_comparator.read_npy_data('/tmp', filename, load_pt_file=True) - self.assertTrue(np.array_equal(result, np.array([1, 2, 3]))) - self.temp_file.close() - - self.temp_file = tempfile.NamedTemporaryFile(suffix='.npy') - tensor = np.array([1, 2, 3]) - filename = self.temp_file.name.split('/')[-1] - np.save(self.temp_file.name, tensor) - result = ms_comparator.read_npy_data('/tmp', filename, load_pt_file=False) - self.assertTrue(np.array_equal(result, np.array([1, 2, 3]))) - self.temp_file.close() def test_process_internal_api_mapping(self): stack_mode = True diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/test_utils.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/test_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d7fb5e38fb82b309caf3ab2a1b621655d7babc86 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/test_utils.py @@ -0,0 +1,24 @@ +import unittest +from unittest.mock import patch + +import numpy as np + +from msprobe.core.common.file_utils import FileCheckConst +from msprobe.mindspore.compare.utils import read_npy_data + + +class TestReadNpyData(unittest.TestCase): + + @patch('msprobe.mindspore.compare.utils.load_npy') + @patch('msprobe.mindspore.compare.utils.FileChecker') + @patch('os.path.join', return_value='/fake/path/to/file.npy') + def test_read_real_data_ms(self, mock_os, mock_file_checker, mock_load_npy): + mock_file_checker.return_value.common_check.return_value = '/fake/path/to/file.npy' + + mock_load_npy.return_value = np.array([1.0, 2.0, 3.0]) + + result = read_npy_data('/fake/dir', 'file_name.npy') + + mock_file_checker.assert_called_once_with('/fake/path/to/file.npy', FileCheckConst.FILE, FileCheckConst.READ_ABLE, FileCheckConst.NUMPY_SUFFIX, False) + mock_load_npy.assert_called_once_with('/fake/path/to/file.npy') + self.assertTrue(np.array_equal(result, np.array([1.0, 2.0, 3.0]))) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/common/test_pt_utils.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/common/test_pt_utils.py index cdc922cc98d59b59ec0be85833d2000cd38913c8..b1ac148ae742517c389f6de474463468ef90b572 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/common/test_pt_utils.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/common/test_pt_utils.py @@ -10,8 +10,8 @@ import torch.distributed as dist from msprobe.core.common.file_utils import FileCheckConst from msprobe.core.common.exceptions import DistributedNotInitializedError from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData -from msprobe.pytorch.common.utils import parameter_adapter, get_rank_if_initialized, \ - get_tensor_rank, get_rank_id, print_rank_0, load_pt, save_pt, save_api_data, load_api_data, save_pkl, load_pkl +from msprobe.pytorch.common.utils import parameter_adapter, get_rank_if_initialized, get_tensor_rank, get_rank_id, \ + print_rank_0, load_pt, save_pt, save_api_data, load_api_data, save_pkl, load_pkl class TestParameterAdapter(unittest.TestCase): @@ -180,6 +180,7 @@ class TestLoadPt(unittest.TestCase): if os.path.isfile(self.temp_file.name): os.remove(self.temp_file.name) + class TestSavePT(unittest.TestCase): def setUp(self): @@ -195,6 +196,7 @@ class TestSavePT(unittest.TestCase): mock_torch_save.assert_called_once_with(self.tensor, self.filepath) mock_change_mode.assert_called_once_with(self.filepath, FileCheckConst.DATA_FILE_AUTHORITY) + class TestSavePT(unittest.TestCase): def setUp(self): diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/compare/test_pt_compare.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/compare/test_pt_compare.py index b079e646c4a8f4098bb233e3e6259ef3ebea9c94..4eda1d6d974bdc4f6699808946fafb4b136cf98e 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/compare/test_pt_compare.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/compare/test_pt_compare.py @@ -3,13 +3,10 @@ import os import shutil import unittest -import numpy as np import torch -from msprobe.core.common.const import Const from msprobe.core.common.utils import CompareException -from msprobe.core.compare.acc_compare import ModeConfig -from msprobe.pytorch.compare.pt_compare import PTComparator, compare +from msprobe.pytorch.compare.pt_compare import compare from msprobe.test.core_ut.compare.test_acc_compare import generate_dump_json, generate_stack_json @@ -40,36 +37,6 @@ class TestUtilsMethods(unittest.TestCase): if os.path.exists(base_dir2): shutil.rmtree(base_dir2) - def test_read_npy_data_bf16(self): - generate_bf16_pt(base_dir1) - - stack_mode = True - auto_analyze = True - fuzzy_match = False - dump_mode = Const.ALL - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - - pt_comparator = PTComparator(mode_config) - result = pt_comparator.read_npy_data(base_dir1, 'bf16.pt') - - target_result = torch.tensor([1, 2, 3, 4], dtype=torch.float32).numpy() - self.assertTrue(np.array_equal(result, target_result)) - - def test_read_npy_data_dict(self): - generate_dict_pt(base_dir1) - - stack_mode = True - auto_analyze = True - fuzzy_match = False - dump_mode = Const.ALL - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - - pt_comparator = PTComparator(mode_config) - - with self.assertRaises(CompareException) as context: - result = pt_comparator.read_npy_data(base_dir1, 'dict.pt') - self.assertEqual(context.exception.code, CompareException.DETACH_ERROR) - def test_compare(self): generate_dump_json(base_dir2) generate_stack_json(base_dir2) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/compare/test_utils.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/compare/test_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..405503d898bb9eca8e23e5cde844b828513865ee --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/compare/test_utils.py @@ -0,0 +1,43 @@ +import unittest +from unittest.mock import patch, MagicMock + +import torch +import numpy as np + +from msprobe.core.common.utils import CompareException +from msprobe.core.common.file_utils import FileCheckConst +from msprobe.pytorch.compare.utils import read_pt_data + + +class TestReadPtData(unittest.TestCase): + + @patch('msprobe.pytorch.compare.utils.load_pt') + @patch('msprobe.pytorch.compare.utils.FileChecker') + @patch('os.path.join', return_value='/fake/path/to/file.pt') + def test_read_pt_data(self, mock_os, mock_file_checker, mock_load_pt): + mock_file_checker.return_value.common_check.return_value = '/fake/path/to/file.pt' + + mock_tensor = MagicMock() + mock_tensor.detach.return_value = mock_tensor + mock_tensor.to.return_value = mock_tensor + mock_tensor.dtype = torch.bfloat16 + mock_tensor.numpy.return_value = np.array([1.0, 2.0, 3.0]) + mock_load_pt.return_value = mock_tensor + + result = read_pt_data('/fake/dir', 'file_name.pt') + + mock_file_checker.assert_called_once_with('/fake/path/to/file.pt', FileCheckConst.FILE, FileCheckConst.READ_ABLE, FileCheckConst.PT_SUFFIX, False) + mock_load_pt.assert_called_once_with('/fake/path/to/file.pt', to_cpu=True) + mock_tensor.to.assert_called_once_with(torch.float32) + self.assertTrue(np.array_equal(result, np.array([1.0, 2.0, 3.0]))) + + @patch('os.path.join', return_value='/fake/path/to/file.pt') + @patch('msprobe.pytorch.compare.utils.FileChecker') + @patch('msprobe.pytorch.compare.utils.load_pt') + def test_read_real_data_pt_exception(self, mock_load_pt, mock_file_checker, mock_os): + mock_file_checker.return_value.common_check.return_value = '/fake/path/to/file.pt' + + mock_load_pt.side_effect = RuntimeError("Test Error") + + with self.assertRaises(CompareException): + read_pt_data('/fake/dir', 'file_name.pt')