From fb503eb57f723712d3d169f52e2d612895de2107 Mon Sep 17 00:00:00 2001 From: curry3 <485078529@qq.com> Date: Wed, 3 Jul 2024 17:03:10 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90feature=E3=80=91atat=E5=A2=9E=E5=8A=A0?= =?UTF-8?q?ut=E6=A1=86=E6=9E=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../atat/test/core_ut/test_utils.py | 32 ++++++++++ .../atat/test/mindspore_ut/test_ms_config.py | 31 ++++++++++ .../test_perturbed_layser.py | 0 .../atat/test/pytorch_ut/test_pt_config.py | 38 ++++++++++++ debug/accuracy_tools/atat/test/run_test.sh | 30 +++++++++ debug/accuracy_tools/atat/test/run_ut.py | 62 +++++++++++++++++++ 6 files changed, 193 insertions(+) create mode 100644 debug/accuracy_tools/atat/test/core_ut/test_utils.py create mode 100644 debug/accuracy_tools/atat/test/mindspore_ut/test_ms_config.py rename debug/accuracy_tools/{test/pytorch/free_benchmark => atat/test/pytorch_ut/free_benchmark/perturbed_layers}/test_perturbed_layser.py (100%) create mode 100644 debug/accuracy_tools/atat/test/pytorch_ut/test_pt_config.py create mode 100644 debug/accuracy_tools/atat/test/run_test.sh create mode 100644 debug/accuracy_tools/atat/test/run_ut.py diff --git a/debug/accuracy_tools/atat/test/core_ut/test_utils.py b/debug/accuracy_tools/atat/test/core_ut/test_utils.py new file mode 100644 index 0000000000..9492bbc9f9 --- /dev/null +++ b/debug/accuracy_tools/atat/test/core_ut/test_utils.py @@ -0,0 +1,32 @@ +from unittest import TestCase +from unittest.mock import patch + +from atat.core.utils import check_seed_all, Const, CompareException + + +class TestUtils(TestCase): + @patch("atat.core.utils.print_error_log") + def test_check_seed_all(self, mock_print_error_log): + self.assertIsNone(check_seed_all(1234, True)) + self.assertIsNone(check_seed_all(0, True)) + self.assertIsNone(check_seed_all(Const.MAX_SEED_VALUE, True)) + + with self.assertRaises(CompareException) as context: + check_seed_all(-1, True) + self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR) + mock_print_error_log.assert_called_with(f"Seed must be between 0 and {Const.MAX_SEED_VALUE}.") + + with self.assertRaises(CompareException) as context: + check_seed_all(Const.MAX_SEED_VALUE + 1, True) + self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR) + mock_print_error_log.assert_called_with(f"Seed must be between 0 and {Const.MAX_SEED_VALUE}.") + + with self.assertRaises(CompareException) as context: + check_seed_all("1234", True) + self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR) + mock_print_error_log.assert_called_with("Seed must be integer.") + + with self.assertRaises(CompareException) as context: + check_seed_all(1234, 1) + self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR) + mock_print_error_log.assert_called_with("seed_all mode must be bool.") diff --git a/debug/accuracy_tools/atat/test/mindspore_ut/test_ms_config.py b/debug/accuracy_tools/atat/test/mindspore_ut/test_ms_config.py new file mode 100644 index 0000000000..0029e24bda --- /dev/null +++ b/debug/accuracy_tools/atat/test/mindspore_ut/test_ms_config.py @@ -0,0 +1,31 @@ +from unittest import TestCase +from unittest.mock import patch, mock_open + +from atat.core.utils import Const +from atat.mindspore.ms_config import parse_json_config + + +class TestMsConfig(TestCase): + def test_parse_json_config(self): + mock_json_data = { + "dump_path": "./dump/", + "rank": [], + "step": [], + "level": "L1", + "seed": 1234, + "statistics": { + "scope": [], + "list": [], + "data_mode": ["all"], + "summary_mode": "statistics" + } + } + with (patch("atat.mindspore.ms_config.FileOpen", mock_open(read_data='')), + patch("atat.mindspore.ms_config.json.load", return_value=mock_json_data)): + common_config, task_config = parse_json_config("./config.json") + self.assertEqual(common_config.task, Const.STATISTICS) + self.assertEqual(task_config.data_mode, ["all"]) + + with self.assertRaises(Exception) as context: + parse_json_config(None) + self.assertEqual(str(context.exception), "json file path is None") diff --git a/debug/accuracy_tools/test/pytorch/free_benchmark/test_perturbed_layser.py b/debug/accuracy_tools/atat/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py similarity index 100% rename from debug/accuracy_tools/test/pytorch/free_benchmark/test_perturbed_layser.py rename to debug/accuracy_tools/atat/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/test_pt_config.py b/debug/accuracy_tools/atat/test/pytorch_ut/test_pt_config.py new file mode 100644 index 0000000000..8279c20776 --- /dev/null +++ b/debug/accuracy_tools/atat/test/pytorch_ut/test_pt_config.py @@ -0,0 +1,38 @@ +from unittest import TestCase +from unittest.mock import patch, mock_open + +from atat.core.utils import Const +from atat.pytorch.pt_config import parse_json_config + + +class TestPtConfig(TestCase): + def test_parse_json_config(self): + mock_json_data = { + "task": "statistics", + "dump_path": "./dump/", + "rank": [], + "step": [], + "level": "L1", + "seed": 1234, + "statistics": { + "scope": [], + "list": [], + "data_mode": ["all"], + }, + "tensor": { + "file_format": "npy" + } + } + with (patch("atat.pytorch.pt_config.os.path.join", return_value="/path/config.json"), + patch("atat.pytorch.pt_config.FileOpen", mock_open(read_data='')), + patch("atat.pytorch.pt_config.json.load", return_value=mock_json_data)): + common_config, task_config = parse_json_config(None, None) + self.assertEqual(common_config.task, Const.STATISTICS) + self.assertEqual(task_config.data_mode, ["all"]) + + with (patch("atat.pytorch.pt_config.os.path.join", return_value="/path/config.json"), + patch("atat.pytorch.pt_config.FileOpen", mock_open(read_data='')), + patch("atat.pytorch.pt_config.json.load", return_value=mock_json_data)): + common_config, task_config = parse_json_config(None, Const.TENSOR) + self.assertEqual(common_config.task, Const.STATISTICS) + self.assertEqual(task_config.file_format, "npy") diff --git a/debug/accuracy_tools/atat/test/run_test.sh b/debug/accuracy_tools/atat/test/run_test.sh new file mode 100644 index 0000000000..1bf0ccb771 --- /dev/null +++ b/debug/accuracy_tools/atat/test/run_test.sh @@ -0,0 +1,30 @@ +#!/bin/bash +CUR_DIR=$(dirname $(readlink -f $0)) +TOP_DIR=${CUR_DIR}/.. +TEST_DIR=${TOP_DIR}/"test" +SRC_DIR=${TOP_DIR}/../ + +install_pytest() { + if ! pip show pytest &> /dev/null; then + echo "pytest not found, trying to install..." + pip install pytest + fi + + if ! pip show pytest-cov &> /dev/null; then + echo "pytest-cov not found, trying to install..." + pip install pytest-cov + fi +} + +run_ut() { + install_pytest + + export PYTHONPATH=${SRC_DIR}:${PYTHONPATH} + python3 run_ut.py +} + +main() { + cd ${TEST_DIR} && run_ut +} + +main $@ diff --git a/debug/accuracy_tools/atat/test/run_ut.py b/debug/accuracy_tools/atat/test/run_ut.py new file mode 100644 index 0000000000..7f51d266c2 --- /dev/null +++ b/debug/accuracy_tools/atat/test/run_ut.py @@ -0,0 +1,62 @@ +import os +import shutil +import subprocess +import sys + +from atat.core.log import print_info_log, print_error_log + + +def get_ignore_dirs(cur_dir): + ignore_dirs = [] + try: + import torch + import torch_npu + except ImportError: + print_info_log(f"Skipping the {cur_dir}/pytorch_ut directory") + ignore_dirs.extend(["--ignore", f"{cur_dir}/pytorch_ut"]) + + try: + import mindspore + except ImportError: + print_info_log(f"Skipping the {cur_dir}/mindspore_ut directory") + ignore_dirs.extend(["--ignore", f"{cur_dir}/mindspore_ut"]) + + return ignore_dirs + + +def run_ut(): + cur_dir = os.path.realpath(os.path.dirname(__file__)) + ut_path = cur_dir + ignore_dirs = get_ignore_dirs(cur_dir) + cov_dir = os.path.dirname(cur_dir) + report_dir = os.path.join(cur_dir, "report") + final_xml_path = os.path.join(report_dir, "final.xml") + cov_report_path = os.path.join(report_dir, "coverage.xml") + + if os.path.exists(report_dir): + shutil.rmtree(report_dir) + os.makedirs(report_dir) + + cmd = ["python3", "-m", "pytest", ut_path, "--junitxml=" + final_xml_path, "--cov=" + cov_dir, + "--cov-branch", "--cov-report=xml:" + cov_report_path] + ignore_dirs + result_ut = subprocess.Popen(cmd, shell=False, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + while result_ut.poll() is None: + line = result_ut.stdout.readline().strip() + if line: + print_info_log(str(line)) + + ut_flag = False + if result_ut.returncode == 0: + ut_flag = True + print_info_log("run ut successfully.") + else: + print_error_log("run ut failed.") + + return ut_flag + + +if __name__ == "__main__": + if run_ut(): + sys.exit(0) + else: + sys.exit(1) -- Gitee