From fb503eb57f723712d3d169f52e2d612895de2107 Mon Sep 17 00:00:00 2001
From: curry3 <485078529@qq.com>
Date: Wed, 3 Jul 2024 17:03:10 +0800
Subject: [PATCH 1/2] [feature] add UT framework for atat
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../atat/test/core_ut/test_utils.py           | 32 ++++++++++
 .../atat/test/mindspore_ut/test_ms_config.py  | 31 ++++++++++
 .../test_perturbed_layser.py                  |  0
 .../atat/test/pytorch_ut/test_pt_config.py    | 38 ++++++++++++
 debug/accuracy_tools/atat/test/run_test.sh    | 30 +++++++++
 debug/accuracy_tools/atat/test/run_ut.py      | 62 +++++++++++++++++++
 6 files changed, 193 insertions(+)
 create mode 100644 debug/accuracy_tools/atat/test/core_ut/test_utils.py
 create mode 100644 debug/accuracy_tools/atat/test/mindspore_ut/test_ms_config.py
 rename debug/accuracy_tools/{test/pytorch/free_benchmark => atat/test/pytorch_ut/free_benchmark/perturbed_layers}/test_perturbed_layser.py (100%)
 create mode 100644 debug/accuracy_tools/atat/test/pytorch_ut/test_pt_config.py
 create mode 100644 debug/accuracy_tools/atat/test/run_test.sh
 create mode 100644 debug/accuracy_tools/atat/test/run_ut.py

diff --git a/debug/accuracy_tools/atat/test/core_ut/test_utils.py b/debug/accuracy_tools/atat/test/core_ut/test_utils.py
new file mode 100644
index 00000000000..9492bbc9f97
--- /dev/null
+++ b/debug/accuracy_tools/atat/test/core_ut/test_utils.py
@@ -0,0 +1,32 @@
+from unittest import TestCase
+from unittest.mock import patch
+
+from atat.core.utils import check_seed_all, Const, CompareException
+
+
+class TestUtils(TestCase):
+    @patch("atat.core.utils.print_error_log")
+    def test_check_seed_all(self, mock_print_error_log):
+        self.assertIsNone(check_seed_all(1234, True))
+        self.assertIsNone(check_seed_all(0, True))
+        self.assertIsNone(check_seed_all(Const.MAX_SEED_VALUE, True))
+
+        with self.assertRaises(CompareException) as context:
+            check_seed_all(-1, True)
+        self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR)
+        mock_print_error_log.assert_called_with(f"Seed must be between 0 and {Const.MAX_SEED_VALUE}.")
+
+        with self.assertRaises(CompareException) as context:
+            check_seed_all(Const.MAX_SEED_VALUE + 1, True)
+        self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR)
+        mock_print_error_log.assert_called_with(f"Seed must be between 0 and {Const.MAX_SEED_VALUE}.")
+
+        with self.assertRaises(CompareException) as context:
+            check_seed_all("1234", True)
+        self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR)
+        mock_print_error_log.assert_called_with("Seed must be integer.")
+
+        with self.assertRaises(CompareException) as context:
+            check_seed_all(1234, 1)
+        self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR)
+        mock_print_error_log.assert_called_with("seed_all mode must be bool.")
diff --git a/debug/accuracy_tools/atat/test/mindspore_ut/test_ms_config.py b/debug/accuracy_tools/atat/test/mindspore_ut/test_ms_config.py
new file mode 100644
index 00000000000..0029e24bdac
--- /dev/null
+++ b/debug/accuracy_tools/atat/test/mindspore_ut/test_ms_config.py
@@ -0,0 +1,31 @@
+from unittest import TestCase
+from unittest.mock import patch, mock_open
+
+from atat.core.utils import Const
+from atat.mindspore.ms_config import parse_json_config
+
+
+class TestMsConfig(TestCase):
+    def test_parse_json_config(self):
+        mock_json_data = {
+            "dump_path": "./dump/",
+            "rank": [],
+            "step": [],
+            "level": "L1",
+            "seed": 
1234, + "statistics": { + "scope": [], + "list": [], + "data_mode": ["all"], + "summary_mode": "statistics" + } + } + with (patch("atat.mindspore.ms_config.FileOpen", mock_open(read_data='')), + patch("atat.mindspore.ms_config.json.load", return_value=mock_json_data)): + common_config, task_config = parse_json_config("./config.json") + self.assertEqual(common_config.task, Const.STATISTICS) + self.assertEqual(task_config.data_mode, ["all"]) + + with self.assertRaises(Exception) as context: + parse_json_config(None) + self.assertEqual(str(context.exception), "json file path is None") diff --git a/debug/accuracy_tools/test/pytorch/free_benchmark/test_perturbed_layser.py b/debug/accuracy_tools/atat/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py similarity index 100% rename from debug/accuracy_tools/test/pytorch/free_benchmark/test_perturbed_layser.py rename to debug/accuracy_tools/atat/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/test_pt_config.py b/debug/accuracy_tools/atat/test/pytorch_ut/test_pt_config.py new file mode 100644 index 00000000000..8279c207765 --- /dev/null +++ b/debug/accuracy_tools/atat/test/pytorch_ut/test_pt_config.py @@ -0,0 +1,38 @@ +from unittest import TestCase +from unittest.mock import patch, mock_open + +from atat.core.utils import Const +from atat.pytorch.pt_config import parse_json_config + + +class TestPtConfig(TestCase): + def test_parse_json_config(self): + mock_json_data = { + "task": "statistics", + "dump_path": "./dump/", + "rank": [], + "step": [], + "level": "L1", + "seed": 1234, + "statistics": { + "scope": [], + "list": [], + "data_mode": ["all"], + }, + "tensor": { + "file_format": "npy" + } + } + with (patch("atat.pytorch.pt_config.os.path.join", return_value="/path/config.json"), + patch("atat.pytorch.pt_config.FileOpen", mock_open(read_data='')), + patch("atat.pytorch.pt_config.json.load", return_value=mock_json_data)): + common_config, task_config = parse_json_config(None, None) + self.assertEqual(common_config.task, Const.STATISTICS) + self.assertEqual(task_config.data_mode, ["all"]) + + with (patch("atat.pytorch.pt_config.os.path.join", return_value="/path/config.json"), + patch("atat.pytorch.pt_config.FileOpen", mock_open(read_data='')), + patch("atat.pytorch.pt_config.json.load", return_value=mock_json_data)): + common_config, task_config = parse_json_config(None, Const.TENSOR) + self.assertEqual(common_config.task, Const.STATISTICS) + self.assertEqual(task_config.file_format, "npy") diff --git a/debug/accuracy_tools/atat/test/run_test.sh b/debug/accuracy_tools/atat/test/run_test.sh new file mode 100644 index 00000000000..1bf0ccb7713 --- /dev/null +++ b/debug/accuracy_tools/atat/test/run_test.sh @@ -0,0 +1,30 @@ +#!/bin/bash +CUR_DIR=$(dirname $(readlink -f $0)) +TOP_DIR=${CUR_DIR}/.. +TEST_DIR=${TOP_DIR}/"test" +SRC_DIR=${TOP_DIR}/../ + +install_pytest() { + if ! pip show pytest &> /dev/null; then + echo "pytest not found, trying to install..." + pip install pytest + fi + + if ! pip show pytest-cov &> /dev/null; then + echo "pytest-cov not found, trying to install..." 
+    pip install pytest-cov
+  fi
+}
+
+run_ut() {
+  install_pytest
+
+  export PYTHONPATH=${SRC_DIR}:${PYTHONPATH}
+  python3 run_ut.py
+}
+
+main() {
+  cd ${TEST_DIR} && run_ut
+}
+
+main "$@"
diff --git a/debug/accuracy_tools/atat/test/run_ut.py b/debug/accuracy_tools/atat/test/run_ut.py
new file mode 100644
index 00000000000..7f51d266c24
--- /dev/null
+++ b/debug/accuracy_tools/atat/test/run_ut.py
@@ -0,0 +1,62 @@
+import os
+import shutil
+import subprocess
+import sys
+
+from atat.core.log import print_info_log, print_error_log
+
+
+def get_ignore_dirs(cur_dir):
+    ignore_dirs = []
+    try:
+        import torch
+        import torch_npu
+    except ImportError:
+        print_info_log(f"Skipping the {cur_dir}/pytorch_ut directory")
+        ignore_dirs.extend(["--ignore", f"{cur_dir}/pytorch_ut"])
+
+    try:
+        import mindspore
+    except ImportError:
+        print_info_log(f"Skipping the {cur_dir}/mindspore_ut directory")
+        ignore_dirs.extend(["--ignore", f"{cur_dir}/mindspore_ut"])
+
+    return ignore_dirs
+
+
+def run_ut():
+    cur_dir = os.path.realpath(os.path.dirname(__file__))
+    ut_path = cur_dir
+    ignore_dirs = get_ignore_dirs(cur_dir)
+    cov_dir = os.path.dirname(cur_dir)
+    report_dir = os.path.join(cur_dir, "report")
+    final_xml_path = os.path.join(report_dir, "final.xml")
+    cov_report_path = os.path.join(report_dir, "coverage.xml")
+
+    if os.path.exists(report_dir):
+        shutil.rmtree(report_dir)
+    os.makedirs(report_dir)
+
+    cmd = ["python3", "-m", "pytest", ut_path, "--junitxml=" + final_xml_path, "--cov=" + cov_dir,
+           "--cov-branch", "--cov-report=xml:" + cov_report_path] + ignore_dirs
+    result_ut = subprocess.Popen(cmd, shell=False, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+    while result_ut.poll() is None:
+        line = result_ut.stdout.readline().strip()
+        if line:
+            print_info_log(str(line))
+
+    ut_flag = False
+    if result_ut.returncode == 0:
+        ut_flag = True
+        print_info_log("run ut successfully.")
+    else:
+        print_error_log("run ut failed.")
+
+    return ut_flag
+
+
+if __name__ == "__main__":
+    if run_ut():
+        sys.exit(0)
+    else:
+        sys.exit(1)
-- 
Gitee

From 0423c19b8d4b133bae82284049d9ad8723a41f1d Mon Sep 17 00:00:00 2001
From: curry3 <485078529@qq.com>
Date: Sat, 6 Jul 2024 16:49:16 +0800
Subject: [PATCH 2/2] [improvement] clean code optimization and rectification
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 debug/accuracy_tools/atat/core/utils.py       |   5 +
 .../accuracy_tools/atat/pytorch/dump/dump.py  | 455 ------------------
 .../accuracy_tools/atat/pytorch/dump/utils.py | 357 --------------
 .../free_benchmark/compare/grad_saver.py      |   5 +-
 .../compare/single_benchmark.py               |  81 ++--
 .../atat/pytorch/free_benchmark/main.py       |   8 +-
 .../perturbed_layers/npu/add_noise.py         |  70 +--
 .../perturbed_layers/npu/bit_noise.py         |  89 ++--
 .../perturbed_layers/npu/change_value.py      |  29 +-
 .../perturbed_layers/npu/improve_precision.py |  40 +-
 .../perturbed_layers/npu/no_change.py         |   6 +-
 .../perturbed_layers/npu/npu_base_layser.py   |  30 +-
 .../result_handlers/base_handler.py           |  64 +--
 .../result_handlers/preheat_handler.py        | 150 +++---
 .../atat/pytorch/functional/data_collector.py |  70 +--
 .../atat/pytorch/functional/data_processor.py | 252 +++++-----
 .../atat/pytorch/functional/json_writer.py    |  61 ++-
 .../atat/pytorch/functional/scope.py          |  51 +-
 .../atat/pytorch/hook_module/api_registry.py  |  15 +-
 .../atat/pytorch/hook_module/wrap_vf.py       |   2 -
 .../atat/pytorch/overflow_check/__init__.py   |   0
 .../atat/pytorch/overflow_check/info_dump.py  | 252 ----------
.../pytorch/overflow_check/overflow_check.py | 190 -------- .../atat/pytorch/overflow_check/utils.py | 114 ----- 24 files changed, 529 insertions(+), 1867 deletions(-) delete mode 100644 debug/accuracy_tools/atat/pytorch/dump/dump.py delete mode 100644 debug/accuracy_tools/atat/pytorch/dump/utils.py delete mode 100644 debug/accuracy_tools/atat/pytorch/overflow_check/__init__.py delete mode 100644 debug/accuracy_tools/atat/pytorch/overflow_check/info_dump.py delete mode 100644 debug/accuracy_tools/atat/pytorch/overflow_check/overflow_check.py delete mode 100644 debug/accuracy_tools/atat/pytorch/overflow_check/utils.py diff --git a/debug/accuracy_tools/atat/core/utils.py b/debug/accuracy_tools/atat/core/utils.py index fdaa33e3ced..aed28a881d8 100644 --- a/debug/accuracy_tools/atat/core/utils.py +++ b/debug/accuracy_tools/atat/core/utils.py @@ -54,9 +54,12 @@ class Const: SUPPORT_DUMP_MODE = ['api', 'acl'] ON = 'ON' OFF = 'OFF' + MAX = 'Max' + MIN = 'Min' BACKWARD = 'backward' FORWARD = 'forward' PRE_FORWARD = "pre_forward" + DATA = 'data' # dump mode ALL = "all" @@ -105,6 +108,8 @@ class Const: OVERFLOW_CHECK = "overflow_check" FREE_BENCHMARK = "free_benchmark" + ATTR_NAME_PREFIX = "wrap_" + class CompareConst: """ Class for compare module const diff --git a/debug/accuracy_tools/atat/pytorch/dump/dump.py b/debug/accuracy_tools/atat/pytorch/dump/dump.py deleted file mode 100644 index 64652bdaec5..00000000000 --- a/debug/accuracy_tools/atat/pytorch/dump/dump.py +++ /dev/null @@ -1,455 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -# Copyright (C) 2019-2020. Huawei Technologies Co., Ltd. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" - -import inspect -import json -import os -import threading -from pathlib import Path - -import numpy as np -import torch - -try: - import torch_npu -except ImportError: - is_gpu = True -else: - is_gpu = False - -from atat.core.utils import (print_warn_log, Const, print_info_log, modify_dump_path, check_inplace_op, CompareConst, - print_error_log) -from atat.core.file_check_util import FileOpen, change_mode, FileCheckConst -from atat.pytorch.common.utils import get_md5_for_tensor -from ..dump.utils import check_writable -from .utils import (DumpUtil, check_if_in_api_list, make_dump_data_dir, get_tensor_rank, create_dirs_if_not_exist, - CompareException, check_single_rank_folder) - - -forward_init_status = False -backward_init_status = False - -thread_lock = threading.Lock() -pkl_name = "" -rank = os.getpid() + 100000 -multi_output_apis = ["_sort_", "npu_flash_attention"] -module_count = {} - - -class APIList(list): - threshold = 1000 - - def __init__(self, *args): - self.dump_count = 0 - self.pkl_mode_changed = False - super().__init__(*args) - - def flush(self): - pkl_path = get_pkl_file_path() - if len(self) == 0 or pkl_path == "": - return - with FileOpen(pkl_path, 'a') as f: - try: - f.write('\n'.join(json.dumps(item) for item in self)) - f.write('\n') - except IOError as ex: - raise Exception("write to disk failed") from ex - self.dump_count += 1 - print_info_log(f"write {len(self)} items to {pkl_path} the {self.dump_count} time") - if not self.pkl_mode_changed: - change_mode(pkl_path, FileCheckConst.DATA_FILE_AUTHORITY) - self.pkl_mode_changed = True - self.clear() - - def append(self, data): - list.append(self, data) - if len(self) >= APIList.threshold: - self.flush() - - -api_list = APIList() - - -class DataInfo(object): - def __init__(self, save_data, summary_data, dtype, shape, md5=None): - if md5 is None: - md5 = [] - self.save_data = save_data - self.summary_data = summary_data - self.dtype = dtype - self.shape = shape - self.md5 = md5 - - -def get_not_float_tensor_info(data): - if DumpUtil.summary_mode == "md5": - return DataInfo([], [], str(data.dtype), tuple(data.shape), get_md5_for_tensor(data)) - if data.numel() == 0 or data.dtype == torch.bool: - tensor_max = [] - tensor_min = [] - tensor_mean = [] - elif len(data.shape) == 0: - item = data.float().item() - tensor_max = item - tensor_min = item - tensor_mean = item - else: - tensor_max = torch._C._VariableFunctionsClass.max(data).float().item() - tensor_min = torch._C._VariableFunctionsClass.min(data).float().item() - tensor_mean = torch._C._VariableFunctionsClass.mean(data.float()).float().item() - return get_tensor_data_info(data, tensor_max, tensor_min, tensor_mean, CompareConst.NAN) - - -def get_scalar_data_info(data): - summary_data = [data, data, data, data] - return DataInfo(data, summary_data, str(type(data)), str([])) - - -def get_float_tensor_info(data): - if DumpUtil.summary_mode == "md5": - return DataInfo([], [], str(data.dtype), tuple(data.shape), get_md5_for_tensor(data)) - tensor_max = torch._C._VariableFunctionsClass.max(data).float().item() - tensor_min = torch._C._VariableFunctionsClass.min(data).float().item() - tensor_mean = torch._C._VariableFunctionsClass.mean(data).float().item() - tensor_norm = torch._C._VariableFunctionsClass.norm(data).float().item() - return get_tensor_data_info(data, tensor_max, tensor_min, tensor_mean, tensor_norm) - - -def get_tensor_data_info(data, *tensor_args): - summary_data = [] - summary_data.extend([*tensor_args]) - if DumpUtil.summary_mode == "all": - 
saved_tensor = data.contiguous().cpu().detach() - if data.dtype == torch.bfloat16: - saved_numpy = saved_tensor.to(torch.float32).numpy() - else: - saved_numpy = saved_tensor.numpy() - return DataInfo(saved_numpy, summary_data, str(data.dtype), tuple(data.shape)) - return DataInfo([], summary_data, str(data.dtype), tuple(data.shape)) - - -def dump_tensor(x, prefix, dump_step): - if isinstance(x, (tuple, list)) and x: - for i, item in enumerate(x): - dump_tensor(item, "{}.{}".format(prefix, i), dump_step) - return - elif isinstance(x, torch.Tensor): - if x.is_meta: - print_info_log(f"Meta tensor {prefix} is skipped.") - return - x_clone = x.clone().detach() - if x_clone.numel() == 0 or len(x_clone.shape) == 0 or not x_clone.is_floating_point(): - if DumpUtil.dump_filter_switch == Const.OFF: - data_info = get_not_float_tensor_info(x_clone) - dump_data_by_rank_count(dump_step, prefix, data_info) - else: - return - else: - data_info = get_float_tensor_info(x_clone) - dump_data_by_rank_count(dump_step, prefix, data_info) - - elif DumpUtil.dump_filter_switch == Const.OFF: - if isinstance(x, bool) or isinstance(x, int) or isinstance(x, float): - data_info = get_scalar_data_info(x) - dump_data_by_rank_count(dump_step, prefix, data_info) - - -def append_pkl_data(dump_step, prefix, data_info): - global api_list - thread_lock.acquire() - api_list.append([prefix, dump_step, data_info.md5, data_info.dtype, data_info.shape, data_info.summary_data]) - thread_lock.release() - - -def dump_data(prefix, data_info): - if DumpUtil.summary_mode != "all": - return - output_path = os.path.join(DumpUtil.dump_data_dir, f'{prefix}.npy') - try: - np.save(output_path, data_info.save_data) - change_mode(output_path, FileCheckConst.DATA_FILE_AUTHORITY) - except Exception as e: - print_warn_log("Dump data failed, error: {}".format(e)) - - -def thread_dump_data(prefix, data_info): - DumpUtil.dump_thread_pool.submit(dump_data, prefix, data_info) - - -def dump_data_by_rank_count(dump_step, prefix, data_info): - print_info_log(f"ptdbg is analyzing rank{rank} api: {prefix}" + " " * 10, end='\r') - if DumpUtil.is_single_rank and DumpUtil.dump_thread_pool: - thread_dump_data(prefix, data_info) - else: - dump_data(prefix, data_info) - append_pkl_data(dump_step, prefix, data_info) - - -def dump_stack_info(name_template): - if check_inplace_op(name_template) and Const.PRE_FORWARD in name_template: - return - - stack_str = [] - try: - for (_, path, line, func, code, _) in inspect.stack()[4:]: - if code: - stack_line = [path, str(line), func, code[0].strip() if code else code] - else: - stack_line = [path, str(line), func, code] - stack_str.append(stack_line) - except Exception as e: - print_warn_log("Dump stack info failed, error: {}".format(e)) - stack_str.append('') - - prefix = name_template.format("stack_info") - if DumpUtil.dump_switch_mode in Const.DUMP_MODE: - complement_set = set(['forward', 'backward', 'input', 'output']) - set(DumpUtil.dump_mode) - if not any(mode in prefix for mode in complement_set): - api_list.append([prefix, stack_str]) - else: - api_list.append([prefix, stack_str]) - - -def dump_api_tensor(dump_step, in_feat, name_template, out_feat): - if check_inplace_op(name_template): - if Const.PRE_FORWARD in name_template: - name_template = name_template.replace(Const.PRE_FORWARD, Const.FORWARD) - else: - if Const.BACKWARD in name_template and Const.BACKWARD in DumpUtil.dump_mode: - return - elif Const.BACKWARD not in name_template and Const.FORWARD in DumpUtil.dump_mode: - if "output" in DumpUtil.dump_mode: - 
dump_tensor(in_feat, name_template.format("output"), dump_step) - if "input" in DumpUtil.dump_mode: - return - - if Const.BACKWARD in name_template and Const.BACKWARD in DumpUtil.dump_mode: - if 'input' in DumpUtil.dump_mode: - dump_tensor(out_feat, name_template.format("input"), dump_step) - if 'output' in DumpUtil.dump_mode: - dump_tensor(in_feat, name_template.format("output"), dump_step) - elif Const.BACKWARD not in name_template and Const.FORWARD in DumpUtil.dump_mode: - if 'input' in DumpUtil.dump_mode: - dump_tensor(in_feat, name_template.format("input"), dump_step) - if 'output' in DumpUtil.dump_mode: - dump_tensor(out_feat, name_template.format("output"), dump_step) - - -def rename_(): - global rank - global pkl_name - if rank is not None and pkl_name is not None: - dir_name = os.path.join(DumpUtil.dump_root, "step{}".format(DumpUtil.iter_num), "rank{}".format(os.getpid() + 100000)) - new_name = os.path.join(DumpUtil.dump_root, "step{}".format(DumpUtil.iter_num), "rank{}".format(rank)) - if not os.path.exists(new_name) and os.path.exists(dir_name): - _, file_name = os.path.split(pkl_name) - os.rename(dir_name, new_name) - pkl_name = os.path.join(new_name, file_name) - - -def dump_acc_cmp(name, in_feat, out_feat, dump_step, module): - if not DumpUtil.get_dump_switch(): - return - if DumpUtil.dump_switch_mode == Const.API_LIST and not check_if_in_api_list(name): - return - if DumpUtil.dump_switch_mode in [Const.LIST, Const.ACL, Const.RANGE, Const.STACK] and not DumpUtil.check_switch_scope(name): - return - dump_file = DumpUtil.get_dump_path() - dump_file = modify_dump_path(dump_file, DumpUtil.dump_switch_mode) - global rank - dump_dir, dump_filename = os.path.split(dump_file) - dump_dir = os.path.join(dump_dir, "step{}".format(DumpUtil.iter_num)) - if not os.path.exists(dump_dir): - Path(dump_dir).mkdir(mode=FileCheckConst.DATA_DIR_AUTHORITY, exist_ok=True) - dump_file = os.path.join(dump_dir, dump_filename) - rank_this = get_tensor_rank(in_feat, out_feat) - DumpUtil.dump_root = os.path.dirname(DumpUtil.dump_path) - if rank_this is not None and rank != rank_this: - rank = rank_this - rename_() - if not DumpUtil.dump_init_enable: - if '.pkl' in dump_filename: - npy_dir = dump_filename[:-4] - else: - npy_dir = dump_filename - DumpUtil.dump_data_dir = os.path.join(DumpUtil.dump_root, "step{}".format(DumpUtil.iter_num), "rank{}".format(rank), npy_dir) - if DumpUtil.target_rank is not None: - if rank != DumpUtil.target_rank: - return - dump_file = create_dirs_if_not_exist(rank, dump_file) - global pkl_name - pkl_name = dump_file - if DumpUtil.dump_init_enable: - DumpUtil.dump_init_enable = False - DumpUtil.dump_data_dir = make_dump_data_dir(dump_file) \ - if DumpUtil.dump_switch_mode not in [Const.STACK, Const.ACL] and DumpUtil.summary_mode == "all" else "" - if os.path.exists(dump_file) and not os.path.isdir(dump_file): - check_writable(dump_file) - try: - os.remove(dump_file) - except FileNotFoundError as e: - print_warn_log("The file does not exist, error: {}".format(e)) - - name_prefix = name - name_template = f"{name_prefix}" + "_{}" - if DumpUtil.is_single_rank is None: - DumpUtil.is_single_rank = check_single_rank_folder(dump_dir) - if DumpUtil.dump_switch_mode in [Const.ALL, Const.API_LIST]: - dump_api_tensor(dump_step, in_feat, name_template, out_feat) - elif DumpUtil.dump_switch_mode == Const.API_STACK: - dump_api_tensor(dump_step, in_feat, name_template, out_feat) - dump_stack_info(name_template) - else: - if DumpUtil.dump_switch_mode == Const.ACL: - acl_dump(module, name, 
name_prefix) - elif DumpUtil.dump_switch_mode != Const.STACK: - dump_api_tensor(dump_step, in_feat, name_template, out_feat) - dump_stack_info(name_template) - - -def acl_dump(module, module_name, name_prefix): - if name_prefix in DumpUtil.backward_input: - dump_mode_backward_acl_dump(module, module_name, DumpUtil.backward_input.get(name_prefix)) - else: - forward_acl_dump(module, module_name) - - -def Op_Need_Trigger(module_name): - if 'Tensor.__getitem__.' in module_name: - return True - return False - - -def forward_acl_dump(module, module_name): - global forward_init_status - global backward_init_status - if not forward_init_status and not backward_init_status: - forward_init_status = True - torch_npu.npu.synchronize() - torch_npu.npu.init_dump() - torch_npu.npu.set_dump(DumpUtil.dump_config) - torch_npu.npu.synchronize() - if Op_Need_Trigger(module_name): - module.forward(*module.input_args, **module.input_kwargs).cpu() - else: - module.forward(*module.input_args, **module.input_kwargs) - torch_npu.npu.synchronize() - torch_npu.npu.finalize_dump() - torch_npu.npu.synchronize() - del module.input_args - del module.input_kwargs - forward_init_status = False - print_info_log("Dump %s op file." % module_name) - - -def acl_backward_dump_status(output, grad, module_name): - if isinstance(output, torch.Tensor): - output.backward(grad, retain_graph=True) - return True - - for api_name in multi_output_apis: - if api_name in module_name: - output[0].backward(grad, retain_graph=True) - return True - return False - - -def dump_mode_backward_acl_dump(module, module_name, grad_path): - global forward_init_status - global backward_init_status - module_name = module_name.replace(Const.FORWARD, Const.BACKWARD) - if not forward_init_status and not backward_init_status: - forward_init_status = True - module.input_args = list(module.input_args) - for i, data in enumerate(module.input_args): - if isinstance(data, torch.Tensor) and data.grad_fn: - module.input_args[i] = data.detach().requires_grad_() - output = module.forward(*module.input_args, **module.input_kwargs) - grad = torch.tensor(np.load(grad_path)).to("npu").requires_grad_() - torch_npu.npu.init_dump() - torch_npu.npu.set_dump(DumpUtil.dump_config) - torch_npu.npu.synchronize() - if not acl_backward_dump_status(output, grad, module_name): - print_warn_log("The output of {} is not of tensor type and cannot be automatically derived. " - "you can manually construct a single API backward case for ACL dump.".format(module_name)) - torch_npu.npu.synchronize() - torch_npu.npu.finalize_dump() - del module.input_args - del module.input_kwargs - forward_init_status = False - print_info_log("Dump %s op file." 
% module_name) - - -def module_count_func(name, name_template): - module_name = name.split("_")[-3] - if Const.FORWARD in name_template: - if module_name not in module_count: - module_count[module_name] = [0, [0]] - else: - if module_count[module_name][-1] and \ - module_count[module_name][0] != module_count[module_name][-1][-1]: - module_count[module_name][-1].pop() - module_count[module_name][0] += 1 - module_count[module_name][-1].append(module_count[module_name][0]) - index = module_count[module_name][0] - else: - backward_stack = module_count[module_name][-1] if module_name in module_count else [] - if not backward_stack: - print_warn_log("The backward stack of {} is empty.".format(module_name)) - index = "abnormal" - else: - index = backward_stack.pop() - return index - - -def acc_cmp_dump(name, **kwargs): - dump_step = kwargs.get('dump_step', 1) - pid = kwargs.get('pid') - name_template = name - if not pid: - return RuntimeError("Not get the specified process pid.") - - def acc_cmp_hook(module, in_feat, out_feat=None): - nonlocal name, name_template - if "_{}_" in name_template: - try: - index = module_count_func(name, name_template) - except IndexError as e: - print_error_log(f"Get module {name_template} index failed.") - raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e - name = name_template.format(index) - if pid == os.getpid(): - dump_acc_cmp(name, in_feat, out_feat, dump_step, module) - if hasattr(module, "input_args"): - del module.input_args - if hasattr(module, "input_kwargs"): - del module.input_kwargs - - return acc_cmp_hook - - -def write_to_disk(): - api_list.flush() - - -def get_pkl_file_path(): - return pkl_name - - -def reset_module_count(): - global module_count - module_count = {} diff --git a/debug/accuracy_tools/atat/pytorch/dump/utils.py b/debug/accuracy_tools/atat/pytorch/dump/utils.py deleted file mode 100644 index 8e58f35606a..00000000000 --- a/debug/accuracy_tools/atat/pytorch/dump/utils.py +++ /dev/null @@ -1,357 +0,0 @@ -import os -import re -import shutil -from pathlib import Path -import torch -import torch.distributed as dist - -from atat.core.utils import print_error_log, CompareException, DumpException, Const, get_time, print_info_log, \ - check_mode_valid, check_switch_valid, check_dump_mode_valid, check_summary_only_valid, generate_compare_script, \ - check_file_valid, make_dump_path_if_not_exists, check_path_before_create, check_summary_mode_valid -from atat.core.file_check_util import FileChecker, FileCheckConst, check_path_length, check_path_pattern_vaild -from atat.pytorch.common.utils import check_is_npu - -from ..dump import dump - -dump_count = 0 -range_begin_flag, range_end_flag = False, False - - -def check_list_or_acl_mode(name_prefix): - global dump_count - for item in DumpUtil.dump_switch_scope: - if name_prefix.startswith(item): - dump_count = dump_count + 1 - return True - return False - - -def check_range_mode(name_prefix): - global range_begin_flag - global range_end_flag - if name_prefix.startswith(DumpUtil.dump_switch_scope[0]): - range_begin_flag = True - return True - if name_prefix.startswith(DumpUtil.dump_switch_scope[1]): - range_end_flag = True - return True - if range_begin_flag and not range_end_flag: - return True - return False - - -def check_stack_mode(name_prefix): - if len(DumpUtil.dump_switch_scope) == 0: - return True - elif len(DumpUtil.dump_switch_scope) == 1: - return name_prefix.startswith(DumpUtil.dump_switch_scope[0]) - elif len(DumpUtil.dump_switch_scope) == 2: - return 
check_range_mode(name_prefix) - else: - print_error_log("dump scope is invalid, Please set the scope mode in" - " set_dump_switch with 'all', 'list', 'range', 'stack', 'acl', 'api_list'!") - return False - - -class DumpConfig: - def __init__(self, mode=None, scope=None, api_list=None, filter_switch=None, dump_mode=None, summary_only=False, summary_mode="all"): - self.mode = mode - self.scope = scope - self.api_list = api_list - self.filter_switch = filter_switch - self.dump_mode = dump_mode - self.summary_only = summary_only - self.summary_mode = summary_mode - - -class DumpUtil(object): - dump_root = None - dump_data_dir = None - dump_path = None - dump_switch = None - dump_switch_mode = Const.ALL # all, api_stack, list, stack... - dump_switch_scope = [] - dump_init_enable = False - dump_api_list = [] - dump_filter_switch = None - dump_mode = ['forward', 'backward', 'input', 'output'] - backward_input = {} - dump_dir_tag = 'ptdbg_dump' - dump_config = None - dataloader_iter = 0 - target_iter = None - iter_num = 0 - target_rank = None - summary_only = False - need_replicate = False - summary_mode = "all" - is_single_rank = None - dump_thread_pool = None - - - @staticmethod - def set_dump_path(save_path): - DumpUtil.dump_path = save_path - DumpUtil.dump_init_enable = True - - @staticmethod - def set_acl_config(acl_config): - if not acl_config: - raise ValueError("acl_config must be configured when mode is 'acl'") - acl_config_checker = FileChecker(acl_config, FileCheckConst.FILE, FileCheckConst.READ_ABLE, - FileCheckConst.JSON_SUFFIX) - acl_config = acl_config_checker.common_check() - DumpUtil.dump_config = acl_config - - @staticmethod - def set_dump_switch(switch, dump_config): - DumpUtil.dump_switch = switch - if dump_config.mode is not None: - DumpUtil.dump_switch_mode = dump_config.mode - DumpUtil.dump_init_enable = True - if dump_config.scope is not None: - DumpUtil.dump_switch_scope = dump_config.scope - if dump_config.api_list is not None: - DumpUtil.dump_api_list = [api.lower() for api in dump_config.api_list] - if dump_config.filter_switch is not None: - DumpUtil.dump_filter_switch = dump_config.filter_switch - if dump_config.dump_mode is not None: - DumpUtil.dump_mode = dump_config.dump_mode if isinstance(dump_config.dump_mode, list) else [dump_config.dump_mode] - - if dump_config.mode == Const.ACL: - DumpUtil.dump_switch_scope = [api_name.replace("backward", "forward") for api_name in dump_config.scope] - - DumpUtil.summary_only = dump_config.summary_only - DumpUtil.summary_mode = dump_config.summary_mode - - check_mapper = { - Const.LIST: check_list_or_acl_mode, - Const.ACL: check_list_or_acl_mode, - Const.RANGE: check_range_mode, - Const.STACK: check_stack_mode - } - - @staticmethod - def check_switch_scope(name_prefix): - if DumpUtil.dump_switch_mode in DumpUtil.check_mapper: - check_func = DumpUtil.check_mapper[DumpUtil.dump_switch_mode] - return check_func(name_prefix) - return False - - @staticmethod - def get_dump_path(): - if DumpUtil.dump_path: - return DumpUtil.dump_path - - if DumpUtil.dump_switch_mode == Const.ALL: - raise RuntimeError("get_dump_path: the file path is empty," - " you must use set_dump_path to set a valid dump path!!!") - else: - dir_path = os.path.realpath("./") - dump_file_name = "scope_dump_{}_{}_{}.pkl".format( - DumpUtil.dump_switch_mode, DumpUtil.dump_switch_scope[0], get_time()) - DumpUtil.dump_path = os.path.join(dir_path, dump_file_name) - return DumpUtil.dump_path - - @staticmethod - def get_dump_switch(): - return DumpUtil.dump_switch == 
"ON" - - -def set_dump_path(fpath=None, dump_tag='ptdbg_dump'): - fpath = load_env_dump_path(fpath) - check_file_valid(fpath) - if not re.match(Const.FILE_PATTERN, dump_tag): - print_error_log('The file path {} contains special characters.'.format(dump_tag)) - raise CompareException(CompareException.INVALID_PATH_ERROR) - real_path = os.path.realpath(fpath) - make_dump_path_if_not_exists(real_path) - fpath_checker = FileChecker(real_path, FileCheckConst.DIR, FileCheckConst.WRITE_ABLE) - fpath_checker.common_check() - DumpUtil.set_dump_path(real_path) - DumpUtil.dump_dir_tag = dump_tag - - -def get_tensor_rank(in_feat, out_feat): - if dist.is_initialized(): - return dist.get_rank() - - def get_tensor_rank_single(x): - if isinstance(x, (list, tuple)): - if len(x) > 0: - return get_tensor_rank_single(x[0]) - return None - elif isinstance(x, torch.Tensor): - device = x.device - if device.type == 'cpu': - return None - else: - return device.index - return None - in_rank = get_tensor_rank_single(in_feat) - if in_rank is None: - out_rank = get_tensor_rank_single(out_feat) - if out_rank is None: - return None - return out_rank - return in_rank - - -def create_dirs_if_not_exist(rank, dump_file): - dump_path, file_name = os.path.split(dump_file) - rank_dir = os.path.join(dump_path, f"rank{rank}") - dump_file = os.path.join(rank_dir, file_name) - if not os.path.isdir(rank_dir): - check_path_pattern_vaild(dump_file) - check_path_length(dump_file, name_length=200) - Path(rank_dir).mkdir(mode=FileCheckConst.DATA_DIR_AUTHORITY, exist_ok=True) - return dump_file - - -def generate_dump_path_str(): - if DumpUtil.dump_switch_mode == 'acl': - if DumpUtil.dump_config == '': - print_error_log("Please provide dump config for register hook before turning on dump switch!") - raise DumpException(DumpException.NONE_ERROR) - dump_path = f"according to dump config {DumpUtil.dump_config}" - else: - dump_dir, dump_file = os.path.split(DumpUtil.dump_path) - if not dump_file.endswith(".pkl"): - dump_dir = DumpUtil.dump_path - dump_path = f"to {dump_dir}" - return dump_path - - -def set_dump_switch(switch, mode=Const.ALL, scope=None, api_list=None, filter_switch=Const.OFF, dump_mode=None, - summary_only=False): - if scope is None: - scope = [] - if api_list is None: - api_list = [] - if dump_mode is None: - dump_mode = [Const.ALL] - check_switch_valid(switch) - if not DumpUtil.dump_path: - set_dump_path() - dump_config = DumpConfig(summary_only=summary_only) - DumpUtil.set_dump_switch(switch, dump_config) - dump_path_str = generate_dump_path_str() - if switch == "OFF": - dump.write_to_disk() - if check_is_npu() and DumpUtil.dump_switch_mode in [Const.ALL, Const.API_STACK, Const.LIST, Const.RANGE, Const.API_LIST]: - generate_compare_script(DumpUtil.dump_data_dir, dump.get_pkl_file_path(), DumpUtil.dump_switch_mode) - set_dump_switch_print_info(switch, mode, dump_path_str) - set_dump_switch_config(mode=mode, scope=scope, api_list=api_list, filter_switch=filter_switch, dump_mode=dump_mode, - summary_only=summary_only) - - -def set_dump_switch_config(mode=Const.ALL, scope=None, api_list=None, filter_switch=Const.OFF, dump_mode=None, - summary_only=False, summary_mode="all"): - if scope is None: - scope = [] - if api_list is None: - api_list = [] - if dump_mode is None: - dump_mode = [Const.ALL] - try: - check_summary_mode_valid(summary_mode) - check_mode_valid(mode, scope, api_list) - check_switch_valid(filter_switch) - dump_mode = check_dump_mode_valid(dump_mode) - summary_only = check_summary_only_valid(summary_only) - except 
(CompareException, AssertionError) as err: - print_error_log(str(err)) - raise CompareException(CompareException.INVALID_PARAM_ERROR) from err - switch = DumpUtil.dump_switch - dump_config = DumpConfig(mode, scope, api_list, filter_switch, dump_mode, summary_only, summary_mode) - DumpUtil.set_dump_switch("OFF", dump_config) - DumpUtil.dump_switch = switch - - -def set_dump_switch_print_info(switch, mode, dump_path_str): - global dump_count - if switch == "ON": - print_info_log(f"Dump switch is turned on. Dump data will be saved {dump_path_str}. ") - if mode == Const.LIST: - dump_count = 0 - else: - print_info_log(f"Dump switch is turned off. ") - if mode == Const.LIST: - print_info_log("The number of matched dump is {}".format(dump_count)) - - -def check_if_in_api_list(name): - if not DumpUtil.dump_api_list: - return False - for api in DumpUtil.dump_api_list: - if api.lower() in name.lower(): - return True - return False - - -def set_backward_input(backward_input): - for index, api_name in enumerate(DumpUtil.dump_switch_scope): - DumpUtil.backward_input[api_name] = backward_input[index] - - -def make_dump_data_dir(dump_file_name): - dump_path, file_name = os.path.split(os.path.realpath(dump_file_name)) - name_body, name_extension = os.path.splitext(file_name) - output_dir = os.path.join(dump_path, f"{name_body}") - check_path_before_create(output_dir) - if not os.path.exists(output_dir): - Path(output_dir).mkdir(mode=0o750, exist_ok=True) - else: - shutil.rmtree(output_dir, ignore_errors=True) - Path(output_dir).mkdir(mode=0o750, exist_ok=True) - return output_dir - - -def make_dump_dirs(): - dump_file_name, dump_file_name_body = "dump.pkl", "dump" - dump_root_dir = load_env_dump_path(DumpUtil.dump_path) - tag_dir = os.path.join(dump_root_dir, DumpUtil.dump_dir_tag) - check_path_length(tag_dir) - check_path_pattern_vaild(tag_dir) - Path(tag_dir).mkdir(mode=0o750, parents=True, exist_ok=True) - DumpUtil.dump_dir = tag_dir - dump_file_path = os.path.join(tag_dir, dump_file_name) - DumpUtil.set_dump_path(dump_file_path) - - -def check_writable(dump_file): - if not os.access(dump_file, os.W_OK): - print_error_log( - 'The path {} does not have permission to write. Please check the path permission'.format( - dump_file)) - raise DumpException(DumpException.INVALID_PATH_ERROR) - - -def load_env_dump_path(dump_path): - if not dump_path: - dump_path = os.getenv(Const.ASCEND_WORK_PATH) - if dump_path: - try: - dump_path = os.path.join(str(dump_path), Const.DUMP_DIR) - except TypeError as err: - print_error_log("Generating dump path from environment variables ASCEND_WORK_PATH failed.") - raise DumpException(DumpException.INVALID_PATH_ERROR) from err - else: - print_error_log("Dump path is None, you can configure it in the following ways:\n" - "1. Configure set_dump_path function.\n" - "2. Configure the dump_path parameter of PrecisionDebugger.\n" - "3. 
Set environment variables ASCEND_WORK_PATH.") - raise DumpException(DumpException.INVALID_PATH_ERROR) - return dump_path - - -def check_single_rank_folder(dump_path): - rank_folder_pattern = re.compile(r'^rank\d+$') - rank_folder_count = 0 - for item in os.listdir(dump_path): - full_path = os.path.join(dump_path, item) - if os.path.isdir(full_path) and rank_folder_pattern.match(item): - rank_folder_count += 1 - if rank_folder_count > 1: - return False - return rank_folder_count == 1 diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/compare/grad_saver.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/compare/grad_saver.py index a8752656ed7..c11855898de 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/compare/grad_saver.py +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/compare/grad_saver.py @@ -1,12 +1,11 @@ import torch from atat.pytorch.free_benchmark import print_info_log_rank_0, print_warn_log_rank_0 -from atat.pytorch.free_benchmark.common.params import DataParams, HandlerParams from atat.pytorch.free_benchmark.common.constant import CommonField -from atat.pytorch.free_benchmark.common.utils import Tools +from atat.pytorch.free_benchmark.common.params import DataParams, HandlerParams +from atat.pytorch.free_benchmark.perturbed_layers.layer_factory import LayerFactory from atat.pytorch.free_benchmark.result_handlers.handler_factory import ( FuzzHandlerFactory, ) -from atat.pytorch.free_benchmark.perturbed_layers.layer_factory import LayerFactory class GradSaver: diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/compare/single_benchmark.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/compare/single_benchmark.py index ed834c468ba..79c515b6ddf 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/compare/single_benchmark.py +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/compare/single_benchmark.py @@ -1,9 +1,9 @@ -import torch import math +import torch from atat.pytorch.free_benchmark import print_warn_log_rank_0 -from atat.pytorch.free_benchmark.common.utils import TorchC from atat.pytorch.free_benchmark.common.constant import ThresholdConfig +from atat.pytorch.free_benchmark.common.utils import TorchC class SingleCompare: @@ -13,6 +13,45 @@ class SingleCompare: self.eb = None self.threshold = None + @staticmethod + def compare_float_seq(actual, golden): + return math.isclose(actual, golden) + + @staticmethod + def compare_other_seq(actual, golden): + return actual == golden + + @staticmethod + def filter_overflow(tensor) -> int: + inf_num = TorchC.sum(TorchC.isinf(tensor)) + nan_num = TorchC.sum(TorchC.isnan(tensor)) + return inf_num + nan_num + + @staticmethod + def replace_inf_or_nan(tensor): + finite_mask = TorchC.isfinite(tensor) + inf_or_nan_mask = TorchC.logical_not(finite_mask) + inf_or_nan_num = TorchC.sum(inf_or_nan_mask).item() + if inf_or_nan_num > 0: + tensor[inf_or_nan_mask] = 1 + return tensor + + def compare_dict_seq(self, actual, golden): + if len(actual) != len(golden): + return False + for key, value in golden.items(): + if not self.compare_seq(value, actual.get(key)): + return False + return True + + def compare_list_seq(self, actual, golden): + if len(actual) != len(golden): + return False + for index_, value in enumerate(golden): + if not self.compare_seq(value, actual[index_]): + return False + return True + def compare_seq(self, actual, golden): if isinstance(golden, torch.Tensor): return self.compare_tensor_seq(actual, golden) @@ -45,7 +84,6 @@ class SingleCompare: return False return True - def 
_cal_compare_metrics(self, actual, golden): diff_value = TorchC.subtract(actual, golden) diff_abs = TorchC.abs(diff_value) @@ -64,40 +102,3 @@ class SingleCompare: TorchC.ge(TorchC.abs(golden), self.threshold.small_value), golden_abs, 1 ) self.eb = TorchC.mean(TorchC.div(diff_value, divided)) - - def compare_dict_seq(self, actual, golden): - if len(actual) != len(golden): - return False - for key, value in golden.items(): - if not self.compare_seq(value, actual.get(key)): - return False - return True - - def compare_list_seq(self, actual, golden): - if len(actual) != len(golden): - return False - for index_, value in enumerate(golden): - if not self.compare_seq(value, actual[index_]): - return False - return True - - def compare_float_seq(self, actual, golden): - return math.isclose(actual, golden) - - def compare_other_seq(self, actual, golden): - return actual == golden - - @staticmethod - def filter_overflow(tensor) -> int: - inf_num = TorchC.sum(TorchC.isinf(tensor)) - nan_num = TorchC.sum(TorchC.isnan(tensor)) - return inf_num + nan_num - - @staticmethod - def replace_inf_or_nan(tensor): - finite_mask = TorchC.isfinite(tensor) - inf_or_nan_mask = TorchC.logical_not(finite_mask) - inf_or_nan_num = TorchC.sum(inf_or_nan_mask).item() - if inf_or_nan_num > 0: - tensor[inf_or_nan_mask] = 1 - return tensor diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/main.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/main.py index c2e0005181d..a51c05356ae 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/main.py +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/main.py @@ -1,16 +1,14 @@ -import importlib from abc import ABC import torch from atat.pytorch.free_benchmark import Const, print_warn_log_rank_0 - -from atat.pytorch.free_benchmark.common.params import data_pre_deal, make_handler_params from atat.pytorch.free_benchmark.common.enums import ( PerturbationMode, FuzzLevel, DeviceType, HandlerType ) +from atat.pytorch.free_benchmark.common.params import data_pre_deal, make_handler_params from atat.pytorch.free_benchmark.compare.grad_saver import GradSaver from atat.pytorch.free_benchmark.perturbed_layers.layer_factory import LayerFactory from atat.pytorch.free_benchmark.result_handlers.handler_factory import ( @@ -33,9 +31,9 @@ class FreeBenchmarkCheck(ABC): def update_iter(self, update_iter): self.current_iter = update_iter - + def if_fix(self): - if self.config.handler_type==HandlerType.FIX: + if self.config.handler_type == HandlerType.FIX: return True return False diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py index d03dbe931d9..0c11a5947de 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py @@ -4,15 +4,51 @@ from atat.pytorch.free_benchmark import ( print_warn_log_rank_0, ) from atat.pytorch.free_benchmark.common.constant import ThresholdConfig +from atat.pytorch.free_benchmark.common.enums import PerturbationMode from atat.pytorch.free_benchmark.common.params import DataParams from atat.pytorch.free_benchmark.common.utils import TorchC -from atat.pytorch.free_benchmark.common.enums import PerturbationMode from atat.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import ( NpuBaseLayer, ) class AddNoiseLayer(NpuBaseLayer): + def __init__(self, api_name): + super().__init__(api_name) + 
self.perturbed_value = None + self.is_added = False + + def add_noise(self, tensor_obj): + if isinstance(tensor_obj, torch.Tensor): + self.perturbed_value = ThresholdConfig.PERTURBATION_VALUE_DICT.get( + tensor_obj.dtype + ) + if not self.pre_check(tensor_obj): + return tensor_obj + noise = self._get_noise(tensor_obj) + result = TorchC.where( + TorchC.gt(TorchC.abs(tensor_obj), self.perturbed_value ** 0.5), + TorchC.add(noise, tensor_obj), + tensor_obj, + ).to(tensor_obj.dtype) + self.is_added = True + return result + if isinstance(tensor_obj, dict): + return {key: self.add_noise(value) for key, value in tensor_obj.items()} + if isinstance(tensor_obj, (tuple, list)): + return type(tensor_obj)([self.add_noise(value) for value in tensor_obj]) + return tensor_obj + + def handle(self, params: DataParams) -> torch.Any: + """ + 对输入添加扰动并返回 + """ + print_info_log_rank_0( + f"[atat] Free benchmark: Perturbation is " + f"{PerturbationMode.ADD_NOISE} of {self.api_name}." + ) + params.perturbed_value = self.add_noise(params.args[params.valid_input_index]) + return self.perturbed_result(params) def _get_noise(self, tensor_obj): dtype = tensor_obj.dtype @@ -59,35 +95,3 @@ class AddNoiseLayer(NpuBaseLayer): ) return False return True - - def add_noise(self, tensor_obj): - if isinstance(tensor_obj, torch.Tensor): - self.perturbed_value = ThresholdConfig.PERTURBATION_VALUE_DICT.get( - tensor_obj.dtype - ) - if not self.pre_check(tensor_obj): - return tensor_obj - noise = self._get_noise(tensor_obj) - result = TorchC.where( - TorchC.gt(TorchC.abs(tensor_obj), self.perturbed_value**0.5), - TorchC.add(noise, tensor_obj), - tensor_obj, - ).to(tensor_obj.dtype) - self.is_added = True - return result - if isinstance(tensor_obj, dict): - return {key: self.add_noise(value) for key, value in tensor_obj.items()} - if isinstance(tensor_obj, (tuple, list)): - return type(tensor_obj)([self.add_noise(value) for value in tensor_obj]) - return tensor_obj - - def handle(self, params: DataParams) -> torch.Any: - """ - 对输入添加扰动并返回 - """ - print_info_log_rank_0( - f"[atat] Free benchmark: Perturbation is " - f"{PerturbationMode.ADD_NOISE} of {self.api_name}." 
- ) - params.perturbed_value = self.add_noise(params.args[params.valid_input_index]) - return self.perturbed_result(params) diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py index 72d04af4120..f8e20626cd8 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py @@ -4,9 +4,9 @@ from atat.pytorch.free_benchmark import ( print_warn_log_rank_0, ) from atat.pytorch.free_benchmark.common.constant import ThresholdConfig +from atat.pytorch.free_benchmark.common.enums import PerturbationMode from atat.pytorch.free_benchmark.common.params import DataParams from atat.pytorch.free_benchmark.common.utils import TorchC -from atat.pytorch.free_benchmark.common.enums import PerturbationMode from atat.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import ( NpuBaseLayer, ) @@ -18,6 +18,50 @@ class BitNoiseLayer(NpuBaseLayer): self.bit_mode = TorchC.bitwise_xor self.bit_tail: int = 1 self.bit_type = None + self.is_added = False + + def add_bit_noise(self, tensor_obj): + """ + 对输入添加噪声 + """ + # finfo应该列入黑名单 + + if isinstance(tensor_obj, torch.Tensor): + self._set_perturbation_bit(tensor_obj) + if not self.pre_check(tensor_obj): + return tensor_obj + sub_normal = torch.finfo(tensor_obj.dtype).smallest_normal + noise = TorchC.full( + tensor_obj.shape, + self.bit_tail, + device=tensor_obj.device, + dtype=self.bit_type, + ) + result = tensor_obj.view(self.bit_type) + result = TorchC.where( + TorchC.gt(TorchC.abs(tensor_obj), sub_normal), + self.bit_mode(result, noise), + result, + ).view(tensor_obj.dtype) + + self.is_added = True + return result + if isinstance(tensor_obj, dict): + return {key: self.add_bit_noise(value) for key, value in tensor_obj.items()} + if isinstance(tensor_obj, (tuple, list)): + return type(tensor_obj)([self.add_bit_noise(value) for value in tensor_obj]) + return tensor_obj + + def handle(self, params: DataParams) -> torch.Any: + """ + 对输入添加扰动并返回 + """ + print_info_log_rank_0( + f"[atat] Free benchmark: Perturbation is " + f"{PerturbationMode.BIT_NOISE} of {self.api_name}." 
+ ) + params.perturbed_value = self.add_bit_noise(params.args[params.valid_input_index]) + return self.perturbed_result(params) def _check_details(self, tensor_obj): """ @@ -62,46 +106,3 @@ class BitNoiseLayer(NpuBaseLayer): if bit_len_type: self.bit_tail = 1 self.bit_type = bit_len_type - - def add_bit_noise(self, tensor_obj): - """ - 对输入添加噪声 - """ - # finfo应该列入黑名单 - - if isinstance(tensor_obj, torch.Tensor): - self._set_perturbation_bit(tensor_obj) - if not self.pre_check(tensor_obj): - return tensor_obj - sub_normal = torch.finfo(tensor_obj.dtype).smallest_normal - noise = TorchC.full( - tensor_obj.shape, - self.bit_tail, - device=tensor_obj.device, - dtype=self.bit_type, - ) - result = tensor_obj.view(self.bit_type) - result = TorchC.where( - TorchC.gt(TorchC.abs(tensor_obj), sub_normal), - self.bit_mode(result, noise), - result, - ).view(tensor_obj.dtype) - - self.is_added = True - return result - if isinstance(tensor_obj, dict): - return {key: self.add_bit_noise(value) for key, value in tensor_obj.items()} - if isinstance(tensor_obj, (tuple, list)): - return type(tensor_obj)([self.add_bit_noise(value) for value in tensor_obj]) - return tensor_obj - - def handle(self, params: DataParams) -> torch.Any: - """ - 对输入添加扰动并返回 - """ - print_info_log_rank_0( - f"[atat] Free benchmark: Perturbation is " - f"{PerturbationMode.BIT_NOISE} of {self.api_name}." - ) - params.perturbed_value = self.add_bit_noise(params.args[params.valid_input_index]) - return self.perturbed_result(params) diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/change_value.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/change_value.py index ab91bcb7eee..510104cdaa1 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/change_value.py @@ -1,8 +1,8 @@ import torch from atat.pytorch.free_benchmark import print_warn_log_rank_0, print_info_log_rank_0 +from atat.pytorch.free_benchmark.common.enums import PerturbationMode from atat.pytorch.free_benchmark.common.params import DataParams from atat.pytorch.free_benchmark.common.utils import TorchC -from atat.pytorch.free_benchmark.common.enums import PerturbationMode from atat.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import ( NpuBaseLayer, ) @@ -13,18 +13,7 @@ class ChangeValueLayer(NpuBaseLayer): super().__init__(api_name) self.head: int = 0 self.tail: int = -1 - - def _check_details(self, tensor_obj): - """ - 判断是否需要添加扰动, 首尾值交换 - """ - if tensor_obj.size(0) < 2: - print_warn_log_rank_0( - f"[atat] Free Benchmark: For {self.api_name}, " - f"size 0 must greater than 1. Cancel change value." - ) - return False - return True + self.is_added = False def change_value(self, tensor_obj): """ @@ -42,7 +31,7 @@ class ChangeValueLayer(NpuBaseLayer): temp_last = TorchC.clone(new_tensor[self.tail][self.tail]) new_tensor[self.head][self.head] = temp_last new_tensor[self.tail][self.tail] = temp_first - + self.is_added = True return new_tensor if isinstance(tensor_obj, dict): @@ -61,3 +50,15 @@ class ChangeValueLayer(NpuBaseLayer): ) params.perturbed_value = self.change_value(params.args[params.valid_input_index]) return self.perturbed_result(params) + + def _check_details(self, tensor_obj): + """ + 判断是否需要添加扰动, 首尾值交换 + """ + if tensor_obj.size(0) < 2: + print_warn_log_rank_0( + f"[atat] Free Benchmark: For {self.api_name}, " + f"size 0 must greater than 1. Cancel change value." 
+ ) + return False + return True diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py index fb126972c68..a650af8755e 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py @@ -1,8 +1,8 @@ import torch from atat.pytorch.free_benchmark import Const, print_info_log_rank_0 from atat.pytorch.free_benchmark.common.constant import CommonField -from atat.pytorch.free_benchmark.common.params import DataParams from atat.pytorch.free_benchmark.common.enums import PerturbationMode +from atat.pytorch.free_benchmark.common.params import DataParams from atat.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import ( NpuBaseLayer, ) @@ -10,23 +10,9 @@ from atat.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import ( class ImprovePrecisionLayer(NpuBaseLayer): - def _set_improve_valus(self, inputs): - # TODO why - if inputs.dtype in [torch.float16, torch.bfloat16]: - self.perturbed_value = torch.float32 - - def _change_dtype(self, inputs): - if hasattr(inputs, CommonField.DEVICE): - device = inputs.device - if device is CommonField.META: - new_inputs = inputs.to( - device=CommonField.META, dtype=self.perturbed_value - ) - else: - new_inputs = inputs.to(dtype=self.perturbed_value).to(device) - else: - new_inputs = inputs.to(dtype=self.perturbed_value) - return new_inputs + def __init__(self, api_name): + super().__init__(api_name) + self.perturbed_value = None def improve_tensor_precision(self, tensor_obj): if ( @@ -62,3 +48,21 @@ class ImprovePrecisionLayer(NpuBaseLayer): new_kwargs["inplace"] = False params.perturbed_result = params.origin_func(*new_args, **new_kwargs) return params.perturbed_result + + def _set_improve_valus(self, inputs): + # TODO why + if inputs.dtype in [torch.float16, torch.bfloat16]: + self.perturbed_value = torch.float32 + + def _change_dtype(self, inputs): + if hasattr(inputs, CommonField.DEVICE): + device = inputs.device + if device is CommonField.META: + new_inputs = inputs.to( + device=CommonField.META, dtype=self.perturbed_value + ) + else: + new_inputs = inputs.to(dtype=self.perturbed_value).to(device) + else: + new_inputs = inputs.to(dtype=self.perturbed_value) + return new_inputs diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/no_change.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/no_change.py index 7ec5870fb72..ecd133dba8e 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/no_change.py @@ -1,13 +1,16 @@ import torch from atat.pytorch.free_benchmark import print_info_log_rank_0 -from atat.pytorch.free_benchmark.common.params import DataParams from atat.pytorch.free_benchmark.common.enums import PerturbationMode +from atat.pytorch.free_benchmark.common.params import DataParams from atat.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import ( NpuBaseLayer, ) class NoChangeLayer(NpuBaseLayer): + def __init__(self, api_name): + super().__init__(api_name) + self.is_added = False def no_change(self, tensor_obj): """ @@ -16,7 +19,6 @@ class NoChangeLayer(NpuBaseLayer): self.is_added = True return tensor_obj - def handle(self, params: DataParams) -> torch.Any: """ 对输入添加扰动并返回 diff 
--git a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py
index ca502365e1b..f45ebd358a8 100644
--- a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py
+++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py
@@ -1,8 +1,7 @@
 from abc import abstractmethod
 from typing import Any
+
 import torch
-from atat.pytorch.free_benchmark.common.constant import CommonField, ThresholdConfig
-from atat.pytorch.free_benchmark.common.utils import TorchC
 from atat.pytorch.free_benchmark.common.params import DataParams
 from atat.pytorch.free_benchmark.perturbed_layers.base_layer import BaseLayer
 
@@ -17,7 +16,20 @@ class NpuBaseLayer(BaseLayer):
     def handle(self, params: DataParams) -> Any:
         pass
 
-    def _check_details(self, tensor_obj):
+    @staticmethod
+    def perturbed_result(params: DataParams) -> Any:
+        args_front = params.args[: params.valid_input_index]
+        args_rear = params.args[params.valid_input_index + 1:]
+        # Operators that carry an inplace flag are switched to non-inplace here
+        if "inplace" in params.kwargs:
+            params.kwargs["inplace"] = False
+        params.perturbed_result = params.origin_func(
+            *args_front, params.perturbed_value, *args_rear, **params.kwargs
+        )
+        return params.perturbed_result
+
+    @staticmethod
+    def _check_details(tensor_obj):
         return True
 
     def pre_check(self, tensor_obj):
@@ -32,15 +44,3 @@ class NpuBaseLayer(BaseLayer):
         if not self._check_details(tensor_obj):
             return False
         return True
-
-    @staticmethod
-    def perturbed_result(params: DataParams) -> Any:
-        args_front = params.args[: params.valid_input_index]
-        args_rear = params.args[params.valid_input_index + 1 :]
-        # 此处会将有inplace属性的算子换为非inplace
-        if "inplace" in params.kwargs:
-            params.kwargs["inplace"] = False
-        params.perturbed_result = params.origin_func(
-            *args_front, params.perturbed_value, *args_rear, **params.kwargs
-        )
-        return params.perturbed_result
diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/base_handler.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/base_handler.py
index 1d59ef9fc3a..b27250e6ca5 100644
--- a/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/base_handler.py
+++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/base_handler.py
@@ -7,7 +7,6 @@ from atat.pytorch.free_benchmark import (
     Const,
     print_warn_log_rank_0,
 )
-from atat.pytorch.free_benchmark.common.utils import TorchC
 from atat.pytorch.free_benchmark.common.constant import ThresholdConfig
 from atat.pytorch.free_benchmark.common.enums import (
     FuzzThreshold,
@@ -15,6 +14,7 @@ from atat.pytorch.free_benchmark.common.enums import (
     PerturbationMode,
 )
 from atat.pytorch.free_benchmark.common.params import DataParams, HandlerParams, make_unequal_row
+from atat.pytorch.free_benchmark.common.utils import TorchC
 
 
 class FuzzHandler(ABC):
@@ -25,9 +25,9 @@
     @staticmethod
     def pre_process(origin_ouput, perturbed_output):
         if (
-            isinstance(origin_ouput, tuple)
-            and hasattr(origin_ouput, "values")
-            and hasattr(origin_ouput, "indices")
+                isinstance(origin_ouput, tuple)
+                and hasattr(origin_ouput, "values")
+                and hasattr(origin_ouput, "indices")
         ):
             origin_ouput = origin_ouput.values
             perturbed_output = perturbed_output.values
@@ -41,19 +41,27 @@
         abs_tol,
     )
 
-    def get_ratio_from_specific_norm(
-        self, origin_output, perturbed_output, norm_type, abs_tol
-    ):
-        if norm_type == 
NormType.ENDLESS_NORM: - return self.get_endless_norm(origin_output, perturbed_output, abs_tol) - return ThresholdConfig.COMP_CONSISTENT - @staticmethod def convert_overflow_ratio_to_consistent(ratio): if math.isnan(ratio) or math.isinf(ratio): return ThresholdConfig.COMP_CONSISTENT return ratio + @abstractmethod + def get_threshold(self, dtype): + pass + + @abstractmethod + def handle(self, data_params: DataParams) -> Any: + pass + + def get_ratio_from_specific_norm( + self, origin_output, perturbed_output, norm_type, abs_tol + ): + if norm_type == NormType.ENDLESS_NORM: + return self.get_endless_norm(origin_output, perturbed_output, abs_tol) + return ThresholdConfig.COMP_CONSISTENT + def get_endless_norm(self, origin_output, perturbed_output, abs_tol): try: ratio_tensor1 = TorchC.where( @@ -72,7 +80,7 @@ class FuzzHandler(ABC): ), 1, ) - except: + except Exception: ratio_tensor1 = TorchC.where( TorchC.gt(TorchC.abs(perturbed_output.to(torch.float32)), abs_tol), TorchC.div( @@ -117,26 +125,13 @@ class FuzzHandler(ABC): if self.params.fuzz_stage == Const.BACKWARD: abs_tol = ThresholdConfig.BACKWARD_OUTPUT_LOWER_BOUND else: - abs_tol = abs_tol**0.5 + abs_tol = abs_tol ** 0.5 return self.get_ratio_from_specific_norm( origin_output, perturbed_output, norm_type, abs_tol ) - @abstractmethod - def get_threshold(self, dtype): - pass - - def _get_default_threshold(self, dtype): - if self.params.pert_mode == PerturbationMode.NO_CHANGE: - threshold = ThresholdConfig.COMP_CONSISTENT - else: - threshold = ThresholdConfig.DTYPE_PER_THD.get( - dtype, ThresholdConfig.DTYPE_PER_THD.get(torch.float32) - ) - return threshold - def npu_compare( - self, origin_output, perturbed_output + self, origin_output, perturbed_output ) -> Tuple[bool, Optional[float]]: if isinstance(perturbed_output, int): @@ -190,7 +185,7 @@ class FuzzHandler(ABC): max_fuzz_ratio if ratio is None else max(max_fuzz_ratio, ratio) ) data_params.is_consistent = ( - is_consistent and data_params.is_consistent + is_consistent and data_params.is_consistent ) if not is_consistent and data_params.grad_unequal_flag: self.unequal_rows.append( @@ -205,9 +200,14 @@ class FuzzHandler(ABC): ) return npu_consistent, max_fuzz_ratio - @abstractmethod - def handle(self, data_params: DataParams) -> Any: - pass - def get_unequal_rows(self): return self.unequal_rows + + def _get_default_threshold(self, dtype): + if self.params.pert_mode == PerturbationMode.NO_CHANGE: + threshold = ThresholdConfig.COMP_CONSISTENT + else: + threshold = ThresholdConfig.DTYPE_PER_THD.get( + dtype, ThresholdConfig.DTYPE_PER_THD.get(torch.float32) + ) + return threshold diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/preheat_handler.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/preheat_handler.py index b8ff3bccf00..c48c374dc8c 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/preheat_handler.py +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/preheat_handler.py @@ -1,16 +1,15 @@ +import math from typing import Any -import torch -import math from atat.pytorch.free_benchmark import print_info_log_rank_0, print_warn_log_rank_0 from atat.pytorch.free_benchmark.common.constant import ThresholdConfig +from atat.pytorch.free_benchmark.common.counter import preheat_counter from atat.pytorch.free_benchmark.common.enums import DeviceType -from atat.pytorch.free_benchmark.common.params import DataParams, make_unequal_row +from atat.pytorch.free_benchmark.common.params import DataParams 
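+# HandlerParams supplies the per-handler settings read in this module
+# (api_name, step, preheat_config); DataParams wraps a single intercepted
+# api call (both inferred from their usage below).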
+from atat.pytorch.free_benchmark.common.params import HandlerParams
 from atat.pytorch.free_benchmark.common.utils import Tools
 from atat.pytorch.free_benchmark.compare.single_benchmark import SingleCompare
-from atat.pytorch.free_benchmark.common.counter import preheat_counter
 from atat.pytorch.free_benchmark.result_handlers.base_handler import FuzzHandler
-from atat.pytorch.free_benchmark.common.params import HandlerParams
 
 
 class PreheatHandler(FuzzHandler):
@@ -19,9 +18,79 @@ class PreheatHandler(FuzzHandler):
         super().__init__(params)
         self.pure_name = Tools.get_pure_api_name(self.params.api_name)
 
+    @staticmethod
+    def compare_npu_and_cpu(data_params: DataParams):
+        args = Tools.convert_device_and_dtype(
+            data_params.args, DeviceType.CPU, change_dtype=True
+        )
+        kwargs = Tools.convert_device_and_dtype(
+            data_params.kwargs, DeviceType.CPU, change_dtype=True
+        )
+        cpu_result = data_params.origin_func(*args, **kwargs)
+        return SingleCompare().compare_seq(data_params.original_result, cpu_result)
+
     def get_threshold(self, dtype):
         return preheat_counter.get_api_thd(self.pure_name, dtype)
 
+    def preheat(self, max_fuzz_ratio, cpu_consistent, first_dtype):
+        # Record every output ratio of the current step together with the
+        # corresponding npu/cpu comparison result
+        preheat_counter.update_preheat_record(
+            self.pure_name,
+            first_dtype,
+            (max_fuzz_ratio, cpu_consistent),
+        )
+        if self._need_adjust_threshold():
+            self._adjust_threshold()
+
+    def handle(self, data_params: DataParams) -> Any:
+
+        if isinstance(data_params.perturbed_result, bool) or not Tools.is_float_tensor(
+            data_params.perturbed_result
+        ):
+            return data_params.original_result
+
+        if self.params.step == 0:
+            preheat_counter.add_one_step_used_api(self.pure_name)
+            return data_params.original_result
+
+        # The current api/step needs preheating
+        npu_consistent, max_fuzz_ratio = self.cmp_output_npu(data_params)
+        data_params.is_consistent = npu_consistent
+
+        preheat_counter.check_step(self.params.step)
+
+        if self.params.preheat_config.get("preheat_step") <= self.params.step:
+            return data_params.original_result
+
+        if not data_params.grad_unequal_flag:
+            data_params.grad_unequal_flag = True
+            data_params.is_consistent = False
+            return data_params.original_result
+        preheat_counter.add_api_called_time(self.pure_name)
+
+        if not self._is_take_a_sample():
+            return data_params.original_result
+
+        cpu_consistent = True
+        try:
+            cpu_consistent = self.compare_npu_and_cpu(data_params)
+        except Exception as e:
+            print_warn_log_rank_0(
+                f"[atat] Free Benchmark: For {self.params.api_name}, "
+                f"comparing to cpu raised an exception: {e}"
+            )
+        first_dtype = None  # default when the perturbed output holds no tensors; avoids an undefined name below
+        try:
+            first_dtype = Tools.get_first_tensor_dtype(data_params.perturbed_result)
+        except RuntimeError:
+            print_warn_log_rank_0(
+                f"[atat] Free Benchmark: For {self.params.api_name}, "
+                f"the output sequence does not contain tensors."
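+                f" Preheat falls back to first_dtype=None for this call."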
+ ) + if preheat_counter.get_api_preheat(self.pure_name, str(first_dtype)): + self.preheat(max_fuzz_ratio, cpu_consistent, first_dtype) + + return data_params.original_result + def _is_take_a_sample(self) -> bool: need_sample_set = self._get_need_sample_set() curr_called_seq = preheat_counter.get_api_called_time(self.pure_name) @@ -61,17 +130,6 @@ class PreheatHandler(FuzzHandler): need_sample_set.add(count) return need_sample_set - - def compare_npu_and_cpu(self, data_params: DataParams): - args = Tools.convert_device_and_dtype( - data_params.args, DeviceType.CPU, change_dtype=True - ) - kwargs = Tools.convert_device_and_dtype( - data_params.kwargs, DeviceType.CPU, change_dtype=True - ) - cpu_result = data_params.origin_func(*args, **kwargs) - return SingleCompare().compare_seq(data_params.original_result, cpu_result) - def _need_adjust_threshold(self) -> bool: sample_count_per_step = self._get_sample_count_per_step() sampled_time = preheat_counter.get_api_sample_time(self.pure_name) @@ -112,63 +170,3 @@ class PreheatHandler(FuzzHandler): preheat_counter.update_api_thd( self.pure_name, dtype_str, new_thd, threshold ) - - def preheat(self, max_fuzz_ratio, cpu_consistent, first_dtype): - # 存储当前step所有输出比值和对应npu\cpu比对结果 - preheat_counter.update_preheat_record( - self.pure_name, - first_dtype, - (max_fuzz_ratio, cpu_consistent), - ) - if self._need_adjust_threshold(): - self._adjust_threshold() - - def handle(self, data_params: DataParams) -> Any: - - if isinstance(data_params.perturbed_result, bool) or not Tools.is_float_tensor( - data_params.perturbed_result - ): - return data_params.original_result - - if self.params.step == 0: - preheat_counter.add_one_step_used_api(self.pure_name) - return data_params.original_result - - # 如果当前api,step需要预热 - npu_consistent, max_fuzz_ratio = self.cmp_output_npu(data_params) - data_params.is_consistent = npu_consistent - - preheat_counter.check_step(self.params.step) - - if self.params.preheat_config.get("preheat_step") <= self.params.step: - return data_params.original_result - - if not data_params.grad_unequal_flag: - data_params.grad_unequal_flag = True - data_params.is_consistent = False - return data_params.original_result - preheat_counter.add_api_called_time(self.pure_name) - - - if not self._is_take_a_sample(): - return data_params.original_result - - cpu_consistent = True - try: - cpu_consistent = self.compare_npu_and_cpu(data_params) - except Exception as e: - print_warn_log_rank_0( - f"[atat] Free Benchmark: For {self.params.api_name}, " - f"when campare to cpu exception raise {e}" - ) - try: - first_dtype = Tools.get_first_tensor_dtype(data_params.perturbed_result) - except RuntimeError: - print_warn_log_rank_0( - f"[atat] Free Benchmark: For {self.params.api_name}, " - f"the output sequence does not contain tensors." 
- ) - if preheat_counter.get_api_preheat(self.pure_name, str(first_dtype)): - self.preheat(max_fuzz_ratio, cpu_consistent, first_dtype) - - return data_params.original_result diff --git a/debug/accuracy_tools/atat/pytorch/functional/data_collector.py b/debug/accuracy_tools/atat/pytorch/functional/data_collector.py index 7964c955db6..2cf2cf389c8 100644 --- a/debug/accuracy_tools/atat/pytorch/functional/data_collector.py +++ b/debug/accuracy_tools/atat/pytorch/functional/data_collector.py @@ -1,11 +1,13 @@ import os + import torch -from ..module_processer import ModuleProcesser -from .scope import build_scope, ListScope + +from .data_processor import build_data_processor, DataProcessor from .json_writer import DataWriter +from .scope import build_scope, ListScope from ..common.log import print_info_log, print_warn_log from ..common.utils import Const -from .data_processor import build_data_processor, DataProcessor +from ..module_processer import ModuleProcesser try: import torch_npu @@ -37,12 +39,6 @@ class DataCollector: else: self.scope = build_scope(None, self.config.scope, self.config.list) - def if_return_forward_new_output(self): - return self.data_processor.if_return_forward_new_output() - - def get_forward_new_output(self): - return self.data_processor.get_forward_new_output() - @property def dump_data_dir(self): return self.data_writer.dump_tensor_data_dir @@ -51,6 +47,38 @@ class DataCollector: def dump_file_path(self): return self.data_writer.dump_file_path + @staticmethod + def check_scope_and_pid(scope, name, pid): + return (not scope or scope.check(name)) and pid == os.getpid() + + @staticmethod + def is_inplace(module): + return getattr(module, "op_is_inplace", False) + + @staticmethod + def op_need_trigger(module_name): + if 'Tensor___getitem___' in module_name: + return True + return False + + @staticmethod + def acl_backward_dump_status(output, grad, module_name): + if isinstance(output, torch.Tensor): + output.backward(grad, retain_graph=True) + return True + + for api_name in DataCollector.multi_output_apis: + if api_name in module_name: + output[0].backward(grad, retain_graph=True) + return True + return False + + def if_return_forward_new_output(self): + return self.data_processor.if_return_forward_new_output() + + def get_forward_new_output(self): + return self.data_processor.get_forward_new_output() + def visit_and_clear_overflow_status(self, api_or_module_name): self.data_processor.visit_and_clear_overflow_status(api_or_module_name) @@ -68,14 +96,6 @@ class DataCollector: self.data_writer.update_data(data_info) return msg - @staticmethod - def check_scope_and_pid(scope, name, pid): - return (not scope or scope.check(name)) and pid == os.getpid() - - @staticmethod - def is_inplace(module): - return getattr(module, "op_is_inplace", False) - def pre_forward_data_collect(self, name, module, pid, module_input_output): backward_name = name.replace("forward", "backward") if self.check_scope_and_pid(self.scope, backward_name, pid): @@ -156,11 +176,6 @@ class DataCollector: else: self.dump_mode_backward_acl_dump(module, module_input_output, module_name) - def op_need_trigger(self, module_name): - if 'Tensor___getitem___' in module_name: - return True - return False - def forward_acl_dump(self, module, module_input_output, module_name): global forward_init_status if not forward_init_status: @@ -179,17 +194,6 @@ class DataCollector: forward_init_status = False print_info_log("Dump %s op file." 
% module_name) - def acl_backward_dump_status(self, output, grad, module_name): - if isinstance(output, torch.Tensor): - output.backward(grad, retain_graph=True) - return True - - for api_name in DataCollector.multi_output_apis: - if api_name in module_name: - output[0].backward(grad, retain_graph=True) - return True - return False - def dump_mode_backward_acl_dump(self, module, module_input_output, module_name): global forward_init_status grad_path = self.config.backward_input.get(module_name) diff --git a/debug/accuracy_tools/atat/pytorch/functional/data_processor.py b/debug/accuracy_tools/atat/pytorch/functional/data_processor.py index 1ef1b79acb2..e55765b6336 100644 --- a/debug/accuracy_tools/atat/pytorch/functional/data_processor.py +++ b/debug/accuracy_tools/atat/pytorch/functional/data_processor.py @@ -1,21 +1,24 @@ -import torch -import zlib -import numpy as np -import os import inspect +import os +import zlib from dataclasses import dataclass, asdict -import torch_npu from typing import Tuple, List, Dict, Optional, Union + +import numpy as np +import torch +import torch_npu + +from ..common import recursive_apply_transform from ..common.exceptions import MsaccException from ..common.file_check import path_len_exceeds_limit, change_mode, FileCheckConst from ..common.log import print_warn_log from ..common.utils import Const -from ..common import recursive_apply_transform -from ..functional. json_writer import DataWriter from ..free_benchmark import FreeBenchmarkCheck, UnequalRow +from ...core.utils import Const bits_for_overflow = 8 + def build_data_processor(config, data_writer): if config.task == DataProcessor.full: return FullTensorDataProcessor(config, data_writer) @@ -27,12 +30,12 @@ def build_data_processor(config, data_writer): return FreeBenchmarkDataProcessor(config, data_writer) else: raise MsaccException(MsaccException.INVALID_PARAM_ERROR, - "task should be in [{}, {}, {}, {}]".format( - DataProcessor.full, - DataProcessor.summary, - DataProcessor.overflow, - DataProcessor.free_benchmark - )) + "task should be in [{}, {}, {}, {}]".format( + DataProcessor.full, + DataProcessor.summary, + DataProcessor.overflow, + DataProcessor.free_benchmark + )) @dataclass @@ -44,14 +47,14 @@ class ModuleForwardInputsOutputs: @property def args_tuple(self): if not isinstance(self.args, tuple): - return (self.args, ) + return (self.args,) else: return self.args @property def output_tuple(self): if not isinstance(self.output, tuple): - return (self.output, ) + return (self.output,) else: return self.output @@ -68,14 +71,14 @@ class ModuleBackwardInputsOutputs: @property def grad_input_tuple(self): if not isinstance(self.grad_input, tuple): - return (self.grad_input, ) + return (self.grad_input,) else: return self.grad_input @property def grad_output_tuple(self): if not isinstance(self.grad_output, tuple): - return (self.grad_output, ) + return (self.grad_output,) else: return self.grad_output @@ -104,13 +107,6 @@ class DataProcessor: self._return_forward_new_output = False self._forward_new_output = None - def if_return_forward_new_output(self): - return self._return_forward_new_output - - def get_forward_new_output(self): - self._return_forward_new_output = False - return self._forward_new_output - @staticmethod def get_md5_for_tensor(x): if x.dtype == torch.bfloat16: @@ -156,21 +152,41 @@ class DataProcessor: return builtin_type(arg), type(arg).__name__ return arg, '' - def update_iter(self, current_iter): - self.current_iter = current_iter - - def visit_and_clear_overflow_status(self, 
api_or_module_name): - if self.current_api_or_module_name != api_or_module_name: - self.current_api_or_module_name = api_or_module_name - self.has_overflow = False + @staticmethod + def handle_tensor_extremum_nan_inf(data_clone, operator): + data_nan = torch._C._VariableFunctionsClass.isnan(data_clone) + if int(torch._C._VariableFunctionsClass.sum(data_nan)) == data_clone.numel(): + return float('nan') + finite_mask = torch._C._VariableFunctionsClass.isfinite(data_clone) + if int(torch._C._VariableFunctionsClass.sum(finite_mask)) > 0: + finite_values = data_clone[finite_mask] + return torch._C._VariableFunctionsClass.max(finite_values).item() if operator == 'max' else \ + torch._C._VariableFunctionsClass.min(finite_values).item() + else: + data_no_nan = data_clone[~data_nan] + return torch._C._VariableFunctionsClass.max(data_no_nan).item() if operator == 'max' else \ + torch._C._VariableFunctionsClass.min(data_no_nan).item() - def _analyze_numpy(self, value, numpy_type): - single_arg = {} - single_arg.update({"type": numpy_type}) - single_arg.update({"value": value}) - return single_arg + @staticmethod + def analyze_api_call_stack(name): + stack_str = [] + for (_, path, line, func, code, _) in inspect.stack()[5:]: + if not code: + continue + stack_line = " ".join([ + "File", ", ".join([ + path, + " ".join(["line", str(line)]), + " ".join(["in", func]), + " ".join(["\n", code[0].strip()]) + ]) + ]) + stack_str.append(stack_line) + stack_info_struct = {name: stack_str} + return stack_info_struct - def get_stat_info(self, data): + @staticmethod + def get_stat_info(data): if data.is_meta: return data_clone = data.detach() @@ -184,7 +200,7 @@ class DataProcessor: tensor_min = False not in data_clone tensor_mean = None tensor_norm = None - elif not len(data_clone.shape): + elif not data_clone.shape: tensor_max = data_clone.item() tensor_min = tensor_max tensor_mean = tensor_max @@ -199,7 +215,15 @@ class DataProcessor: return tensor_max, tensor_min, tensor_mean, tensor_norm - def _analyze_builtin(self, arg): + @staticmethod + def _analyze_numpy(value, numpy_type): + single_arg = {} + single_arg.update({"type": numpy_type}) + single_arg.update({"value": value}) + return single_arg + + @staticmethod + def _analyze_builtin(arg): single_arg = {} if isinstance(arg, slice): single_arg.update({"type": "slice"}) @@ -214,12 +238,28 @@ class DataProcessor: single_arg.update({"value": arg}) return single_arg - def _analyze_torch_size(self, arg): + @staticmethod + def _analyze_torch_size(arg): single_arg = {} single_arg.update({"type": "torch.Size"}) single_arg.update({"value": list(arg)}) return single_arg + def if_return_forward_new_output(self): + return self._return_forward_new_output + + def get_forward_new_output(self): + self._return_forward_new_output = False + return self._forward_new_output + + def update_iter(self, current_iter): + self.current_iter = current_iter + + def visit_and_clear_overflow_status(self, api_or_module_name): + if self.current_api_or_module_name != api_or_module_name: + self.current_api_or_module_name = api_or_module_name + self.has_overflow = False + def is_dump_for_data_mode(self, forward_backward, input_output): """ Compare the parameters with data_mode to determine whether to dump. 
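+        For example (behavior inferred from the membership checks below):
+        with data_mode=["forward"], is_dump_for_data_mode("forward", "input")
+        returns True because "forward" is present in data_mode.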
@@ -235,56 +275,6 @@ class DataProcessor: forward_backward in self.config.data_mode or input_output in self.config.data_mode) - @staticmethod - def handle_tensor_extremum_nan_inf(data_clone, operator): - data_nan = torch._C._VariableFunctionsClass.isnan(data_clone) - if int(torch._C._VariableFunctionsClass.sum(data_nan)) == data_clone.numel(): - return float('nan') - finite_mask = torch._C._VariableFunctionsClass.isfinite(data_clone) - if int(torch._C._VariableFunctionsClass.sum(finite_mask)) > 0: - finite_values = data_clone[finite_mask] - return torch._C._VariableFunctionsClass.max(finite_values).item() if operator == 'max' else \ - torch._C._VariableFunctionsClass.min(finite_values).item() - else: - data_no_nan = data_clone[~data_nan] - return torch._C._VariableFunctionsClass.max(data_no_nan).item() if operator == 'max' else \ - torch._C._VariableFunctionsClass.min(data_no_nan).item() - - def _analyze_maybe_overflow_tensor(self, tensor_json, tensor): - data_clone = tensor.detach() - if hasattr(torch_npu._C, '_npu_is_support_inf_nan') and torch_npu._C._npu_is_support_inf_nan(): - if tensor_json['Max'] is None: - return - if np.isinf(tensor_json['Max']) or np.isnan(tensor_json['Max']): - tensor_json['Max_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(data_clone, "max") - self.has_overflow = True - if np.isinf(tensor_json['Min']) or np.isnan(tensor_json['Min']): - tensor_json['Min_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(data_clone, "min") - self.has_overflow = True - else: - self.has_overflow = check_overflow_npu() - if self.has_overflow: - clear_overflow_npu() - - def _analyze_tensor(self, tensor, suffix): - tensor_max, tensor_min, tensor_mean, tensor_norm = self.get_stat_info(tensor) - - tensor_json = {} - tensor_json.update({'type': 'torch.Tensor'}) - tensor_json.update({'dtype': str(tensor.dtype)}) - tensor_json.update({"shape": tensor.shape}) - tensor_json.update({"Max": tensor_max}) - tensor_json.update({"Min": tensor_min}) - self._analyze_maybe_overflow_tensor(tensor_json, tensor) - tensor_json.update({"Mean": tensor_mean}) - tensor_json.update({"Norm": tensor_norm}) - tensor_json.update({"requires_grad": tensor.requires_grad}) - if self.config.summary_mode == "md5": - tensor_md5 = self.get_md5_for_tensor(tensor) - tensor_json.update({"md5": tensor_md5}) - - return tensor_json - def analyze_single_element(self, element, suffix_stack): if suffix_stack and suffix_stack[-1] in self.torch_object_key: return self.torch_object_key[suffix_stack[-1]](element) @@ -305,31 +295,13 @@ class DataProcessor: def analyze_element(self, element): return recursive_apply_transform(element, self.analyze_single_element) - @staticmethod - def analyze_api_call_stack(name): - stack_str = [] - for (_, path, line, func, code, _) in inspect.stack()[5:]: - if not code: - continue - stack_line = " ".join([ - "File", ", ".join([ - path, - " ".join(["line", str(line)]), - " ".join(["in", func]), - " ".join(["\n", code[0].strip()]) - ]) - ]) - stack_str.append(stack_line) - stack_info_struct = {name: stack_str} - return stack_info_struct - def analyze_pre_forward(self, name, module, - module_input_output: ModuleForwardInputsOutputs): + module_input_output: ModuleForwardInputsOutputs): pass def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs): api_info_struct = {} - if self.is_dump_for_data_mode(Const.FORWARD, Const.INPUT): # check whether data_mode contains forward or input + if self.is_dump_for_data_mode(Const.FORWARD, Const.INPUT): # check whether 
data_mode contains forward or input api_info_struct[name] = {} self.api_data_category = Const.INPUT args_info_list = self.analyze_element(module_input_output.args_tuple) @@ -339,7 +311,8 @@ class DataProcessor: kwargs_info_list = self.analyze_element(module_input_output.kwargs) api_info_struct[name][Const.INPUT_KWARGS] = kwargs_info_list - if self.is_dump_for_data_mode(Const.FORWARD, Const.OUTPUT): # check whether data_mode contains forward or output + if self.is_dump_for_data_mode(Const.FORWARD, + Const.OUTPUT): # check whether data_mode contains forward or output api_info_struct[name] = api_info_struct.get(name, {}) self.api_data_category = Const.OUTPUT output_info_list = self.analyze_element(module_input_output.output_tuple) @@ -372,7 +345,6 @@ class DataProcessor: return api_info_struct - def analyze_backward(self, name, module, module_input_output: ModuleBackwardInputsOutputs): api_info_struct = {} if self.is_dump_for_data_mode(Const.BACKWARD, Const.OUTPUT): @@ -389,14 +361,52 @@ class DataProcessor: return api_info_struct + def _analyze_maybe_overflow_tensor(self, tensor_json, tensor): + data_clone = tensor.detach() + if hasattr(torch_npu._C, '_npu_is_support_inf_nan') and torch_npu._C._npu_is_support_inf_nan(): + if tensor_json[Const.MAX] is None: + return + if np.isinf(tensor_json[Const.MAX]) or np.isnan(tensor_json[Const.MAX]): + tensor_json['Max_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(data_clone, "max") + self.has_overflow = True + if np.isinf(tensor_json[Const.MIN]) or np.isnan(tensor_json[Const.MIN]): + tensor_json['Min_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(data_clone, "min") + self.has_overflow = True + else: + self.has_overflow = check_overflow_npu() + if self.has_overflow: + clear_overflow_npu() + + def _analyze_tensor(self, tensor, suffix): + tensor_max, tensor_min, tensor_mean, tensor_norm = self.get_stat_info(tensor) + + tensor_json = {} + tensor_json.update({'type': 'torch.Tensor'}) + tensor_json.update({'dtype': str(tensor.dtype)}) + tensor_json.update({"shape": tensor.shape}) + tensor_json.update({"Max": tensor_max}) + tensor_json.update({"Min": tensor_min}) + self._analyze_maybe_overflow_tensor(tensor_json, tensor) + tensor_json.update({"Mean": tensor_mean}) + tensor_json.update({"Norm": tensor_norm}) + tensor_json.update({"requires_grad": tensor.requires_grad}) + if self.config.summary_mode == "md5": + tensor_md5 = self.get_md5_for_tensor(tensor) + tensor_json.update({"md5": tensor_md5}) + + return tensor_json + class FullTensorDataProcessor(DataProcessor): - def _analyze_tensor(self, tensor, suffix): + def __init__(self, config, data_writer): + super().__init__(config, data_writer) self.data_path = self.data_writer.dump_tensor_data_dir + + def _analyze_tensor(self, tensor, suffix): dump_data_name = (self.current_api_or_module_name + Const.SEP + self.api_data_category + Const.SEP + suffix + ".pt") - file_path = os.path.join(self.data_writer.dump_tensor_data_dir, dump_data_name) + file_path = os.path.join(self.data_path, dump_data_name) if not path_len_exceeds_limit(file_path): torch.save(tensor, file_path) change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY) @@ -412,15 +422,15 @@ class OverflowTensorDataProcessor(DataProcessor): def __init__(self, config, data_writer): super().__init__(config, data_writer) + self.data_path = self.data_writer.dump_tensor_data_dir self.cached_tensors_and_file_paths = {} self.real_overflow_dump_times = 0 self.overflow_nums = config.overflow_num def _analyze_tensor(self, tensor, suffix): - 
self.data_path = self.data_writer.dump_tensor_data_dir dump_data_name = (self.current_api_or_module_name + Const.SEP + self.api_data_category + Const.SEP + suffix + ".pt") - file_path = os.path.join(self.data_writer.dump_tensor_data_dir, dump_data_name) + file_path = os.path.join(self.data_path, dump_data_name) if not path_len_exceeds_limit(file_path): self.cached_tensors_and_file_paths.update({file_path: tensor}) else: @@ -437,7 +447,7 @@ class OverflowTensorDataProcessor(DataProcessor): return api_info_struct if self.has_overflow else None def analyze_backward(self, name, module, - module_input_output: ModuleBackwardInputsOutputs): + module_input_output: ModuleBackwardInputsOutputs): self.has_overflow = False api_info_struct = super().analyze_backward(name, module, module_input_output) self.maybe_save_overflow_data_and_check_overflow_times() @@ -483,7 +493,7 @@ class FreeBenchmarkDataProcessor(DataProcessor): return def analyze_pre_forward(self, name, module, - module_input_output: ModuleForwardInputsOutputs): + module_input_output: ModuleForwardInputsOutputs): args = module_input_output.args kwargs = module_input_output.kwargs self.checker.pre_forward(name, module, self, args, kwargs) @@ -495,7 +505,7 @@ class FreeBenchmarkDataProcessor(DataProcessor): module_input_output.args, module_input_output.kwargs, module_input_output.output, - ) + ) self.update_unequal_rows(unequal_rows) if self.checker.if_fix(): self._return_forward_new_output = True @@ -507,11 +517,11 @@ class FreeBenchmarkDataProcessor(DataProcessor): return None - def overflow_debug_mode_enable(): overflow_mode = os.getenv(OverflowConst.OVERFLOW_DEBUG_MODE_ENABLE, Const.ENV_DISABLE) return overflow_mode == Const.ENV_ENABLE + def check_overflow_npu(): if overflow_debug_mode_enable(): float_status = torch.zeros(bits_for_overflow).npu() @@ -523,6 +533,7 @@ def check_overflow_npu(): else: return torch_npu._C._check_overflow_npu() + def clear_overflow_npu(): if overflow_debug_mode_enable(): float_status = torch.zeros(bits_for_overflow).npu() @@ -530,6 +541,7 @@ def clear_overflow_npu(): else: torch_npu._C._clear_overflow_npu() + class OverflowConst: """ Class for Overflow diff --git a/debug/accuracy_tools/atat/pytorch/functional/json_writer.py b/debug/accuracy_tools/atat/pytorch/functional/json_writer.py index 0fee3aa9731..61c34eda6bb 100644 --- a/debug/accuracy_tools/atat/pytorch/functional/json_writer.py +++ b/debug/accuracy_tools/atat/pytorch/functional/json_writer.py @@ -1,16 +1,15 @@ -import os import csv -from pathlib import Path import json +import os +from pathlib import Path + from ..common.file_check import FileCheckConst, change_mode from ..common.log import print_info_log_rank_0 from ..common.utils import Const +from ...core.utils import Const class DataWriter: # TODO: UT - # dump_json_name = "dump.json" - # stack_json_name = "stack.json" - # construct_json_name = "construct.json" def __init__(self, init_json=None) -> None: self.dump_count = 0 @@ -18,17 +17,31 @@ class DataWriter: # TODO: UT self.dump_file_path = None # os.path.join(dump_dir, DataWriter.dump_json_name) self.stack_file_path = None # os.path.join(dump_dir, DataWriter.stack_json_name) self.construct_file_path = None # os.path.join(dump_dir, DataWriter.construct_json_name) - self.free_benchmark_file_path = None + self.free_benchmark_file_path = None self.dump_tensor_data_dir = None self.buffer_size = 1000 - self.cache_data = {"data": {}} + self.cache_data = {Const.DATA: {}} self.cache_stack = {} self.cache_construct = {} + @staticmethod + def 
write_data_to_csv(result: list, result_header: tuple, file_path: str): + if len(result) == 0: + return + is_exists = os.path.exists(file_path) + append = "a+" if is_exists else "w+" + with os.fdopen( + os.open(file_path, Const.WRITE_FLAGS, FileCheckConst.DATA_FILE_AUTHORITY), append, newline="" + ) as csv_file: + spawn_writer = csv.writer(csv_file) + if not is_exists: + spawn_writer.writerow(result_header) + spawn_writer.writerows([result, ]) + def initialize_json_file(self, **kwargs): - kwargs.update({"dump_data_dir": self.dump_tensor_data_dir, "data": {}}) + kwargs.update({"dump_data_dir": self.dump_tensor_data_dir, Const.DATA: {}}) with os.fdopen( - os.open(self.dump_file_path, Const.OVERWRITE_FLAGS, FileCheckConst.DATA_FILE_AUTHORITY), 'w' + os.open(self.dump_file_path, Const.OVERWRITE_FLAGS, FileCheckConst.DATA_FILE_AUTHORITY), 'w' ) as f: json.dump(kwargs, f) @@ -42,7 +55,8 @@ class DataWriter: # TODO: UT Path(self.construct_file_path).touch() change_mode(self.construct_file_path, FileCheckConst.DATA_FILE_AUTHORITY) - def update_dump_paths(self, dump_file_path, stack_file_path, construct_file_path, dump_data_dir, free_benchmark_file_path): + def update_dump_paths(self, dump_file_path, stack_file_path, construct_file_path, dump_data_dir, + free_benchmark_file_path): self.dump_file_path = dump_file_path self.stack_file_path = stack_file_path self.construct_file_path = construct_file_path @@ -51,13 +65,13 @@ class DataWriter: # TODO: UT def update_data(self, new_data): key = next(iter(new_data.keys())) # assert len(new_data.keys()) == 1 - if key in self.cache_data["data"]: - self.cache_data["data"][key].update(new_data[key]) + if key in self.cache_data[Const.DATA]: + self.cache_data[Const.DATA][key].update(new_data[key]) else: - self.cache_data["data"].update(new_data) + self.cache_data[Const.DATA].update(new_data) def flush_data_when_buffer_is_full(self): - if len(self.cache_data["data"]) >= self.buffer_size: + if len(self.cache_data[Const.DATA]) >= self.buffer_size: self.write_data_json(self.dump_file_path) def update_stack(self, new_data): @@ -77,13 +91,13 @@ class DataWriter: # TODO: UT else: self.init_json['data_path'] = self.dump_tensor_data_dir data_to_write = self.init_json - data_to_write['data'].update(self.cache_data['data']) + data_to_write[Const.DATA].update(self.cache_data[Const.DATA]) with open(file_path, 'w+') as f: fcntl.flock(f, fcntl.LOCK_EX) json.dump(data_to_write, f, indent=1) fcntl.flock(f, fcntl.LOCK_UN) - self.cache_data["data"].clear() + self.cache_data[Const.DATA].clear() def write_stack_info_json(self, file_path): import fcntl @@ -103,18 +117,3 @@ class DataWriter: # TODO: UT self.write_data_json(self.dump_file_path) self.write_stack_info_json(self.stack_file_path) self.write_construct_info_json(self.construct_file_path) - - @staticmethod - def write_data_to_csv(result: list, result_header: tuple, file_path: str): - if len(result) == 0: - return - is_exists = os.path.exists(file_path) - append = "a+" if is_exists else "w+" - with os.fdopen( - os.open(file_path, Const.WRITE_FLAGS, FileCheckConst.DATA_FILE_AUTHORITY), append, newline="" - ) as csv_file: - spawn_writer = csv.writer(csv_file) - if not is_exists: - spawn_writer.writerow(result_header) - spawn_writer.writerows([result,]) - \ No newline at end of file diff --git a/debug/accuracy_tools/atat/pytorch/functional/scope.py b/debug/accuracy_tools/atat/pytorch/functional/scope.py index e557b876b1b..e2997d91966 100644 --- a/debug/accuracy_tools/atat/pytorch/functional/scope.py +++ 
b/debug/accuracy_tools/atat/pytorch/functional/scope.py
@@ -3,7 +3,11 @@
 from ..common.exceptions import ScopeException
 from ..common.utils import Const
 
 
-def build_scope(scope_class, scope=[], api_list=[]):
+def build_scope(scope_class, scope=None, api_list=None):
+    if api_list is None:
+        api_list = []
+    if scope is None:
+        scope = []
     if not scope and not api_list:
         return None
     if scope_class:
@@ -30,31 +34,35 @@ class BaseScope(ABC):
     Module_Type_Module = "Module"
     Module_Type_API = "api"
 
+    def __init__(self, scope, api_list):
+        scope, api_list = self.rectify_args(scope, api_list)
+        self.scope = scope
+        self.api_list = api_list
+
     @staticmethod
     def rectify_args(scope, api_list):
         if not isinstance(api_list, list):
             raise ScopeException(ScopeException.InvalidApiStr,
-                                 f"api_list参数须配置为列表,实际类型为{type(api_list)}.")
+                                 f"api_list must be a list; got {type(api_list)}.")
         for api in api_list:
             if not isinstance(api, str):
                 raise ScopeException(ScopeException.InvalidApiStr,
-                                     f"api_list中的元素须配置为字符串,实际类型为{type(api)}.")
+                                     f"elements of api_list must be strings; got {type(api)}.")
         if isinstance(scope, str):
             scope = [scope]
             return scope, api_list
         if not isinstance(scope, list):
             raise ScopeException(ScopeException.InvalidScope,
-                                 f"scope参数须配置为字符串或列表,实际类型为{type(scope)}.")
+                                 f"scope must be a string or a list; got {type(scope)}.")
         for s in scope:
             if not isinstance(s, str):
                 raise ScopeException(ScopeException.InvalidScope,
-                                     f"scope列表元素要求类型为字符串,实际类型为{type(s)}.")
+                                     f"elements of the scope list must be strings; got {type(s)}.")
         return scope, api_list
 
-    def __init__(self, scope, api_list):
-        scope, api_list = self.rectify_args(scope, api_list)
-        self.scope = scope
-        self.api_list = api_list
+    @abstractmethod
+    def check(self, name):
+        pass
 
     def check_api_list(self, api_name):
         if not self.api_list:
@@ -62,18 +70,14 @@ class BaseScope(ABC):
         for api_str in self.api_list:
             if api_str in api_name:
                 return True
-
-    @abstractmethod
-    def check(self, name):
-        pass
-
+        return False
 
 class ListScope(BaseScope):
     @staticmethod
     def rectify_args(scope, api_list):
         if scope and api_list:
             raise ScopeException(ScopeException.ArgConflict,
-                                 f"scope和api_list不可以同时配置,实际配置为scope={scope}, api_list={api_list}.")
+                                 f"scope and api_list cannot both be configured; got scope={scope}, api_list={api_list}.")
         return super(ListScope, ListScope).rectify_args(scope, api_list)
 
     def check(self, module_name):
@@ -83,6 +87,12 @@ class ListScope(BaseScope):
 
 
 class RangeScope(BaseScope, ABC):
+
+    def __init__(self, *args):
+        super().__init__(*args)
+        self.in_scope = False
+        self.is_valid = self.check_scope_is_valid()
+
     @staticmethod
     def rectify_args(scope, api_list):
         scope, api_list = super(RangeScope, RangeScope).rectify_args(scope, api_list)
@@ -91,7 +101,7 @@
             scope.append(scope[0])
         elif len(scope) > 2:
             raise ScopeException(ScopeException.InvalidScope,
-                                 f"scope参数指定区间断点,须传入长度为1或2的列表,实际长度为{len(scope)}.")
+                                 f"scope defines range endpoints and must be a list of length 1 or 2; got length {len(scope)}.")
 
         return scope, api_list
 
@@ -99,11 +109,6 @@
     def check_scope_is_valid(self):
         pass
 
-    def __init__(self, *args):
-        super().__init__(*args)
-        self.in_scope = False
-        self.is_valid = self.check_scope_is_valid()
-
     def begin_module(self, module_name):
         pass
 
@@ -143,6 +148,7 @@ class ModuleRangeScope(RangeScope):
     pre_hook and full_backward_hook are needed to precisely control where a module begins and ends;
     begin_module and end_module are called when these hooks fire to bound the range
     """
+
     def check_scope_is_valid(self):
         if not self.scope:
             return True
@@ -169,6 +175,3 @@
         if not self.scope or self.in_scope:
             return self.check_api_list(module_name)
         return False
-
-
-
diff --git 
a/debug/accuracy_tools/atat/pytorch/hook_module/api_registry.py b/debug/accuracy_tools/atat/pytorch/hook_module/api_registry.py index 003a8699cd7..fd43ae70c0b 100644 --- a/debug/accuracy_tools/atat/pytorch/hook_module/api_registry.py +++ b/debug/accuracy_tools/atat/pytorch/hook_module/api_registry.py @@ -25,6 +25,7 @@ from .wrap_vf import get_vf_ops from .wrap_distributed import get_distributed_ops from .wrap_aten import get_aten_ops from ..common.utils import torch_without_guard_version, npu_distributed_api, is_gpu +from ...core.utils import Const torch_version_above_2 = torch.__version__.split('+')[0] > '2.0' if not is_gpu: @@ -108,19 +109,19 @@ class ApiRegistry: self.store_ori_attr(torch.Tensor, get_tensor_ops(), self.tensor_ori_attr) wrap_tensor.wrap_tensor_ops_and_bind(hook) for attr_name in dir(wrap_tensor.HOOKTensor): - if attr_name.startswith("wrap_"): + if attr_name.startswith(Const.ATTR_NAME_PREFIX): self.tensor_hook_attr[attr_name[5:]] = getattr(wrap_tensor.HOOKTensor, attr_name) self.store_ori_attr(torch, get_torch_ops(), self.torch_ori_attr) wrap_torch.wrap_torch_ops_and_bind(hook) for attr_name in dir(wrap_torch.HOOKTorchOP): - if attr_name.startswith("wrap_"): + if attr_name.startswith(Const.ATTR_NAME_PREFIX): self.torch_hook_attr[attr_name[5:]] = getattr(wrap_torch.HOOKTorchOP, attr_name) self.store_ori_attr(torch.nn.functional, get_functional_ops(), self.functional_ori_attr) wrap_functional.wrap_functional_ops_and_bind(hook) for attr_name in dir(wrap_functional.HOOKFunctionalOP): - if attr_name.startswith("wrap_"): + if attr_name.startswith(Const.ATTR_NAME_PREFIX): self.functional_hook_attr[attr_name[5:]] = getattr(wrap_functional.HOOKFunctionalOP, attr_name) self.store_ori_attr(dist, get_distributed_ops(), self.distributed_ori_attr) @@ -128,7 +129,7 @@ class ApiRegistry: if not is_gpu and not torch_without_guard_version: self.store_ori_attr(torch_npu.distributed, npu_distributed_api, self.npu_distributed_ori_attr) for attr_name in dir(wrap_distributed.HOOKDistributedOP): - if attr_name.startswith("wrap_"): + if attr_name.startswith(Const.ATTR_NAME_PREFIX): self.distributed_hook_attr[attr_name[5:]] = getattr(wrap_distributed.HOOKDistributedOP, attr_name) if not is_gpu and not torch_without_guard_version and attr_name[5:] in npu_distributed_api: self.npu_distributed_hook_attr[attr_name[5:]] = getattr(wrap_distributed.HOOKDistributedOP, @@ -138,20 +139,20 @@ class ApiRegistry: self.store_ori_attr(torch.ops.aten, get_aten_ops(), self.aten_ori_attr) wrap_aten.wrap_aten_ops_and_bind(hook) for attr_name in dir(wrap_aten.HOOKAtenOP): - if attr_name.startswith("wrap_"): + if attr_name.startswith(Const.ATTR_NAME_PREFIX): self.aten_hook_attr[attr_name[5:]] = getattr(wrap_aten.HOOKAtenOP, attr_name) self.store_ori_attr(torch._VF, get_vf_ops(), self.vf_ori_attr) wrap_vf.wrap_vf_ops_and_bind(hook) for attr_name in dir(wrap_vf.HOOKVfOP): - if attr_name.startswith("wrap_"): + if attr_name.startswith(Const.ATTR_NAME_PREFIX): self.vf_hook_attr[attr_name[5:]] = getattr(wrap_vf.HOOKVfOP, attr_name) if not is_gpu: self.store_ori_attr(torch_npu, get_npu_ops(), self.torch_npu_ori_attr) wrap_npu_custom.wrap_npu_ops_and_bind(hook) for attr_name in dir(wrap_npu_custom.HOOKNpuOP): - if attr_name.startswith("wrap_"): + if attr_name.startswith(Const.ATTR_NAME_PREFIX): self.torch_npu_hook_attr[attr_name[5:]] = getattr(wrap_npu_custom.HOOKNpuOP, attr_name) diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_vf.py b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_vf.py index 
08d47308e07..351a307ccf5 100644 --- a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_vf.py +++ b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_vf.py @@ -32,8 +32,6 @@ with FileOpen(yaml_path, 'r') as f: def get_vf_ops(): global WrapVfOps - # _all_functional_ops = dir(torch.nn.functional) - # assert set(WrapFunctionalOps) <= set(_all_functional_ops) return WrapVfOps diff --git a/debug/accuracy_tools/atat/pytorch/overflow_check/__init__.py b/debug/accuracy_tools/atat/pytorch/overflow_check/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/debug/accuracy_tools/atat/pytorch/overflow_check/info_dump.py b/debug/accuracy_tools/atat/pytorch/overflow_check/info_dump.py deleted file mode 100644 index 161e9f23f0f..00000000000 --- a/debug/accuracy_tools/atat/pytorch/overflow_check/info_dump.py +++ /dev/null @@ -1,252 +0,0 @@ -import inspect -import fcntl -import os -import threading - -import json -import numpy as np -import torch - -from atat.core.file_check_util import FileOpen, FileCheckConst, change_mode -from atat.core.utils import get_time -from ..common.utils import print_error_log - - -special_torch_object = ["memory_format"] -lock = threading.Lock() - - -def write_npy(file_path, tensor): - saved_tensor = tensor.contiguous().cpu().detach() - if tensor.dtype == torch.bfloat16: - saved_numpy = saved_tensor.to(torch.float32).numpy() - else: - saved_numpy = saved_tensor.numpy() - if os.path.exists(file_path): - raise ValueError(f"File {file_path} already exists") - np.save(file_path, saved_numpy) - full_path = os.path.abspath(file_path) - return full_path - - -class APIInfo: - def __init__(self, api_name, is_forward, save_real_data=False): - self.rank = os.getpid() - self.api_name = api_name - self.save_real_data = save_real_data - self.torch_object_key = {'device': self.analyze_device_in_kwargs, 'dtype': self.analyze_dtype_in_kwargs} - self.is_forward = is_forward - self.args_num = 0 - - def analyze_element(self, element): - if isinstance(element, (list, tuple)): - out = [] - for item in element: - out.append(self.analyze_element(item)) - return out - elif isinstance(element, dict): - out_dict = {} - for key, value in element.items(): - if key in self.torch_object_key.keys(): - fun = self.torch_object_key[key] - out_dict[key] = fun(value) - elif key in special_torch_object: - continue - else: - out_dict[key] = self.analyze_element(value) - return out_dict - elif isinstance(element, torch.Tensor): - out_tensor = self.analyze_tensor(element, self.save_real_data) - return out_tensor - elif self.is_builtin_class(element): - out_builtin = self.analyze_builtin(element) - return out_builtin - else: - msg = f"Type {type(element)} is unsupported at analyze_element" - print_error_log(msg) - - raise NotImplementedError(msg) - - def analyze_tensor(self, arg, save_real_data): - single_arg = {} - if not save_real_data: - single_arg.update({'type': 'torch.Tensor'}) - single_arg.update({'dtype': str(arg.dtype)}) - single_arg.update({'shape': arg.shape}) - single_arg.update({'Max': self.transfer_types(self.get_tensor_extremum(arg, 'max'), str(arg.dtype))}) - single_arg.update({'Min': self.transfer_types(self.get_tensor_extremum(arg, 'min'), str(arg.dtype))}) - single_arg.update({'requires_grad': arg.requires_grad}) - - else: - dump_path = "./" - api_args = self.api_name + '.' 
+ str(self.args_num) - rank = arg.device.index - if self.is_forward: - forward_real_data_path = os.path.join(dump_path, "forward_real_data_" + get_time(), f"rank{rank}") - if not os.path.exists(forward_real_data_path): - os.makedirs(forward_real_data_path, 0o755) - - file_path = os.path.join(forward_real_data_path, f'{api_args}.npy') - else: - backward_real_data_path = os.path.join(dump_path, "backward_real_data_" + get_time(), f"rank{rank}") - if not os.path.exists(backward_real_data_path): - os.makedirs(backward_real_data_path, 0o755) - file_path = os.path.join(backward_real_data_path, f'{api_args}.npy') - self.args_num += 1 - npy_path = write_npy(file_path, arg) - single_arg.update({'type': 'torch.Tensor'}) - single_arg.update({'datapath': npy_path}) - single_arg.update({'requires_grad': arg.requires_grad}) - return single_arg - - def analyze_builtin(self, arg): - single_arg = {} - if isinstance(arg, slice): - single_arg.update({'type': "slice"}) - single_arg.update({'value': [arg.start, arg.stop, arg.step]}) - else: - single_arg.update({'type': self.get_type_name(str(type(arg)))}) - single_arg.update({'value': arg}) - return single_arg - - def transfer_types(self, data, dtype): - if 'int' in dtype or 'bool' in dtype: - return int(data) - else: - return float(data) - - def is_builtin_class(self, element): - if element is None or isinstance(element, (bool, int, float, str, slice)): - return True - return False - - def analyze_device_in_kwargs(self, element): - single_arg = {} - single_arg.update({'type': 'torch.device'}) - if not isinstance(element, str): - - if hasattr(element, "index"): - device_value = element.type + ":" + str(element.index) - single_arg.update({'value': device_value}) - else: - device_value = element.type - else: - single_arg.update({'value': element}) - return single_arg - - def analyze_dtype_in_kwargs(self, element): - single_arg = {} - single_arg.update({'type': 'torch.dtype'}) - single_arg.update({'value': str(element)}) - return single_arg - - def get_tensor_extremum(self, data, operator): - if data.dtype is torch.bool: - if operator == 'max': - return True in data - elif operator == 'min': - return False not in data - if operator == 'max': - return torch._C._VariableFunctionsClass.max(data).item() - else: - return torch._C._VariableFunctionsClass.min(data).item() - - def get_type_name(self, name): - - left = name.index("'") - right = name.rindex("'") - return name[left + 1: right] - - -class ForwardAPIInfo(APIInfo): - def __init__(self, name, save_real_data, args, kwargs): - super().__init__(name, is_forward=True, save_real_data=save_real_data) - self.analyze_api_input(args, kwargs) - self.analyze_api_call_stack() - - def analyze_api_input(self, args, kwargs): - args_info_list = self.analyze_element(args) - kwargs_info_dict = self.analyze_element(kwargs) - self.api_info_struct = {self.api_name: {"args": args_info_list, "kwargs": kwargs_info_dict}} - - def analyze_api_call_stack(self): - stack_str = [] - for (_, path, line, func, code, _) in inspect.stack()[3:]: - if not code: - continue - stack_line = " ".join([ - "File", ", ".join([path, " ".join(["line", str(line)]), " ".join(["in", func]), - " ".join(["\n", code[0].strip()])])]) - stack_str.append(stack_line) - self.stack_info_struct = {self.api_name: stack_str} - - -class BackwardAPIInfo(APIInfo): - def __init__(self, name, grads): - super().__init__(name, is_forward=False) - self.analyze_api_input(grads) - - def analyze_api_input(self, grads): - grads_info_list = self.analyze_element(grads) - 
self.grad_info_struct = {self.api_name: grads_info_list} - - -def write_api_info_json(api_info): - dump_path = "./" - rank = api_info.rank - if isinstance(api_info, ForwardAPIInfo): - file_path = os.path.join(dump_path, f'forward_info_{rank}.json') - stack_file_path = os.path.join(dump_path, f'stack_info_{rank}.json') - write_json(file_path, api_info.api_info_struct) - write_json(stack_file_path, api_info.stack_info_struct, indent=4) - - elif isinstance(api_info, BackwardAPIInfo): - file_path = os.path.join(dump_path, f'backward_info_{rank}.json') - write_json(file_path, api_info.grad_info_struct) - else: - raise ValueError(f"Invalid api_info type {type(api_info)}") - - -def write_json(file_path, data, indent=None): - if not os.path.exists(file_path): - with FileOpen(file_path, 'w') as f: - f.write("{\n}") - change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY) - lock.acquire() - with FileOpen(file_path, 'a+') as f: - fcntl.flock(f, fcntl.LOCK_EX) - try: - f.seek(0, os.SEEK_END) - f.seek(f.tell() - 1, os.SEEK_SET) - f.truncate() - if f.tell() > 3: - f.seek(f.tell() - 1, os.SEEK_SET) - f.truncate() - f.write(',\n') - f.write(json.dumps(data, indent=indent)[1:-1] + '\n}') - except Exception as e: - raise ValueError(f"Json save failed:{e}") from e - finally: - fcntl.flock(f, fcntl.LOCK_UN) - lock.release() - - -def initialize_output_json(): - dump_path = os.path.realpath("./") - files = ['forward_info.json', 'backward_info.json', 'stack_info.json'] - - forward_real_data_path = os.path.join(dump_path, 'forward_real_data') - if os.path.exists(forward_real_data_path): - raise ValueError(f"file {forward_real_data_path} already exists, please remove it first") - else: - os.mkdir(forward_real_data_path, mode=0o750) - - backward_real_data_path = os.path.join(dump_path, 'backward_real_data') - if os.path.exists(backward_real_data_path): - raise ValueError(f"file {backward_real_data_path} already exists, please remove it first") - else: - os.mkdir(backward_real_data_path, mode=0o750) - for file in files: - file_path = os.path.join(dump_path, file) - if os.path.exists(file_path): - raise ValueError(f"file {file_path} already exists, please remove it first or use a new dump path") \ No newline at end of file diff --git a/debug/accuracy_tools/atat/pytorch/overflow_check/overflow_check.py b/debug/accuracy_tools/atat/pytorch/overflow_check/overflow_check.py deleted file mode 100644 index f8f9926b6cd..00000000000 --- a/debug/accuracy_tools/atat/pytorch/overflow_check/overflow_check.py +++ /dev/null @@ -1,190 +0,0 @@ -import os -from pathlib import Path - -import torch - -try: - import torch_npu -except ImportError: - is_gpu = True -else: - is_gpu = False - -from atat.core.file_check_util import FileCheckConst -from atat.core.utils import print_warn_log, get_time, print_info_log -from ..dump.dump import forward_init_status, forward_acl_dump -from .utils import OverFlowUtil, dump_overflow, check_overflow_npu, clear_overflow_npu -from ..dump.utils import DumpUtil, Const, get_tensor_rank, create_dirs_if_not_exist, check_single_rank_folder -from .info_dump import write_api_info_json, ForwardAPIInfo, BackwardAPIInfo -from ..dump import dump - -backward_init_status = False -api_overflow = [] -forward_api_info = {} -backward_api_info = {} -FORWARD_REAL_DATA_PATH = os.path.join('./', 'forward_real_data') -BACKWARD_REAL_DATA_PATH = os.path.join('./', 'backward_real_data') -rank = os.getpid() -pkl_name = '' - - -def check_overflow_environment(pid): - if not OverFlowUtil.get_overflow_check_switch(): - return False 
- if pid != os.getpid(): - return False - if is_gpu: - print_warn_log("Overflow detection is not supported in the GPU environment.") - return False - global backward_init_status - if backward_init_status or forward_init_status: - return False - return True - - -def check_data_overflow(x): - if isinstance(x, (tuple, list)) and x: - for i, item in enumerate(x): - if True == check_data_overflow(item): - return True - return False - else: - if isinstance(x, torch.Tensor) and x.numel() != 0 and x.dtype != torch.bool: - if x.is_meta: - return False - if len(x.shape) == 0: - tensor_max = x.cpu().detach().float().numpy().tolist() - tensor_min = tensor_max - else: - tensor_max = torch._C._VariableFunctionsClass.max(x).cpu().detach().float().numpy().tolist() - tensor_min = torch._C._VariableFunctionsClass.min(x).cpu().detach().float().numpy().tolist() - # inf - if tensor_max == float('inf') or tensor_min == float('-inf'): - return True - if x.dtype in [torch.float16, torch.float32, torch.bfloat16] and \ - (tensor_max == torch.finfo(x.dtype).max or tensor_min == torch.finfo(x.dtype).min): - return True - # nan - elif tensor_max != tensor_max or tensor_min != tensor_min: - return True - else: - return False - elif isinstance(x, bool) or isinstance(x, int) or isinstance(x, float): - if x == float('inf') or x == float('-inf') or x != x: - return True - else: - return False - else: - return False - - -def check_path(apis, path): - return any(api in path for api in apis) - - -def overflow_check(name, **kwargs): - overflow_nums = OverFlowUtil.overflow_nums - pid = kwargs.get('pid') - dump_mode = DumpUtil.dump_switch_mode - if not pid: - return RuntimeError("Not get the specified process pid.") - - def overflowcheck_hook(module, in_feat, out_feat=None): - if not check_overflow_environment(pid): - return - dump_file = DumpUtil.get_dump_path() - global rank - dump_dir, dump_filename = os.path.split(dump_file) - dump_dir = os.path.join(dump_dir, "step{}".format(DumpUtil.iter_num)) - if not os.path.exists(dump_dir): - Path(dump_dir).mkdir(mode=FileCheckConst.DATA_DIR_AUTHORITY, exist_ok=True) - if DumpUtil.is_single_rank is None: - DumpUtil.is_single_rank = check_single_rank_folder(dump_dir) - dump_file = os.path.join(dump_dir, dump_filename) - rank_this = get_tensor_rank(in_feat, out_feat) - DumpUtil.dump_root = os.path.dirname(DumpUtil.dump_path) - if rank_this is not None and rank != rank_this: - rank = rank_this - dump.rename_() - if DumpUtil.target_rank is not None: - if rank != DumpUtil.target_rank: - return - dump_path = create_dirs_if_not_exist(rank, dump_file) - global pkl_name - pkl_name = dump_path - dump_dir = os.path.split(dump_path)[0] - global api_overflow - global forward_api_info - global backward_api_info - - module_name = name - if hasattr(torch_npu._C, '_npu_is_support_inf_nan') and torch_npu._C._npu_is_support_inf_nan(): - # backward API endwith backward - if module_name.endswith(Const.BACKWARD): - check_feat = in_feat - else: - check_feat = out_feat - module.has_overflow = check_data_overflow(check_feat) - else: - module.has_overflow = check_overflow_npu() - if not module.has_overflow: - if hasattr(module, 'input_args'): - del module.input_args - if hasattr(module, 'input_kwargs'): - del module.input_kwargs - if module.has_overflow and OverFlowUtil.check_overflow_dump_times(overflow_nums): - if overflow_type_judge(in_feat, out_feat, module_name) and DumpUtil.need_replicate: - if module_name.endswith(Const.FORWARD): - forward_api_info.update({name: ForwardAPIInfo(name, True, 
-def overflow_check(name, **kwargs):
-    overflow_nums = OverFlowUtil.overflow_nums
-    pid = kwargs.get('pid')
-    dump_mode = DumpUtil.dump_switch_mode
-    if not pid:
-        return RuntimeError("Not get the specified process pid.")
-
-    def overflowcheck_hook(module, in_feat, out_feat=None):
-        if not check_overflow_environment(pid):
-            return
-        dump_file = DumpUtil.get_dump_path()
-        global rank
-        dump_dir, dump_filename = os.path.split(dump_file)
-        dump_dir = os.path.join(dump_dir, "step{}".format(DumpUtil.iter_num))
-        if not os.path.exists(dump_dir):
-            Path(dump_dir).mkdir(mode=FileCheckConst.DATA_DIR_AUTHORITY, exist_ok=True)
-        if DumpUtil.is_single_rank is None:
-            DumpUtil.is_single_rank = check_single_rank_folder(dump_dir)
-        dump_file = os.path.join(dump_dir, dump_filename)
-        rank_this = get_tensor_rank(in_feat, out_feat)
-        DumpUtil.dump_root = os.path.dirname(DumpUtil.dump_path)
-        if rank_this is not None and rank != rank_this:
-            rank = rank_this
-            dump.rename_()
-        if DumpUtil.target_rank is not None:
-            if rank != DumpUtil.target_rank:
-                return
-        dump_path = create_dirs_if_not_exist(rank, dump_file)
-        global pkl_name
-        pkl_name = dump_path
-        dump_dir = os.path.split(dump_path)[0]
-        global api_overflow
-        global forward_api_info
-        global backward_api_info
-
-        module_name = name
-        if hasattr(torch_npu._C, '_npu_is_support_inf_nan') and torch_npu._C._npu_is_support_inf_nan():
-            # backward API endwith backward
-            if module_name.endswith(Const.BACKWARD):
-                check_feat = in_feat
-            else:
-                check_feat = out_feat
-            module.has_overflow = check_data_overflow(check_feat)
-        else:
-            module.has_overflow = check_overflow_npu()
-        if not module.has_overflow:
-            if hasattr(module, 'input_args'):
-                del module.input_args
-            if hasattr(module, 'input_kwargs'):
-                del module.input_kwargs
-        if module.has_overflow and OverFlowUtil.check_overflow_dump_times(overflow_nums):
-            if overflow_type_judge(in_feat, out_feat, module_name) and DumpUtil.need_replicate:
-                if module_name.endswith(Const.FORWARD):
-                    forward_api_info.update({name: ForwardAPIInfo(name, True,
-                                                                  module.input_args, module.input_kwargs)})
-                    api_overflow.append(module_name)
-                else:
-                    api_overflow.append(module_name.replace("backward", "forward"))
-                    backward_api_info.update({name: BackwardAPIInfo(name, out_feat)})
-            OverFlowUtil.inc_overflow_dump_times()
-            dump_file_name = os.path.join(dump_dir,
-                                          "{}_{}.pkl".format(module_name, OverFlowUtil.real_overflow_dump_times))
-            dump_overflow(module_name, in_feat, out_feat, dump_file_name)
-            dump.pkl_name = dump_file_name
-
-            print_warn_log("[overflow {} times]: module name :'{}' is overflow and dump file is saved in '{}'."
-                           .format(OverFlowUtil.real_overflow_dump_times, module_name,
-                                   os.path.realpath(dump_file_name)))
-            if dump_mode == "acl":
-                acl_dump(module, module_name)
-            dump.write_to_disk()
-            # clear overflow flag for the next check
-            clear_overflow_npu()
-            if not OverFlowUtil.check_overflow_dump_times(overflow_nums):
-                for key in forward_api_info:
-                    write_api_info_json(forward_api_info[key])
-                for key in backward_api_info:
-                    write_api_info_json(backward_api_info[key])
-                raise ValueError("[overflow {} times]: dump file is saved in '{}'."
-                                 .format(OverFlowUtil.real_overflow_dump_times, os.path.realpath(dump_file_name)))
-
-    def overflow_type_judge(in_feat, out_feat, module_name):
-        if module_name.endswith(Const.BACKWARD):
-            check_feat = out_feat
-        else:
-            check_feat = in_feat
-        if check_data_overflow(check_feat):
-            print_warn_log("module name :'{}' is overflow and its inputs already has an overflow, so you need "
-                           "to go back to find where the overflow started.".format(module_name))
-            return False
-        elif not check_data_overflow(in_feat) and not check_data_overflow(out_feat):
-            print_warn_log("module name :'{}' is overflow and its inputs and outputs do not overflow, "
-                           "so this is a process overflow".format(module_name))
-            return False
-        else:
-            print_warn_log("module name :'{}' is overflow. Its input is normal and its output "
-                           "is overflow.".format(module_name))
-            return True
-
-    def acl_dump(module, module_name):
-        if "forward" in module_name:
-            forward_acl_dump(module, module_name)
-        if "backward" in module_name:
-            print_info_log("The overflow is caused by backward operator {}. "
-                           "You can use reverse acl dump(mode='acl') to get operator dump data.".format(module_name))
-
-    return overflowcheck_hook
diff --git a/debug/accuracy_tools/atat/pytorch/overflow_check/utils.py b/debug/accuracy_tools/atat/pytorch/overflow_check/utils.py
deleted file mode 100644
index d254d584550..00000000000
--- a/debug/accuracy_tools/atat/pytorch/overflow_check/utils.py
+++ /dev/null
@@ -1,114 +0,0 @@
-import os
-import torch
-
-try:
-    import torch_npu
-except ImportError:
-    is_gpu = True
-else:
-    is_gpu = False
-
-from atat.core.utils import check_switch_valid, check_inplace_op, OverflowConst
-from ..common.utils import Const
-from ..dump.dump import dump_stack_info, get_scalar_data_info, dump_data_by_rank_count, \
-    get_not_float_tensor_info, get_float_tensor_info
-from ..dump.utils import DumpUtil, make_dump_data_dir
-
-
-class OverFlowUtil(object):
-    overflow_check_switch = None
-    overflow_filter_switch = Const.OFF
-    real_overflow_dump_times = 0
-    overflow_nums = 1
-
-    @staticmethod
-    def set_overflow_check_switch(switch, filter_switch):
-        OverFlowUtil.overflow_check_switch = switch
-        OverFlowUtil.overflow_filter_switch = filter_switch
-
-    @staticmethod
-    def get_overflow_check_switch():
-        if OverFlowUtil.overflow_check_switch is None:
-            return True
-        return OverFlowUtil.overflow_check_switch == "ON"
-
-    @staticmethod
-    def inc_overflow_dump_times():
-        OverFlowUtil.real_overflow_dump_times += 1
-
-    @staticmethod
-    def check_overflow_dump_times(need_dump_times):
-        if need_dump_times == -1:
-            return True
-        return OverFlowUtil.real_overflow_dump_times < need_dump_times
-
-
-def set_overflow_check_switch(switch, filter_switch=Const.OFF):
-    check_switch_valid(switch)
-    check_switch_valid(filter_switch)
-
-    OverFlowUtil.set_overflow_check_switch(switch, filter_switch)
-
-
-def dump_overflow(module_name, in_feat, out_feat, dump_file):
-    name_template = f"{module_name}" + "_{}"
-    DumpUtil.dump_data_dir = make_dump_data_dir(dump_file)
-    dump_stack_info(name_template)
-    if check_inplace_op(name_template):
-        if Const.PRE_FORWARD in name_template:
-            name_template = name_template.replace(Const.PRE_FORWARD, Const.FORWARD)
-        else:
-            _dump_tensor_completely(in_feat, name_template.format("output"))
-        return
-
-    if "forward" in name_template:
-        _dump_tensor_completely(in_feat, name_template.format("input"))
-        _dump_tensor_completely(out_feat, name_template.format("output"))
-    else:
-        _dump_tensor_completely(in_feat, name_template.format("output"))
-        _dump_tensor_completely(out_feat, name_template.format("input"))


-def _dump_tensor_completely(x, prefix):
-    dump_flag = Const.DUMP_RATIO_MAX + 1
-    if isinstance(x, (tuple, list)) and x:
-        for i, item in enumerate(x):
-            _dump_tensor_completely(item, "{}.{}".format(prefix, i))
-    elif isinstance(x, torch.Tensor):
-        if x.numel() == 0 or len(x.shape) == 0 or not x.is_floating_point():
-            if OverFlowUtil.overflow_filter_switch == Const.OFF:
-                data_info = get_not_float_tensor_info(x)
-                dump_data_by_rank_count(dump_flag, prefix, data_info)
-        else:
-            data_info = get_float_tensor_info(x)
-            dump_data_by_rank_count(dump_flag, prefix, data_info)
-
-    elif OverFlowUtil.overflow_filter_switch == Const.OFF:
-        if isinstance(x, bool) or isinstance(x, int) or isinstance(x, float):
-            data_info = get_scalar_data_info(x)
-            dump_data_by_rank_count(dump_flag, prefix, data_info)
-
-
-def overflow_debug_mode_enalbe():
-    overflow_mode = os.getenv(OverflowConst.OVERFLOW_DEBUG_MODE_ENABLE, Const.ENV_DISABLE)
-    return overflow_mode == Const.ENV_ENABLE
-
-
-def check_overflow_npu():
-    if overflow_debug_mode_enalbe():
-        float_status = torch.zeros(8).npu()
-        result = torch_npu.npu_get_float_status(float_status, OverflowConst.OVERFLOW_DEBUG_MODE)
-        if (result.cpu()[0] != 0):
-            return True
-        else:
-            return False
-    else:
-        return torch_npu._C._check_overflow_npu()
-
-
-def clear_overflow_npu():
-    if overflow_debug_mode_enalbe():
-        float_status = torch.zeros(8).npu()
-        torch_npu.npu_clear_float_status(float_status, OverflowConst.OVERFLOW_DEBUG_MODE)
-    else:
-        torch_npu._C._clear_overflow_npu()
\ No newline at end of file
-- 
Gitee
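P.S. for reviewers: overflowcheck_hook in the deleted overflow_check.py is a
standard torch.nn.Module forward hook that sets module.has_overflow and dumps
when it is set. A self-contained sketch of that wiring using only public
PyTorch APIs; make_overflow_hook and the printed tag are illustrative names,
not atat symbols:

import torch

def make_overflow_hook(name):
    def hook(module, inputs, output):
        # Mirror the deleted hook: flag the module when any floating-point
        # input or output carries inf/-inf/nan.
        tensors = [t for t in (*inputs, output) if isinstance(t, torch.Tensor)]
        module.has_overflow = any(
            t.is_floating_point() and not torch.isfinite(t).all().item()
            for t in tensors
        )
        if module.has_overflow:
            print(f"[overflow] {name}")
    return hook

layer = torch.nn.Linear(4, 4)
layer.register_forward_hook(make_overflow_hook("Linear_0_forward"))
_ = layer(torch.full((2, 4), float("inf")))  # prints "[overflow] Linear_0_forward"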