diff --git a/debug/accuracy_tools/atat/test/core_ut/test_common_config.py b/debug/accuracy_tools/atat/test/core_ut/test_common_config.py new file mode 100644 index 0000000000000000000000000000000000000000..5dd7aee7ba3f149263b9382d937616a20a6c6293 --- /dev/null +++ b/debug/accuracy_tools/atat/test/core_ut/test_common_config.py @@ -0,0 +1,152 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +from unittest import TestCase +from unittest.mock import patch + +from atat.core.common.log import logger +from atat.core.common.utils import Const +from atat.core.common.exceptions import MsaccException +from atat.core.common_config import CommonConfig, BaseConfig + + +class TestCommonConfig(TestCase): + @patch.object(logger, "error_log_with_exp") + def test_common_config(self, mock_error_log_with_exp): + json_config = dict() + + common_config = CommonConfig(json_config) + self.assertIsNone(common_config.task) + self.assertIsNone(common_config.dump_path) + self.assertIsNone(common_config.rank) + self.assertIsNone(common_config.step) + self.assertIsNone(common_config.level) + self.assertIsNone(common_config.seed) + self.assertIsNone(common_config.acl_config) + self.assertFalse(common_config.is_deterministic) + self.assertFalse(common_config.enable_dataloader) + + json_config.update({"task": "md5"}) + CommonConfig(json_config) + self.assertEqual(mock_error_log_with_exp.call_args[0][0], + "task is invalid, it should be one of {}".format(Const.TASK_LIST)) + self.assertEqual(str(mock_error_log_with_exp.call_args[0][1]), + MsaccException.err_strs.get(MsaccException.INVALID_PARAM_ERROR)) + + json_config.update({"task": Const.TENSOR}) + json_config.update({"rank": 0}) + CommonConfig(json_config) + self.assertEqual(mock_error_log_with_exp.call_args[0][0], + "rank is invalid, it should be a list") + self.assertEqual(str(mock_error_log_with_exp.call_args[0][1]), + MsaccException.err_strs.get(MsaccException.INVALID_PARAM_ERROR)) + + json_config.update({"task": Const.TENSOR}) + json_config.update({"rank": [0]}) + json_config.update({"step": 0}) + CommonConfig(json_config) + self.assertEqual(mock_error_log_with_exp.call_args[0][0], + "step is invalid, it should be a list") + self.assertEqual(str(mock_error_log_with_exp.call_args[0][1]), + MsaccException.err_strs.get(MsaccException.INVALID_PARAM_ERROR)) + + json_config.update({"task": Const.TENSOR}) + json_config.update({"rank": [0]}) + json_config.update({"step": [0]}) + json_config.update({"level": "L3"}) + CommonConfig(json_config) + self.assertEqual(mock_error_log_with_exp.call_args[0][0], + "level is invalid, it should be one of {}".format(Const.LEVEL_LIST)) + self.assertEqual(str(mock_error_log_with_exp.call_args[0][1]), + MsaccException.err_strs.get(MsaccException.INVALID_PARAM_ERROR)) + + json_config.update({"task": Const.TENSOR}) + json_config.update({"rank": [0]}) + json_config.update({"step": [0]}) + json_config.update({"level": "L0"}) + json_config.update({"seed": "1234"}) + CommonConfig(json_config) + self.assertEqual(mock_error_log_with_exp.call_args[0][0], + "seed is invalid, it should be an integer") + self.assertEqual(str(mock_error_log_with_exp.call_args[0][1]), + MsaccException.err_strs.get(MsaccException.INVALID_PARAM_ERROR)) + + json_config.update({"task": Const.TENSOR}) + json_config.update({"rank": [0]}) + json_config.update({"step": [0]}) + json_config.update({"level": "L0"}) + json_config.update({"seed": 1234}) + json_config.update({"is_deterministic": "ENABLE"}) + CommonConfig(json_config) + self.assertEqual(mock_error_log_with_exp.call_args[0][0], + "is_deterministic is invalid, it should be a boolean") + self.assertEqual(str(mock_error_log_with_exp.call_args[0][1]), + MsaccException.err_strs.get(MsaccException.INVALID_PARAM_ERROR)) + + json_config.update({"task": Const.TENSOR}) + json_config.update({"rank": [0]}) + json_config.update({"step": [0]}) + json_config.update({"level": "L0"}) + json_config.update({"seed": 1234}) + json_config.update({"is_deterministic": True}) + json_config.update({"enable_dataloader": "ENABLE"}) + CommonConfig(json_config) + self.assertEqual(mock_error_log_with_exp.call_args[0][0], + "enable_dataloader is invalid, it should be a boolean") + self.assertEqual(str(mock_error_log_with_exp.call_args[0][1]), + MsaccException.err_strs.get(MsaccException.INVALID_PARAM_ERROR)) + + @patch.object(logger, "error_log_with_exp") + def test_base_config(self, mock_error_log_with_exp): + json_config = dict() + + base_config = BaseConfig(json_config) + base_config.check_config() + self.assertIsNone(base_config.scope) + self.assertIsNone(base_config.list) + self.assertIsNone(base_config.data_mode) + self.assertIsNone(base_config.backward_input) + self.assertIsNone(base_config.file_format) + self.assertIsNone(base_config.summary_mode) + self.assertIsNone(base_config.overflow_num) + self.assertIsNone(base_config.check_mode) + + json_config.update({"scope": "Tensor_Add"}) + base_config = BaseConfig(json_config) + base_config.check_config() + self.assertEqual(mock_error_log_with_exp.call_args[0][0], + "scope is invalid, it should be a list") + self.assertEqual(str(mock_error_log_with_exp.call_args[0][1]), + MsaccException.err_strs.get(MsaccException.INVALID_PARAM_ERROR)) + + json_config.update({"scope": ["Tensor_Add"]}) + json_config.update({"list": "Tensor_Add"}) + base_config = BaseConfig(json_config) + base_config.check_config() + self.assertEqual(mock_error_log_with_exp.call_args[0][0], + "list is invalid, it should be a list") + self.assertEqual(str(mock_error_log_with_exp.call_args[0][1]), + MsaccException.err_strs.get(MsaccException.INVALID_PARAM_ERROR)) + + json_config.update({"scope": ["Tensor_Add"]}) + json_config.update({"list": ["Tensor_Add"]}) + json_config.update({"data_mode": "all"}) + base_config = BaseConfig(json_config) + base_config.check_config() + self.assertEqual(mock_error_log_with_exp.call_args[0][0], + "data_mode is invalid, it should be a list") + self.assertEqual(str(mock_error_log_with_exp.call_args[0][1]), + MsaccException.err_strs.get(MsaccException.INVALID_PARAM_ERROR)) diff --git a/debug/accuracy_tools/atat/test/core_ut/test_file_check.py b/debug/accuracy_tools/atat/test/core_ut/test_file_check.py new file mode 100644 index 0000000000000000000000000000000000000000..3305acb0b74a1862be5f654c07d2e068337d22e9 --- /dev/null +++ b/debug/accuracy_tools/atat/test/core_ut/test_file_check.py @@ -0,0 +1,218 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +import os + +from unittest import TestCase +from unittest.mock import patch, MagicMock + +from atat.core.common.log import logger +from atat.core.common.exceptions import FileCheckException +from atat.core.common.file_check import (FileCheckConst, + check_link, + check_path_length, + check_path_exists, + check_path_readability, + check_path_writability, + check_path_executable, + check_other_user_writable, + check_path_owner_consistent, + check_path_pattern_vaild, + check_file_size, + check_common_file_size, + check_file_suffix, + check_path_type) + + +class TestFileCheckUtil(TestCase): + @patch.object(logger, "error") + def test_check_link(self, mock_logger_error): + with patch("atat.core.common.file_check.os.path.islink", return_value=True): + with self.assertRaises(FileCheckException) as context: + check_link("link_path") + self.assertEqual(str(context.exception), + FileCheckException.err_strs.get(FileCheckException.SOFT_LINK_ERROR)) + mock_logger_error.assert_called_with("The file path link_path is a soft link.") + + @patch.object(logger, "error") + def test_check_path_length(self, mock_logger_error): + path = "P" * (FileCheckConst.DIRECTORY_LENGTH + 1) + with self.assertRaises(FileCheckException) as context: + check_path_length(path) + self.assertEqual(str(context.exception), + FileCheckException.err_strs.get(FileCheckException.ILLEGAL_PATH_ERROR)) + mock_logger_error.assert_called_with("The file path length exceeds limit.") + + path = "P" * (FileCheckConst.FILE_NAME_LENGTH + 1) + with self.assertRaises(FileCheckException) as context: + check_path_length(path) + self.assertEqual(str(context.exception), + FileCheckException.err_strs.get(FileCheckException.ILLEGAL_PATH_ERROR)) + mock_logger_error.assert_called_with("The file path length exceeds limit.") + + path = "P" * (FileCheckConst.FILE_NAME_LENGTH - 5) + with self.assertRaises(FileCheckException) as context: + check_path_length(path, name_length=FileCheckConst.FILE_NAME_LENGTH - 6) + self.assertEqual(str(context.exception), + FileCheckException.err_strs.get(FileCheckException.ILLEGAL_PATH_ERROR)) + mock_logger_error.assert_called_with("The file path length exceeds limit.") + + @patch.object(logger, "error") + def test_check_path_exists(self, mock_logger_error): + with patch("atat.core.common.file_check.os.path.exists", return_value=False): + with self.assertRaises(FileCheckException) as context: + check_path_exists("file_path") + self.assertEqual(str(context.exception), + FileCheckException.err_strs.get(FileCheckException.ILLEGAL_PATH_ERROR)) + mock_logger_error.assert_called_with("The file path file_path does not exist.") + + @patch.object(logger, "error") + def test_check_path_readability(self, mock_logger_error): + path = "file_path" + with patch("atat.core.common.file_check.os.access", return_value=False): + with self.assertRaises(FileCheckException) as context: + check_path_readability(path) + self.assertEqual(str(context.exception), + FileCheckException.err_strs.get(FileCheckException.FILE_PERMISSION_ERROR)) + mock_logger_error.assert_called_with(f"The file path {path} is not readable.") + + mock_access = MagicMock() + mock_access.return_value = True + with patch("atat.core.common.file_check.os.access", new=mock_access): + check_path_readability(path) + self.assertEqual(mock_access.call_args[0], (path, os.R_OK)) + + @patch.object(logger, "error") + def test_check_path_writability(self, mock_logger_error): + path = "file_path" + with patch("atat.core.common.file_check.os.access", return_value=False): + with self.assertRaises(FileCheckException) as context: + check_path_writability(path) + self.assertEqual(str(context.exception), + FileCheckException.err_strs.get(FileCheckException.FILE_PERMISSION_ERROR)) + mock_logger_error.assert_called_with(f"The file path {path} is not writable.") + + mock_access = MagicMock() + mock_access.return_value = True + with patch("atat.core.common.file_check.os.access", new=mock_access): + check_path_writability(path) + self.assertEqual(mock_access.call_args[0], (path, os.W_OK)) + + @patch.object(logger, "error") + def test_check_path_executable(self, mock_logger_error): + path = "file_path" + with patch("atat.core.common.file_check.os.access", return_value=False): + with self.assertRaises(FileCheckException) as context: + check_path_executable(path) + self.assertEqual(str(context.exception), + FileCheckException.err_strs.get(FileCheckException.FILE_PERMISSION_ERROR)) + mock_logger_error.assert_called_with(f"The file path {path} is not executable.") + + mock_access = MagicMock() + mock_access.return_value = True + with patch("atat.core.common.file_check.os.access", new=mock_access): + check_path_executable(path) + self.assertEqual(mock_access.call_args[0], (path, os.X_OK)) + + @patch.object(logger, "error") + def test_check_other_user_writable(self, mock_logger_error): + class TestStat: + def __init__(self, mode): + self.st_mode = mode + + path = "file_path" + mock_stat = TestStat(0o002) + with patch("atat.core.common.file_check.os.stat", return_value=mock_stat): + with self.assertRaises(FileCheckException) as context: + check_other_user_writable(path) + self.assertEqual(str(context.exception), + FileCheckException.err_strs.get(FileCheckException.FILE_PERMISSION_ERROR)) + mock_logger_error.assert_called_with(f"The file path {path} may be insecure " + "because other users have write permissions. ") + + @patch.object(logger, "error") + def test_check_path_owner_consistent(self, mock_logger_error): + file_path = os.path.realpath(__file__) + file_owner = os.stat(file_path).st_uid + with patch("atat.core.common.file_check.os.getuid", return_value=file_owner+1): + with self.assertRaises(FileCheckException) as context: + check_path_owner_consistent(file_path) + self.assertEqual(str(context.exception), + FileCheckException.err_strs.get(FileCheckException.FILE_PERMISSION_ERROR)) + mock_logger_error.assert_called_with(f"The file path {file_path} may be insecure " + "because is does not belong to you.") + + @patch.object(logger, "error") + def test_check_path_pattern_vaild(self, mock_logger_error): + path = "path" + mock_re_match = MagicMock() + mock_re_match.return_value = False + with patch("atat.core.common.file_check.re.match", new=mock_re_match): + with self.assertRaises(FileCheckException) as context: + check_path_pattern_vaild(path) + self.assertEqual(str(context.exception), + FileCheckException.err_strs.get(FileCheckException.ILLEGAL_PATH_ERROR)) + mock_logger_error.assert_called_with(f"The file path {path} contains special characters.") + mock_re_match.assert_called_with(FileCheckConst.FILE_VALID_PATTERN, path) + + @patch.object(logger, "error") + def test_check_file_size(self, mock_logger_error): + file_path = os.path.realpath(__file__) + file_size = os.path.getsize(file_path) + max_size = file_size + with self.assertRaises(FileCheckException) as context: + check_file_size(file_path, max_size) + self.assertEqual(str(context.exception), + FileCheckException.err_strs.get(FileCheckException.FILE_TOO_LARGE_ERROR)) + mock_logger_error.assert_called_with(f"The size of file path {file_path} exceeds {max_size} bytes.") + + def test_check_common_file_size(self): + mock_check_file_size = MagicMock() + with patch("atat.core.common.file_check.os.path.isfile", return_value=True), \ + patch("atat.core.common.file_check.check_file_size", new=mock_check_file_size): + for suffix, max_size in FileCheckConst.FILE_SIZE_DICT.items(): + check_common_file_size(suffix) + mock_check_file_size.assert_called_with(suffix, max_size) + + @patch.object(logger, "error") + def test_check_file_suffix(self, mock_logger_error): + file_path = "file_path" + suffix = "suffix" + with self.assertRaises(FileCheckException) as context: + check_file_suffix(file_path, suffix) + self.assertEqual(str(context.exception), + FileCheckException.err_strs.get(FileCheckException.INVALID_FILE_ERROR)) + mock_logger_error.assert_called_with(f"The {file_path} should be a {suffix} file!") + + @patch.object(logger, "error") + def test_check_path_type(self, mock_logger_error): + file_path = "file_path" + + with patch("atat.core.common.file_check.os.path.isfile", return_value=False), \ + patch("atat.core.common.file_check.os.path.isdir", return_value=True): + with self.assertRaises(FileCheckException) as context: + check_path_type(file_path, FileCheckConst.FILE) + self.assertEqual(str(context.exception), + FileCheckException.err_strs.get(FileCheckException.INVALID_FILE_ERROR)) + mock_logger_error.assert_called_with(f"The {file_path} should be a file!") + + with patch("atat.core.common.file_check.os.path.isfile", return_value=True), \ + patch("atat.core.common.file_check.os.path.isdir", return_value=False): + with self.assertRaises(FileCheckException) as context: + check_path_type(file_path, FileCheckConst.DIR) + self.assertEqual(str(context.exception), + FileCheckException.err_strs.get(FileCheckException.INVALID_FILE_ERROR)) + mock_logger_error.assert_called_with(f"The {file_path} should be a dictionary!") diff --git a/debug/accuracy_tools/atat/test/core_ut/test_log.py b/debug/accuracy_tools/atat/test/core_ut/test_log.py new file mode 100644 index 0000000000000000000000000000000000000000..6d7998d5ae068bf5044eb097b92cd5dbc6971e0a --- /dev/null +++ b/debug/accuracy_tools/atat/test/core_ut/test_log.py @@ -0,0 +1,109 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +from unittest import TestCase +from unittest.mock import patch, MagicMock + +from atat.core.common.log import BaseLogger, logger + + +class TestLog(TestCase): + @patch("atat.core.common.log.print") + def test__print_log(self, mock_print): + logger._print_log("level", "msg") + self.assertIn("[level] msg", mock_print.call_args[0][0]) + self.assertEqual("\n", mock_print.call_args[1].get("end")) + + logger._print_log("level", "msg", end="end") + self.assertIn("[level] msg", mock_print.call_args[0][0]) + self.assertEqual("end", mock_print.call_args[1].get("end")) + + @patch.object(BaseLogger, "_print_log") + def test_print_info_log(self, mock__print_log): + logger.info("info_msg") + mock__print_log.assert_called_with("INFO", "info_msg") + + @patch.object(BaseLogger, "_print_log") + def test_print_warn_log(self, mock__print_log): + logger.warning("warn_msg") + mock__print_log.assert_called_with("WARNING", "warn_msg") + + @patch.object(BaseLogger, "_print_log") + def test_print_error_log(self, mock__print_log): + logger.error("error_msg") + mock__print_log.assert_called_with("ERROR", "error_msg") + + @patch.object(BaseLogger, "error") + def test_error_log_with_exp(self, mock_error): + with self.assertRaises(Exception) as context: + logger.error_log_with_exp("msg", Exception("Exception")) + self.assertEqual(str(context.exception), "Exception") + mock_error.assert_called_with("msg") + + @patch.object(BaseLogger, "get_rank") + def test_on_rank_0(self, mock_get_rank): + mock_func = MagicMock() + func_rank_0 = logger.on_rank_0(mock_func) + + mock_get_rank.return_value = 1 + func_rank_0() + mock_func.assert_not_called() + + mock_get_rank.return_value = 0 + func_rank_0() + mock_func.assert_called() + + mock_func = MagicMock() + func_rank_0 = logger.on_rank_0(mock_func) + mock_get_rank.return_value = None + func_rank_0() + mock_func.assert_called() + + @patch.object(BaseLogger, "get_rank") + def test_info_on_rank_0(self, mock_get_rank): + mock_print = MagicMock() + with patch("atat.core.common.log.print", new=mock_print): + mock_get_rank.return_value = 0 + logger.info_on_rank_0("msg") + self.assertIn("[INFO] msg", mock_print.call_args[0][0]) + + mock_get_rank.return_value = 1 + logger.info_on_rank_0("msg") + mock_print.assert_called_once() + + @patch.object(BaseLogger, "get_rank") + def test_error_on_rank_0(self, mock_get_rank): + mock_print = MagicMock() + with patch("atat.core.common.log.print", new=mock_print): + mock_get_rank.return_value = 0 + logger.error_on_rank_0("msg") + self.assertIn("[ERROR] msg", mock_print.call_args[0][0]) + + mock_get_rank.return_value = 1 + logger.error_on_rank_0("msg") + mock_print.assert_called_once() + + @patch.object(BaseLogger, "get_rank") + def test_warning_on_rank_0(self, mock_get_rank): + mock_print = MagicMock() + with patch("atat.core.common.log.print", new=mock_print): + mock_get_rank.return_value = 0 + logger.warning_on_rank_0("msg") + self.assertIn("[WARNING] msg", mock_print.call_args[0][0]) + + mock_get_rank.return_value = 1 + logger.warning_on_rank_0("msg") + mock_print.assert_called_once() diff --git a/debug/accuracy_tools/atat/test/core_ut/test_utils.py b/debug/accuracy_tools/atat/test/core_ut/test_utils.py index 89734f2c572bff2aa864db16f23dfe8665042f74..fae0e4e2558bbaad058c8f172549d69cf3dfc13d 100644 --- a/debug/accuracy_tools/atat/test/core_ut/test_utils.py +++ b/debug/accuracy_tools/atat/test/core_ut/test_utils.py @@ -1,13 +1,51 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2022-2023. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +import os +import uuid + from unittest import TestCase -from unittest.mock import patch +from unittest.mock import patch, MagicMock, mock_open -from atat.core.common.utils import check_seed_all, Const, CompareException, check_inplace_op from atat.core.common.log import logger +from atat.core.common.utils import (Const, CompareException, + check_seed_all, + check_inplace_op, + make_dump_path_if_not_exists, + check_mode_valid, + check_switch_valid, + check_dump_mode_valid, + check_summary_mode_valid, + check_summary_only_valid, + check_file_or_directory_path, + check_compare_param, + check_configuration_param, + is_starts_with, + _check_json, + check_json_file, + check_file_size, + check_regex_prefix_format_valid, + get_dump_data_path, + task_dumppath_get) +from atat.core.common.file_check import FileCheckConst class TestUtils(TestCase): @patch.object(logger, "error") - def test_check_seed_all(self, mock_print_error_log): + def test_check_seed_all(self, mock_error): self.assertIsNone(check_seed_all(1234, True)) self.assertIsNone(check_seed_all(0, True)) self.assertIsNone(check_seed_all(Const.MAX_SEED_VALUE, True)) @@ -15,23 +53,23 @@ class TestUtils(TestCase): with self.assertRaises(CompareException) as context: check_seed_all(-1, True) self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR) - mock_print_error_log.assert_called_with(f"Seed must be between 0 and {Const.MAX_SEED_VALUE}.") + mock_error.assert_called_with(f"Seed must be between 0 and {Const.MAX_SEED_VALUE}.") with self.assertRaises(CompareException) as context: check_seed_all(Const.MAX_SEED_VALUE + 1, True) self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR) - mock_print_error_log.assert_called_with(f"Seed must be between 0 and {Const.MAX_SEED_VALUE}.") + mock_error.assert_called_with(f"Seed must be between 0 and {Const.MAX_SEED_VALUE}.") with self.assertRaises(CompareException) as context: check_seed_all("1234", True) self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR) - mock_print_error_log.assert_called_with("Seed must be integer.") + mock_error.assert_called_with("Seed must be integer.") with self.assertRaises(CompareException) as context: check_seed_all(1234, 1) self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR) - mock_print_error_log.assert_called_with("seed_all mode must be bool.") - + mock_error.assert_called_with("seed_all mode must be bool.") + def test_check_inplace_op(self): test_prefix_1 = "Distributed.broadcast.0.forward.input.0" self.assertTrue(check_inplace_op(test_prefix_1)) @@ -39,3 +77,268 @@ class TestUtils(TestCase): self.assertFalse(check_inplace_op(test_prefix_2)) test_prefix_3 = "Torch.sum.0.backward.output.0" self.assertFalse(check_inplace_op(test_prefix_3)) + + @patch.object(logger, "error") + def test_make_dump_path_if_not_exists(self, mock_error): + file_path = os.path.realpath(__file__) + dirname = os.path.dirname(file_path) + str(uuid.uuid4()) + + def test_mkdir(self, **kwargs): + raise OSError + + if not os.path.exists(dirname): + with patch("atat.core.common.utils.Path.mkdir", new=test_mkdir): + with self.assertRaises(CompareException) as context: + make_dump_path_if_not_exists(dirname) + self.assertEqual(context.exception.code, CompareException.INVALID_PATH_ERROR) + + make_dump_path_if_not_exists(file_path) + mock_error.assert_called_with(f"{file_path} already exists and is not a directory.") + + def test_check_mode_valid(self): + with self.assertRaises(ValueError) as context: + check_mode_valid("all", scope="scope") + self.assertEqual(str(context.exception), "scope param set invalid, it's must be a list.") + + with self.assertRaises(ValueError) as context: + check_mode_valid("all", api_list="api_list") + self.assertEqual(str(context.exception), "api_list param set invalid, it's must be a list.") + + mode = "all_list" + with self.assertRaises(CompareException) as context: + check_mode_valid(mode) + self.assertEqual(context.exception.code, CompareException.INVALID_DUMP_MODE) + self.assertEqual(str(context.exception), + f"Current mode '{mode}' is not supported. Please use the field in {Const.DUMP_MODE}") + + mode = "list" + with self.assertRaises(ValueError) as context: + check_mode_valid(mode) + self.assertEqual(str(context.exception), + "set_dump_switch, scope param set invalid, it's should not be an empty list.") + + @patch.object(logger, "error") + def test_check_switch_valid(self, mock_error): + with self.assertRaises(CompareException) as context: + check_switch_valid("Close") + self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR) + mock_error.assert_called_with("Please set switch with 'ON' or 'OFF'.") + + @patch.object(logger, "warning") + def test_check_dump_mode_valid(self, mock_warning): + dump_mode = check_dump_mode_valid("all") + mock_warning.assert_called_with("Please set dump_mode as a list.") + self.assertEqual(dump_mode, ["forward", "backward", "input", "output"]) + + with self.assertRaises(ValueError) as context: + check_dump_mode_valid("all_forward") + self.assertEqual(str(context.exception), + "Please set dump_mode as a list containing one or more of the following: " + + "'all', 'forward', 'backward', 'input', 'output'.") + + def test_check_summary_mode_valid(self): + with self.assertRaises(CompareException) as context: + check_summary_mode_valid("MD5") + self.assertEqual(context.exception.code, CompareException.INVALID_SUMMARY_MODE) + self.assertEqual(str(context.exception), "The summary_mode is not valid") + + @patch.object(logger, "error") + def test_check_summary_only_valid(self, mock_error): + summary_only = check_summary_only_valid(True) + self.assertTrue(summary_only) + + with self.assertRaises(CompareException) as context: + check_summary_only_valid("True") + self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR) + mock_error.assert_called_with("Params summary_only only support True or False.") + + def test_check_file_or_directory_path(self): + class TestFileChecker: + file_path = "" + path_type = "" + ability = "" + checked = False + + def __init__(self, file_path, path_type, ability=None): + TestFileChecker.file_path = file_path + TestFileChecker.path_type = path_type + TestFileChecker.ability = ability + + def common_check(self): + TestFileChecker.checked = True + + file_path = os.path.realpath(__file__) + dirname = os.path.dirname(file_path) + + with patch("atat.core.common.utils.FileChecker", new=TestFileChecker): + check_file_or_directory_path(file_path, isdir=False) + self.assertTrue(TestFileChecker.checked) + self.assertEqual(TestFileChecker.file_path, file_path) + self.assertEqual(TestFileChecker.path_type, FileCheckConst.FILE) + self.assertEqual(TestFileChecker.ability, FileCheckConst.READ_ABLE) + + TestFileChecker.checked = False + with patch("atat.core.common.utils.FileChecker", new=TestFileChecker): + check_file_or_directory_path(dirname, isdir=True) + self.assertTrue(TestFileChecker.checked) + self.assertEqual(TestFileChecker.file_path, dirname) + self.assertEqual(TestFileChecker.path_type, FileCheckConst.DIR) + self.assertEqual(TestFileChecker.ability, FileCheckConst.WRITE_ABLE) + + @patch.object(logger, "error") + def test_check_compare_param(self, mock_error): + params = { + "npu_json_path": "npu_json_path", + "bench_json_path": "bench_json_path", + "stack_json_path": "stack_json_path", + "npu_dump_data_dir": "npu_dump_data_dir", + "bench_dump_data_dir": "bench_dump_data_dir" + } + + call_args = [ + ("npu_json_path", False), + ("bench_json_path", False), + ("stack_json_path", False), + ("npu_dump_data_dir", True), + ("bench_dump_data_dir", True), + ("output_path", True), + ("npu_json_path", False), + ("bench_json_path", False), + ("stack_json_path", False), + ("output_path", True) + ] + + with self.assertRaises(CompareException) as context: + check_compare_param("npu_json_path", "output_path") + self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR) + mock_error.assert_called_with("Invalid input parameters") + + mock_check_file_or_directory_path = MagicMock() + mock_check_json_file = MagicMock() + with patch("atat.core.common.utils.FileOpen", mock_open(read_data="")), \ + patch("atat.core.common.utils.check_json_file", new=mock_check_json_file), \ + patch("atat.core.common.utils.check_file_or_directory_path", new=mock_check_file_or_directory_path): + check_compare_param(params, "output_path") + check_compare_param(params, "output_path", summary_compare=False, md5_compare=True) + for i in range(len(call_args)): + self.assertEqual(mock_check_file_or_directory_path.call_args_list[i][0], call_args[i]) + self.assertEqual(len(mock_check_json_file.call_args[0]), 4) + self.assertEqual(mock_check_json_file.call_args[0][0], params) + + @patch.object(logger, "error") + def test_check_configuration_param(self, mock_error): + with self.assertRaises(CompareException) as context: + check_configuration_param(stack_mode="False", auto_analyze=True, fuzzy_match=False) + self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR) + mock_error.assert_called_with("Invalid input parameters which should be only bool type.") + + def test_is_starts_with(self): + string = "input_slot0" + self.assertFalse(is_starts_with(string, [])) + self.assertFalse(is_starts_with("", ["input"])) + self.assertFalse(is_starts_with(string, ["output"])) + self.assertTrue(is_starts_with(string, ["input", "output"])) + + @patch.object(logger, "error") + def test__check_json(self, mock_error): + class TestOpen: + def __init__(self, string): + self.string = string + + def readline(self): + return self.string + + def seek(self, begin, end): + self.string = str(begin) + "_" + str(end) + + with self.assertRaises(CompareException) as context: + _check_json(TestOpen(""), "test.json") + self.assertEqual(context.exception.code, CompareException.INVALID_DUMP_FILE) + mock_error.assert_called_with("dump file test.json have empty line!") + + handler = TestOpen("jons file\n") + _check_json(handler, "test.json") + self.assertEqual(handler.string, "0_0") + + @patch("atat.core.common.utils._check_json") + def test_check_json_file(self, _mock_check_json): + input_param = { + "npu_json_path": "npu_json_path", + "bench_json_path": "bench_json_path", + "stack_json_path": "stack_json_path" + } + check_json_file(input_param, "npu_json", "bench_json", "stack_json") + self.assertEqual(_mock_check_json.call_args_list[0][0], ("npu_json", "npu_json_path")) + self.assertEqual(_mock_check_json.call_args_list[1][0], ("bench_json", "bench_json_path")) + self.assertEqual(_mock_check_json.call_args_list[2][0], ("stack_json", "stack_json_path")) + + @patch.object(logger, "error") + def test_check_file_size(self, mock_error): + with patch("atat.core.common.utils.os.path.getsize", return_value=120): + with self.assertRaises(CompareException) as context: + check_file_size("input_file", 100) + self.assertEqual(context.exception.code, CompareException.INVALID_FILE_ERROR) + mock_error.assert_called_with("The size (120) of input_file exceeds (100) bytes, tools not support.") + + def test_check_regex_prefix_format_valid(self): + prefix = "A" * 21 + with self.assertRaises(ValueError) as context: + check_regex_prefix_format_valid(prefix) + self.assertEqual(str(context.exception), f"Maximum length of prefix is {Const.REGEX_PREFIX_MAX_LENGTH}, " + f"while current length is {len(prefix)}") + + prefix = "(prefix)" + with self.assertRaises(ValueError) as context: + check_regex_prefix_format_valid(prefix) + self.assertEqual(str(context.exception), f"prefix contains invalid characters, " + f"prefix pattern {Const.REGEX_PREFIX_PATTERN}") + + @patch("atat.core.common.utils.check_file_or_directory_path") + def test_get_dump_data_path(self, mock_check_file_or_directory_path): + file_path = os.path.realpath(__file__) + dirname = os.path.dirname(file_path) + + dump_data_path, file_is_exist = get_dump_data_path(dirname) + self.assertEqual(mock_check_file_or_directory_path.call_args[0], (dirname, True)) + self.assertEqual(dump_data_path, dirname) + self.assertTrue(file_is_exist) + + @patch.object(logger, "error") + def test_task_dumppath_get(self, mock_error): + input_param = { + "npu_json_path": None, + "bench_json_path": "bench_json_path" + } + npu_json = { + "task": Const.TENSOR, + "dump_data_dir": "dump_data_dir", + "data": "data" + } + + with self.assertRaises(CompareException) as context: + task_dumppath_get(input_param) + self.assertEqual(context.exception.code, CompareException.INVALID_PATH_ERROR) + mock_error.assert_called_with("Please check the json path is valid.") + + input_param["npu_json_path"] = "npu_json_path" + with patch("atat.core.common.utils.FileOpen", mock_open(read_data="")), \ + patch("atat.core.common.utils.json.load", return_value=npu_json): + summary_compare, md5_compare = task_dumppath_get(input_param) + self.assertFalse(summary_compare) + self.assertFalse(md5_compare) + + npu_json["task"] = Const.STATISTICS + with patch("atat.core.common.utils.FileOpen", mock_open(read_data="")), \ + patch("atat.core.common.utils.json.load", return_value=npu_json), \ + patch("atat.core.common.utils.md5_find", return_value=True): + summary_compare, md5_compare = task_dumppath_get(input_param) + self.assertFalse(summary_compare) + self.assertTrue(md5_compare) + + npu_json["task"] = Const.OVERFLOW_CHECK + with patch("atat.core.common.utils.FileOpen", mock_open(read_data="")), \ + patch("atat.core.common.utils.json.load", return_value=npu_json): + with self.assertRaises(CompareException) as context: + task_dumppath_get(input_param) + self.assertEqual(context.exception.code, CompareException.INVALID_TASK_ERROR) + mock_error.assert_called_with("Compare is not required for overflow_check or free_benchmark.") diff --git a/debug/accuracy_tools/atat/test/mindspore_ut/test_api_kbk_dump.py b/debug/accuracy_tools/atat/test/mindspore_ut/test_api_kbk_dump.py new file mode 100644 index 0000000000000000000000000000000000000000..47d60999b16a16d7593559c581354d4674438343 --- /dev/null +++ b/debug/accuracy_tools/atat/test/mindspore_ut/test_api_kbk_dump.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +import os + +from unittest import TestCase +from unittest.mock import patch + +from atat.core.common_config import CommonConfig, BaseConfig +from atat.mindspore.debugger.debugger_config import DebuggerConfig +from atat.mindspore.dump.api_kbk_dump import ApiKbkDump + + +class TestApiKbkDump(TestCase): + + def test_handle(self): + json_config = { + "task": "statistics", + "dump_path": "/absolute_path", + "rank": [], + "step": [0, 2], + "level": "L1" + } + + common_config = CommonConfig(json_config) + task_config = BaseConfig(json_config) + config = DebuggerConfig(common_config, task_config) + dumper = ApiKbkDump(config) + self.assertEqual(dumper.dump_json["common_dump_settings"]["iteration"], "0|2") + + os.environ["MS_ACL_DUMP_CFG_PATH"] = "path" + with patch("atat.mindspore.dump.api_kbk_dump.make_dump_path_if_not_exists"), \ + patch("atat.mindspore.dump.api_kbk_dump.FileOpen"), \ + patch("atat.mindspore.dump.api_kbk_dump.json.dump"), \ + patch("atat.mindspore.dump.api_kbk_dump.logger.info"): + dumper.handle() + self.assertEqual(os.environ.get("GRAPH_OP_RUN"), "1") + self.assertEqual(os.environ.get("MS_ACL_DUMP_CFG_PATH"), None) diff --git a/debug/accuracy_tools/atat/test/mindspore_ut/test_debugger_config.py b/debug/accuracy_tools/atat/test/mindspore_ut/test_debugger_config.py new file mode 100644 index 0000000000000000000000000000000000000000..dce76d652f3dab064ff89a04e31cb871093a1098 --- /dev/null +++ b/debug/accuracy_tools/atat/test/mindspore_ut/test_debugger_config.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +from unittest import TestCase + +from atat.core.common.utils import Const +from atat.core.common_config import CommonConfig, BaseConfig +from atat.mindspore.debugger.debugger_config import DebuggerConfig + + +class TestDebuggerConfig(TestCase): + def test_init(self): + json_config = { + "dump_path": "/absolute_path", + "rank": [], + "step": [], + "level": "L1" + } + common_config = CommonConfig(json_config) + task_config = BaseConfig(json_config) + debugger_config = DebuggerConfig(common_config, task_config) + self.assertEqual(debugger_config.task, Const.STATISTICS) + self.assertEqual(debugger_config.file_format, "npy") + self.assertEqual(debugger_config.check_mode, "all") + + common_config.dump_path = "./path" + with self.assertRaises(Exception) as context: + DebuggerConfig(common_config, task_config) + self.assertEqual(str(context.exception), "Dump path must be absolute path.") diff --git a/debug/accuracy_tools/atat/test/mindspore_ut/test_dump_tool_factory.py b/debug/accuracy_tools/atat/test/mindspore_ut/test_dump_tool_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..f6626f551fec02438e39ed474374654201e6204c --- /dev/null +++ b/debug/accuracy_tools/atat/test/mindspore_ut/test_dump_tool_factory.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +from unittest import TestCase + +from atat.core.common_config import CommonConfig, BaseConfig +from atat.mindspore.debugger.debugger_config import DebuggerConfig +from atat.mindspore.dump.dump_tool_factory import DumpToolFactory + + +class TestDumpToolFactory(TestCase): + + def test_create(self): + json_config = { + "task": "statistics", + "dump_path": "/absolute_path", + "rank": [], + "step": [0, 2], + "level": "L1" + } + + common_config = CommonConfig(json_config) + task_config = BaseConfig(json_config) + config = DebuggerConfig(common_config, task_config) + + config.level = "module" + with self.assertRaises(Exception) as context: + DumpToolFactory.create(config) + self.assertEqual(str(context.exception), "valid level is needed.") + + config.level = "cell" + with self.assertRaises(Exception) as context: + DumpToolFactory.create(config) + self.assertEqual(str(context.exception), "Cell dump in not supported now.") + + config.level = "kernel" + dumper = DumpToolFactory.create(config) + self.assertEqual(dumper.dump_json["common_dump_settings"]["net_name"], "Net") diff --git a/debug/accuracy_tools/atat/test/mindspore_ut/test_kernel_graph_dump.py b/debug/accuracy_tools/atat/test/mindspore_ut/test_kernel_graph_dump.py new file mode 100644 index 0000000000000000000000000000000000000000..6c59521a17d57585170753f4d935b6109b920595 --- /dev/null +++ b/debug/accuracy_tools/atat/test/mindspore_ut/test_kernel_graph_dump.py @@ -0,0 +1,66 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +import os + +from unittest import TestCase +from unittest.mock import patch + +from atat.core.common_config import CommonConfig, BaseConfig +from atat.mindspore.debugger.debugger_config import DebuggerConfig +from atat.mindspore.dump.kernel_graph_dump import KernelGraphDump + + +class TestKernelGraphDump(TestCase): + + def test_handle(self): + json_config = { + "task": "tensor", + "dump_path": "/absolute_path", + "rank": [], + "step": [0, 2], + "level": "L2" + } + + common_config = CommonConfig(json_config) + task_config = BaseConfig(json_config) + task_config.data_mode = ["output"] + task_config.file_format = "bin" + config = DebuggerConfig(common_config, task_config) + dumper = KernelGraphDump(config) + self.assertEqual(dumper.dump_json["common_dump_settings"]["iteration"], "0|2") + self.assertEqual(dumper.dump_json["common_dump_settings"]["file_format"], "bin") + self.assertEqual(dumper.dump_json["common_dump_settings"]["input_output"], 2) + + with patch("atat.mindspore.dump.kernel_graph_dump.make_dump_path_if_not_exists"), \ + patch("atat.mindspore.dump.kernel_graph_dump.FileOpen"), \ + patch("atat.mindspore.dump.kernel_graph_dump.json.dump"), \ + patch("atat.mindspore.dump.kernel_graph_dump.logger.info"): + + os.environ["GRAPH_OP_RUN"] = "1" + with self.assertRaises(Exception) as context: + dumper.handle() + self.assertEqual(str(context.exception), "Must run in graph mode, not kbk mode") + if "GRAPH_OP_RUN" in os.environ: + del os.environ["GRAPH_OP_RUN"] + + dumper.handle() + self.assertIn("kernel_graph_dump.json", os.environ.get("MS_ACL_DUMP_CFG_PATH")) + + if "MINDSPORE_DUMP_CONFIG" in os.environ: + del os.environ["MINDSPORE_DUMP_CONFIG"] + if "MS_ACL_DUMP_CFG_PATH" in os.environ: + del os.environ["MS_ACL_DUMP_CFG_PATH"] diff --git a/debug/accuracy_tools/atat/test/mindspore_ut/test_kernel_graph_overflow_check.py b/debug/accuracy_tools/atat/test/mindspore_ut/test_kernel_graph_overflow_check.py new file mode 100644 index 0000000000000000000000000000000000000000..101482458dc0a901e0066937ae6df9e9a23fe4fc --- /dev/null +++ b/debug/accuracy_tools/atat/test/mindspore_ut/test_kernel_graph_overflow_check.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +import os + +from unittest import TestCase +from unittest.mock import patch + +from atat.core.common_config import CommonConfig, BaseConfig +from atat.mindspore.debugger.debugger_config import DebuggerConfig +from atat.mindspore.overflow_check.kernel_graph_overflow_check import KernelGraphOverflowCheck + + +class TestKernelGraphOverflowCheck(TestCase): + + def test_handle(self): + json_config = { + "task": "overflow_check", + "dump_path": "/absolute_path", + "rank": [], + "step": [], + "level": "L2" + } + + common_config = CommonConfig(json_config) + task_config = BaseConfig(json_config) + task_config.check_mode = "atomic" + config = DebuggerConfig(common_config, task_config) + checker = KernelGraphOverflowCheck(config) + self.assertEqual(checker.dump_json["common_dump_settings"]["op_debug_mode"], 2) + + os.environ["MS_ACL_DUMP_CFG_PATH"] = "path" + with patch("atat.mindspore.overflow_check.kernel_graph_overflow_check.make_dump_path_if_not_exists"), \ + patch("atat.mindspore.overflow_check.kernel_graph_overflow_check.FileOpen"), \ + patch("atat.mindspore.overflow_check.kernel_graph_overflow_check.json.dump"), \ + patch("atat.mindspore.overflow_check.kernel_graph_overflow_check.logger.info"): + + os.environ["GRAPH_OP_RUN"] = "1" + with self.assertRaises(Exception) as context: + checker.handle() + self.assertEqual(str(context.exception), "Must run in graph mode, not kbk mode") + if "GRAPH_OP_RUN" in os.environ: + del os.environ["GRAPH_OP_RUN"] + + checker.handle() + self.assertIn("kernel_graph_overflow_check.json", os.environ.get("MINDSPORE_DUMP_CONFIG")) + self.assertEqual(os.environ.get("MS_ACL_DUMP_CFG_PATH"), None) + + if "MINDSPORE_DUMP_CONFIG" in os.environ: + del os.environ["MINDSPORE_DUMP_CONFIG"] diff --git a/debug/accuracy_tools/atat/test/mindspore_ut/test_ms_config.py b/debug/accuracy_tools/atat/test/mindspore_ut/test_ms_config.py index 6be8949684c89012f0dc2165ba24eab4e7a77f1c..69f3793d7d598b98a6480b5a470bb7d341dc114b 100644 --- a/debug/accuracy_tools/atat/test/mindspore_ut/test_ms_config.py +++ b/debug/accuracy_tools/atat/test/mindspore_ut/test_ms_config.py @@ -1,8 +1,25 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" from unittest import TestCase from unittest.mock import patch, mock_open from atat.core.common.utils import Const -from atat.mindspore.ms_config import parse_json_config +from atat.mindspore.ms_config import (parse_json_config, parse_task_config, + TensorConfig, StatisticsConfig, OverflowCheck) class TestMsConfig(TestCase): @@ -21,7 +38,7 @@ class TestMsConfig(TestCase): } } with patch("atat.mindspore.ms_config.FileOpen", mock_open(read_data='')), \ - patch("atat.mindspore.ms_config.json.load", return_value=mock_json_data): + patch("atat.mindspore.ms_config.json.load", return_value=mock_json_data): common_config, task_config = parse_json_config("./config.json") self.assertEqual(common_config.task, Const.STATISTICS) self.assertEqual(task_config.data_mode, ["all"]) @@ -29,3 +46,24 @@ class TestMsConfig(TestCase): with self.assertRaises(Exception) as context: parse_json_config(None) self.assertEqual(str(context.exception), "json file path is None") + + def test_parse_task_config(self): + mock_json_config = { + "tensor": None, + "statistics": None, + "overflow_check": None, + "free_benchmark": None + } + + task_config = parse_task_config("tensor", mock_json_config) + self.assertTrue(isinstance(task_config, TensorConfig)) + + task_config = parse_task_config("statistics", mock_json_config) + self.assertTrue(isinstance(task_config, StatisticsConfig)) + + task_config = parse_task_config("overflow_check", mock_json_config) + self.assertTrue(isinstance(task_config, OverflowCheck)) + + with self.assertRaises(Exception) as context: + parse_task_config("free_benchmark", mock_json_config) + self.assertEqual(str(context.exception), "task is invalid.") diff --git a/debug/accuracy_tools/atat/test/mindspore_ut/test_overflow_check_tool_factory.py b/debug/accuracy_tools/atat/test/mindspore_ut/test_overflow_check_tool_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..497fe1376abcff0607d6afd3f2b03d94f963bcd7 --- /dev/null +++ b/debug/accuracy_tools/atat/test/mindspore_ut/test_overflow_check_tool_factory.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +from unittest import TestCase + +from atat.core.common_config import CommonConfig, BaseConfig +from atat.mindspore.debugger.debugger_config import DebuggerConfig +from atat.mindspore.overflow_check.overflow_check_tool_factory import OverflowCheckToolFactory + + +class TestOverflowCheckToolFactory(TestCase): + + def test_create(self): + json_config = { + "task": "overflow_check", + "dump_path": "/absolute_path", + "rank": [], + "step": [], + "level": "L2" + } + + common_config = CommonConfig(json_config) + task_config = BaseConfig(json_config) + config = DebuggerConfig(common_config, task_config) + + config.level = "module" + with self.assertRaises(Exception) as context: + OverflowCheckToolFactory.create(config) + self.assertEqual(str(context.exception), "valid level is needed.") + + config.level = "cell" + with self.assertRaises(Exception) as context: + OverflowCheckToolFactory.create(config) + self.assertEqual(str(context.exception), "Overflow check in not supported in this mode.") + + config.level = "kernel" + dumper = OverflowCheckToolFactory.create(config) + self.assertEqual(dumper.dump_json["common_dump_settings"]["file_format"], "npy") diff --git a/debug/accuracy_tools/atat/test/mindspore_ut/test_precision_debugger.py b/debug/accuracy_tools/atat/test/mindspore_ut/test_precision_debugger.py new file mode 100644 index 0000000000000000000000000000000000000000..834a58e41a426d975ac97b0f757db5a0432a297f --- /dev/null +++ b/debug/accuracy_tools/atat/test/mindspore_ut/test_precision_debugger.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +from unittest import TestCase +from unittest.mock import patch + +from atat.core.common_config import CommonConfig, BaseConfig +from atat.mindspore.debugger.debugger_config import DebuggerConfig +from atat.mindspore.debugger.precision_debugger import PrecisionDebugger + + +class TestPrecisionDebugger(TestCase): + def test_start(self): + class Handler: + called = False + + def handle(self): + Handler.called = True + + json_config = { + "task": "statistics", + "dump_path": "/absolute_path", + "rank": [], + "step": [], + "level": "L1" + } + + common_config = CommonConfig(json_config) + task_config = BaseConfig(json_config) + handler = Handler() + + with patch("atat.mindspore.debugger.precision_debugger.parse_json_config", + return_value=[common_config, task_config]), \ + patch("atat.mindspore.debugger.precision_debugger.TaskHandlerFactory.create", return_value=handler): + debugger = PrecisionDebugger() + debugger.start() + self.assertTrue(isinstance(debugger.config, DebuggerConfig)) + self.assertTrue(Handler.called) + + PrecisionDebugger._instance = None + with self.assertRaises(Exception) as context: + debugger.start() + self.assertEqual(str(context.exception), "No instance of PrecisionDebugger found.") diff --git a/debug/accuracy_tools/atat/test/mindspore_ut/test_task_handler_factory.py b/debug/accuracy_tools/atat/test/mindspore_ut/test_task_handler_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..02cd9934cb1635b9cec68bcd4b54dd2753ea23e9 --- /dev/null +++ b/debug/accuracy_tools/atat/test/mindspore_ut/test_task_handler_factory.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +from unittest import TestCase +from unittest.mock import patch + +from atat.core.common_config import CommonConfig, BaseConfig +from atat.mindspore.debugger.debugger_config import DebuggerConfig +from atat.mindspore.dump.kernel_graph_dump import KernelGraphDump +from atat.mindspore.task_handler_factory import TaskHandlerFactory + + +class TestTaskHandlerFactory(TestCase): + + def test_create(self): + class HandlerFactory: + def create(self): + return None + + tasks = {"statistics": HandlerFactory} + + json_config = { + "task": "statistics", + "dump_path": "/absolute_path", + "rank": [], + "step": [], + "level": "L2" + } + + common_config = CommonConfig(json_config) + task_config = BaseConfig(json_config) + config = DebuggerConfig(common_config, task_config) + + handler = TaskHandlerFactory.create(config) + self.assertTrue(isinstance(handler, KernelGraphDump)) + + with patch("atat.mindspore.task_handler_factory.TaskHandlerFactory.tasks", new=tasks): + with self.assertRaises(Exception) as context: + TaskHandlerFactory.create(config) + self.assertEqual(str(context.exception), "Can not find task handler") + + config.task = "free_benchmark" + with self.assertRaises(Exception) as context: + TaskHandlerFactory.create(config) + self.assertEqual(str(context.exception), "valid task is needed.")