From 8b4e1abde918337f93bd1376a56b35f87e31dbdc Mon Sep 17 00:00:00 2001 From: jiangchao_j Date: Wed, 17 Jul 2024 17:13:54 +0800 Subject: [PATCH 1/2] add ut case - part1 --- .../atat/test/core_ut/test_file_check.py | 218 ++++++++++++ .../atat/test/core_ut/test_utils.py | 319 +++++++++++++++++- 2 files changed, 529 insertions(+), 8 deletions(-) create mode 100644 debug/accuracy_tools/atat/test/core_ut/test_file_check.py diff --git a/debug/accuracy_tools/atat/test/core_ut/test_file_check.py b/debug/accuracy_tools/atat/test/core_ut/test_file_check.py new file mode 100644 index 0000000000..3305acb0b7 --- /dev/null +++ b/debug/accuracy_tools/atat/test/core_ut/test_file_check.py @@ -0,0 +1,218 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +import os + +from unittest import TestCase +from unittest.mock import patch, MagicMock + +from atat.core.common.log import logger +from atat.core.common.exceptions import FileCheckException +from atat.core.common.file_check import (FileCheckConst, + check_link, + check_path_length, + check_path_exists, + check_path_readability, + check_path_writability, + check_path_executable, + check_other_user_writable, + check_path_owner_consistent, + check_path_pattern_vaild, + check_file_size, + check_common_file_size, + check_file_suffix, + check_path_type) + + +class TestFileCheckUtil(TestCase): + @patch.object(logger, "error") + def test_check_link(self, mock_logger_error): + with patch("atat.core.common.file_check.os.path.islink", return_value=True): + with self.assertRaises(FileCheckException) as context: + check_link("link_path") + self.assertEqual(str(context.exception), + FileCheckException.err_strs.get(FileCheckException.SOFT_LINK_ERROR)) + mock_logger_error.assert_called_with("The file path link_path is a soft link.") + + @patch.object(logger, "error") + def test_check_path_length(self, mock_logger_error): + path = "P" * (FileCheckConst.DIRECTORY_LENGTH + 1) + with self.assertRaises(FileCheckException) as context: + check_path_length(path) + self.assertEqual(str(context.exception), + FileCheckException.err_strs.get(FileCheckException.ILLEGAL_PATH_ERROR)) + mock_logger_error.assert_called_with("The file path length exceeds limit.") + + path = "P" * (FileCheckConst.FILE_NAME_LENGTH + 1) + with self.assertRaises(FileCheckException) as context: + check_path_length(path) + self.assertEqual(str(context.exception), + FileCheckException.err_strs.get(FileCheckException.ILLEGAL_PATH_ERROR)) + mock_logger_error.assert_called_with("The file path length exceeds limit.") + + path = "P" * (FileCheckConst.FILE_NAME_LENGTH - 5) + with self.assertRaises(FileCheckException) as context: + check_path_length(path, name_length=FileCheckConst.FILE_NAME_LENGTH - 6) + self.assertEqual(str(context.exception), + FileCheckException.err_strs.get(FileCheckException.ILLEGAL_PATH_ERROR)) + mock_logger_error.assert_called_with("The file path length exceeds 
limit.") + + @patch.object(logger, "error") + def test_check_path_exists(self, mock_logger_error): + with patch("atat.core.common.file_check.os.path.exists", return_value=False): + with self.assertRaises(FileCheckException) as context: + check_path_exists("file_path") + self.assertEqual(str(context.exception), + FileCheckException.err_strs.get(FileCheckException.ILLEGAL_PATH_ERROR)) + mock_logger_error.assert_called_with("The file path file_path does not exist.") + + @patch.object(logger, "error") + def test_check_path_readability(self, mock_logger_error): + path = "file_path" + with patch("atat.core.common.file_check.os.access", return_value=False): + with self.assertRaises(FileCheckException) as context: + check_path_readability(path) + self.assertEqual(str(context.exception), + FileCheckException.err_strs.get(FileCheckException.FILE_PERMISSION_ERROR)) + mock_logger_error.assert_called_with(f"The file path {path} is not readable.") + + mock_access = MagicMock() + mock_access.return_value = True + with patch("atat.core.common.file_check.os.access", new=mock_access): + check_path_readability(path) + self.assertEqual(mock_access.call_args[0], (path, os.R_OK)) + + @patch.object(logger, "error") + def test_check_path_writability(self, mock_logger_error): + path = "file_path" + with patch("atat.core.common.file_check.os.access", return_value=False): + with self.assertRaises(FileCheckException) as context: + check_path_writability(path) + self.assertEqual(str(context.exception), + FileCheckException.err_strs.get(FileCheckException.FILE_PERMISSION_ERROR)) + mock_logger_error.assert_called_with(f"The file path {path} is not writable.") + + mock_access = MagicMock() + mock_access.return_value = True + with patch("atat.core.common.file_check.os.access", new=mock_access): + check_path_writability(path) + self.assertEqual(mock_access.call_args[0], (path, os.W_OK)) + + @patch.object(logger, "error") + def test_check_path_executable(self, mock_logger_error): + path = "file_path" + with patch("atat.core.common.file_check.os.access", return_value=False): + with self.assertRaises(FileCheckException) as context: + check_path_executable(path) + self.assertEqual(str(context.exception), + FileCheckException.err_strs.get(FileCheckException.FILE_PERMISSION_ERROR)) + mock_logger_error.assert_called_with(f"The file path {path} is not executable.") + + mock_access = MagicMock() + mock_access.return_value = True + with patch("atat.core.common.file_check.os.access", new=mock_access): + check_path_executable(path) + self.assertEqual(mock_access.call_args[0], (path, os.X_OK)) + + @patch.object(logger, "error") + def test_check_other_user_writable(self, mock_logger_error): + class TestStat: + def __init__(self, mode): + self.st_mode = mode + + path = "file_path" + mock_stat = TestStat(0o002) + with patch("atat.core.common.file_check.os.stat", return_value=mock_stat): + with self.assertRaises(FileCheckException) as context: + check_other_user_writable(path) + self.assertEqual(str(context.exception), + FileCheckException.err_strs.get(FileCheckException.FILE_PERMISSION_ERROR)) + mock_logger_error.assert_called_with(f"The file path {path} may be insecure " + "because other users have write permissions. 
") + + @patch.object(logger, "error") + def test_check_path_owner_consistent(self, mock_logger_error): + file_path = os.path.realpath(__file__) + file_owner = os.stat(file_path).st_uid + with patch("atat.core.common.file_check.os.getuid", return_value=file_owner+1): + with self.assertRaises(FileCheckException) as context: + check_path_owner_consistent(file_path) + self.assertEqual(str(context.exception), + FileCheckException.err_strs.get(FileCheckException.FILE_PERMISSION_ERROR)) + mock_logger_error.assert_called_with(f"The file path {file_path} may be insecure " + "because is does not belong to you.") + + @patch.object(logger, "error") + def test_check_path_pattern_vaild(self, mock_logger_error): + path = "path" + mock_re_match = MagicMock() + mock_re_match.return_value = False + with patch("atat.core.common.file_check.re.match", new=mock_re_match): + with self.assertRaises(FileCheckException) as context: + check_path_pattern_vaild(path) + self.assertEqual(str(context.exception), + FileCheckException.err_strs.get(FileCheckException.ILLEGAL_PATH_ERROR)) + mock_logger_error.assert_called_with(f"The file path {path} contains special characters.") + mock_re_match.assert_called_with(FileCheckConst.FILE_VALID_PATTERN, path) + + @patch.object(logger, "error") + def test_check_file_size(self, mock_logger_error): + file_path = os.path.realpath(__file__) + file_size = os.path.getsize(file_path) + max_size = file_size + with self.assertRaises(FileCheckException) as context: + check_file_size(file_path, max_size) + self.assertEqual(str(context.exception), + FileCheckException.err_strs.get(FileCheckException.FILE_TOO_LARGE_ERROR)) + mock_logger_error.assert_called_with(f"The size of file path {file_path} exceeds {max_size} bytes.") + + def test_check_common_file_size(self): + mock_check_file_size = MagicMock() + with patch("atat.core.common.file_check.os.path.isfile", return_value=True), \ + patch("atat.core.common.file_check.check_file_size", new=mock_check_file_size): + for suffix, max_size in FileCheckConst.FILE_SIZE_DICT.items(): + check_common_file_size(suffix) + mock_check_file_size.assert_called_with(suffix, max_size) + + @patch.object(logger, "error") + def test_check_file_suffix(self, mock_logger_error): + file_path = "file_path" + suffix = "suffix" + with self.assertRaises(FileCheckException) as context: + check_file_suffix(file_path, suffix) + self.assertEqual(str(context.exception), + FileCheckException.err_strs.get(FileCheckException.INVALID_FILE_ERROR)) + mock_logger_error.assert_called_with(f"The {file_path} should be a {suffix} file!") + + @patch.object(logger, "error") + def test_check_path_type(self, mock_logger_error): + file_path = "file_path" + + with patch("atat.core.common.file_check.os.path.isfile", return_value=False), \ + patch("atat.core.common.file_check.os.path.isdir", return_value=True): + with self.assertRaises(FileCheckException) as context: + check_path_type(file_path, FileCheckConst.FILE) + self.assertEqual(str(context.exception), + FileCheckException.err_strs.get(FileCheckException.INVALID_FILE_ERROR)) + mock_logger_error.assert_called_with(f"The {file_path} should be a file!") + + with patch("atat.core.common.file_check.os.path.isfile", return_value=True), \ + patch("atat.core.common.file_check.os.path.isdir", return_value=False): + with self.assertRaises(FileCheckException) as context: + check_path_type(file_path, FileCheckConst.DIR) + self.assertEqual(str(context.exception), + FileCheckException.err_strs.get(FileCheckException.INVALID_FILE_ERROR)) + 
mock_logger_error.assert_called_with(f"The {file_path} should be a dictionary!") diff --git a/debug/accuracy_tools/atat/test/core_ut/test_utils.py b/debug/accuracy_tools/atat/test/core_ut/test_utils.py index 89734f2c57..fae0e4e255 100644 --- a/debug/accuracy_tools/atat/test/core_ut/test_utils.py +++ b/debug/accuracy_tools/atat/test/core_ut/test_utils.py @@ -1,13 +1,51 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2022-2023. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +import os +import uuid + from unittest import TestCase -from unittest.mock import patch +from unittest.mock import patch, MagicMock, mock_open -from atat.core.common.utils import check_seed_all, Const, CompareException, check_inplace_op from atat.core.common.log import logger +from atat.core.common.utils import (Const, CompareException, + check_seed_all, + check_inplace_op, + make_dump_path_if_not_exists, + check_mode_valid, + check_switch_valid, + check_dump_mode_valid, + check_summary_mode_valid, + check_summary_only_valid, + check_file_or_directory_path, + check_compare_param, + check_configuration_param, + is_starts_with, + _check_json, + check_json_file, + check_file_size, + check_regex_prefix_format_valid, + get_dump_data_path, + task_dumppath_get) +from atat.core.common.file_check import FileCheckConst class TestUtils(TestCase): @patch.object(logger, "error") - def test_check_seed_all(self, mock_print_error_log): + def test_check_seed_all(self, mock_error): self.assertIsNone(check_seed_all(1234, True)) self.assertIsNone(check_seed_all(0, True)) self.assertIsNone(check_seed_all(Const.MAX_SEED_VALUE, True)) @@ -15,23 +53,23 @@ class TestUtils(TestCase): with self.assertRaises(CompareException) as context: check_seed_all(-1, True) self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR) - mock_print_error_log.assert_called_with(f"Seed must be between 0 and {Const.MAX_SEED_VALUE}.") + mock_error.assert_called_with(f"Seed must be between 0 and {Const.MAX_SEED_VALUE}.") with self.assertRaises(CompareException) as context: check_seed_all(Const.MAX_SEED_VALUE + 1, True) self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR) - mock_print_error_log.assert_called_with(f"Seed must be between 0 and {Const.MAX_SEED_VALUE}.") + mock_error.assert_called_with(f"Seed must be between 0 and {Const.MAX_SEED_VALUE}.") with self.assertRaises(CompareException) as context: check_seed_all("1234", True) self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR) - mock_print_error_log.assert_called_with("Seed must be integer.") + mock_error.assert_called_with("Seed must be integer.") with self.assertRaises(CompareException) as context: check_seed_all(1234, 1) self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR) - mock_print_error_log.assert_called_with("seed_all mode must be bool.") - + mock_error.assert_called_with("seed_all mode must be bool.") + def 
test_check_inplace_op(self): test_prefix_1 = "Distributed.broadcast.0.forward.input.0" self.assertTrue(check_inplace_op(test_prefix_1)) @@ -39,3 +77,268 @@ class TestUtils(TestCase): self.assertFalse(check_inplace_op(test_prefix_2)) test_prefix_3 = "Torch.sum.0.backward.output.0" self.assertFalse(check_inplace_op(test_prefix_3)) + + @patch.object(logger, "error") + def test_make_dump_path_if_not_exists(self, mock_error): + file_path = os.path.realpath(__file__) + dirname = os.path.dirname(file_path) + str(uuid.uuid4()) + + def test_mkdir(self, **kwargs): + raise OSError + + if not os.path.exists(dirname): + with patch("atat.core.common.utils.Path.mkdir", new=test_mkdir): + with self.assertRaises(CompareException) as context: + make_dump_path_if_not_exists(dirname) + self.assertEqual(context.exception.code, CompareException.INVALID_PATH_ERROR) + + make_dump_path_if_not_exists(file_path) + mock_error.assert_called_with(f"{file_path} already exists and is not a directory.") + + def test_check_mode_valid(self): + with self.assertRaises(ValueError) as context: + check_mode_valid("all", scope="scope") + self.assertEqual(str(context.exception), "scope param set invalid, it's must be a list.") + + with self.assertRaises(ValueError) as context: + check_mode_valid("all", api_list="api_list") + self.assertEqual(str(context.exception), "api_list param set invalid, it's must be a list.") + + mode = "all_list" + with self.assertRaises(CompareException) as context: + check_mode_valid(mode) + self.assertEqual(context.exception.code, CompareException.INVALID_DUMP_MODE) + self.assertEqual(str(context.exception), + f"Current mode '{mode}' is not supported. Please use the field in {Const.DUMP_MODE}") + + mode = "list" + with self.assertRaises(ValueError) as context: + check_mode_valid(mode) + self.assertEqual(str(context.exception), + "set_dump_switch, scope param set invalid, it's should not be an empty list.") + + @patch.object(logger, "error") + def test_check_switch_valid(self, mock_error): + with self.assertRaises(CompareException) as context: + check_switch_valid("Close") + self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR) + mock_error.assert_called_with("Please set switch with 'ON' or 'OFF'.") + + @patch.object(logger, "warning") + def test_check_dump_mode_valid(self, mock_warning): + dump_mode = check_dump_mode_valid("all") + mock_warning.assert_called_with("Please set dump_mode as a list.") + self.assertEqual(dump_mode, ["forward", "backward", "input", "output"]) + + with self.assertRaises(ValueError) as context: + check_dump_mode_valid("all_forward") + self.assertEqual(str(context.exception), + "Please set dump_mode as a list containing one or more of the following: " + + "'all', 'forward', 'backward', 'input', 'output'.") + + def test_check_summary_mode_valid(self): + with self.assertRaises(CompareException) as context: + check_summary_mode_valid("MD5") + self.assertEqual(context.exception.code, CompareException.INVALID_SUMMARY_MODE) + self.assertEqual(str(context.exception), "The summary_mode is not valid") + + @patch.object(logger, "error") + def test_check_summary_only_valid(self, mock_error): + summary_only = check_summary_only_valid(True) + self.assertTrue(summary_only) + + with self.assertRaises(CompareException) as context: + check_summary_only_valid("True") + self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR) + mock_error.assert_called_with("Params summary_only only support True or False.") + + def 
test_check_file_or_directory_path(self): + class TestFileChecker: + file_path = "" + path_type = "" + ability = "" + checked = False + + def __init__(self, file_path, path_type, ability=None): + TestFileChecker.file_path = file_path + TestFileChecker.path_type = path_type + TestFileChecker.ability = ability + + def common_check(self): + TestFileChecker.checked = True + + file_path = os.path.realpath(__file__) + dirname = os.path.dirname(file_path) + + with patch("atat.core.common.utils.FileChecker", new=TestFileChecker): + check_file_or_directory_path(file_path, isdir=False) + self.assertTrue(TestFileChecker.checked) + self.assertEqual(TestFileChecker.file_path, file_path) + self.assertEqual(TestFileChecker.path_type, FileCheckConst.FILE) + self.assertEqual(TestFileChecker.ability, FileCheckConst.READ_ABLE) + + TestFileChecker.checked = False + with patch("atat.core.common.utils.FileChecker", new=TestFileChecker): + check_file_or_directory_path(dirname, isdir=True) + self.assertTrue(TestFileChecker.checked) + self.assertEqual(TestFileChecker.file_path, dirname) + self.assertEqual(TestFileChecker.path_type, FileCheckConst.DIR) + self.assertEqual(TestFileChecker.ability, FileCheckConst.WRITE_ABLE) + + @patch.object(logger, "error") + def test_check_compare_param(self, mock_error): + params = { + "npu_json_path": "npu_json_path", + "bench_json_path": "bench_json_path", + "stack_json_path": "stack_json_path", + "npu_dump_data_dir": "npu_dump_data_dir", + "bench_dump_data_dir": "bench_dump_data_dir" + } + + call_args = [ + ("npu_json_path", False), + ("bench_json_path", False), + ("stack_json_path", False), + ("npu_dump_data_dir", True), + ("bench_dump_data_dir", True), + ("output_path", True), + ("npu_json_path", False), + ("bench_json_path", False), + ("stack_json_path", False), + ("output_path", True) + ] + + with self.assertRaises(CompareException) as context: + check_compare_param("npu_json_path", "output_path") + self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR) + mock_error.assert_called_with("Invalid input parameters") + + mock_check_file_or_directory_path = MagicMock() + mock_check_json_file = MagicMock() + with patch("atat.core.common.utils.FileOpen", mock_open(read_data="")), \ + patch("atat.core.common.utils.check_json_file", new=mock_check_json_file), \ + patch("atat.core.common.utils.check_file_or_directory_path", new=mock_check_file_or_directory_path): + check_compare_param(params, "output_path") + check_compare_param(params, "output_path", summary_compare=False, md5_compare=True) + for i in range(len(call_args)): + self.assertEqual(mock_check_file_or_directory_path.call_args_list[i][0], call_args[i]) + self.assertEqual(len(mock_check_json_file.call_args[0]), 4) + self.assertEqual(mock_check_json_file.call_args[0][0], params) + + @patch.object(logger, "error") + def test_check_configuration_param(self, mock_error): + with self.assertRaises(CompareException) as context: + check_configuration_param(stack_mode="False", auto_analyze=True, fuzzy_match=False) + self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR) + mock_error.assert_called_with("Invalid input parameters which should be only bool type.") + + def test_is_starts_with(self): + string = "input_slot0" + self.assertFalse(is_starts_with(string, [])) + self.assertFalse(is_starts_with("", ["input"])) + self.assertFalse(is_starts_with(string, ["output"])) + self.assertTrue(is_starts_with(string, ["input", "output"])) + + @patch.object(logger, "error") + def 
test__check_json(self, mock_error): + class TestOpen: + def __init__(self, string): + self.string = string + + def readline(self): + return self.string + + def seek(self, begin, end): + self.string = str(begin) + "_" + str(end) + + with self.assertRaises(CompareException) as context: + _check_json(TestOpen(""), "test.json") + self.assertEqual(context.exception.code, CompareException.INVALID_DUMP_FILE) + mock_error.assert_called_with("dump file test.json have empty line!") + + handler = TestOpen("jons file\n") + _check_json(handler, "test.json") + self.assertEqual(handler.string, "0_0") + + @patch("atat.core.common.utils._check_json") + def test_check_json_file(self, _mock_check_json): + input_param = { + "npu_json_path": "npu_json_path", + "bench_json_path": "bench_json_path", + "stack_json_path": "stack_json_path" + } + check_json_file(input_param, "npu_json", "bench_json", "stack_json") + self.assertEqual(_mock_check_json.call_args_list[0][0], ("npu_json", "npu_json_path")) + self.assertEqual(_mock_check_json.call_args_list[1][0], ("bench_json", "bench_json_path")) + self.assertEqual(_mock_check_json.call_args_list[2][0], ("stack_json", "stack_json_path")) + + @patch.object(logger, "error") + def test_check_file_size(self, mock_error): + with patch("atat.core.common.utils.os.path.getsize", return_value=120): + with self.assertRaises(CompareException) as context: + check_file_size("input_file", 100) + self.assertEqual(context.exception.code, CompareException.INVALID_FILE_ERROR) + mock_error.assert_called_with("The size (120) of input_file exceeds (100) bytes, tools not support.") + + def test_check_regex_prefix_format_valid(self): + prefix = "A" * 21 + with self.assertRaises(ValueError) as context: + check_regex_prefix_format_valid(prefix) + self.assertEqual(str(context.exception), f"Maximum length of prefix is {Const.REGEX_PREFIX_MAX_LENGTH}, " + f"while current length is {len(prefix)}") + + prefix = "(prefix)" + with self.assertRaises(ValueError) as context: + check_regex_prefix_format_valid(prefix) + self.assertEqual(str(context.exception), f"prefix contains invalid characters, " + f"prefix pattern {Const.REGEX_PREFIX_PATTERN}") + + @patch("atat.core.common.utils.check_file_or_directory_path") + def test_get_dump_data_path(self, mock_check_file_or_directory_path): + file_path = os.path.realpath(__file__) + dirname = os.path.dirname(file_path) + + dump_data_path, file_is_exist = get_dump_data_path(dirname) + self.assertEqual(mock_check_file_or_directory_path.call_args[0], (dirname, True)) + self.assertEqual(dump_data_path, dirname) + self.assertTrue(file_is_exist) + + @patch.object(logger, "error") + def test_task_dumppath_get(self, mock_error): + input_param = { + "npu_json_path": None, + "bench_json_path": "bench_json_path" + } + npu_json = { + "task": Const.TENSOR, + "dump_data_dir": "dump_data_dir", + "data": "data" + } + + with self.assertRaises(CompareException) as context: + task_dumppath_get(input_param) + self.assertEqual(context.exception.code, CompareException.INVALID_PATH_ERROR) + mock_error.assert_called_with("Please check the json path is valid.") + + input_param["npu_json_path"] = "npu_json_path" + with patch("atat.core.common.utils.FileOpen", mock_open(read_data="")), \ + patch("atat.core.common.utils.json.load", return_value=npu_json): + summary_compare, md5_compare = task_dumppath_get(input_param) + self.assertFalse(summary_compare) + self.assertFalse(md5_compare) + + npu_json["task"] = Const.STATISTICS + with patch("atat.core.common.utils.FileOpen", 
mock_open(read_data="")), \ + patch("atat.core.common.utils.json.load", return_value=npu_json), \ + patch("atat.core.common.utils.md5_find", return_value=True): + summary_compare, md5_compare = task_dumppath_get(input_param) + self.assertFalse(summary_compare) + self.assertTrue(md5_compare) + + npu_json["task"] = Const.OVERFLOW_CHECK + with patch("atat.core.common.utils.FileOpen", mock_open(read_data="")), \ + patch("atat.core.common.utils.json.load", return_value=npu_json): + with self.assertRaises(CompareException) as context: + task_dumppath_get(input_param) + self.assertEqual(context.exception.code, CompareException.INVALID_TASK_ERROR) + mock_error.assert_called_with("Compare is not required for overflow_check or free_benchmark.") -- Gitee From 2b0d9d3c71fb704e2af1e06e08e8f4e115281f5b Mon Sep 17 00:00:00 2001 From: jiangchao_j Date: Thu, 18 Jul 2024 15:32:53 +0800 Subject: [PATCH 2/2] change Const use --- .../api_accuracy_checker/compare/compare.py | 82 +----- .../accuracy_tools/atat/core/common/const.py | 237 ++++++++++++++++++ .../atat/core/common/file_check.py | 38 +-- .../accuracy_tools/atat/core/common/utils.py | 212 +--------------- .../accuracy_tools/atat/core/common_config.py | 2 +- .../atat/core/data_dump/data_collector.py | 2 +- .../core/data_dump/data_processor/base.py | 3 +- .../core/data_dump/data_processor/factory.py | 2 +- .../data_processor/pytorch_processor.py | 4 +- .../atat/core/data_dump/json_writer.py | 4 +- .../atat/core/data_dump/scope.py | 2 +- .../atat/pytorch/advisor/advisor.py | 6 +- .../atat/pytorch/advisor/advisor_result.py | 4 +- .../api_accuracy_checker/common/utils.py | 4 +- .../api_accuracy_checker/compare/algorithm.py | 2 +- .../compare/api_precision_compare.py | 5 +- .../api_accuracy_checker/compare/compare.py | 92 +------ .../compare/compare_column.py | 2 +- .../compare/compare_utils.py | 18 +- .../run_ut/data_generate.py | 3 +- .../run_ut/multi_run_ut.py | 3 +- .../api_accuracy_checker/run_ut/run_ut.py | 16 +- .../atat/pytorch/compare/acc_compare.py | 7 +- .../atat/pytorch/compare/highlight.py | 3 +- .../atat/pytorch/compare/npy_compare.py | 3 +- .../atat/pytorch/debugger/debugger_config.py | 2 +- .../atat/pytorch/free_benchmark/__init__.py | 2 +- .../atat/pytorch/free_benchmark/main.py | 3 +- .../perturbed_layers/npu/improve_precision.py | 3 +- .../result_handlers/base_handler.py | 3 +- .../atat/pytorch/functional/dump_module.py | 2 +- .../atat/pytorch/hook_module/api_registry.py | 3 +- .../atat/pytorch/hook_module/hook_module.py | 2 +- .../atat/pytorch/hook_module/wrap_aten.py | 3 +- .../pytorch/hook_module/wrap_distributed.py | 3 +- .../pytorch/hook_module/wrap_functional.py | 3 +- .../pytorch/hook_module/wrap_npu_custom.py | 3 +- .../atat/pytorch/hook_module/wrap_tensor.py | 3 +- .../atat/pytorch/hook_module/wrap_torch.py | 3 +- .../atat/pytorch/hook_module/wrap_vf.py | 3 +- .../atat/pytorch/module_processer.py | 2 +- .../atat/pytorch/online_dispatch/compare.py | 3 +- .../atat/pytorch/online_dispatch/dispatch.py | 4 +- .../pytorch/online_dispatch/dump_compare.py | 5 +- .../atat/pytorch/online_dispatch/utils.py | 15 +- .../atat/pytorch/parse_tool/lib/utils.py | 2 +- .../accuracy_tools/atat/pytorch/pt_config.py | 2 +- debug/accuracy_tools/atat/pytorch/service.py | 4 +- .../atat/test/core_ut/test_file_check.py | 4 +- .../atat/test/core_ut/test_utils.py | 3 +- .../atat/test/mindspore_ut/test_ms_config.py | 2 +- .../compare/test_api_precision_compare.py | 2 +- .../perturbed_layers/test_perturbed_layser.py | 2 +- 
.../result_handlers/test_result_handler.py | 2 +- .../pytorch_ut/free_benchmark/test_main.py | 2 +- .../atat/test/pytorch_ut/test_pt_config.py | 2 +- .../grad_tool/common/base_comparator.py | 24 +- .../grad_tool/grad_pt/grad_monitor.py | 4 +- .../tensorboard-plugins/{ OWNERS => OWNERS} | 18 +- profiler/advisor/README.md | 12 +- .../compare_backend/utils/constant.py | 2 +- profiler/module_visualization/__init__.py | 0 .../module_visualization/graph/__init__.py | 0 .../module_visualization/graph/prof_node.py | 90 +++++++ .../graph_build/__init__.py | 0 .../graph_build/fwd_module_node.py | 29 +++ .../graph_build/prof_graph_builder.py | 115 +++++++++ .../module_visualization/prof_graph_export.py | 39 +++ .../prof_parse/__init__.py | 0 .../prof_parse/prof_data_pre_process.py | 102 ++++++++ profiler/prof_common/base_node.py | 78 ++++++ profiler/prof_common/constant.py | 15 +- profiler/prof_common/file_reader.py | 59 +++++ profiler/prof_common/path_manager.py | 191 ++++++++++++++ profiler/prof_common/trace_event_bean.py | 69 +++++ profiler/prof_common/tree_builder.py | 33 +++ profiler/prof_common/utils.py | 25 ++ 77 files changed, 1220 insertions(+), 538 deletions(-) create mode 100644 debug/accuracy_tools/atat/core/common/const.py rename plugins/tensorboard-plugins/{ OWNERS => OWNERS} (93%) create mode 100644 profiler/module_visualization/__init__.py create mode 100644 profiler/module_visualization/graph/__init__.py create mode 100644 profiler/module_visualization/graph/prof_node.py create mode 100644 profiler/module_visualization/graph_build/__init__.py create mode 100644 profiler/module_visualization/graph_build/fwd_module_node.py create mode 100644 profiler/module_visualization/graph_build/prof_graph_builder.py create mode 100644 profiler/module_visualization/prof_graph_export.py create mode 100644 profiler/module_visualization/prof_parse/__init__.py create mode 100644 profiler/module_visualization/prof_parse/prof_data_pre_process.py create mode 100644 profiler/prof_common/base_node.py create mode 100644 profiler/prof_common/file_reader.py create mode 100644 profiler/prof_common/path_manager.py create mode 100644 profiler/prof_common/trace_event_bean.py create mode 100644 profiler/prof_common/tree_builder.py create mode 100644 profiler/prof_common/utils.py diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py index c7b175cd49..1b79635904 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py @@ -1,13 +1,10 @@ # 进行比对及结果展示 import os -import csv from collections import namedtuple import torch import numpy as np -from rich.table import Table -from rich.console import Console -from api_accuracy_checker.common.utils import get_json_contents, write_csv, print_warn_log, Const +from api_accuracy_checker.common.utils import get_json_contents, write_csv, print_info_log, Const from api_accuracy_checker.compare.compare_utils import CompareConst, check_dtype_comparable, DETAIL_TEST_ROWS, \ precision_configs, BENCHMARK_COMPARE_SUPPORT_LIST, AbsoluteStandardApi, BinaryStandardApi, ULPStandardApi, \ ThousandthStandardApi, apis_threshold @@ -17,7 +14,6 @@ from api_accuracy_checker.compare.algorithm import get_rmse, get_error_balance, get_small_value_err_ratio, get_finite_and_infinite_mask, get_small_value_mask, check_inf_nan_value, \ check_small_value, check_norm_value, get_abs_bench_with_eps, get_ulp_err from api_accuracy_checker.common.config 
import msCheckerConfig -from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileOpen ResultInfo = namedtuple('ResultInfo', ['full_api_name', 'fwd_success_status', 'bwd_success_status', @@ -49,83 +45,13 @@ class Comparator: else: self.stack_info = None - self.test_result_cnt = { - "success_num": 0, "warning_num": 0, "error_num": 0, - "forward_fail_num": 0, "backward_fail_num": 0, "forward_and_backward_fail_num": 0, - "total_num": 0, "total_skip_num": 0 - } - @staticmethod def get_path_from_rank(rank, path_list, path_pattern): return path_list[-1] if len(path_list) == 1 else path_pattern.format(rank) - def print_pretest_result(self): - for save_path in self.save_path_list: - self.get_statistics_from_result_csv(save_path) - total_tests = self.test_result_cnt.get("total_num", 0) - if total_tests != 0: - passing_rate = "{:.2%}".format(self.test_result_cnt.get("success_num", 0) / total_tests) - else: - passing_rate = "0%" - - print_warn_log("The follwing tables will be deprecated in the future." - "The following results are for reference only.") - console = Console() - table_total = Table( - show_header=True, title="Overall Statistics", show_lines=True, width=75 - ) - table_total.add_column("Result") - table_total.add_column("Statistics") - table_total.add_row("[green]Pass[/green]", str(self.test_result_cnt.get("success_num", 0))) - table_total.add_row("[yellow]Warning[/yellow]", str(self.test_result_cnt.get("warning_num", 0))) - table_total.add_row("[red]Error[/red]", str(self.test_result_cnt.get("error_num", 0))) - table_total.add_row("Passing Rate", passing_rate) - table_total.add_row("Skip Tests", str(self.test_result_cnt.get("total_skip_num", 0))) - - table_detail = Table( - show_header=True, title="Detail Statistics", show_lines=True, width=75 - ) - table_detail.add_column("Result") - table_detail.add_column("Statistics") - table_detail.add_row("Forward Error", str(self.test_result_cnt.get("forward_fail_num", 0))) - table_detail.add_row("Backward Error", str(self.test_result_cnt.get("backward_fail_num", 0))) - table_detail.add_row("Both Forward & Backward Error", str(self.test_result_cnt.get("forward_and_backward_fail_num", 0))) - - console.print(table_total) - console.print(table_detail) - - def get_statistics_from_result_csv(self, save_path): - checklist = [CompareConst.PASS, CompareConst.ERROR, CompareConst.WARNING, CompareConst.SPACE, CompareConst.SKIP, "skip"] - with FileOpen(save_path, 'r') as file: - reader = csv.reader(file) - result_csv_rows = [row for row in reader] - result_csv_name = os.path.basename(save_path) - for item in result_csv_rows[1:]: - if not isinstance(item, list) or len(item) < 3: - raise ValueError("The number of columns in %s is incorrect" % result_csv_name) - if not all(item[i] and item[i] in checklist for i in (1, 2)): - raise ValueError( - "The value in the 2nd or 3rd column of %s is wrong, it must be pass, error, warning, skip, or SPACE" - % result_csv_name) - column1 = item[1] - column2 = item[2] - if column1.upper() == CompareConst.SKIP: - self.test_result_cnt["total_skip_num"] += 1 - continue - self.test_result_cnt["total_num"] += 1 - if column1 == CompareConst.PASS and column2 in [CompareConst.PASS, CompareConst.SPACE, CompareConst.SKIP]: - self.test_result_cnt['success_num'] += 1 - elif column1 == CompareConst.ERROR and column2 == CompareConst.ERROR: - self.test_result_cnt['forward_and_backward_fail_num'] += 1 - self.test_result_cnt['error_num'] += 1 - elif column1 == CompareConst.ERROR: - self.test_result_cnt['forward_fail_num'] += 
1 - self.test_result_cnt['error_num'] += 1 - elif column2 == CompareConst.ERROR: - self.test_result_cnt['backward_fail_num'] += 1 - self.test_result_cnt['error_num'] += 1 - elif column1 == CompareConst.WARNING or column2 == CompareConst.WARNING: - self.test_result_cnt['warning_num'] += 1 + @staticmethod + def print_pretest_result(): + print_info_log("Successfully completed run_ut/multi_run_ut.") def write_csv_title(self): summary_test_rows = [[self.COLUMN_API_NAME, self.COLUMN_FORWARD_SUCCESS, diff --git a/debug/accuracy_tools/atat/core/common/const.py b/debug/accuracy_tools/atat/core/common/const.py new file mode 100644 index 0000000000..89de3a4e5a --- /dev/null +++ b/debug/accuracy_tools/atat/core/common/const.py @@ -0,0 +1,237 @@ +import os +import stat +import numpy as np + +class Const: + """ + Class for const + """ + SEP = "." + REGEX_PREFIX_MAX_LENGTH = 20 + REGEX_PREFIX_PATTERN = r"^[a-zA-Z0-9_-]+$" + FILE_PATTERN = r'^[a-zA-Z0-9_./-]+$' + COMMA = "," + FLOAT_EPSILON = np.finfo(float).eps + OFF = 'OFF' + BACKWARD = 'backward' + FORWARD = 'forward' + + # dump mode + ALL = "all" + LIST = "list" + RANGE = "range" + STACK = "stack" + ACL = "acl" + API_LIST = "api_list" + API_STACK = "api_stack" + DUMP_MODE = [ALL, LIST, RANGE, STACK, ACL, API_LIST, API_STACK] + SUMMARY = "summary" + MD5 = "md5" + SUMMARY_MODE = [ALL, SUMMARY, MD5] + + WRITE_FLAGS = os.O_WRONLY | os.O_CREAT + WRITE_MODES = stat.S_IWUSR | stat.S_IRUSR + OVERWRITE_FLAGS = os.O_WRONLY | os.O_CREAT | os.O_TRUNC + + PKL_SUFFIX = ".pkl" + NUMPY_SUFFIX = ".npy" + ONE_GB = 1073741824 # 1 * 1024 * 1024 * 1024 + TEN_GB = 10737418240 # 10 * 1024 * 1024 * 1024 + FILE_PATTERN = r'^[a-zA-Z0-9_./-]+$' + DISTRIBUTED_PREFIX_LENGTH = 60 + # env dump path + KWARGS = 'kwargs' + INPUT = 'input' + OUTPUT = 'output' + INPUT_ARGS = 'input_args' + INPUT_KWARGS = 'input_kwargs' + GRAD_INPUT = 'grad_input' + GRAD_OUTPUT = 'grad_output' + START = "start" + STOP = "stop" + ENV_ENABLE = "1" + ENV_DISABLE = "0" + MAX_SEED_VALUE = 4294967295 # 2**32 - 1 + TASK_LIST = ["tensor", "statistics", "overflow_check", "free_benchmark"] + LEVEL_LIST = ["L0", "L1", "L2", "mix"] + STATISTICS = "statistics" + TENSOR = "tensor" + OVERFLOW_CHECK = "overflow_check" + FREE_BENCHMARK = "free_benchmark" + ATTR_NAME_PREFIX = "wrap_" + KERNEL_DUMP = "kernel_dump" + DATA = "data" + PT_FRAMEWORK = "pytorch" + MS_FRAMEWORK = "mindspore" + DIRECTORY_LENGTH = 4096 + FILE_NAME_LENGTH = 255 + FLOAT_TYPE = [np.half, np.single, float, np.double, np.float64, np.longdouble, np.float32, np.float16] + BOOL_TYPE = [bool, np.uint8] + INT_TYPE = [np.int32, np.int64] + NPU = 'NPU' + DISTRIBUTED = 'Distributed' + + INPLACE_LIST = [ + "broadcast", "all_reduce", "reduce", "all_gather", "gather", "scatter", "reduce_scatter", + "_reduce_scatter_base", "_all_gather_base", "send", "recv", "irecv", "isend", "all_to_all_single" + ] + + CONVERT = { + "int32_to_int64": ["torch.int32", "torch.int64"], + } + + CONVERT_API = { + "int32_to_int64": ["cross_entropy"] + } + +class CompareConst: + """ + Class for compare module const + """ + SPACE = " " + # compare result column name + NPU_NAME = "NPU Name" + BENCH_NAME = "Bench Name" + NPU_DTYPE = "NPU Dtype" + BENCH_DTYPE = "Bench Dtype" + NPU_SHAPE = "NPU Tensor Shape" + BENCH_SHAPE = "Bench Tensor Shape" + NPU_MAX = "NPU max" + NPU_MIN = "NPU min" + NPU_MEAN = "NPU mean" + NPU_NORM = "NPU l2norm" + BENCH_MAX = "Bench max" + BENCH_MIN = "Bench min" + BENCH_MEAN = "Bench mean" + BENCH_NORM = "Bench l2norm" + MAX_DIFF = "Max diff" + MIN_DIFF = "Min diff" + 
MEAN_DIFF = "Mean diff" + NORM_DIFF = "L2norm diff" + COSINE = "Cosine" + MAX_ABS_ERR = "MaxAbsErr" + MAX_RELATIVE_ERR = "MaxRelativeErr" + MIN_RELATIVE_ERR = "MinRelativeErr" + MEAN_RELATIVE_ERR = "MeanRelativeErr" + NORM_RELATIVE_ERR = "NormRelativeErr" + ACCURACY = "Accuracy Reached or Not" + STACK = "NPU_Stack_Info" + DATA_NAME = "Data_name" + ERROR_MESSAGE = "Err_message" + ONE_THOUSANDTH_ERR_RATIO = "One Thousandth Err Ratio" + FIVE_THOUSANDTHS_ERR_RATIO = "Five Thousandths Err Ratio" + NPU_MD5 = "NPU MD5" + BENCH_MD5 = "BENCH MD5" + RESULT = "Result" + + COMPARE_RESULT_HEADER = [ + NPU_NAME, BENCH_NAME, NPU_DTYPE, BENCH_DTYPE, NPU_SHAPE, BENCH_SHAPE, COSINE, MAX_ABS_ERR, MAX_RELATIVE_ERR, + ONE_THOUSANDTH_ERR_RATIO, FIVE_THOUSANDTHS_ERR_RATIO, + NPU_MAX, NPU_MIN, NPU_MEAN, NPU_NORM, BENCH_MAX, BENCH_MIN, BENCH_MEAN, BENCH_NORM, ACCURACY, ERROR_MESSAGE + ] + + SUMMARY_COMPARE_RESULT_HEADER = [ + NPU_NAME, BENCH_NAME, NPU_DTYPE, BENCH_DTYPE, NPU_SHAPE, BENCH_SHAPE, MAX_DIFF, MIN_DIFF, MEAN_DIFF, NORM_DIFF, + MAX_RELATIVE_ERR, MIN_RELATIVE_ERR, MEAN_RELATIVE_ERR, NORM_RELATIVE_ERR, + NPU_MAX, NPU_MIN, NPU_MEAN, NPU_NORM, BENCH_MAX, BENCH_MIN, BENCH_MEAN, BENCH_NORM, RESULT, ERROR_MESSAGE + ] + + MD5_COMPARE_RESULT_HEADER = [ + NPU_NAME, BENCH_NAME, NPU_DTYPE, BENCH_DTYPE, NPU_SHAPE, BENCH_SHAPE, NPU_MD5, BENCH_MD5, RESULT + ] + + # compare standard + THOUSAND_RATIO_THRESHOLD = 0.001 + FIVE_THOUSAND_RATIO_THRESHOLD = 0.005 + COSINE_THRESHOLD = 0.9999 + + # compare result data + READ_NONE = 'No data' + NONE = 'None' + SHAPE_UNMATCH = 'shape unmatched' + DIFF = 'Different' + UNSUPPORTED = 'unsupported' + NAN = 'Nan' + PASS = 'pass' + WARNING = 'Warning' + ERROR = 'error' + SKIP = 'SKIP' + BFLOAT16_MIN = -3.3895313892515355e+38 + BFLOAT16_MAX = 3.3895313892515355e+38 + BFLOAT16_EPS = 3.90625e-3 # 2 ** -8 + + # accuracy standards + COS_THRESHOLD = 0.99 + MAX_ABS_ERR_THRESHOLD = 0.001 + COS_MAX_THRESHOLD = 0.9 + MAX_ABS_ERR_MAX_THRESHOLD = 1 + ACCURACY_CHECK_YES = "Yes" + ACCURACY_CHECK_NO = "No" + ACCURACY_CHECK_UNMATCH = "Unmatched" + + # error message + NO_BENCH = "No bench data matched." 
+ + # compare const + FLOAT_TYPE = [np.half, np.single, float, np.double, np.float64, np.longdouble] + + # highlight xlsx color const + RED = "FFFF0000" + YELLOW = "FFFF00" + BLUE = "0000FF" + + # highlight rules const + OVERFLOW_LIST = ['nan\t', 'inf\t', '-inf\t', 'nan', 'inf', '-inf'] + MAX_DIFF_RED = 1e+10 + ORDER_MAGNITUDE_DIFF_YELLOW = 1 + ONE_THOUSAND_ERROR_IN_RED = 0.9 + ONE_THOUSAND_ERROR_OUT_RED = 0.6 + ONE_THOUSAND_ERROR_DIFF_YELLOW = 0.1 + COSINE_DIFF_YELLOW = 0.1 + MAX_RELATIVE_OUT_RED = 0.5 + MAX_RELATIVE_OUT_YELLOW = 0.1 + MAX_RELATIVE_IN_YELLOW = 0.01 + +class FileCheckConst: + """ + Class for file check const + """ + READ_ABLE = "read" + WRITE_ABLE = "write" + READ_WRITE_ABLE = "read and write" + DIRECTORY_LENGTH = 4096 + FILE_NAME_LENGTH = 255 + FILE_VALID_PATTERN = r"^[a-zA-Z0-9_.:/-]+$" + FILE_PATTERN = r'^[a-zA-Z0-9_./-]+$' + PKL_SUFFIX = ".pkl" + NUMPY_SUFFIX = ".npy" + JSON_SUFFIX = ".json" + PT_SUFFIX = ".pt" + CSV_SUFFIX = ".csv" + YAML_SUFFIX = ".yaml" + MAX_PKL_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024 + MAX_NUMPY_SIZE = 10737418240 # 10 * 1024 * 1024 * 1024 + MAX_JSON_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024 + MAX_PT_SIZE = 10737418240 # 10 * 1024 * 1024 * 1024 + MAX_CSV_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024 + MAX_YAML_SIZE = 1048576 # 10 * 1024 * 1024 + DIR = "dir" + FILE = "file" + DATA_DIR_AUTHORITY = 0o750 + DATA_FILE_AUTHORITY = 0o640 + FILE_SIZE_DICT = { + PKL_SUFFIX: MAX_PKL_SIZE, + NUMPY_SUFFIX: MAX_NUMPY_SIZE, + JSON_SUFFIX: MAX_JSON_SIZE, + PT_SUFFIX: MAX_PT_SIZE, + CSV_SUFFIX: MAX_CSV_SIZE, + YAML_SUFFIX: MAX_YAML_SIZE + } + +class OverflowConst: + """ + Class for Overflow + """ + OVERFLOW_DEBUG_MODE_ENABLE = "OVERFLOW_DEBUG_MODE_ENABLE" + OVERFLOW_ORIGINAL_MODE = 0 + OVERFLOW_DEBUG_MODE = 1 diff --git a/debug/accuracy_tools/atat/core/common/file_check.py b/debug/accuracy_tools/atat/core/common/file_check.py index 43207e85e7..2df825aa35 100644 --- a/debug/accuracy_tools/atat/core/common/file_check.py +++ b/debug/accuracy_tools/atat/core/common/file_check.py @@ -19,43 +19,7 @@ import re from atat.core.common.log import logger from atat.core.common.exceptions import FileCheckException - - -class FileCheckConst: - """ - Class for file check const - """ - READ_ABLE = "read" - WRITE_ABLE = "write" - READ_WRITE_ABLE = "read and write" - DIRECTORY_LENGTH = 4096 - FILE_NAME_LENGTH = 255 - FILE_VALID_PATTERN = r"^[a-zA-Z0-9_.:/-]+$" - FILE_PATTERN = r'^[a-zA-Z0-9_./-]+$' - PKL_SUFFIX = ".pkl" - NUMPY_SUFFIX = ".npy" - JSON_SUFFIX = ".json" - PT_SUFFIX = ".pt" - CSV_SUFFIX = ".csv" - YAML_SUFFIX = ".yaml" - MAX_PKL_SIZE = 1 * 1024 * 1024 * 1024 - MAX_NUMPY_SIZE = 10 * 1024 * 1024 * 1024 - MAX_JSON_SIZE = 1 * 1024 * 1024 * 1024 - MAX_PT_SIZE = 10 * 1024 * 1024 * 1024 - MAX_CSV_SIZE = 1 * 1024 * 1024 * 1024 - MAX_YAML_SIZE = 10 * 1024 * 1024 - DIR = "dir" - FILE = "file" - DATA_DIR_AUTHORITY = 0o750 - DATA_FILE_AUTHORITY = 0o640 - FILE_SIZE_DICT = { - PKL_SUFFIX: MAX_PKL_SIZE, - NUMPY_SUFFIX: MAX_NUMPY_SIZE, - JSON_SUFFIX: MAX_JSON_SIZE, - PT_SUFFIX: MAX_PT_SIZE, - CSV_SUFFIX: MAX_CSV_SIZE, - YAML_SUFFIX: MAX_YAML_SIZE - } +from atat.core.common.const import FileCheckConst class FileChecker: diff --git a/debug/accuracy_tools/atat/core/common/utils.py b/debug/accuracy_tools/atat/core/common/utils.py index 0c74bf038d..088530f3c5 100644 --- a/debug/accuracy_tools/atat/core/common/utils.py +++ b/debug/accuracy_tools/atat/core/common/utils.py @@ -26,7 +26,8 @@ from datetime import datetime, timezone from pathlib import Path import numpy as np -from 
atat.core.common.file_check import FileOpen, FileChecker, FileCheckConst +from atat.core.common.file_check import FileOpen, FileChecker +from atat.core.common.const import Const, FileCheckConst, CompareConst, OverflowConst from atat.core.common.log import logger @@ -34,206 +35,6 @@ device = collections.namedtuple('device', ['type', 'index']) prefixes = ['api_stack', 'list', 'range', 'acl'] -class Const: - """ - Class for const - """ - SEP = "." - MODEL_TYPE = ['.onnx', '.pb', '.om'] - DIM_PATTERN = r"^(-?[0-9]+)(,-?[0-9]+)*" - REGEX_PREFIX_MAX_LENGTH = 20 - REGEX_PREFIX_PATTERN = r"^[a-zA-Z0-9_-]+$" - SEMICOLON = ";" - COLON = ":" - EQUAL = "=" - COMMA = "," - DOT = "." - DUMP_RATIO_MAX = 100 - SUMMERY_DATA_NUMS = 256 - FLOAT_EPSILON = np.finfo(float).eps - SUPPORT_DUMP_MODE = ['api', 'acl'] - ON = 'ON' - OFF = 'OFF' - BACKWARD = 'backward' - FORWARD = 'forward' - PRE_FORWARD = "pre_forward" - - # dump mode - ALL = "all" - LIST = "list" - RANGE = "range" - STACK = "stack" - ACL = "acl" - API_LIST = "api_list" - API_STACK = "api_stack" - DUMP_MODE = [ALL, LIST, RANGE, STACK, ACL, API_LIST, API_STACK] - AUTO = "auto" - ONLINE_DUMP_MODE = [ALL, LIST, AUTO, OFF] - SUMMARY = "summary" - MD5 = "md5" - SUMMARY_MODE = [ALL, SUMMARY, MD5] - - WRITE_FLAGS = os.O_WRONLY | os.O_CREAT - WRITE_MODES = stat.S_IWUSR | stat.S_IRUSR - OVERWRITE_FLAGS = os.O_WRONLY | os.O_CREAT | os.O_TRUNC - - PKL_SUFFIX = ".pkl" - NUMPY_SUFFIX = ".npy" - ONE_GB = 1 * 1024 * 1024 * 1024 - TEN_GB = 10 * 1024 * 1024 * 1024 - FILE_PATTERN = r'^[a-zA-Z0-9_./-]+$' - FILE_NAME_LENGTH = 255 - DIRECTORY_LENGTH = 4096 - DISTRIBUTED_PREFIX_LENGTH = 60 - SUMMARY_COLUMN_NUM = 6 - STACK_COLUMN_NUM = 2 - # env dump path - ASCEND_WORK_PATH = "ASCEND_WORK_PATH" - DUMP_DIR = "dump_data" - - KWARGS = 'kwargs' - INPUT = 'input' - OUTPUT = 'output' - INPUT_ARGS = 'input_args' - INPUT_KWARGS = 'input_kwargs' - GRAD_INPUT = 'grad_input' - GRAD_OUTPUT = 'grad_output' - START = "start" - STOP = "stop" - ENV_ENABLE = "1" - ENV_DISABLE = "0" - - MAX_SEED_VALUE = 2**32 - 1 - - INPLACE_LIST = ["broadcast", "all_reduce", "reduce", "all_gather", "gather", "scatter", "reduce_scatter", - "_reduce_scatter_base", "_all_gather_base"] - - TASK_LIST = ["tensor", "statistics", "overflow_check", "free_benchmark"] - LEVEL_LIST = ["L0", "L1", "L2", "mix"] - STATISTICS = "statistics" - TENSOR = "tensor" - OVERFLOW_CHECK = "overflow_check" - FREE_BENCHMARK = "free_benchmark" - KERNEL_DUMP = "kernel_dump" - DATA = "data" - PT_FRAMEWORK = "pytorch" - MS_FRAMEWORK = "mindspore" - DIRECTORY_LENGTH = 4096 - FILE_NAME_LENGTH = 255 - FILE_PATTERN = r'^[a-zA-Z0-9_./-]+$' - FLOAT_TYPE = [np.half, np.single, float, np.double, np.float64, np.longdouble, np.float32, np.float16] - BOOL_TYPE = [bool, np.uint8] - INT_TYPE = [np.int32, np.int64] - NPU = 'NPU' - DISTRIBUTED = 'Distributed' - INPLACE_LIST = ["broadcast", "all_reduce", "reduce", "all_gather", "gather", "scatter", "reduce_scatter", - "_reduce_scatter_base", "_all_gather_base", "all_to_all_single"] - - -class CompareConst: - """ - Class for compare module const - """ - # compare result column name - NPU_NAME = "NPU Name" - BENCH_NAME = "Bench Name" - NPU_DTYPE = "NPU Dtype" - BENCH_DTYPE = "Bench Dtype" - NPU_SHAPE = "NPU Tensor Shape" - BENCH_SHAPE = "Bench Tensor Shape" - NPU_MAX = "NPU max" - NPU_MIN = "NPU min" - NPU_MEAN = "NPU mean" - NPU_NORM = "NPU l2norm" - BENCH_MAX = "Bench max" - BENCH_MIN = "Bench min" - BENCH_MEAN = "Bench mean" - BENCH_NORM = "Bench l2norm" - MAX_DIFF = "Max diff" - MIN_DIFF = "Min diff" 
- MEAN_DIFF = "Mean diff" - NORM_DIFF = "L2norm diff" - COSINE = "Cosine" - MAX_ABS_ERR = "MaxAbsErr" - MAX_RELATIVE_ERR = "MaxRelativeErr" - MIN_RELATIVE_ERR = "MinRelativeErr" - MEAN_RELATIVE_ERR = "MeanRelativeErr" - NORM_RELATIVE_ERR = "NormRelativeErr" - ACCURACY = "Accuracy Reached or Not" - STACK = "NPU_Stack_Info" - DATA_NAME = "Data_name" - ERROR_MESSAGE = "Err_message" - ONE_THOUSANDTH_ERR_RATIO = "One Thousandth Err Ratio" - FIVE_THOUSANDTHS_ERR_RATIO = "Five Thousandths Err Ratio" - NPU_MD5 = "NPU MD5" - BENCH_MD5 = "BENCH MD5" - RESULT = "Result" - - COMPARE_RESULT_HEADER = [ - NPU_NAME, BENCH_NAME, NPU_DTYPE, BENCH_DTYPE, NPU_SHAPE, BENCH_SHAPE, COSINE, MAX_ABS_ERR, MAX_RELATIVE_ERR, - ONE_THOUSANDTH_ERR_RATIO, FIVE_THOUSANDTHS_ERR_RATIO, - NPU_MAX, NPU_MIN, NPU_MEAN, NPU_NORM, BENCH_MAX, BENCH_MIN, BENCH_MEAN, BENCH_NORM, ACCURACY, ERROR_MESSAGE - ] - - SUMMARY_COMPARE_RESULT_HEADER = [ - NPU_NAME, BENCH_NAME, NPU_DTYPE, BENCH_DTYPE, NPU_SHAPE, BENCH_SHAPE, MAX_DIFF, MIN_DIFF, MEAN_DIFF, NORM_DIFF, - MAX_RELATIVE_ERR, MIN_RELATIVE_ERR, MEAN_RELATIVE_ERR, NORM_RELATIVE_ERR, - NPU_MAX, NPU_MIN, NPU_MEAN, NPU_NORM, BENCH_MAX, BENCH_MIN, BENCH_MEAN, BENCH_NORM, RESULT, ERROR_MESSAGE - ] - - MD5_COMPARE_RESULT_HEADER = [ - NPU_NAME, BENCH_NAME, NPU_DTYPE, BENCH_DTYPE, NPU_SHAPE, BENCH_SHAPE, NPU_MD5, BENCH_MD5, RESULT - ] - - # compare standard - THOUSAND_RATIO_THRESHOLD = 0.001 - FIVE_THOUSAND_RATIO_THRESHOLD = 0.005 - COSINE_THRESHOLD = 0.9999 - - # compare result data - READ_NONE = 'No data' - NAN = 'Nan' - NONE = 'None' - SHAPE_UNMATCH = 'shape unmatched' - DTYPE_UNMATCH = 'dtype unmatched' - PASS = 'Pass' - WARNING = 'Warning' - DIFF = 'Different' - UNSUPPORTED = 'unsupported' - - # accuracy standards - COS_THRESHOLD = 0.99 - MAX_ABS_ERR_THRESHOLD = 0.001 - COS_MAX_THRESHOLD = 0.9 - MAX_ABS_ERR_MAX_THRESHOLD = 1 - ACCURACY_CHECK_YES = "Yes" - ACCURACY_CHECK_NO = "No" - ACCURACY_CHECK_UNMATCH = "Unmatched" - - # error message - NO_BENCH = "No bench data matched." 
- - # compare const - FLOAT_TYPE = [np.half, np.single, float, np.double, np.float64, np.longdouble] - - # highlight xlsx color const - RED = "FFFF0000" - YELLOW = "FFFF00" - BLUE = "0000FF" - - # highlight rules const - OVERFLOW_LIST = ['nan\t', 'inf\t', '-inf\t', 'nan', 'inf', '-inf'] - MAX_DIFF_RED = 1e+10 - ORDER_MAGNITUDE_DIFF_YELLOW = 1 - ONE_THOUSAND_ERROR_IN_RED = 0.9 - ONE_THOUSAND_ERROR_OUT_RED = 0.6 - ONE_THOUSAND_ERROR_DIFF_YELLOW = 0.1 - COSINE_DIFF_YELLOW = 0.1 - MAX_RELATIVE_OUT_RED = 0.5 - MAX_RELATIVE_OUT_YELLOW = 0.1 - MAX_RELATIVE_IN_YELLOW = 0.01 - - class CompareException(Exception): """ Class for Accuracy Compare Exception @@ -273,15 +74,6 @@ class DumpException(CompareException): pass -class OverflowConst: - """ - Class for Overflow - """ - OVERFLOW_DEBUG_MODE_ENABLE = "OVERFLOW_DEBUG_MODE_ENABLE" - OVERFLOW_ORIGINAL_MODE = 0 - OVERFLOW_DEBUG_MODE = 1 - - def make_dump_path_if_not_exists(dump_path): if not os.path.exists(dump_path): try: diff --git a/debug/accuracy_tools/atat/core/common_config.py b/debug/accuracy_tools/atat/core/common_config.py index bc4ffd8090..e256372ca8 100644 --- a/debug/accuracy_tools/atat/core/common_config.py +++ b/debug/accuracy_tools/atat/core/common_config.py @@ -1,4 +1,4 @@ -from atat.core.common.utils import Const +from atat.core.common.const import Const from atat.core.common.log import logger from atat.core.common.exceptions import MsaccException diff --git a/debug/accuracy_tools/atat/core/data_dump/data_collector.py b/debug/accuracy_tools/atat/core/data_dump/data_collector.py index 2a0bc34ba8..f6a9a70b13 100644 --- a/debug/accuracy_tools/atat/core/data_dump/data_collector.py +++ b/debug/accuracy_tools/atat/core/data_dump/data_collector.py @@ -4,7 +4,7 @@ import os from atat.core.data_dump.scope import build_scope, ListScope from atat.core.data_dump.json_writer import DataWriter from atat.core.common.log import logger -from atat.core.common.utils import Const +from atat.core.common.const import Const from atat.core.data_dump.data_processor.factory import DataProcessorFactory diff --git a/debug/accuracy_tools/atat/core/data_dump/data_processor/base.py b/debug/accuracy_tools/atat/core/data_dump/data_processor/base.py index 1ee3314b36..208c053192 100644 --- a/debug/accuracy_tools/atat/core/data_dump/data_processor/base.py +++ b/debug/accuracy_tools/atat/core/data_dump/data_processor/base.py @@ -4,7 +4,8 @@ from dataclasses import dataclass from typing import Tuple, Dict, Optional, Any import numpy as np from atat.core.common.log import logger -from atat.core.common.utils import Const, convert_tuple +from atat.core.common.utils import convert_tuple +from atat.core.common.const import Const @dataclass diff --git a/debug/accuracy_tools/atat/core/data_dump/data_processor/factory.py b/debug/accuracy_tools/atat/core/data_dump/data_processor/factory.py index 00f2f72e7a..bcc771f368 100644 --- a/debug/accuracy_tools/atat/core/data_dump/data_processor/factory.py +++ b/debug/accuracy_tools/atat/core/data_dump/data_processor/factory.py @@ -1,4 +1,4 @@ -from atat.core.common.utils import Const +from atat.core.common.const import Const class DataProcessorFactory: diff --git a/debug/accuracy_tools/atat/core/data_dump/data_processor/pytorch_processor.py b/debug/accuracy_tools/atat/core/data_dump/data_processor/pytorch_processor.py index 9f96635e9a..cf3c5ebe58 100644 --- a/debug/accuracy_tools/atat/core/data_dump/data_processor/pytorch_processor.py +++ b/debug/accuracy_tools/atat/core/data_dump/data_processor/pytorch_processor.py @@ -6,9 +6,9 @@ from 
typing import List import numpy as np import torch from atat.core.common.exceptions import MsaccException -from atat.core.common.file_check import path_len_exceeds_limit, change_mode, FileCheckConst +from atat.core.common.file_check import path_len_exceeds_limit, change_mode from atat.core.common.log import logger -from atat.core.common.utils import Const, OverflowConst +from atat.core.common.const import Const, OverflowConst, FileCheckConst from atat.core.data_dump.data_processor.base import BaseDataProcessor, ModuleBackwardInputsOutputs, \ ModuleForwardInputsOutputs, TensorStatInfo from atat.pytorch.free_benchmark import FreeBenchmarkCheck, UnequalRow diff --git a/debug/accuracy_tools/atat/core/data_dump/json_writer.py b/debug/accuracy_tools/atat/core/data_dump/json_writer.py index dd0d2f9c7b..23f37b2342 100644 --- a/debug/accuracy_tools/atat/core/data_dump/json_writer.py +++ b/debug/accuracy_tools/atat/core/data_dump/json_writer.py @@ -4,9 +4,9 @@ import fcntl import json from pathlib import Path -from atat.core.common.file_check import FileCheckConst, change_mode +from atat.core.common.file_check import change_mode from atat.core.common.log import logger -from atat.core.common.utils import Const +from atat.core.common.const import Const, FileCheckConst class DataWriter: diff --git a/debug/accuracy_tools/atat/core/data_dump/scope.py b/debug/accuracy_tools/atat/core/data_dump/scope.py index dc473d7e14..e7114f343f 100644 --- a/debug/accuracy_tools/atat/core/data_dump/scope.py +++ b/debug/accuracy_tools/atat/core/data_dump/scope.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from atat.core.common.exceptions import ScopeException -from atat.core.common.utils import Const +from atat.core.common.const import Const def build_scope(scope_class, scope=None, api_list=None): diff --git a/debug/accuracy_tools/atat/pytorch/advisor/advisor.py b/debug/accuracy_tools/atat/pytorch/advisor/advisor.py index f4cb441f5e..43b3f40f97 100644 --- a/debug/accuracy_tools/atat/pytorch/advisor/advisor.py +++ b/debug/accuracy_tools/atat/pytorch/advisor/advisor.py @@ -20,9 +20,9 @@ import os from atat.pytorch.advisor.advisor_result import AdvisorResult from atat.pytorch.advisor.advisor_const import AdvisorConst from atat.pytorch.common.log import logger -from atat.core.common.utils import CompareException, CompareConst, Const -from atat.core.common.file_check import FileChecker, FileCheckConst - +from atat.core.common.utils import CompareException +from atat.core.common.file_check import FileChecker +from atat.core.common.const import Const, CompareConst, FileCheckConst class Advisor: """ diff --git a/debug/accuracy_tools/atat/pytorch/advisor/advisor_result.py b/debug/accuracy_tools/atat/pytorch/advisor/advisor_result.py index 59845a7541..a24fa2a115 100644 --- a/debug/accuracy_tools/atat/pytorch/advisor/advisor_result.py +++ b/debug/accuracy_tools/atat/pytorch/advisor/advisor_result.py @@ -19,8 +19,8 @@ import time from atat.pytorch.advisor.advisor_const import AdvisorConst from atat.pytorch.common.log import logger -from atat.core.common.utils import Const -from atat.core.common.file_check import FileCheckConst, change_mode +from atat.core.common.const import Const, FileCheckConst +from atat.core.common.file_check import change_mode class AdvisorResult: diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/common/utils.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/common/utils.py index 022edbfcf3..9e1b02c015 100644 --- 
a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/common/utils.py +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/common/utils.py @@ -29,8 +29,8 @@ else: IS_GPU = False from atat.pytorch.common.log import logger -from atat.core.common.file_check import FileCheckConst, FileChecker, FileOpen, change_mode, create_directory -from atat.pytorch.common.utils import Const +from atat.core.common.file_check import FileChecker, FileOpen, change_mode, create_directory +from atat.core.common.const import Const, FileCheckConst from atat.core.common.utils import CompareException diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/algorithm.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/algorithm.py index a450edb929..3f13534a5a 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/algorithm.py +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/algorithm.py @@ -1,7 +1,7 @@ # 定义比对算法及比对标准 import torch import numpy as np -from atat.pytorch.api_accuracy_checker.compare.compare_utils import CompareConst +from atat.core.common.const import CompareConst #cos diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/api_precision_compare.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/api_precision_compare.py index 7e0617eb3a..89ed3a1008 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/api_precision_compare.py +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/api_precision_compare.py @@ -8,15 +8,16 @@ import pandas as pd from atat.pytorch.api_accuracy_checker.common.utils import write_csv from atat.pytorch.api_accuracy_checker.common.config import msCheckerConfig -from atat.pytorch.api_accuracy_checker.compare.compare_utils import CompareConst, API_PRECISION_COMPARE_RESULT_FILE_NAME, \ +from atat.pytorch.api_accuracy_checker.compare.compare_utils import API_PRECISION_COMPARE_RESULT_FILE_NAME, \ API_PRECISION_COMPARE_DETAILS_FILE_NAME, BENCHMARK_COMPARE_SUPPORT_LIST, API_PRECISION_COMPARE_UNSUPPORT_LIST, \ ApiPrecisionCompareColumn, AbsoluteStandardApi, BinaryStandardApi, BINARY_COMPARE_UNSUPPORT_LIST, \ convert_str_to_float, CompareMessage from atat.pytorch.api_accuracy_checker.compare.compare_column import ApiPrecisionOutputColumn from atat.pytorch.api_accuracy_checker.run_ut.run_ut import get_validated_result_csv_path -from atat.core.common.file_check import FileCheckConst, FileChecker, change_mode, check_path_before_create, create_directory +from atat.core.common.file_check import FileChecker, change_mode, check_path_before_create, create_directory from atat.pytorch.common.log import logger from atat.core.common.utils import CompareException +from atat.core.common.const import CompareConst, FileCheckConst CompareConfig = namedtuple('CompareConfig', ['npu_csv_path', 'gpu_csv_path', 'result_csv_path', 'details_csv_path']) unsupported_message = 'This data type does not support benchmark compare.' 
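The hunks above all make the same mechanical change: constants that were previously re-exported from `atat.core.common.utils`, `atat.pytorch.common.utils`, `atat.core.common.file_check`, and `compare_utils` are now imported from the consolidated `atat.core.common.const` module. A minimal sketch of the resulting import pattern, using only module paths and names that appear in the hunks above (illustration only, not part of the patch):

```python
# Consolidated import style introduced by this patch:
from atat.core.common.const import Const, CompareConst, FileCheckConst

# Previously the same constants were pulled in from several scattered locations, e.g.:
#   from atat.core.common.utils import Const, CompareConst
#   from atat.core.common.file_check import FileCheckConst
#   from atat.pytorch.common.utils import Const
```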
diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare.py index fbba1dca00..cfc783bd75 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare.py +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare.py @@ -1,13 +1,10 @@ # 进行比对及结果展示 import os -import csv import torch import numpy as np -from rich.table import Table -from rich.console import Console from atat.pytorch.common.log import logger -from atat.pytorch.api_accuracy_checker.common.utils import get_json_contents, write_csv, Const -from atat.pytorch.api_accuracy_checker.compare.compare_utils import CompareConst, check_dtype_comparable, \ +from atat.pytorch.api_accuracy_checker.common.utils import get_json_contents, write_csv +from atat.pytorch.api_accuracy_checker.compare.compare_utils import check_dtype_comparable, \ DETAIL_TEST_ROWS, precision_configs, BENCHMARK_COMPARE_SUPPORT_LIST, AbsoluteStandardApi, BinaryStandardApi, \ apis_threshold from atat.pytorch.api_accuracy_checker.compare.compare_column import CompareColumn @@ -16,7 +13,7 @@ from atat.pytorch.api_accuracy_checker.compare.algorithm import get_rmse, get_er get_small_value_err_ratio, get_finite_and_infinite_mask, get_small_value_mask, check_inf_nan_value, \ check_small_value, check_norm_value, get_abs_bench_with_eps from atat.pytorch.api_accuracy_checker.common.config import msCheckerConfig -from atat.core.common.file_check import FileOpen +from atat.core.common.const import Const, CompareConst class Comparator: @@ -36,11 +33,10 @@ class Comparator: else: self.stack_info = None - self.test_result_cnt = { - "forward_fail_num": 0, "backward_fail_num": 0, "forward_and_backward_fail_num": 0, "success_num": 0, - "total_num": 0, "forward_or_backward_fail_num": 0 - } - + @staticmethod + def print_pretest_result(): + logger.info("Successfully completed run_ut/multi_run_ut.") + @staticmethod def _compare_dropout(bench_output, device_output): tensor_num = bench_output.numel() @@ -77,80 +73,6 @@ class Comparator: rtol = apis_threshold.get(api_name).get(dtype).get('rtol') return small_value_threshold, small_value_atol, rtol - def print_pretest_result(self): - self.get_statistics_from_result_csv() - total_tests = self.test_result_cnt.get("total_num", 0) - if total_tests != 0: - passing_rate = "{:.2%}".format(self.test_result_cnt.get("success_num", 0) / total_tests) - else: - passing_rate = "0%" - - logger.warning("The follwing tables will be deprecated in the future." 
- "The following results are for reference only.") - console = Console() - table_total = Table( - show_header=True, title="Overall Statistics", show_lines=True, width=75 - ) - table_total.add_column("Result") - table_total.add_column("Statistics") - table_total.add_row("[green]Pass[/green]", str(self.test_result_cnt.get("success_num", 0))) - table_total.add_row("[yellow]Warning[/yellow]", str(self.test_result_cnt.get("warning_num", 0))) - table_total.add_row("[red]Error[/red]", str(self.test_result_cnt.get("error_num", 0))) - table_total.add_row("Passing Rate", passing_rate) - table_total.add_row("Skip Tests", str(self.test_result_cnt.get("total_skip_num", 0))) - - table_detail = Table( - show_header=True, title="Detail Statistics", show_lines=True, width=75 - ) - table_detail.add_column("Result") - table_detail.add_column("Statistics") - table_detail.add_row("Forward Error", str(self.test_result_cnt.get("forward_fail_num", 0))) - table_detail.add_row("Backward Error", str(self.test_result_cnt.get("backward_fail_num", 0))) - table_detail.add_row("Both Forward & Backward Error", - str(self.test_result_cnt.get("forward_and_backward_fail_num", 0))) - - console.print(table_total) - console.print(table_detail) - - def get_statistics_from_result_csv(self): - checklist = [CompareConst.PASS, CompareConst.ERROR, CompareConst.WARNING, CompareConst.SPACE, CompareConst.SKIP, - "skip"] - self.test_result_cnt = { - "success_num": 0, "warning_num": 0, "error_num": 0, - "forward_fail_num": 0, "backward_fail_num": 0, "forward_and_backward_fail_num": 0, - "total_num": 0, "total_skip_num": 0 - } - with FileOpen(self.save_path, 'r') as file: - reader = csv.reader(file) - result_csv_rows = [row for row in reader] - result_csv_name = os.path.basename(self.save_path) - for item in result_csv_rows[1:]: - if not isinstance(item, list) or len(item) < 3: - raise ValueError("The number of columns in %s is incorrect" % result_csv_name) - if not all(item[i] and item[i] in checklist for i in (1, 2)): - raise ValueError( - "The value in the 2nd or 3rd column of %s is wrong, it must be pass, error, warning, skip, or SPACE" - % result_csv_name) - column1 = item[1] - column2 = item[2] - if column1.upper() == CompareConst.SKIP: - self.test_result_cnt["total_skip_num"] += 1 - continue - self.test_result_cnt["total_num"] += 1 - if column1 == CompareConst.PASS and column2 in [CompareConst.PASS, CompareConst.SPACE]: - self.test_result_cnt['success_num'] += 1 - elif column1 == CompareConst.ERROR and column2 == CompareConst.ERROR: - self.test_result_cnt['forward_and_backward_fail_num'] += 1 - self.test_result_cnt['error_num'] += 1 - elif column1 == CompareConst.ERROR: - self.test_result_cnt['forward_fail_num'] += 1 - self.test_result_cnt['error_num'] += 1 - elif column2 == CompareConst.ERROR: - self.test_result_cnt['backward_fail_num'] += 1 - self.test_result_cnt['error_num'] += 1 - elif column1 == CompareConst.WARNING or column2 == CompareConst.WARNING: - self.test_result_cnt['warning_num'] += 1 - def write_csv_title(self): summary_test_rows = [[self.COLUMN_API_NAME, self.COLUMN_FORWARD_SUCCESS, self.COLUMN_BACKWARD_SUCCESS, "Message"]] diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare_column.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare_column.py index bd88d6742f..a018b19275 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare_column.py +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare_column.py @@ -1,4 +1,4 
@@ -from atat.pytorch.api_accuracy_checker.compare.compare_utils import CompareConst +from atat.core.common.const import CompareConst class CompareColumn: diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare_utils.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare_utils.py index fe841eb063..bcf6b8ea19 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare_utils.py +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare_utils.py @@ -3,7 +3,8 @@ import os import numpy as np import torch import yaml -from atat.core.common.utils import Const, CompareException +from atat.core.common.utils import CompareException +from atat.core.common.const import Const from atat.pytorch.common.log import logger from atat.core.common.file_check import FileOpen @@ -77,21 +78,6 @@ precision_configs = { } } - -class CompareConst: - NAN = np.nan - NA = "N/A" - PASS = 'pass' - WARNING = 'warning' - ERROR = 'error' - SKIP = 'SKIP' - TRUE = 'TRUE' - FALSE = 'FALSE' - BFLOAT16_MIN = -3.3895313892515355e+38 - BFLOAT16_MAX = 3.3895313892515355e+38 - BFLOAT16_EPS = 2 ** -8 - SPACE = " " - class ApiPrecisionCompareColumn: API_NAME = 'API Name' diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/data_generate.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/data_generate.py index e983413bf0..c6b721eee2 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/data_generate.py +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/data_generate.py @@ -20,8 +20,9 @@ import math import torch import numpy -from atat.pytorch.api_accuracy_checker.common.utils import Const, check_file_or_directory_path, check_object_type, get_full_data_path, CompareException +from atat.pytorch.api_accuracy_checker.common.utils import check_file_or_directory_path, check_object_type, get_full_data_path, CompareException from atat.pytorch.common.log import logger +from atat.core.common.const import Const TORCH_TYPE = ["torch.device", "torch.dtype"] TENSOR_DATA_LIST = ["torch.Tensor", "torch.nn.parameter.Parameter"] diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py index b9d1a4fd1f..d2ab9c1e95 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py @@ -13,9 +13,10 @@ from atat.pytorch.api_accuracy_checker.run_ut.run_ut import _run_ut_parser, get_ get_validated_details_csv_path, preprocess_forward_content from atat.pytorch.api_accuracy_checker.compare.compare import Comparator from atat.pytorch.common import parse_json_info_forward_backward -from atat.core.common.file_check import FileCheckConst, FileChecker, check_file_suffix, check_link, FileOpen, \ +from atat.core.common.file_check import FileChecker, check_file_suffix, check_link, FileOpen, \ check_path_before_create, create_directory from atat.pytorch.common.log import logger +from atat.core.common.const import FileCheckConst def split_json_file(input_file, num_splits, filter_api): diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_ut.py index 77f3bf714a..47cbd99447 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_ut.py +++ 
b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_ut.py @@ -27,10 +27,10 @@ from atat.pytorch.hook_module.wrap_functional import FunctionalOPTemplate from atat.pytorch.hook_module.wrap_torch import TorchOPTemplate from atat.pytorch.api_accuracy_checker.common.config import msCheckerConfig from atat.pytorch.common.parse_json import parse_json_info_forward_backward -from atat.core.common.file_check import FileOpen, FileCheckConst, FileChecker, \ +from atat.core.common.file_check import FileOpen, FileChecker, \ change_mode, check_file_suffix, check_link, check_path_before_create, create_directory from atat.pytorch.common.log import logger -from atat.pytorch.common.utils import Const +from atat.core.common.const import Const, FileCheckConst current_time = time.strftime("%Y%m%d%H%M%S") UT_ERROR_DATA_DIR = 'ut_error_data' + current_time @@ -40,7 +40,11 @@ RunUTConfig = namedtuple('RunUTConfig', ['forward_content', 'backward_content', 'save_error_data', 'is_continue_run_ut', 'real_data_path']) not_backward_list = ['repeat_interleave'] not_detach_set = {'resize_', 'resize_as_', 'set_', 'transpose_', 't_', 'squeeze_', 'unsqueeze_'} - +RAISE_PRECISION = { + torch.float16: torch.float32, + torch.bfloat16: torch.float32, + torch.float32: torch.float64 +} tqdm_params = { 'smoothing': 0, # 平滑进度条的预计剩余时间,取值范围0到1 'desc': 'Processing', # 进度条前的描述文字 @@ -75,7 +79,7 @@ def deal_detach(arg, to_detach=True): def deal_dtype(arg, raise_dtype=None): - if raise_dtype is None or arg.dtype not in Const.RAISE_PRECISION or raise_dtype == arg.dtype: + if raise_dtype is None or arg.dtype not in RAISE_PRECISION or raise_dtype == arg.dtype: return arg return arg.type(raise_dtype) @@ -120,7 +124,7 @@ def generate_cpu_params(input_args, input_kwargs, need_backward, api_name): return arg_in def is_tensor_with_raise_precision(arg_in, check_kwargs=False): - if arg_in.dtype in Const.RAISE_PRECISION: + if arg_in.dtype in RAISE_PRECISION: return True if check_kwargs and arg_in.dtype in [torch.half, torch.bfloat16]: return True @@ -139,7 +143,7 @@ def generate_cpu_params(input_args, input_kwargs, need_backward, api_name): need_raise_dtypes = recursive_find_dtypes(input_args) need_raise_dtypes.update(recursive_find_dtypes(input_kwargs, check_kwargs=True)) if len(need_raise_dtypes) == 1: - raise_dtype = Const.RAISE_PRECISION.get(need_raise_dtypes.pop(), torch.float32) + raise_dtype = RAISE_PRECISION.get(need_raise_dtypes.pop(), torch.float32) elif len(need_raise_dtypes) >= 2: raise_dtype = torch.float32 diff --git a/debug/accuracy_tools/atat/pytorch/compare/acc_compare.py b/debug/accuracy_tools/atat/pytorch/compare/acc_compare.py index 2d7bdcfff3..061c9cdfca 100644 --- a/debug/accuracy_tools/atat/pytorch/compare/acc_compare.py +++ b/debug/accuracy_tools/atat/pytorch/compare/acc_compare.py @@ -32,9 +32,10 @@ from atat.pytorch.compare.highlight import HighlightRules, get_header_index from atat.pytorch.compare.npy_compare import compare_ops_apply, get_error_type, reshape_value, get_relative_err, get_error_message from atat.pytorch.advisor.advisor import Advisor from atat.pytorch.common.log import logger -from atat.core.common.utils import check_compare_param, add_time_with_xlsx, CompareException, CompareConst, \ - format_value, check_file_not_exists, check_configuration_param, task_dumppath_get, Const -from atat.core.common.file_check import FileChecker, FileCheckConst, change_mode, FileOpen, create_directory +from atat.core.common.utils import check_compare_param, add_time_with_xlsx, CompareException, \ + format_value, 
check_file_not_exists, check_configuration_param, task_dumppath_get +from atat.core.common.file_check import FileChecker, change_mode, FileOpen, create_directory +from atat.core.common.const import Const, CompareConst, FileCheckConst def check_graph_mode(a_op_name, b_op_name): diff --git a/debug/accuracy_tools/atat/pytorch/compare/highlight.py b/debug/accuracy_tools/atat/pytorch/compare/highlight.py index d94e86b013..3a6898dedb 100644 --- a/debug/accuracy_tools/atat/pytorch/compare/highlight.py +++ b/debug/accuracy_tools/atat/pytorch/compare/highlight.py @@ -1,7 +1,8 @@ import math import abc import numpy as np -from atat.core.common.utils import CompareConst, get_header_index +from atat.core.common.utils import get_header_index +from atat.core.common.const import CompareConst class HighlightCheck(abc.ABC): diff --git a/debug/accuracy_tools/atat/pytorch/compare/npy_compare.py b/debug/accuracy_tools/atat/pytorch/compare/npy_compare.py index 2e1f22ab3f..0cf4c6c00a 100644 --- a/debug/accuracy_tools/atat/pytorch/compare/npy_compare.py +++ b/debug/accuracy_tools/atat/pytorch/compare/npy_compare.py @@ -1,6 +1,7 @@ import abc import numpy as np -from atat.core.common.utils import CompareConst, Const, format_value +from atat.core.common.utils import format_value +from atat.core.common.const import Const, CompareConst from atat.pytorch.common.log import logger diff --git a/debug/accuracy_tools/atat/pytorch/debugger/debugger_config.py b/debug/accuracy_tools/atat/pytorch/debugger/debugger_config.py index 6f2bfe8551..1ad69701e4 100644 --- a/debug/accuracy_tools/atat/pytorch/debugger/debugger_config.py +++ b/debug/accuracy_tools/atat/pytorch/debugger/debugger_config.py @@ -1,6 +1,6 @@ from atat.pytorch.common import seed_all from atat.pytorch.common.log import logger -from atat.core.common.utils import Const +from atat.core.common.const import Const class DebuggerConfig: diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/__init__.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/__init__.py index f86fc41d55..b9d41330a8 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/__init__.py +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/__init__.py @@ -1,6 +1,6 @@ from atat.core.common.log import logger from atat.core.common.exceptions import FreeBenchmarkException -from atat.pytorch.common.utils import Const +from atat.core.common.const import Const from .main import FreeBenchmarkCheck from .common.params import UnequalRow diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/main.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/main.py index d7c2eba377..2ebc0a6db9 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/main.py +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/main.py @@ -1,7 +1,8 @@ from abc import ABC import torch -from atat.pytorch.free_benchmark import Const, logger +from atat.core.common.const import Const +from atat.pytorch.free_benchmark import logger from atat.pytorch.free_benchmark.common.constant import CommonField from atat.pytorch.free_benchmark.common.enums import ( DeviceType, diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py index 2df26afc1b..03718e3c4d 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py @@ -1,5 +1,6 @@ import torch 
-from atat.pytorch.free_benchmark import Const, logger +from atat.core.common.const import Const +from atat.pytorch.free_benchmark import logger from atat.pytorch.free_benchmark.common.constant import CommonField from atat.pytorch.free_benchmark.common.enums import PerturbationMode from atat.pytorch.free_benchmark.common.params import DataParams diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/base_handler.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/base_handler.py index c1bfbae24a..c57d7e390a 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/base_handler.py +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/base_handler.py @@ -3,7 +3,8 @@ from abc import ABC, abstractmethod from typing import Any, Optional, Tuple import torch -from atat.pytorch.free_benchmark import Const, logger +from atat.core.common.const import Const +from atat.pytorch.free_benchmark import logger from atat.pytorch.free_benchmark.common.constant import ThresholdConfig from atat.pytorch.free_benchmark.common.enums import ( FuzzThreshold, diff --git a/debug/accuracy_tools/atat/pytorch/functional/dump_module.py b/debug/accuracy_tools/atat/pytorch/functional/dump_module.py index 8652f13f9b..675fa2a1bf 100644 --- a/debug/accuracy_tools/atat/pytorch/functional/dump_module.py +++ b/debug/accuracy_tools/atat/pytorch/functional/dump_module.py @@ -1,6 +1,6 @@ import torch.nn as nn from atat.pytorch.common.log import logger -from atat.core.common.utils import Const +from atat.core.common.const import Const from atat.pytorch.hook_module.api_registry import api_register from atat.pytorch.debugger.precision_debugger import PrecisionDebugger from atat.core.common.exceptions import MsaccException diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/api_registry.py b/debug/accuracy_tools/atat/pytorch/hook_module/api_registry.py index 6910276f94..3b971cc71e 100644 --- a/debug/accuracy_tools/atat/pytorch/hook_module/api_registry.py +++ b/debug/accuracy_tools/atat/pytorch/hook_module/api_registry.py @@ -25,7 +25,8 @@ from atat.pytorch.hook_module.wrap_functional import get_functional_ops from atat.pytorch.hook_module.wrap_tensor import get_tensor_ops from atat.pytorch.hook_module.wrap_torch import get_torch_ops from atat.pytorch.hook_module.wrap_vf import get_vf_ops -from atat.pytorch.common.utils import torch_without_guard_version, npu_distributed_api, is_gpu, Const +from atat.pytorch.common.utils import torch_without_guard_version, npu_distributed_api, is_gpu +from atat.core.common.const import Const torch_version_above_2 = torch.__version__.split('+')[0] > '2.0' diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/hook_module.py b/debug/accuracy_tools/atat/pytorch/hook_module/hook_module.py index d45a951d47..57212b6e45 100644 --- a/debug/accuracy_tools/atat/pytorch/hook_module/hook_module.py +++ b/debug/accuracy_tools/atat/pytorch/hook_module/hook_module.py @@ -20,7 +20,7 @@ import threading import torch import torch.nn as nn import torch.utils.hooks as full_hooks -from atat.core.common.utils import Const +from atat.core.common.const import Const class HOOKModule(nn.Module): diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_aten.py b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_aten.py index c247a27082..c5a3c6365d 100644 --- a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_aten.py +++ b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_aten.py @@ -21,7 +21,8 @@ import torch import yaml from 
atat.pytorch.hook_module.hook_module import HOOKModule -from atat.pytorch.common.utils import torch_device_guard, Const +from atat.pytorch.common.utils import torch_device_guard +from atat.core.common.const import Const from atat.core.common.file_check import FileOpen diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_distributed.py b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_distributed.py index 1059bf7488..e02189ac1b 100644 --- a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_distributed.py +++ b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_distributed.py @@ -21,7 +21,8 @@ import torch.distributed as dist import yaml from atat.pytorch.hook_module.hook_module import HOOKModule -from atat.pytorch.common.utils import torch_device_guard, Const +from atat.pytorch.common.utils import torch_device_guard +from atat.core.common.const import Const from atat.core.common.file_check import FileOpen diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_functional.py b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_functional.py index 8c829904cb..fa97f5ee31 100644 --- a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_functional.py +++ b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_functional.py @@ -21,7 +21,8 @@ import torch import yaml from atat.pytorch.hook_module.hook_module import HOOKModule -from atat.pytorch.common.utils import torch_device_guard, Const +from atat.pytorch.common.utils import torch_device_guard +from atat.core.common.const import Const from atat.pytorch.common.log import logger from atat.core.common.file_check import FileOpen diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_npu_custom.py b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_npu_custom.py index 90ad9cb9c4..7d0882804f 100644 --- a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_npu_custom.py +++ b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_npu_custom.py @@ -21,7 +21,8 @@ import torch_npu import yaml from atat.pytorch.hook_module.hook_module import HOOKModule -from atat.pytorch.common.utils import torch_device_guard, torch_without_guard_version, Const +from atat.pytorch.common.utils import torch_device_guard, torch_without_guard_version +from atat.core.common.const import Const from atat.core.common.file_check import FileOpen cur_path = os.path.dirname(os.path.realpath(__file__)) diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_tensor.py b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_tensor.py index d53291b78f..6fac181402 100644 --- a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_tensor.py +++ b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_tensor.py @@ -21,7 +21,8 @@ import torch import yaml from atat.pytorch.hook_module.hook_module import HOOKModule -from atat.pytorch.common.utils import torch_device_guard, parameter_adapter, Const +from atat.pytorch.common.utils import torch_device_guard, parameter_adapter +from atat.core.common.const import Const from atat.core.common.file_check import FileOpen cur_path = os.path.dirname(os.path.realpath(__file__)) diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_torch.py b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_torch.py index 3cdece2306..f0bd01fe46 100644 --- a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_torch.py +++ b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_torch.py @@ -21,7 +21,8 @@ import torch import yaml from atat.pytorch.hook_module.hook_module import HOOKModule -from atat.pytorch.common.utils import torch_device_guard, 
Const +from atat.pytorch.common.utils import torch_device_guard +from atat.core.common.const import Const from atat.core.common.file_check import FileOpen cur_path = os.path.dirname(os.path.realpath(__file__)) diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_vf.py b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_vf.py index c5f3cb7ee0..d4c570221d 100644 --- a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_vf.py +++ b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_vf.py @@ -22,7 +22,8 @@ import yaml from atat.pytorch.hook_module.hook_module import HOOKModule from atat.core.common.file_check import FileOpen -from atat.pytorch.common.utils import torch_device_guard, Const +from atat.pytorch.common.utils import torch_device_guard +from atat.core.common.const import Const cur_path = os.path.dirname(os.path.realpath(__file__)) yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") diff --git a/debug/accuracy_tools/atat/pytorch/module_processer.py b/debug/accuracy_tools/atat/pytorch/module_processer.py index f56513907c..8ce9140e32 100644 --- a/debug/accuracy_tools/atat/pytorch/module_processer.py +++ b/debug/accuracy_tools/atat/pytorch/module_processer.py @@ -1,7 +1,7 @@ from functools import wraps import torch from torch.utils.hooks import BackwardHook -from atat.core.common.utils import Const +from atat.core.common.const import Const from atat.core.data_dump.scope import ModuleRangeScope diff --git a/debug/accuracy_tools/atat/pytorch/online_dispatch/compare.py b/debug/accuracy_tools/atat/pytorch/online_dispatch/compare.py index d7f9e4e339..e6d55ca061 100644 --- a/debug/accuracy_tools/atat/pytorch/online_dispatch/compare.py +++ b/debug/accuracy_tools/atat/pytorch/online_dispatch/compare.py @@ -7,7 +7,8 @@ from collections import namedtuple from rich.table import Table from rich.console import Console from .single_compare import single_benchmark_compare_wrap -from .utils import DispatchException, CompareConst +from .utils import DispatchException +from atat.core.common.const import CompareConst from atat.core.common.file_check import FileOpen from atat.pytorch.common.log import logger from atat.core.common.utils import CompareException diff --git a/debug/accuracy_tools/atat/pytorch/online_dispatch/dispatch.py b/debug/accuracy_tools/atat/pytorch/online_dispatch/dispatch.py index 386c3eac17..7502d746ac 100644 --- a/debug/accuracy_tools/atat/pytorch/online_dispatch/dispatch.py +++ b/debug/accuracy_tools/atat/pytorch/online_dispatch/dispatch.py @@ -22,8 +22,8 @@ from .utils import get_callstack, data_to_cpu, logger_debug, logger_error, logge DispatchException from .compare import Comparator from atat.core.common.file_check import FileOpen -from atat.pytorch.common.utils import Const -from atat.core.common.utils import CompareConst, check_file_or_directory_path, check_path_before_create +from atat.core.common.utils import check_file_or_directory_path, check_path_before_create +from atat.core.common.const import Const, CompareConst current_time = time.strftime("%Y%m%d%H%M%S") RESULT_FILE_NAME = "accuracy_checking_result_" + current_time + ".csv" diff --git a/debug/accuracy_tools/atat/pytorch/online_dispatch/dump_compare.py b/debug/accuracy_tools/atat/pytorch/online_dispatch/dump_compare.py index b8d824fd1e..cd7c5a3f28 100644 --- a/debug/accuracy_tools/atat/pytorch/online_dispatch/dump_compare.py +++ b/debug/accuracy_tools/atat/pytorch/online_dispatch/dump_compare.py @@ -5,11 +5,10 @@ from datetime import datetime, timezone import pandas as pd import torch -from 
atat.pytorch.common.utils import Const from .utils import np_save_data, logger_debug, logger_error, logger_warn, logger_user, COLOR_RED, COLOR_GREEN, \ COLOR_RESET, CSV_COLUMN_NAME -from atat.core.common.file_check import FileOpen, change_mode, FileCheckConst -from atat.core.common.utils import CompareConst +from atat.core.common.file_check import FileOpen, change_mode +from atat.core.common.const import CompareConst, FileCheckConst, Const from atat.pytorch.common.log import logger class DispatchRunParam: diff --git a/debug/accuracy_tools/atat/pytorch/online_dispatch/utils.py b/debug/accuracy_tools/atat/pytorch/online_dispatch/utils.py index 1f9c2e916c..f3fcffb6f2 100644 --- a/debug/accuracy_tools/atat/pytorch/online_dispatch/utils.py +++ b/debug/accuracy_tools/atat/pytorch/online_dispatch/utils.py @@ -12,8 +12,8 @@ except ImportError: else: pta_cpu_device = torch.device("cpu") -from atat.core.common.utils import CompareConst -from atat.core.common.file_check import change_mode, FileCheckConst +from atat.core.common.const import CompareConst, FileCheckConst +from atat.core.common.file_check import change_mode cpu_device = torch._C.device("cpu") COLOR_RED = '\033[31m' @@ -58,17 +58,6 @@ BOOL_TYPE = [bool, np.uint8] INT_TYPE = [np.int32, np.int64] -class CompareConst: - NAN = np.nan - NA = "N/A" - PASS = 'pass' - WARNING = 'warning' - ERROR = 'error' - SKIP = 'SKIP' - TRUE = 'TRUE' - FALSE = 'FALSE' - - def get_callstack(): callstack = [] for (_, path, line, func, code, _) in inspect.stack()[2:]: diff --git a/debug/accuracy_tools/atat/pytorch/parse_tool/lib/utils.py b/debug/accuracy_tools/atat/pytorch/parse_tool/lib/utils.py index aeb1d7f6d2..ce42d242ba 100644 --- a/debug/accuracy_tools/atat/pytorch/parse_tool/lib/utils.py +++ b/debug/accuracy_tools/atat/pytorch/parse_tool/lib/utils.py @@ -30,7 +30,7 @@ from atat.pytorch.parse_tool.lib.file_desc import DumpDecodeFileDesc, FileDesc from atat.pytorch.parse_tool.lib.parse_exception import ParseException from atat.core.common.file_check import change_mode, check_other_user_writable,\ check_path_executable, check_path_owner_consistent -from atat.core.common.file_check import FileCheckConst +from atat.core.common.const import FileCheckConst from atat.core.common.file_check import FileOpen from atat.core.common.utils import check_file_or_directory_path from atat.pytorch.common.log import logger diff --git a/debug/accuracy_tools/atat/pytorch/pt_config.py b/debug/accuracy_tools/atat/pytorch/pt_config.py index e04b88bb96..0674b91b34 100644 --- a/debug/accuracy_tools/atat/pytorch/pt_config.py +++ b/debug/accuracy_tools/atat/pytorch/pt_config.py @@ -3,7 +3,7 @@ import os from atat.core.common_config import CommonConfig, BaseConfig from atat.core.common.file_check import FileOpen -from atat.core.common.utils import Const +from atat.core.common.const import Const class TensorConfig(BaseConfig): diff --git a/debug/accuracy_tools/atat/pytorch/service.py b/debug/accuracy_tools/atat/pytorch/service.py index cd80d0852a..d0b9c4d4b2 100644 --- a/debug/accuracy_tools/atat/pytorch/service.py +++ b/debug/accuracy_tools/atat/pytorch/service.py @@ -3,8 +3,8 @@ import os from pathlib import Path from atat.pytorch.common.log import logger -from atat.core.common.file_check import FileChecker, FileCheckConst, check_path_before_create -from atat.core.common.utils import Const +from atat.core.common.file_check import FileChecker, check_path_before_create +from atat.core.common.const import Const, FileCheckConst from atat.core.common.exceptions import 
DistributedNotInitializedError, MsaccException from atat.core.data_dump.data_collector import build_data_collector from atat.core.data_dump.scope import BaseScope diff --git a/debug/accuracy_tools/atat/test/core_ut/test_file_check.py b/debug/accuracy_tools/atat/test/core_ut/test_file_check.py index 3305acb0b7..aa7882aa59 100644 --- a/debug/accuracy_tools/atat/test/core_ut/test_file_check.py +++ b/debug/accuracy_tools/atat/test/core_ut/test_file_check.py @@ -20,9 +20,9 @@ from unittest import TestCase from unittest.mock import patch, MagicMock from atat.core.common.log import logger +from atat.core.common.const import FileCheckConst from atat.core.common.exceptions import FileCheckException -from atat.core.common.file_check import (FileCheckConst, - check_link, +from atat.core.common.file_check import (check_link, check_path_length, check_path_exists, check_path_readability, diff --git a/debug/accuracy_tools/atat/test/core_ut/test_utils.py b/debug/accuracy_tools/atat/test/core_ut/test_utils.py index fae0e4e255..b3273358e4 100644 --- a/debug/accuracy_tools/atat/test/core_ut/test_utils.py +++ b/debug/accuracy_tools/atat/test/core_ut/test_utils.py @@ -21,7 +21,8 @@ from unittest import TestCase from unittest.mock import patch, MagicMock, mock_open from atat.core.common.log import logger -from atat.core.common.utils import (Const, CompareException, +from atat.core.common.const import Const +from atat.core.common.utils import (CompareException, check_seed_all, check_inplace_op, make_dump_path_if_not_exists, diff --git a/debug/accuracy_tools/atat/test/mindspore_ut/test_ms_config.py b/debug/accuracy_tools/atat/test/mindspore_ut/test_ms_config.py index 6be8949684..fe92a90aa1 100644 --- a/debug/accuracy_tools/atat/test/mindspore_ut/test_ms_config.py +++ b/debug/accuracy_tools/atat/test/mindspore_ut/test_ms_config.py @@ -1,7 +1,7 @@ from unittest import TestCase from unittest.mock import patch, mock_open -from atat.core.common.utils import Const +from atat.core.common.const import Const from atat.mindspore.ms_config import parse_json_config diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py b/debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py index aab90f122f..701c67b074 100644 --- a/debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +++ b/debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py @@ -9,7 +9,7 @@ from atat.pytorch.api_accuracy_checker.compare.api_precision_compare import ( check_error_rate, get_api_checker_result, ) -from atat.pytorch.api_accuracy_checker.compare.compare_utils import CompareConst +from atat.core.common.const import CompareConst class TestApiPrecisionCompare(unittest.TestCase): diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py b/debug/accuracy_tools/atat/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py index 448b518c33..828d646c52 100644 --- a/debug/accuracy_tools/atat/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +++ b/debug/accuracy_tools/atat/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py @@ -1,7 +1,7 @@ from unittest import TestCase import torch -from atat.pytorch.common.utils import Const +from atat.core.common.const import Const from atat.pytorch.free_benchmark.common.enums import DeviceType, PerturbationMode from 
atat.pytorch.free_benchmark.common.params import data_pre_deal from atat.pytorch.free_benchmark.perturbed_layers.layer_factory import LayerFactory diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py b/debug/accuracy_tools/atat/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py index 948fdaecea..d46e26e094 100644 --- a/debug/accuracy_tools/atat/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +++ b/debug/accuracy_tools/atat/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py @@ -2,7 +2,7 @@ from abc import ABC from unittest import TestCase import torch -from atat.pytorch.common.utils import Const +from atat.core.common.const import Const from atat.pytorch.free_benchmark.common.constant import PreheatConfig, ThresholdConfig from atat.pytorch.free_benchmark.common.counter import preheat_counter from atat.pytorch.free_benchmark.common.enums import ( diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/free_benchmark/test_main.py b/debug/accuracy_tools/atat/test/pytorch_ut/free_benchmark/test_main.py index 4c1aa1deff..d326e993c0 100644 --- a/debug/accuracy_tools/atat/test/pytorch_ut/free_benchmark/test_main.py +++ b/debug/accuracy_tools/atat/test/pytorch_ut/free_benchmark/test_main.py @@ -4,7 +4,7 @@ from unittest import TestCase import torch import torch.nn as nn -from atat.pytorch.common.utils import Const +from atat.core.common.const import Const from atat.pytorch.free_benchmark import FreeBenchmarkCheck from atat.pytorch.free_benchmark.common.constant import CommonField, PreheatConfig from atat.pytorch.free_benchmark.common.enums import ( diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/test_pt_config.py b/debug/accuracy_tools/atat/test/pytorch_ut/test_pt_config.py index c931c85507..fa52fe0e1b 100644 --- a/debug/accuracy_tools/atat/test/pytorch_ut/test_pt_config.py +++ b/debug/accuracy_tools/atat/test/pytorch_ut/test_pt_config.py @@ -1,7 +1,7 @@ from unittest import TestCase from unittest.mock import patch, mock_open -from atat.core.common.utils import Const +from atat.core.common.const import Const from atat.pytorch.pt_config import parse_json_config diff --git a/debug/accuracy_tools/grad_tool/common/base_comparator.py b/debug/accuracy_tools/grad_tool/common/base_comparator.py index f940ef5135..d3254ae71f 100644 --- a/debug/accuracy_tools/grad_tool/common/base_comparator.py +++ b/debug/accuracy_tools/grad_tool/common/base_comparator.py @@ -40,9 +40,9 @@ class BaseComparator(ABC): create_directory(output_dir) for rank in tqdm(ranks, desc="rank"): print_info_log(f"now comparing rank {rank}:") - cls.compare(os.path.join(path1, f"rank_{rank}"), - os.path.join(path2, f"rank_{rank}"), - os.path.join(output_dir, f"rank_{rank}")) + cls.compare(os.path.join(path1, f"rank{rank}"), + os.path.join(path2, f"rank{rank}"), + os.path.join(output_dir, f"rank{rank}")) @classmethod def compare(cls, path1: str, path2: str, output_dir: str): @@ -59,15 +59,15 @@ class BaseComparator(ABC): check_file_or_directory_path(path1, file_type=GradConst.DIR) check_file_or_directory_path(path2, file_type=GradConst.DIR) dirs = [] - for dirname in os.listdir(path1): - splits = dirname.split('_') - if not splits or splits[0] != dir_prefix or not splits[1].isdigit(): + for dir_name in os.listdir(path1): + index = dir_name.replace(dir_prefix, "", 1) + if not dir_name.startswith(dir_prefix) or not index.isdigit(): continue - folder2 = os.path.join(path2, dirname) + folder2 = os.path.join(path2, dir_name) if 
not os.path.isdir(folder2): continue - dirs.append(int(splits[1])) + dirs.append(int(index)) dirs = sorted(dirs) return dirs @@ -101,8 +101,8 @@ class BaseComparator(ABC): total_count_summary = 0 for grad_name in grad_weight_order: grad_file = cls._get_name_matched_grad_file(grad_name, grad_files) - grad1 = os.path.join(path1, f"step_{step}", grad_file) - grad2 = os.path.join(path2, f"step_{step}", grad_file) + grad1 = os.path.join(path1, f"step{step}", grad_file) + grad2 = os.path.join(path2, f"step{step}", grad_file) same_count, total_count = cls._calculate_similarity(grad1, grad2) same_count_summary += same_count total_count_summary += total_count @@ -124,8 +124,8 @@ class BaseComparator(ABC): @classmethod def _get_matched_grad_files(cls, path1: str, path2: str, step: int): - path1 = os.path.join(path1, f"step_{step}") - path2 = os.path.join(path2, f"step_{step}") + path1 = os.path.join(path1, f"step{step}") + path2 = os.path.join(path2, f"step{step}") check_file_or_directory_path(path1, file_type=GradConst.DIR) check_file_or_directory_path(path2, file_type=GradConst.DIR) grad_files = [] diff --git a/debug/accuracy_tools/grad_tool/grad_pt/grad_monitor.py b/debug/accuracy_tools/grad_tool/grad_pt/grad_monitor.py index 6733d566d6..f3079e622c 100644 --- a/debug/accuracy_tools/grad_tool/grad_pt/grad_monitor.py +++ b/debug/accuracy_tools/grad_tool/grad_pt/grad_monitor.py @@ -96,8 +96,8 @@ class PtGradientMonitor(BaseMonitor): output_lines.append(grad_info) if self._level_adp["have_grad_direction"]: PtGradientMonitor.save_grad_direction(param_name, grad, - f'{self._output_path}/rank_{self._rank}/step_{self._step}') - output_path = os.path.join(self._output_path, f"rank_{getattr(self, '_rank')}", + f'{self._output_path}/rank{self._rank}/step{self._step}') + output_path = os.path.join(self._output_path, f"rank{getattr(self, '_rank')}", f"grad_summary_{self._step}.csv") write_csv(output_path, output_lines, GradStatCsv.generate_csv_header(self._level_adp, self._bounds)) diff --git a/plugins/tensorboard-plugins/ OWNERS b/plugins/tensorboard-plugins/OWNERS similarity index 93% rename from plugins/tensorboard-plugins/ OWNERS rename to plugins/tensorboard-plugins/OWNERS index 34c383beaf..507672c739 100644 --- a/plugins/tensorboard-plugins/ OWNERS +++ b/plugins/tensorboard-plugins/OWNERS @@ -1,9 +1,9 @@ -options: - no_parent_owners: true -approvers: -- wo-wenjie -- ly-qianxiao -reviewers: -- wo-wenjie -- ly-qianxiao -- leo920320 +options: + no_parent_owners: true +approvers: +- wo-wenjie +- ly-qianxiao +reviewers: +- wo-wenjie +- ly-qianxiao +- leo920320 diff --git a/profiler/advisor/README.md b/profiler/advisor/README.md index ccaccdda01..c650f40b3e 100644 --- a/profiler/advisor/README.md +++ b/profiler/advisor/README.md @@ -92,31 +92,31 @@ msprof-analyze的advisor功能是将Ascend PyTorch Profiler或者msprof采集的 - 总体性能瓶颈 ```bash - msprof-analyze advisor all -d {profiling_path} [-bp benchmark_profiling_path] [-cv cann_version] [-tv torch_version] [-pt profiling_type] [-D] [-h] + msprof-analyze advisor all -d {profiling_path} [-bp benchmark_profiling_path] [-cv cann_version] [-tv torch_version] [-pt profiling_type] [--debug] [-h] ``` - 计算瓶颈 ```bash - msprof-analyze advisor computation -d {profiling_path} [-bp benchmark_profiling_path] [-cv cann_version] [-tv torch_version] [-pt profiling_type] [-D] [-h] + msprof-analyze advisor computation -d {profiling_path} [-cv cann_version] [-tv torch_version] [-pt profiling_type] [--debug] [-h] ``` - 调度瓶颈 ```bash - msprof-analyze advisor schedule -d {profiling_path} [-bp 
benchmark_profiling_path] [-cv cann_version] [-tv torch_version] [-D] [-h] + msprof-analyze advisor schedule -d {profiling_path} [-cv cann_version] [-tv torch_version] [--debug] [-h] ``` #### 参数介绍 | 参数 | 说明 | 是否必选 | | ---------------------------------- | ------------------------------------------------------------ | -------- | -| -d
--profiling_path | 性能数据所在目录。性能数据通过Profiling工具采集获取。请确保性能数据采集时配置“aic-metrics”参数为“PipeUtilization”,“aicpu”参数为“on”。advisor依赖Profiling工具解析后的timeline数据、summary数据以及info.json*文件,请确保指定的“profiling_dir”目录下存在以上文件。 | 是 | +| -d
--profiling_path | 性能数据文件或目录所在路径,Ascend PyTorch Profiler采集场景指定为`*_ascend_pt`性能数据结果目录,其他场景指定为`PROF_XXX`性能数据结果目录。建议通过Ascend PyTorch Profiler获取性能数据。
advisor依赖Profiling工具解析后的timeline数据(.json)、summary(.csv)数据以及info.json*文件,请确保指定的“profiling_path”目录下存在以上文件。 | 是 | | -bp
--benchmark_profiling_path | 基准性能数据所在目录,用于性能比对。性能数据通过Profiling工具采集获取。
**computation和schedule不支持该参数。** | 否 | | -cv
--cann_version | 使用Profiling工具采集时对应的CANN软件版本,可通过在环境中执行如下命令获取其version字段,目前配套的兼容版本为“6.3.RC2”,“7.0.RC1”、“7.0.0”、“8.0.RC1”,此字段不填默认按“8.0.RC1”版本数据进行处理,其余版本采集的Profiling数据在分析时可能会导致不可知问题:`cat /usr/local/Ascend/ascend-toolkit/latest/aarch64-linux/ascend_toolkit_install.info` | 否 | | -tv
--torch_version | 运行环境的torch版本,默认为1.11.0,支持torch1.11.0和torch2.1.0,当运行环境torch版本为其他版本如torch1.11.3时,可以忽略小版本号差异选择相近的torch版本如1.11.0。 | 否 | -| -pt
--profiling_type | 配置性能数据采集使用的Profiling工具类型。可取值:
ascend_pytorch_profiler:使用Ascend PyThon Profiler接口方式采集的性能数据时配置,默认值。
msprof:使用msprof命令行方式采集的性能数据时配置。
mslite:使用[Benchmark](https://gitee.com/ascend/tools/tree/master/ais-bench_workload/tool/ais_bench)工具采集的性能数据时配置。
**schedule不支持该参数。** | 否 | -| -D
--debug | 工具执行报错时可打开此开关,将会展示详细保存堆栈信息。 | 否 | +| -pt
--profiling_type | 配置性能数据采集使用的Profiling工具类型。可取值:
ascend_pytorch_profiler:使用Ascend PyTorch Profiler接口方式采集的性能数据时配置,默认值。
msprof:使用msprof命令行方式采集的性能数据时配置。功能完善中,暂不建议使用。
mslite:使用[Benchmark](https://gitee.com/ascend/tools/tree/master/ais-bench_workload/tool/ais_bench)工具采集的性能数据时配置。不建议使用。
**schedule不支持该参数。** | 否 | +| --debug | 工具执行报错时可打开此开关,将会展示详细保存堆栈信息。 | 否 | | -h,-H
--help | 在需要查询当前命令附属子命令或相关参数时,给出帮助建议。 | 否 | ### 报告解析 diff --git a/profiler/compare_tools/compare_backend/utils/constant.py b/profiler/compare_tools/compare_backend/utils/constant.py index 1b77b214c8..e2854692ae 100644 --- a/profiler/compare_tools/compare_backend/utils/constant.py +++ b/profiler/compare_tools/compare_backend/utils/constant.py @@ -74,7 +74,7 @@ class Constant(object): MEMORY_LIST = "memory_list" COMMUNICATION_DICT = "comm_dict" - #compare type + # compare type OVERALL_COMPARE = "overall" BWD_LIST = ["bwd", "backward", "back"] diff --git a/profiler/module_visualization/__init__.py b/profiler/module_visualization/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/profiler/module_visualization/graph/__init__.py b/profiler/module_visualization/graph/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/profiler/module_visualization/graph/prof_node.py b/profiler/module_visualization/graph/prof_node.py new file mode 100644 index 0000000000..cfcdabbb99 --- /dev/null +++ b/profiler/module_visualization/graph/prof_node.py @@ -0,0 +1,90 @@ +# Copyright (c) 2024 Huawei Technologies Co., Ltd +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from profiler.prof_common.constant import Constant +from profiler.prof_common.base_node import BaseNode +from profiler.prof_common.trace_event_bean import TraceEventBean + + +class ProfNode(BaseNode): + MODULE_TYPE = 1 + + def __init__(self, event: TraceEventBean, parent_node=None): + super().__init__(event, parent_node) + self._kernel_total_list = [] + + @property + def node_id(self): + return self._event.unique_id + + @property + def total_kernels(self): + return self._kernel_total_list + + @property + def host_total_dur(self): + if self.is_root_node: + return sum((node.host_total_dur for node in self.child_nodes)) + return self._event.dur + + @property + def host_self_dur(self): + return self.host_total_dur - sum((node.host_total_dur for node in self.child_nodes)) + + @property + def device_total_dur(self): + if self.is_root_node: + return sum((node.device_total_dur for node in self.child_nodes)) + return sum((kernel.dur for kernel in self._kernel_total_list)) + + @property + def device_self_dur(self): + return self.device_total_dur - sum((node.device_total_dur for node in self.child_nodes)) + + @property + def input_data(self) -> dict: + data = {} + input_dim = self._event.args.get("Input Dims") + if input_dim: + data["Input Dims"] = input_dim + input_type = self._event.args.get("Input type") + if input_type: + data["Input type"] = input_type + return data + + @property + def data(self): + return {"Input Data": self.input_data, + "Host Self Duration(us)": round(self.host_self_dur, 2), + "Host Total Duration(us)": round(self.host_total_dur, 2), + "Device Self Duration(us)": round(self.device_self_dur, 2), + "Device Total Duration(us)": round(self.device_total_dur, 2)} + + @property + def info(self): + return {"id": self.node_id, + "node_type": self.MODULE_TYPE, + "data": self.data, + 
"upnode": self.parent_node.node_id if self.parent_node else "None", + "subnodes": [node.node_id for node in iter(self.child_nodes)]} + + @property + def is_root_node(self): + return self.node_id == Constant.NPU_ROOT_ID + + def update_child_nodes(self, node): + self._child_nodes.append(node) + + def update_kernel_total_list(self, kernel_list: list): + self._kernel_total_list.extend(kernel_list) diff --git a/profiler/module_visualization/graph_build/__init__.py b/profiler/module_visualization/graph_build/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/profiler/module_visualization/graph_build/fwd_module_node.py b/profiler/module_visualization/graph_build/fwd_module_node.py new file mode 100644 index 0000000000..34d7ab8296 --- /dev/null +++ b/profiler/module_visualization/graph_build/fwd_module_node.py @@ -0,0 +1,29 @@ +# Copyright (c) 2024 Huawei Technologies Co., Ltd +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from profiler.prof_common.base_node import BaseNode +from profiler.prof_common.trace_event_bean import TraceEventBean + + +class FwdModuleNode(BaseNode): + def __init__(self, event: TraceEventBean, parent_node=None): + super().__init__(event, parent_node) + self._bwd_op_list = [] + + @property + def bwd_op_list(self): + return self._bwd_op_list + + def update_bwd_op(self, bwd_op_list: list): + self._bwd_op_list.extend(bwd_op_list) diff --git a/profiler/module_visualization/graph_build/prof_graph_builder.py b/profiler/module_visualization/graph_build/prof_graph_builder.py new file mode 100644 index 0000000000..83331b6250 --- /dev/null +++ b/profiler/module_visualization/graph_build/prof_graph_builder.py @@ -0,0 +1,115 @@ +# Copyright (c) 2024 Huawei Technologies Co., Ltd +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from profiler.module_visualization.graph.prof_node import ProfNode +from profiler.module_visualization.graph_build.fwd_module_node import FwdModuleNode +from profiler.prof_common.tree_builder import TreeBuilder +from profiler.prof_common.trace_event_bean import TraceEventBean +from profiler.prof_common.constant import Constant +from profiler.module_visualization.prof_parse.prof_data_pre_process import ProfDataPreProcess + + +class ProfGraphBuilder: + def __init__(self, prof_data_path: str): + self._prof_data_path = prof_data_path + self._prof_data = {} + + @classmethod + def _create_event_bean_from_ops(cls, op_list: list, name: str) -> TraceEventBean: + min_start = min((op.start_time for op in iter(op_list))) + max_end = max((op.end_time for op in iter(op_list))) + # 以反向算子的区间作为反向module的区间范围,为了module包含算子,做了+1 +2处理 + return TraceEventBean({"ts": min_start - 1, "dur": float(max_end - min_start) + 2, "name": name}) + + @classmethod + def _trans_flow_to_dict(cls, flow_events: dict, end_events: list) -> dict: + end_event_dict = {} + for event in end_events: + end_event_dict[event.start_time] = event + result_data = {} + for flow in flow_events.values(): + start_point = flow.get("start") + end_point = flow.get("end") + if not start_point or not end_point: + continue + end_event = end_event_dict.get(end_point.start_time) + if end_event: + result_data.setdefault(start_point.start_time, []).append(end_event) + return result_data + + def build_graph(self): + self._prof_data = ProfDataPreProcess(self._prof_data_path).run() + all_data = [*self._prof_data.get(Constant.MODULE_EVENT, []), + *self.find_bwd_module(), + *self._prof_data.get(Constant.CPU_OP_EVENT, [])] + all_data.sort(key=lambda x: x.start_time) + name_dict = {} + for event in all_data: + order_id = name_dict.get(event.name, 0) + event.set_id(f"{event.name}_{order_id}") + name_dict[event.name] = order_id + 1 + root_node = TreeBuilder.build_tree(all_data, ProfNode, TraceEventBean({}, Constant.NPU_ROOT_ID)) + kernel_flow_dict = self._trans_flow_to_dict(self._prof_data.get(Constant.TORCH_TO_NPU_FLOW, {}), + self._prof_data.get(Constant.KERNEL_EVENT, [])) + for start_time, kernels in kernel_flow_dict.items(): + matched_node = root_node.binary_search(start_time) + while matched_node != Constant.INVALID_RETURN: + matched_node.update_kernel_total_list(kernels) + matched_node = matched_node.binary_search(start_time) + all_data = root_node.find_all_child_nodes() + all_data.append(root_node) + return all_data + + def find_bwd_module(self) -> list: + bwd_module_list = [] + fwdbwd_flow = self._prof_data.get(Constant.FWD_BWD_FLOW, {}) + module_list = self._prof_data.get(Constant.MODULE_EVENT, []) + cpu_op_list = self._prof_data.get(Constant.CPU_OP_EVENT, []) + if not fwdbwd_flow or not module_list or not cpu_op_list: + return bwd_module_list + fwd_tid = module_list[0].tid + bwd_tid = fwd_tid + for end_point in (flow.get("end") for flow in fwdbwd_flow.values()): + if end_point: + bwd_tid = end_point.tid + break + if fwd_tid == bwd_tid: + return bwd_module_list + # 将每一个反向包成一个module,名字叫“nn.Module: BACKWARD_0” + cpu_op_list.sort(key=lambda x: x.start_time) + pre_status = Constant.FWD_OR_OPT + bwd_op_list = [] + for op in cpu_op_list: + if op.tid == bwd_tid: + bwd_op_list.append(op) + pre_status = Constant.BACKWARD + elif pre_status == Constant.BACKWARD: + bwd_module_list.append(self._create_event_bean_from_ops(bwd_op_list, "nn.Module: BACKWARD")) + bwd_op_list.clear() + pre_status = Constant.FWD_OR_OPT + + # 通过连线匹配正向module,构建出反向的整体module关系 + root_node = 
TreeBuilder.build_tree(module_list, FwdModuleNode, TraceEventBean({})) + fwdbwd_flow_dict = self._trans_flow_to_dict(fwdbwd_flow, cpu_op_list) + for start_time, end_events in fwdbwd_flow_dict.items(): + matched_node = root_node.binary_search(start_time) + while matched_node != Constant.INVALID_RETURN: + matched_node.update_bwd_op(end_events) + matched_node = matched_node.binary_search(start_time) + all_nodes = root_node.find_all_child_nodes() + for module_node in all_nodes: + if module_node.bwd_op_list: + bwd_module_list.append( + self._create_event_bean_from_ops(module_node.bwd_op_list, f"{module_node.name} [BACKWARD]")) + return bwd_module_list diff --git a/profiler/module_visualization/prof_graph_export.py b/profiler/module_visualization/prof_graph_export.py new file mode 100644 index 0000000000..d336e97f74 --- /dev/null +++ b/profiler/module_visualization/prof_graph_export.py @@ -0,0 +1,39 @@ +# Copyright (c) 2024 Huawei Technologies Co., Ltd +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from datetime import datetime + +from profiler.prof_common.constant import Constant +from profiler.prof_common.file_reader import FileReader +from profiler.prof_common.path_manager import PathManager +from profiler.module_visualization.graph_build.prof_graph_builder import ProfGraphBuilder + + +class ProfGraphExport: + @staticmethod + def export_to_json(prof_data_path: str, output_path: str): + logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s") + try: + PathManager.input_path_common_check(prof_data_path) + PathManager.check_input_directory_path(output_path) + PathManager.make_dir_safety(output_path) + all_nodes = ProfGraphBuilder(prof_data_path).build_graph() + result_data = {"root": Constant.NPU_ROOT_ID, "node": {}} + for node in all_nodes: + result_data["node"][node.node_id] = node.info + file_name = "prof_graph_json_{}.vis".format(datetime.utcnow().strftime("%Y%m%d%H%M%S%f")[:-3]) + FileReader.write_json_file(output_path, result_data, file_name) + except RuntimeError as err: + logging.error(err) diff --git a/profiler/module_visualization/prof_parse/__init__.py b/profiler/module_visualization/prof_parse/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/profiler/module_visualization/prof_parse/prof_data_pre_process.py b/profiler/module_visualization/prof_parse/prof_data_pre_process.py new file mode 100644 index 0000000000..9dc820e4ca --- /dev/null +++ b/profiler/module_visualization/prof_parse/prof_data_pre_process.py @@ -0,0 +1,102 @@ +# Copyright (c) 2024 Huawei Technologies Co., Ltd +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+
+from profiler.prof_common.file_reader import FileReader
+from profiler.prof_common.constant import Constant
+from profiler.prof_common.trace_event_bean import TraceEventBean
+
+
+class ProfDataPreProcess:
+    def __init__(self, prof_data_path: str):
+        self._prof_data_path = prof_data_path
+        self._trace_path = ""
+        self._kernel_pid = None
+        self._result_data = {Constant.CPU_OP_EVENT: [], Constant.MODULE_EVENT: [], Constant.KERNEL_EVENT: [],
+                             Constant.TORCH_TO_NPU_FLOW: {}, Constant.FWD_BWD_FLOW: {}}
+
+    def run(self) -> dict:
+        self._check_trace_path()
+        self._parse_trace_events()
+        self._check_result_data()
+        return self._result_data
+
+    def _check_trace_path(self):
+        if os.path.isfile(self._prof_data_path):
+            (split_file_path, split_file_name) = os.path.split(self._prof_data_path)
+            (shot_name, extension) = os.path.splitext(split_file_name)
+            if extension != ".json":
+                msg = f"Invalid profiling path suffix: {self._prof_data_path}. " \
+                      f"The input should be a .json file path, such as trace_view.json."
+                raise RuntimeError(msg)
+            self._trace_path = self._prof_data_path
+            return
+        ascend_output = os.path.join(self._prof_data_path, "ASCEND_PROFILER_OUTPUT")
+        profiler_output = ascend_output if os.path.isdir(ascend_output) else self._prof_data_path
+        json_path = os.path.join(profiler_output, "trace_view.json")
+        if not os.path.isfile(json_path):
+            msg = f"Invalid profiling path: {self._prof_data_path}. The data path should be the " \
+                  f"folder ending with ascend_pt that is collected by the Ascend PyTorch Profiler."
+            raise RuntimeError(msg)
+        self._trace_path = json_path
+
+    def _parse_trace_events(self):
+        trace_data = FileReader.read_json_file(self._trace_path)
+        self._check_trace_data(trace_data)
+        iter_trace_data = iter(trace_data)
+        for event in iter_trace_data:
+            bean = TraceEventBean(event)
+            if bean.is_optimizer():
+                self._result_data[Constant.MODULE_EVENT].append(bean)
+            elif bean.is_cpu_op():
+                if not bean.is_step():
+                    self._result_data[Constant.CPU_OP_EVENT].append(bean)
+            elif bean.is_nn_module():
+                self._result_data[Constant.MODULE_EVENT].append(bean)
+            elif bean.is_torch_to_npu():
+                if bean.is_flow_start():
+                    self._result_data[Constant.TORCH_TO_NPU_FLOW].setdefault(bean.id, {})["start"] = bean
+                else:
+                    self._result_data[Constant.TORCH_TO_NPU_FLOW].setdefault(bean.id, {})["end"] = bean
+            elif bean.is_fwd_bwd_flow():
+                if bean.is_flow_start():
+                    self._result_data[Constant.FWD_BWD_FLOW].setdefault(bean.id, {})["start"] = bean
+                else:
+                    self._result_data[Constant.FWD_BWD_FLOW].setdefault(bean.id, {})["end"] = bean
+            elif bean.is_kernel_event(self._kernel_pid):
+                self._result_data[Constant.KERNEL_EVENT].append(bean)
+
+    def _check_trace_data(self, trace_data):
+        if not isinstance(trace_data, list):
+            msg = f"Invalid profiling data. This feature only supports performance data " \
+                  f"collected by the Ascend PyTorch Profiler."
+            raise RuntimeError(msg)
+        iter_trace_data = iter(trace_data)
+        for event in iter_trace_data:
+            bean = TraceEventBean(event)
+            if bean.is_npu_process():
+                self._kernel_pid = bean.pid
+                break
+        if self._kernel_pid is None:
+            msg = f"There is no operator on the NPU side in this data. Please check whether the NPU switch is enabled."
+            raise RuntimeError(msg)
+
+    def _check_result_data(self):
+        if not self._result_data.get(Constant.CPU_OP_EVENT):
+            msg = f"This data does not contain any aten operators. Please make sure the CPU switch is enabled."
+            raise RuntimeError(msg)
+        if not self._result_data.get(Constant.MODULE_EVENT):
+            msg = f"This data does not contain any module events. Please make sure the with_stack switch is turned on."
+            raise RuntimeError(msg)
diff --git a/profiler/prof_common/base_node.py b/profiler/prof_common/base_node.py
new file mode 100644
index 0000000000..b7cd678000
--- /dev/null
+++ b/profiler/prof_common/base_node.py
@@ -0,0 +1,78 @@
+# Copyright (c) 2024 Huawei Technologies Co., Ltd
+# All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from math import ceil
+from queue import Queue
+
+from decimal import Decimal
+
+from profiler.prof_common.constant import Constant
+from profiler.prof_common.trace_event_bean import TraceEventBean
+
+
+class BaseNode:
+    def __init__(self, event: TraceEventBean, parent_node=None):
+        self._event = event
+        self._parent_node = parent_node
+        self._child_nodes = []
+
+    @property
+    def parent_node(self):
+        return self._parent_node
+
+    @property
+    def child_nodes(self):
+        return self._child_nodes
+
+    @property
+    def name(self):
+        return self._event.name
+
+    @property
+    def start_time(self) -> Decimal:
+        return self._event.start_time
+
+    @property
+    def end_time(self) -> Decimal:
+        return self._event.end_time
+
+    def update_child_nodes(self, node):
+        self._child_nodes.append(node)
+
+    def binary_search(self, ts_time):
+        if not self.child_nodes:
+            return Constant.INVALID_RETURN
+        right = len(self.child_nodes) - 1
+        left = 0
+        while right > left:
+            mid = left + ceil((right - left) / 2)
+            if ts_time >= self.child_nodes[mid].start_time:
+                left = mid
+            else:
+                right = mid - 1
+        if self.child_nodes[left].start_time < ts_time < self.child_nodes[left].end_time:
+            return self.child_nodes[left]
+        return Constant.INVALID_RETURN
+
+    def find_all_child_nodes(self) -> list:
+        result_data = []
+        node_queue = Queue()
+        for child_node in self.child_nodes:
+            node_queue.put(child_node)
+        while not node_queue.empty():
+            tree_node = node_queue.get()
+            result_data.append(tree_node)
+            for child_node in tree_node.child_nodes:
+                node_queue.put(child_node)
+        return result_data
diff --git a/profiler/prof_common/constant.py b/profiler/prof_common/constant.py
index 5789b89cb1..87bc51b56b 100644
--- a/profiler/prof_common/constant.py
+++ b/profiler/prof_common/constant.py
@@ -15,4 +15,17 @@ class Constant(object):
     COLLECTION_PATH = "collection_path"
     ANALYSIS_MODE = "analysis_mode"
-    CONTEXT_SETTINGS = dict(help_option_names=['-H', '-h', '--help'])
\ No newline at end of file
+    CONTEXT_SETTINGS = dict(help_option_names=['-H', '-h', '--help'])
+
+    MAX_FILE_SIZE_5_GB = 1024 * 1024 * 1024 * 5
+
+    MODULE_EVENT = "module_event"
+    CPU_OP_EVENT = "op_event"
+    TORCH_TO_NPU_FLOW = "torch_to_device"
+    KERNEL_EVENT = "kernel_event"
+    FWD_BWD_FLOW = "fwd_to_bwd"
+    NPU_ROOT_ID = "NPU"
+
+    FWD_OR_OPT = 0
+    BACKWARD = 1
+    INVALID_RETURN = -1
diff --git a/profiler/prof_common/file_reader.py b/profiler/prof_common/file_reader.py
new file mode 100644
index 0000000000..d8a9c8fb4d
--- /dev/null
+++ b/profiler/prof_common/file_reader.py
@@ -0,0 +1,59 @@
+# Copyright (c) 2024 Huawei Technologies Co., Ltd
+# All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+import logging
+import os
+
+from profiler.prof_common.path_manager import PathManager
+from profiler.prof_common.constant import Constant
+
+
+class FileReader:
+    DATA_FILE_AUTHORITY = 0o640
+    DATA_DIR_AUTHORITY = 0o750
+
+    @classmethod
+    def read_json_file(cls, file_path: str) -> any:
+        PathManager.check_path_readable(file_path)
+        if not os.path.isfile(file_path):
+            raise FileNotFoundError("File does not exist.")
+        file_size = os.path.getsize(file_path)
+        if file_size <= 0:
+            return []
+        if file_size > Constant.MAX_FILE_SIZE_5_GB:
+            msg = f"The file ({file_path}) size exceeds the preset maximum value; failed to read the file."
+            raise RuntimeError(msg)
+        try:
+            with open(file_path, "rt") as file:
+                json_data = json.loads(file.read())
+        except Exception as e:
+            msg = f"Can't read file: {file_path}"
+            raise RuntimeError(msg) from e
+        return json_data
+
+    @classmethod
+    def write_json_file(cls, output_path: str, data: dict, file_name: str, format_json: bool = False) -> None:
+        if not data:
+            return
+        output_file = os.path.join(output_path, file_name)
+        PathManager.check_path_writeable(output_path)
+        try:
+            with os.fdopen(
+                    os.open(output_file, os.O_WRONLY | os.O_CREAT, cls.DATA_FILE_AUTHORITY), 'w'
+            ) as file:
+                indent = 4 if format_json else None
+                file.write(json.dumps(data, indent=indent))
+        except Exception as e:
+            raise RuntimeError(f"Can't create the file: {output_file}") from e
diff --git a/profiler/prof_common/path_manager.py b/profiler/prof_common/path_manager.py
new file mode 100644
index 0000000000..3e41b8b50a
--- /dev/null
+++ b/profiler/prof_common/path_manager.py
@@ -0,0 +1,191 @@
+# Copyright (c) 2024 Huawei Technologies Co., Ltd
+# All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os +import re +import shutil +import platform + + +class PathManager: + MAX_PATH_LENGTH = 4096 + MAX_FILE_NAME_LENGTH = 255 + DATA_FILE_AUTHORITY = 0o640 + DATA_DIR_AUTHORITY = 0o750 + WINDOWS = "windows" + + @classmethod + def check_input_directory_path(cls, path: str): + """ + Function Description: + check whether the path is valid, some businesses can accept a path that does not exist, + so the function do not verify whether the path exists + Parameter: + path: the path to check, whether the incoming path is absolute or relative depends on the business + Exception Description: + when invalid data throw exception + """ + cls.input_path_common_check(path) + base_name = os.path.basename(path) + if os.path.isfile(path): + msg = f"Invalid input path which is a file path: {base_name}" + raise RuntimeError(msg) + + @classmethod + def check_input_file_path(cls, path: str): + """ + Function Description: + check whether the file path is valid, some businesses can accept a path that does not exist, + so the function do not verify whether the path exists + Parameter: + path: the file path to check, whether the incoming path is absolute or relative depends on the business + Exception Description: + when invalid data throw exception + """ + cls.input_path_common_check(path) + base_name = os.path.basename(path) + if os.path.isdir(path): + msg = f"Invalid input path which is a directory path: {base_name}" + raise RuntimeError(msg) + + @classmethod + def check_path_length(cls, path: str): + if len(path) > cls.MAX_PATH_LENGTH: + raise RuntimeError("Length of input path exceeds the limit.") + path_split_list = path.split("/") + for path in path_split_list: + path_list = path.split("\\") + for name in path_list: + if len(name) > cls.MAX_FILE_NAME_LENGTH: + raise RuntimeError("Length of input path exceeds the limit.") + + @classmethod + def input_path_common_check(cls, path: str): + cls.check_path_length(path) + + if os.path.islink(path): + msg = f"Invalid input path which is a soft link." + raise RuntimeError(msg) + + if platform.system().lower() == cls.WINDOWS: + pattern = r'(\.|:|\\|/|_|-|\s|[~0-9a-zA-Z\u4e00-\u9fa5])+' + else: + pattern = r'(\.|/|_|-|\s|[~0-9a-zA-Z])+' + if not re.fullmatch(pattern, path): + msg = f"Invalid input path." + raise RuntimeError(msg) + + @classmethod + def check_path_owner_consistent(cls, path: str): + """ + Function Description: + check whether the path belong to process owner + Parameter: + path: the path to check + Exception Description: + when invalid path, prompt the user + """ + base_name = os.path.basename(path) + if not os.path.exists(path): + msg = f"Invalid path: {base_name}" + raise RuntimeError(msg) + if platform.system().lower() == cls.WINDOWS: + return + if os.stat(path).st_uid != os.getuid(): + check_msg = input("The path does not belong to you, do you want to continue? [y/n]") + if check_msg.lower() != "y": + raise RuntimeError("The user choose not to continue.") + + @classmethod + def check_path_writeable(cls, path): + """ + Function Description: + check whether the path is writable + Parameter: + path: the path to check + Exception Description: + when invalid data throw exception + """ + cls.check_path_owner_consistent(path) + if os.path.islink(path): + msg = f"Invalid path which is a soft link." 
+            raise RuntimeError(msg)
+        base_name = os.path.basename(path)
+        if not os.access(path, os.W_OK):
+            msg = f"The path permission check failed: {base_name}"
+            raise RuntimeError(msg)
+
+    @classmethod
+    def check_path_readable(cls, path):
+        """
+        Function Description:
+            check whether the path is readable
+        Parameter:
+            path: the path to check
+        Exception Description:
+            when invalid data throw exception
+        """
+        cls.check_path_owner_consistent(path)
+        if os.path.islink(path):
+            msg = f"Invalid path which is a soft link."
+            raise RuntimeError(msg)
+        base_name = os.path.basename(path)
+        if not os.access(path, os.R_OK):
+            msg = f"The path permission check failed: {base_name}"
+            raise RuntimeError(msg)
+
+    @classmethod
+    def remove_path_safety(cls, path: str):
+        base_name = os.path.basename(path)
+        msg = f"Failed to remove path: {base_name}"
+        if os.path.islink(path):
+            raise RuntimeError(msg)
+        if os.path.exists(path):
+            try:
+                shutil.rmtree(path)
+            except Exception as err:
+                raise RuntimeError(msg) from err
+
+    @classmethod
+    def make_dir_safety(cls, path: str):
+        base_name = os.path.basename(path)
+        msg = f"Failed to make directory: {base_name}"
+        if os.path.islink(path):
+            raise RuntimeError(msg)
+        if os.path.exists(path):
+            return
+        try:
+            os.makedirs(path, mode=cls.DATA_DIR_AUTHORITY)
+        except Exception as err:
+            raise RuntimeError(msg) from err
+
+    @classmethod
+    def create_file_safety(cls, path: str):
+        base_name = os.path.basename(path)
+        msg = f"Failed to create file: {base_name}"
+        if os.path.islink(path):
+            raise RuntimeError(msg)
+        if os.path.exists(path):
+            return
+        try:
+            os.close(os.open(path, os.O_WRONLY | os.O_CREAT, cls.DATA_FILE_AUTHORITY))
+        except Exception as err:
+            raise RuntimeError(msg) from err
+
+    @classmethod
+    def get_realpath(cls, path: str) -> str:
+        if os.path.islink(path):
+            msg = f"Invalid input path which is a soft link."
+            raise RuntimeError(msg)
+        return os.path.realpath(path)
diff --git a/profiler/prof_common/trace_event_bean.py b/profiler/prof_common/trace_event_bean.py
new file mode 100644
index 0000000000..2d4b96e4f6
--- /dev/null
+++ b/profiler/prof_common/trace_event_bean.py
@@ -0,0 +1,69 @@
+# Copyright (c) 2024 Huawei Technologies Co., Ltd
+# All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from decimal import Decimal + +from profiler.prof_common.utils import convert_to_decimal +from profiler.prof_common.analyze_dict import AnalyzeDict + + +class TraceEventBean(AnalyzeDict): + def __init__(self, data: dict, unique_id: int = None): + super().__init__(data) + self._id = unique_id + + @property + def unique_id(self): + return self._id + + @property + def start_time(self) -> Decimal: + return convert_to_decimal(self.ts) + + @property + def end_time(self) -> Decimal: + return self.start_time + convert_to_decimal(self.dur) + + def set_id(self, name_id): + self._id = name_id + + def is_cpu_op(self): + return self.cat == "cpu_op" + + def is_optimizer(self): + return self.cat == "cpu_op" and self.name.lower().startswith("optimizer") + + def is_nn_module(self): + return self.cat == "python_function" and self.name.lower().startswith("nn.module") + + def is_step(self): + return self.name.lower().startswith("profilerstep#") + + def is_torch_to_npu(self): + return self.cat == "async_npu" + + def is_fwd_bwd_flow(self): + return self.cat == "fwdbwd" + + def is_flow_start(self): + return self.ph == "s" + + def is_flow_end(self): + return self.ph == "f" + + def is_kernel_event(self, kernel_pid): + return self.ph == "X" and self.pid == kernel_pid + + def is_npu_process(self): + return self.ph == "M" and self.name == "process_name" and self.args.get("name", "") == "Ascend Hardware" diff --git a/profiler/prof_common/tree_builder.py b/profiler/prof_common/tree_builder.py new file mode 100644 index 0000000000..b7d3e1baf6 --- /dev/null +++ b/profiler/prof_common/tree_builder.py @@ -0,0 +1,33 @@ +# Copyright (c) 2024 Huawei Technologies Co., Ltd +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from profiler.prof_common.trace_event_bean import TraceEventBean + + +class TreeBuilder: + @staticmethod + def build_tree(event_list: list, node_class: any, root_bean: any): + root_node = node_class(root_bean) + event_list.sort(key=lambda x: x.start_time) + last_node = root_node + for event in event_list: + while last_node: + if last_node != root_node and event.start_time > last_node.end_time: + last_node = last_node.parent_node + continue + tree_node = node_class(event, last_node) + last_node.update_child_nodes(tree_node) + last_node = tree_node + break + return root_node diff --git a/profiler/prof_common/utils.py b/profiler/prof_common/utils.py new file mode 100644 index 0000000000..a9db41ad0b --- /dev/null +++ b/profiler/prof_common/utils.py @@ -0,0 +1,25 @@ +# Copyright (c) 2024 Huawei Technologies Co., Ltd +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from decimal import Decimal
+
+
+def convert_to_decimal(data: any) -> Decimal:
+    try:
+        decimal_value = Decimal(data)
+    except Exception:
+        logging.error('Invalid profiling data: failed to convert the value to decimal.')
+        return Decimal(0)
+    return decimal_value
--
Gitee
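Note: as a quick reference, the snippet below is an illustrative sketch (not part of the patch) of how the exporter added above might be invoked; the input and output paths are placeholders.

    from profiler.module_visualization.prof_graph_export import ProfGraphExport

    # prof_data_path may be either a trace_view.json file or an *_ascend_pt folder
    # collected by the Ascend PyTorch Profiler; the .vis graph file is written to output_path.
    ProfGraphExport.export_to_json("path/to/xxx_ascend_pt", "path/to/output_dir")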