From 1468c1ae5d382067ae3d501b5335baadfcf8cbda Mon Sep 17 00:00:00 2001 From: jiangchao_j Date: Wed, 17 Jul 2024 15:56:57 +0800 Subject: [PATCH] add ut case, part2 --- .../atat/test/core_ut/test_common_config.py | 152 ++++++++++++++++++ .../atat/test/core_ut/test_log.py | 109 +++++++++++++ .../test/mindspore_ut/test_api_kbk_dump.py | 51 ++++++ .../test/mindspore_ut/test_debugger_config.py | 42 +++++ .../mindspore_ut/test_dump_tool_factory.py | 51 ++++++ .../mindspore_ut/test_kernel_graph_dump.py | 66 ++++++++ .../test_kernel_graph_overflow_check.py | 63 ++++++++ .../atat/test/mindspore_ut/test_ms_config.py | 42 ++++- .../test_overflow_check_tool_factory.py | 51 ++++++ .../mindspore_ut/test_precision_debugger.py | 56 +++++++ .../mindspore_ut/test_task_handler_factory.py | 58 +++++++ 11 files changed, 739 insertions(+), 2 deletions(-) create mode 100644 debug/accuracy_tools/atat/test/core_ut/test_common_config.py create mode 100644 debug/accuracy_tools/atat/test/core_ut/test_log.py create mode 100644 debug/accuracy_tools/atat/test/mindspore_ut/test_api_kbk_dump.py create mode 100644 debug/accuracy_tools/atat/test/mindspore_ut/test_debugger_config.py create mode 100644 debug/accuracy_tools/atat/test/mindspore_ut/test_dump_tool_factory.py create mode 100644 debug/accuracy_tools/atat/test/mindspore_ut/test_kernel_graph_dump.py create mode 100644 debug/accuracy_tools/atat/test/mindspore_ut/test_kernel_graph_overflow_check.py create mode 100644 debug/accuracy_tools/atat/test/mindspore_ut/test_overflow_check_tool_factory.py create mode 100644 debug/accuracy_tools/atat/test/mindspore_ut/test_precision_debugger.py create mode 100644 debug/accuracy_tools/atat/test/mindspore_ut/test_task_handler_factory.py diff --git a/debug/accuracy_tools/atat/test/core_ut/test_common_config.py b/debug/accuracy_tools/atat/test/core_ut/test_common_config.py new file mode 100644 index 0000000000..5dd7aee7ba --- /dev/null +++ b/debug/accuracy_tools/atat/test/core_ut/test_common_config.py @@ -0,0 +1,152 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +from unittest import TestCase +from unittest.mock import patch + +from atat.core.common.log import logger +from atat.core.common.utils import Const +from atat.core.common.exceptions import MsaccException +from atat.core.common_config import CommonConfig, BaseConfig + + +class TestCommonConfig(TestCase): + @patch.object(logger, "error_log_with_exp") + def test_common_config(self, mock_error_log_with_exp): + json_config = dict() + + common_config = CommonConfig(json_config) + self.assertIsNone(common_config.task) + self.assertIsNone(common_config.dump_path) + self.assertIsNone(common_config.rank) + self.assertIsNone(common_config.step) + self.assertIsNone(common_config.level) + self.assertIsNone(common_config.seed) + self.assertIsNone(common_config.acl_config) + self.assertFalse(common_config.is_deterministic) + self.assertFalse(common_config.enable_dataloader) + + json_config.update({"task": "md5"}) + CommonConfig(json_config) + self.assertEqual(mock_error_log_with_exp.call_args[0][0], + "task is invalid, it should be one of {}".format(Const.TASK_LIST)) + self.assertEqual(str(mock_error_log_with_exp.call_args[0][1]), + MsaccException.err_strs.get(MsaccException.INVALID_PARAM_ERROR)) + + json_config.update({"task": Const.TENSOR}) + json_config.update({"rank": 0}) + CommonConfig(json_config) + self.assertEqual(mock_error_log_with_exp.call_args[0][0], + "rank is invalid, it should be a list") + self.assertEqual(str(mock_error_log_with_exp.call_args[0][1]), + MsaccException.err_strs.get(MsaccException.INVALID_PARAM_ERROR)) + + json_config.update({"task": Const.TENSOR}) + json_config.update({"rank": [0]}) + json_config.update({"step": 0}) + CommonConfig(json_config) + self.assertEqual(mock_error_log_with_exp.call_args[0][0], + "step is invalid, it should be a list") + self.assertEqual(str(mock_error_log_with_exp.call_args[0][1]), + MsaccException.err_strs.get(MsaccException.INVALID_PARAM_ERROR)) + + json_config.update({"task": Const.TENSOR}) + json_config.update({"rank": [0]}) + json_config.update({"step": [0]}) + json_config.update({"level": "L3"}) + CommonConfig(json_config) + self.assertEqual(mock_error_log_with_exp.call_args[0][0], + "level is invalid, it should be one of {}".format(Const.LEVEL_LIST)) + self.assertEqual(str(mock_error_log_with_exp.call_args[0][1]), + MsaccException.err_strs.get(MsaccException.INVALID_PARAM_ERROR)) + + json_config.update({"task": Const.TENSOR}) + json_config.update({"rank": [0]}) + json_config.update({"step": [0]}) + json_config.update({"level": "L0"}) + json_config.update({"seed": "1234"}) + CommonConfig(json_config) + self.assertEqual(mock_error_log_with_exp.call_args[0][0], + "seed is invalid, it should be an integer") + self.assertEqual(str(mock_error_log_with_exp.call_args[0][1]), + MsaccException.err_strs.get(MsaccException.INVALID_PARAM_ERROR)) + + json_config.update({"task": Const.TENSOR}) + json_config.update({"rank": [0]}) + json_config.update({"step": [0]}) + json_config.update({"level": "L0"}) + json_config.update({"seed": 1234}) + json_config.update({"is_deterministic": "ENABLE"}) + CommonConfig(json_config) + self.assertEqual(mock_error_log_with_exp.call_args[0][0], + "is_deterministic is invalid, it should be a boolean") + self.assertEqual(str(mock_error_log_with_exp.call_args[0][1]), + MsaccException.err_strs.get(MsaccException.INVALID_PARAM_ERROR)) + + json_config.update({"task": Const.TENSOR}) + json_config.update({"rank": [0]}) + json_config.update({"step": [0]}) + json_config.update({"level": "L0"}) + json_config.update({"seed": 1234}) + json_config.update({"is_deterministic": True}) + json_config.update({"enable_dataloader": "ENABLE"}) + CommonConfig(json_config) + self.assertEqual(mock_error_log_with_exp.call_args[0][0], + "enable_dataloader is invalid, it should be a boolean") + self.assertEqual(str(mock_error_log_with_exp.call_args[0][1]), + MsaccException.err_strs.get(MsaccException.INVALID_PARAM_ERROR)) + + @patch.object(logger, "error_log_with_exp") + def test_base_config(self, mock_error_log_with_exp): + json_config = dict() + + base_config = BaseConfig(json_config) + base_config.check_config() + self.assertIsNone(base_config.scope) + self.assertIsNone(base_config.list) + self.assertIsNone(base_config.data_mode) + self.assertIsNone(base_config.backward_input) + self.assertIsNone(base_config.file_format) + self.assertIsNone(base_config.summary_mode) + self.assertIsNone(base_config.overflow_num) + self.assertIsNone(base_config.check_mode) + + json_config.update({"scope": "Tensor_Add"}) + base_config = BaseConfig(json_config) + base_config.check_config() + self.assertEqual(mock_error_log_with_exp.call_args[0][0], + "scope is invalid, it should be a list") + self.assertEqual(str(mock_error_log_with_exp.call_args[0][1]), + MsaccException.err_strs.get(MsaccException.INVALID_PARAM_ERROR)) + + json_config.update({"scope": ["Tensor_Add"]}) + json_config.update({"list": "Tensor_Add"}) + base_config = BaseConfig(json_config) + base_config.check_config() + self.assertEqual(mock_error_log_with_exp.call_args[0][0], + "list is invalid, it should be a list") + self.assertEqual(str(mock_error_log_with_exp.call_args[0][1]), + MsaccException.err_strs.get(MsaccException.INVALID_PARAM_ERROR)) + + json_config.update({"scope": ["Tensor_Add"]}) + json_config.update({"list": ["Tensor_Add"]}) + json_config.update({"data_mode": "all"}) + base_config = BaseConfig(json_config) + base_config.check_config() + self.assertEqual(mock_error_log_with_exp.call_args[0][0], + "data_mode is invalid, it should be a list") + self.assertEqual(str(mock_error_log_with_exp.call_args[0][1]), + MsaccException.err_strs.get(MsaccException.INVALID_PARAM_ERROR)) diff --git a/debug/accuracy_tools/atat/test/core_ut/test_log.py b/debug/accuracy_tools/atat/test/core_ut/test_log.py new file mode 100644 index 0000000000..6d7998d5ae --- /dev/null +++ b/debug/accuracy_tools/atat/test/core_ut/test_log.py @@ -0,0 +1,109 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +from unittest import TestCase +from unittest.mock import patch, MagicMock + +from atat.core.common.log import BaseLogger, logger + + +class TestLog(TestCase): + @patch("atat.core.common.log.print") + def test__print_log(self, mock_print): + logger._print_log("level", "msg") + self.assertIn("[level] msg", mock_print.call_args[0][0]) + self.assertEqual("\n", mock_print.call_args[1].get("end")) + + logger._print_log("level", "msg", end="end") + self.assertIn("[level] msg", mock_print.call_args[0][0]) + self.assertEqual("end", mock_print.call_args[1].get("end")) + + @patch.object(BaseLogger, "_print_log") + def test_print_info_log(self, mock__print_log): + logger.info("info_msg") + mock__print_log.assert_called_with("INFO", "info_msg") + + @patch.object(BaseLogger, "_print_log") + def test_print_warn_log(self, mock__print_log): + logger.warning("warn_msg") + mock__print_log.assert_called_with("WARNING", "warn_msg") + + @patch.object(BaseLogger, "_print_log") + def test_print_error_log(self, mock__print_log): + logger.error("error_msg") + mock__print_log.assert_called_with("ERROR", "error_msg") + + @patch.object(BaseLogger, "error") + def test_error_log_with_exp(self, mock_error): + with self.assertRaises(Exception) as context: + logger.error_log_with_exp("msg", Exception("Exception")) + self.assertEqual(str(context.exception), "Exception") + mock_error.assert_called_with("msg") + + @patch.object(BaseLogger, "get_rank") + def test_on_rank_0(self, mock_get_rank): + mock_func = MagicMock() + func_rank_0 = logger.on_rank_0(mock_func) + + mock_get_rank.return_value = 1 + func_rank_0() + mock_func.assert_not_called() + + mock_get_rank.return_value = 0 + func_rank_0() + mock_func.assert_called() + + mock_func = MagicMock() + func_rank_0 = logger.on_rank_0(mock_func) + mock_get_rank.return_value = None + func_rank_0() + mock_func.assert_called() + + @patch.object(BaseLogger, "get_rank") + def test_info_on_rank_0(self, mock_get_rank): + mock_print = MagicMock() + with patch("atat.core.common.log.print", new=mock_print): + mock_get_rank.return_value = 0 + logger.info_on_rank_0("msg") + self.assertIn("[INFO] msg", mock_print.call_args[0][0]) + + mock_get_rank.return_value = 1 + logger.info_on_rank_0("msg") + mock_print.assert_called_once() + + @patch.object(BaseLogger, "get_rank") + def test_error_on_rank_0(self, mock_get_rank): + mock_print = MagicMock() + with patch("atat.core.common.log.print", new=mock_print): + mock_get_rank.return_value = 0 + logger.error_on_rank_0("msg") + self.assertIn("[ERROR] msg", mock_print.call_args[0][0]) + + mock_get_rank.return_value = 1 + logger.error_on_rank_0("msg") + mock_print.assert_called_once() + + @patch.object(BaseLogger, "get_rank") + def test_warning_on_rank_0(self, mock_get_rank): + mock_print = MagicMock() + with patch("atat.core.common.log.print", new=mock_print): + mock_get_rank.return_value = 0 + logger.warning_on_rank_0("msg") + self.assertIn("[WARNING] msg", mock_print.call_args[0][0]) + + mock_get_rank.return_value = 1 + logger.warning_on_rank_0("msg") + mock_print.assert_called_once() diff --git a/debug/accuracy_tools/atat/test/mindspore_ut/test_api_kbk_dump.py b/debug/accuracy_tools/atat/test/mindspore_ut/test_api_kbk_dump.py new file mode 100644 index 0000000000..47d60999b1 --- /dev/null +++ b/debug/accuracy_tools/atat/test/mindspore_ut/test_api_kbk_dump.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +import os + +from unittest import TestCase +from unittest.mock import patch + +from atat.core.common_config import CommonConfig, BaseConfig +from atat.mindspore.debugger.debugger_config import DebuggerConfig +from atat.mindspore.dump.api_kbk_dump import ApiKbkDump + + +class TestApiKbkDump(TestCase): + + def test_handle(self): + json_config = { + "task": "statistics", + "dump_path": "/absolute_path", + "rank": [], + "step": [0, 2], + "level": "L1" + } + + common_config = CommonConfig(json_config) + task_config = BaseConfig(json_config) + config = DebuggerConfig(common_config, task_config) + dumper = ApiKbkDump(config) + self.assertEqual(dumper.dump_json["common_dump_settings"]["iteration"], "0|2") + + os.environ["MS_ACL_DUMP_CFG_PATH"] = "path" + with patch("atat.mindspore.dump.api_kbk_dump.make_dump_path_if_not_exists"), \ + patch("atat.mindspore.dump.api_kbk_dump.FileOpen"), \ + patch("atat.mindspore.dump.api_kbk_dump.json.dump"), \ + patch("atat.mindspore.dump.api_kbk_dump.logger.info"): + dumper.handle() + self.assertEqual(os.environ.get("GRAPH_OP_RUN"), "1") + self.assertEqual(os.environ.get("MS_ACL_DUMP_CFG_PATH"), None) diff --git a/debug/accuracy_tools/atat/test/mindspore_ut/test_debugger_config.py b/debug/accuracy_tools/atat/test/mindspore_ut/test_debugger_config.py new file mode 100644 index 0000000000..dce76d652f --- /dev/null +++ b/debug/accuracy_tools/atat/test/mindspore_ut/test_debugger_config.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +from unittest import TestCase + +from atat.core.common.utils import Const +from atat.core.common_config import CommonConfig, BaseConfig +from atat.mindspore.debugger.debugger_config import DebuggerConfig + + +class TestDebuggerConfig(TestCase): + def test_init(self): + json_config = { + "dump_path": "/absolute_path", + "rank": [], + "step": [], + "level": "L1" + } + common_config = CommonConfig(json_config) + task_config = BaseConfig(json_config) + debugger_config = DebuggerConfig(common_config, task_config) + self.assertEqual(debugger_config.task, Const.STATISTICS) + self.assertEqual(debugger_config.file_format, "npy") + self.assertEqual(debugger_config.check_mode, "all") + + common_config.dump_path = "./path" + with self.assertRaises(Exception) as context: + DebuggerConfig(common_config, task_config) + self.assertEqual(str(context.exception), "Dump path must be absolute path.") diff --git a/debug/accuracy_tools/atat/test/mindspore_ut/test_dump_tool_factory.py b/debug/accuracy_tools/atat/test/mindspore_ut/test_dump_tool_factory.py new file mode 100644 index 0000000000..f6626f551f --- /dev/null +++ b/debug/accuracy_tools/atat/test/mindspore_ut/test_dump_tool_factory.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +from unittest import TestCase + +from atat.core.common_config import CommonConfig, BaseConfig +from atat.mindspore.debugger.debugger_config import DebuggerConfig +from atat.mindspore.dump.dump_tool_factory import DumpToolFactory + + +class TestDumpToolFactory(TestCase): + + def test_create(self): + json_config = { + "task": "statistics", + "dump_path": "/absolute_path", + "rank": [], + "step": [0, 2], + "level": "L1" + } + + common_config = CommonConfig(json_config) + task_config = BaseConfig(json_config) + config = DebuggerConfig(common_config, task_config) + + config.level = "module" + with self.assertRaises(Exception) as context: + DumpToolFactory.create(config) + self.assertEqual(str(context.exception), "valid level is needed.") + + config.level = "cell" + with self.assertRaises(Exception) as context: + DumpToolFactory.create(config) + self.assertEqual(str(context.exception), "Cell dump in not supported now.") + + config.level = "kernel" + dumper = DumpToolFactory.create(config) + self.assertEqual(dumper.dump_json["common_dump_settings"]["net_name"], "Net") diff --git a/debug/accuracy_tools/atat/test/mindspore_ut/test_kernel_graph_dump.py b/debug/accuracy_tools/atat/test/mindspore_ut/test_kernel_graph_dump.py new file mode 100644 index 0000000000..6c59521a17 --- /dev/null +++ b/debug/accuracy_tools/atat/test/mindspore_ut/test_kernel_graph_dump.py @@ -0,0 +1,66 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +import os + +from unittest import TestCase +from unittest.mock import patch + +from atat.core.common_config import CommonConfig, BaseConfig +from atat.mindspore.debugger.debugger_config import DebuggerConfig +from atat.mindspore.dump.kernel_graph_dump import KernelGraphDump + + +class TestKernelGraphDump(TestCase): + + def test_handle(self): + json_config = { + "task": "tensor", + "dump_path": "/absolute_path", + "rank": [], + "step": [0, 2], + "level": "L2" + } + + common_config = CommonConfig(json_config) + task_config = BaseConfig(json_config) + task_config.data_mode = ["output"] + task_config.file_format = "bin" + config = DebuggerConfig(common_config, task_config) + dumper = KernelGraphDump(config) + self.assertEqual(dumper.dump_json["common_dump_settings"]["iteration"], "0|2") + self.assertEqual(dumper.dump_json["common_dump_settings"]["file_format"], "bin") + self.assertEqual(dumper.dump_json["common_dump_settings"]["input_output"], 2) + + with patch("atat.mindspore.dump.kernel_graph_dump.make_dump_path_if_not_exists"), \ + patch("atat.mindspore.dump.kernel_graph_dump.FileOpen"), \ + patch("atat.mindspore.dump.kernel_graph_dump.json.dump"), \ + patch("atat.mindspore.dump.kernel_graph_dump.logger.info"): + + os.environ["GRAPH_OP_RUN"] = "1" + with self.assertRaises(Exception) as context: + dumper.handle() + self.assertEqual(str(context.exception), "Must run in graph mode, not kbk mode") + if "GRAPH_OP_RUN" in os.environ: + del os.environ["GRAPH_OP_RUN"] + + dumper.handle() + self.assertIn("kernel_graph_dump.json", os.environ.get("MS_ACL_DUMP_CFG_PATH")) + + if "MINDSPORE_DUMP_CONFIG" in os.environ: + del os.environ["MINDSPORE_DUMP_CONFIG"] + if "MS_ACL_DUMP_CFG_PATH" in os.environ: + del os.environ["MS_ACL_DUMP_CFG_PATH"] diff --git a/debug/accuracy_tools/atat/test/mindspore_ut/test_kernel_graph_overflow_check.py b/debug/accuracy_tools/atat/test/mindspore_ut/test_kernel_graph_overflow_check.py new file mode 100644 index 0000000000..101482458d --- /dev/null +++ b/debug/accuracy_tools/atat/test/mindspore_ut/test_kernel_graph_overflow_check.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +import os + +from unittest import TestCase +from unittest.mock import patch + +from atat.core.common_config import CommonConfig, BaseConfig +from atat.mindspore.debugger.debugger_config import DebuggerConfig +from atat.mindspore.overflow_check.kernel_graph_overflow_check import KernelGraphOverflowCheck + + +class TestKernelGraphOverflowCheck(TestCase): + + def test_handle(self): + json_config = { + "task": "overflow_check", + "dump_path": "/absolute_path", + "rank": [], + "step": [], + "level": "L2" + } + + common_config = CommonConfig(json_config) + task_config = BaseConfig(json_config) + task_config.check_mode = "atomic" + config = DebuggerConfig(common_config, task_config) + checker = KernelGraphOverflowCheck(config) + self.assertEqual(checker.dump_json["common_dump_settings"]["op_debug_mode"], 2) + + os.environ["MS_ACL_DUMP_CFG_PATH"] = "path" + with patch("atat.mindspore.overflow_check.kernel_graph_overflow_check.make_dump_path_if_not_exists"), \ + patch("atat.mindspore.overflow_check.kernel_graph_overflow_check.FileOpen"), \ + patch("atat.mindspore.overflow_check.kernel_graph_overflow_check.json.dump"), \ + patch("atat.mindspore.overflow_check.kernel_graph_overflow_check.logger.info"): + + os.environ["GRAPH_OP_RUN"] = "1" + with self.assertRaises(Exception) as context: + checker.handle() + self.assertEqual(str(context.exception), "Must run in graph mode, not kbk mode") + if "GRAPH_OP_RUN" in os.environ: + del os.environ["GRAPH_OP_RUN"] + + checker.handle() + self.assertIn("kernel_graph_overflow_check.json", os.environ.get("MINDSPORE_DUMP_CONFIG")) + self.assertEqual(os.environ.get("MS_ACL_DUMP_CFG_PATH"), None) + + if "MINDSPORE_DUMP_CONFIG" in os.environ: + del os.environ["MINDSPORE_DUMP_CONFIG"] diff --git a/debug/accuracy_tools/atat/test/mindspore_ut/test_ms_config.py b/debug/accuracy_tools/atat/test/mindspore_ut/test_ms_config.py index 6be8949684..69f3793d7d 100644 --- a/debug/accuracy_tools/atat/test/mindspore_ut/test_ms_config.py +++ b/debug/accuracy_tools/atat/test/mindspore_ut/test_ms_config.py @@ -1,8 +1,25 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" from unittest import TestCase from unittest.mock import patch, mock_open from atat.core.common.utils import Const -from atat.mindspore.ms_config import parse_json_config +from atat.mindspore.ms_config import (parse_json_config, parse_task_config, + TensorConfig, StatisticsConfig, OverflowCheck) class TestMsConfig(TestCase): @@ -21,7 +38,7 @@ class TestMsConfig(TestCase): } } with patch("atat.mindspore.ms_config.FileOpen", mock_open(read_data='')), \ - patch("atat.mindspore.ms_config.json.load", return_value=mock_json_data): + patch("atat.mindspore.ms_config.json.load", return_value=mock_json_data): common_config, task_config = parse_json_config("./config.json") self.assertEqual(common_config.task, Const.STATISTICS) self.assertEqual(task_config.data_mode, ["all"]) @@ -29,3 +46,24 @@ class TestMsConfig(TestCase): with self.assertRaises(Exception) as context: parse_json_config(None) self.assertEqual(str(context.exception), "json file path is None") + + def test_parse_task_config(self): + mock_json_config = { + "tensor": None, + "statistics": None, + "overflow_check": None, + "free_benchmark": None + } + + task_config = parse_task_config("tensor", mock_json_config) + self.assertTrue(isinstance(task_config, TensorConfig)) + + task_config = parse_task_config("statistics", mock_json_config) + self.assertTrue(isinstance(task_config, StatisticsConfig)) + + task_config = parse_task_config("overflow_check", mock_json_config) + self.assertTrue(isinstance(task_config, OverflowCheck)) + + with self.assertRaises(Exception) as context: + parse_task_config("free_benchmark", mock_json_config) + self.assertEqual(str(context.exception), "task is invalid.") diff --git a/debug/accuracy_tools/atat/test/mindspore_ut/test_overflow_check_tool_factory.py b/debug/accuracy_tools/atat/test/mindspore_ut/test_overflow_check_tool_factory.py new file mode 100644 index 0000000000..497fe1376a --- /dev/null +++ b/debug/accuracy_tools/atat/test/mindspore_ut/test_overflow_check_tool_factory.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +from unittest import TestCase + +from atat.core.common_config import CommonConfig, BaseConfig +from atat.mindspore.debugger.debugger_config import DebuggerConfig +from atat.mindspore.overflow_check.overflow_check_tool_factory import OverflowCheckToolFactory + + +class TestOverflowCheckToolFactory(TestCase): + + def test_create(self): + json_config = { + "task": "overflow_check", + "dump_path": "/absolute_path", + "rank": [], + "step": [], + "level": "L2" + } + + common_config = CommonConfig(json_config) + task_config = BaseConfig(json_config) + config = DebuggerConfig(common_config, task_config) + + config.level = "module" + with self.assertRaises(Exception) as context: + OverflowCheckToolFactory.create(config) + self.assertEqual(str(context.exception), "valid level is needed.") + + config.level = "cell" + with self.assertRaises(Exception) as context: + OverflowCheckToolFactory.create(config) + self.assertEqual(str(context.exception), "Overflow check in not supported in this mode.") + + config.level = "kernel" + dumper = OverflowCheckToolFactory.create(config) + self.assertEqual(dumper.dump_json["common_dump_settings"]["file_format"], "npy") diff --git a/debug/accuracy_tools/atat/test/mindspore_ut/test_precision_debugger.py b/debug/accuracy_tools/atat/test/mindspore_ut/test_precision_debugger.py new file mode 100644 index 0000000000..834a58e41a --- /dev/null +++ b/debug/accuracy_tools/atat/test/mindspore_ut/test_precision_debugger.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +from unittest import TestCase +from unittest.mock import patch + +from atat.core.common_config import CommonConfig, BaseConfig +from atat.mindspore.debugger.debugger_config import DebuggerConfig +from atat.mindspore.debugger.precision_debugger import PrecisionDebugger + + +class TestPrecisionDebugger(TestCase): + def test_start(self): + class Handler: + called = False + + def handle(self): + Handler.called = True + + json_config = { + "task": "statistics", + "dump_path": "/absolute_path", + "rank": [], + "step": [], + "level": "L1" + } + + common_config = CommonConfig(json_config) + task_config = BaseConfig(json_config) + handler = Handler() + + with patch("atat.mindspore.debugger.precision_debugger.parse_json_config", + return_value=[common_config, task_config]), \ + patch("atat.mindspore.debugger.precision_debugger.TaskHandlerFactory.create", return_value=handler): + debugger = PrecisionDebugger() + debugger.start() + self.assertTrue(isinstance(debugger.config, DebuggerConfig)) + self.assertTrue(Handler.called) + + PrecisionDebugger._instance = None + with self.assertRaises(Exception) as context: + debugger.start() + self.assertEqual(str(context.exception), "No instance of PrecisionDebugger found.") diff --git a/debug/accuracy_tools/atat/test/mindspore_ut/test_task_handler_factory.py b/debug/accuracy_tools/atat/test/mindspore_ut/test_task_handler_factory.py new file mode 100644 index 0000000000..02cd9934cb --- /dev/null +++ b/debug/accuracy_tools/atat/test/mindspore_ut/test_task_handler_factory.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +from unittest import TestCase +from unittest.mock import patch + +from atat.core.common_config import CommonConfig, BaseConfig +from atat.mindspore.debugger.debugger_config import DebuggerConfig +from atat.mindspore.dump.kernel_graph_dump import KernelGraphDump +from atat.mindspore.task_handler_factory import TaskHandlerFactory + + +class TestTaskHandlerFactory(TestCase): + + def test_create(self): + class HandlerFactory: + def create(self): + return None + + tasks = {"statistics": HandlerFactory} + + json_config = { + "task": "statistics", + "dump_path": "/absolute_path", + "rank": [], + "step": [], + "level": "L2" + } + + common_config = CommonConfig(json_config) + task_config = BaseConfig(json_config) + config = DebuggerConfig(common_config, task_config) + + handler = TaskHandlerFactory.create(config) + self.assertTrue(isinstance(handler, KernelGraphDump)) + + with patch("atat.mindspore.task_handler_factory.TaskHandlerFactory.tasks", new=tasks): + with self.assertRaises(Exception) as context: + TaskHandlerFactory.create(config) + self.assertEqual(str(context.exception), "Can not find task handler") + + config.task = "free_benchmark" + with self.assertRaises(Exception) as context: + TaskHandlerFactory.create(config) + self.assertEqual(str(context.exception), "valid task is needed.") -- Gitee