diff --git a/debug/accuracy_tools/atat/pytorch/common/exceptions.py b/debug/accuracy_tools/atat/pytorch/common/exceptions.py index 17733b5bfd5f4b8ffcb3cb3602e3f5f54fdef97d..c1adb0cf702c70618c252b7e66141f3f3a875d00 100644 --- a/debug/accuracy_tools/atat/pytorch/common/exceptions.py +++ b/debug/accuracy_tools/atat/pytorch/common/exceptions.py @@ -1,4 +1,3 @@ - class CodedException(Exception): def __init__(self, code, error_info=''): self.error_info = self.err_strs.get(code) + error_info @@ -10,7 +9,7 @@ class CodedException(Exception): class MsaccException(CodedException): INVALID_PARAM_ERROR = 0 OVERFLOW_NUMS_ERROR = 1 - + err_strs = { INVALID_PARAM_ERROR: "[msacc] 无效参数: ", OVERFLOW_NUMS_ERROR: "[msacc] 超过预设溢出次数 当前溢出次数:" @@ -68,8 +67,11 @@ class StepException(CodedException): InvalidPostProcess: "[msacc] 错误的step后处理配置: ", } + class FreeBenchmarkException(CodedException): UnsupportedType = 0 + InvalidGrad = 1 err_strs = { - UnsupportedType: "[msacc] Free benchmark get unsupported type: " - } \ No newline at end of file + UnsupportedType: "[msacc] Free benchmark get unsupported type: ", + InvalidGrad: "[msacc] Free benchmark gradient invalid: ", + } diff --git a/debug/accuracy_tools/atat/pytorch/common/utils.py b/debug/accuracy_tools/atat/pytorch/common/utils.py index e88d506b2c340f9b6141c2e0bb775a693d61a16c..2f700e915fe6e2b3ac11886325024fddc5d0f38f 100644 --- a/debug/accuracy_tools/atat/pytorch/common/utils.py +++ b/debug/accuracy_tools/atat/pytorch/common/utils.py @@ -145,6 +145,8 @@ class Const: GRAD_OUTPUT = 'grad_output' START = "start" STOP = "stop" + MAX = 'Max' + MIN = 'Min' # dump mode ALL = "all" @@ -178,6 +180,7 @@ class Const: # env dump path ASCEND_WORK_PATH = "ASCEND_WORK_PATH" DUMP_DIR = "dump_data" + DATA = "data" ENV_ENABLE = "1" ENV_DISABLE = "0" @@ -192,4 +195,6 @@ class Const: STATISTICS = "statistics" TENSOR = "tensor" OVERFLOW_CHECK = "overflow_check" - FREE_BENCHMARK = "free_benchmark" \ No newline at end of file + FREE_BENCHMARK = "free_benchmark" + + ATTR_NAME_PREFIX = "wrap_" diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/compare/grad_saver.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/compare/grad_saver.py index a8752656ed72bc21773aca2bb06d4e69d96a5c4b..5094da3e2a41b3e96a01cf357d434401e5ed75cc 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/compare/grad_saver.py +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/compare/grad_saver.py @@ -1,12 +1,12 @@ import torch -from atat.pytorch.free_benchmark import print_info_log_rank_0, print_warn_log_rank_0 -from atat.pytorch.free_benchmark.common.params import DataParams, HandlerParams +from atat.pytorch.common.exceptions import FreeBenchmarkException +from atat.pytorch.free_benchmark import print_warn_log_rank_0 from atat.pytorch.free_benchmark.common.constant import CommonField -from atat.pytorch.free_benchmark.common.utils import Tools +from atat.pytorch.free_benchmark.common.params import DataParams, HandlerParams +from atat.pytorch.free_benchmark.perturbed_layers.layer_factory import LayerFactory from atat.pytorch.free_benchmark.result_handlers.handler_factory import ( FuzzHandlerFactory, ) -from atat.pytorch.free_benchmark.perturbed_layers.layer_factory import LayerFactory class GradSaver: @@ -39,9 +39,20 @@ class GradSaver: handler, grad, perturbed_grad, index=input_index ) data_processor.update_unequal_rows(handler.get_unequal_rows()) + except IndexError: + print_warn_log_rank_0( + f"[atat] Free benchmark: grad index out of range. api:{self.handler_params.api_name}." + f"index:{new_grad_index}, perturbation grad len {len(self.perturbed_grad_input)}" + ) + return grad + except FreeBenchmarkException as e: + print_warn_log_rank_0( + f"[atat] Free benchmark: grad input check error: {e}" + ) + return grad except Exception as e: print_warn_log_rank_0( - f"[atat] Free benchmark: grad compara error: {e}" + f"[atat] Free benchmark: grad compare error: {e}" ) return grad return grad @@ -72,27 +83,20 @@ class GradSaver: def check_grad_input(self, origin_grad, new_grad_index): if self.perturbed_grad_input is None: - print_info_log_rank_0( - f"[atat] Free benchmark: grad not exsits : {self.api_name}." + raise FreeBenchmarkException( + FreeBenchmarkException.InvalidGrad, + f"grad not exists : {self.api_name}." ) - return None - try: - with torch.no_grad(): - perturbed_grad = self.perturbed_grad_input[new_grad_index].to( - origin_grad.device - ) - except IndexError: - print_warn_log_rank_0( - f"[atat] Free benchmark: grad index out of range. api:{self.handler_params.api_name}." - f"index:{new_grad_index}, perturbation grad len {len(self.perturbed_grad_input)}" + with torch.no_grad(): + perturbed_grad = self.perturbed_grad_input[new_grad_index].to( + origin_grad.device ) - return None if origin_grad.shape != perturbed_grad.shape: - print_warn_log_rank_0( - f"[atat] Free benchmark: grad shapes are unconsistent. api:{self.handler_params.api_name}." + raise FreeBenchmarkException( + FreeBenchmarkException.InvalidGrad, + f"grad shapes are inconsistent. api:{self.handler_params.api_name}." f"origin:{origin_grad.shape}, perturbation: {perturbed_grad.shape}" ) - return None return perturbed_grad def cache_backward_input(self, backward_input_list): @@ -117,12 +121,12 @@ class GradSaver: for object_ in self.backward_input: if isinstance(object_, dict) and CommonField.FUZZ_TENSOR in object_.keys(): tensor_ = torch.tensor( - object_.get(CommonField.FUZZ_TENSOR).data, - dtype=object_.get(CommonField.FUZZ_TENSOR).dtype, - device=object_.get(CommonField.DEVICE), - requires_grad=object_.get(CommonField.REQUIRES_GRAD), - ) - + object_.get(CommonField.FUZZ_TENSOR).data, + dtype=object_.get(CommonField.FUZZ_TENSOR).dtype, + device=object_.get(CommonField.DEVICE), + requires_grad=object_.get(CommonField.REQUIRES_GRAD), + ) + if tensor_.requires_grad: inner_args_tmp.append(CommonField.HOLD_PLACE) need_grad_tensors.append(tensor_) diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/compare/single_benchmark.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/compare/single_benchmark.py index ed834c468ba6f15437da4479a3e2b3257fd7b6c1..80c526be91f90f02603d49d535a7054be99d920c 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/compare/single_benchmark.py +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/compare/single_benchmark.py @@ -1,9 +1,9 @@ -import torch import math +import torch from atat.pytorch.free_benchmark import print_warn_log_rank_0 -from atat.pytorch.free_benchmark.common.utils import TorchC from atat.pytorch.free_benchmark.common.constant import ThresholdConfig +from atat.pytorch.free_benchmark.common.utils import TorchC class SingleCompare: @@ -13,6 +13,37 @@ class SingleCompare: self.eb = None self.threshold = None + @staticmethod + def filter_overflow(tensor) -> int: + inf_num = TorchC.sum(TorchC.isinf(tensor)) + nan_num = TorchC.sum(TorchC.isnan(tensor)) + return inf_num + nan_num + + @staticmethod + def replace_inf_or_nan(tensor): + finite_mask = TorchC.isfinite(tensor) + inf_or_nan_mask = TorchC.logical_not(finite_mask) + inf_or_nan_num = TorchC.sum(inf_or_nan_mask).item() + if inf_or_nan_num > 0: + tensor[inf_or_nan_mask] = 1 + return tensor + + def compare_dict_seq(self, actual, golden): + if len(actual) != len(golden): + return False + for key, value in golden.items(): + if not self.compare_seq(value, actual.get(key)): + return False + return True + + def compare_list_seq(self, actual, golden): + if len(actual) != len(golden): + return False + for index_, value in enumerate(golden): + if not self.compare_seq(value, actual[index_]): + return False + return True + def compare_seq(self, actual, golden): if isinstance(golden, torch.Tensor): return self.compare_tensor_seq(actual, golden) @@ -45,6 +76,11 @@ class SingleCompare: return False return True + def compare_float_seq(self, actual, golden): + return math.isclose(actual, golden) + + def compare_other_seq(self, actual, golden): + return actual == golden def _cal_compare_metrics(self, actual, golden): diff_value = TorchC.subtract(actual, golden) @@ -62,42 +98,5 @@ class SingleCompare: # 获取误差均衡性 divided = TorchC.where( TorchC.ge(TorchC.abs(golden), self.threshold.small_value), golden_abs, 1 - ) + ) self.eb = TorchC.mean(TorchC.div(diff_value, divided)) - - def compare_dict_seq(self, actual, golden): - if len(actual) != len(golden): - return False - for key, value in golden.items(): - if not self.compare_seq(value, actual.get(key)): - return False - return True - - def compare_list_seq(self, actual, golden): - if len(actual) != len(golden): - return False - for index_, value in enumerate(golden): - if not self.compare_seq(value, actual[index_]): - return False - return True - - def compare_float_seq(self, actual, golden): - return math.isclose(actual, golden) - - def compare_other_seq(self, actual, golden): - return actual == golden - - @staticmethod - def filter_overflow(tensor) -> int: - inf_num = TorchC.sum(TorchC.isinf(tensor)) - nan_num = TorchC.sum(TorchC.isnan(tensor)) - return inf_num + nan_num - - @staticmethod - def replace_inf_or_nan(tensor): - finite_mask = TorchC.isfinite(tensor) - inf_or_nan_mask = TorchC.logical_not(finite_mask) - inf_or_nan_num = TorchC.sum(inf_or_nan_mask).item() - if inf_or_nan_num > 0: - tensor[inf_or_nan_mask] = 1 - return tensor diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py index d03dbe931d91e5ed91b70c7b2b8fe1fb8f1342fa..d5ba63c6a943ba105e157c6871d1bcdc937620b5 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py @@ -4,9 +4,9 @@ from atat.pytorch.free_benchmark import ( print_warn_log_rank_0, ) from atat.pytorch.free_benchmark.common.constant import ThresholdConfig +from atat.pytorch.free_benchmark.common.enums import PerturbationMode from atat.pytorch.free_benchmark.common.params import DataParams from atat.pytorch.free_benchmark.common.utils import TorchC -from atat.pytorch.free_benchmark.common.enums import PerturbationMode from atat.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import ( NpuBaseLayer, ) @@ -14,6 +14,38 @@ from atat.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import ( class AddNoiseLayer(NpuBaseLayer): + def add_noise(self, tensor_obj): + if isinstance(tensor_obj, torch.Tensor): + self.perturbed_value = ThresholdConfig.PERTURBATION_VALUE_DICT.get( + tensor_obj.dtype + ) + if not self.pre_check(tensor_obj): + return tensor_obj + noise = self._get_noise(tensor_obj) + result = TorchC.where( + TorchC.gt(TorchC.abs(tensor_obj), self.perturbed_value ** 0.5), + TorchC.add(noise, tensor_obj), + tensor_obj, + ).to(tensor_obj.dtype) + self.is_added = True + return result + if isinstance(tensor_obj, dict): + return {key: self.add_noise(value) for key, value in tensor_obj.items()} + if isinstance(tensor_obj, (tuple, list)): + return type(tensor_obj)([self.add_noise(value) for value in tensor_obj]) + return tensor_obj + + def handle(self, params: DataParams) -> torch.Any: + """ + 对输入添加扰动并返回 + """ + print_info_log_rank_0( + f"[atat] Free benchmark: Perturbation is " + f"{PerturbationMode.ADD_NOISE} of {self.api_name}." + ) + params.perturbed_value = self.add_noise(params.args[params.valid_input_index]) + return self.perturbed_result(params) + def _get_noise(self, tensor_obj): dtype = tensor_obj.dtype device = str(tensor_obj.device) @@ -59,35 +91,3 @@ class AddNoiseLayer(NpuBaseLayer): ) return False return True - - def add_noise(self, tensor_obj): - if isinstance(tensor_obj, torch.Tensor): - self.perturbed_value = ThresholdConfig.PERTURBATION_VALUE_DICT.get( - tensor_obj.dtype - ) - if not self.pre_check(tensor_obj): - return tensor_obj - noise = self._get_noise(tensor_obj) - result = TorchC.where( - TorchC.gt(TorchC.abs(tensor_obj), self.perturbed_value**0.5), - TorchC.add(noise, tensor_obj), - tensor_obj, - ).to(tensor_obj.dtype) - self.is_added = True - return result - if isinstance(tensor_obj, dict): - return {key: self.add_noise(value) for key, value in tensor_obj.items()} - if isinstance(tensor_obj, (tuple, list)): - return type(tensor_obj)([self.add_noise(value) for value in tensor_obj]) - return tensor_obj - - def handle(self, params: DataParams) -> torch.Any: - """ - 对输入添加扰动并返回 - """ - print_info_log_rank_0( - f"[atat] Free benchmark: Perturbation is " - f"{PerturbationMode.ADD_NOISE} of {self.api_name}." - ) - params.perturbed_value = self.add_noise(params.args[params.valid_input_index]) - return self.perturbed_result(params) diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py index 72d04af412067882826ea402ed6fa00490bce348..2c1ed9a3e1ceaee11021d5aacbacfd103b9dd9d7 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py @@ -4,9 +4,9 @@ from atat.pytorch.free_benchmark import ( print_warn_log_rank_0, ) from atat.pytorch.free_benchmark.common.constant import ThresholdConfig +from atat.pytorch.free_benchmark.common.enums import PerturbationMode from atat.pytorch.free_benchmark.common.params import DataParams from atat.pytorch.free_benchmark.common.utils import TorchC -from atat.pytorch.free_benchmark.common.enums import PerturbationMode from atat.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import ( NpuBaseLayer, ) @@ -19,6 +19,49 @@ class BitNoiseLayer(NpuBaseLayer): self.bit_tail: int = 1 self.bit_type = None + def add_bit_noise(self, tensor_obj): + """ + 对输入添加噪声 + """ + # finfo应该列入黑名单 + + if isinstance(tensor_obj, torch.Tensor): + self._set_perturbation_bit(tensor_obj) + if not self.pre_check(tensor_obj): + return tensor_obj + sub_normal = torch.finfo(tensor_obj.dtype).smallest_normal + noise = TorchC.full( + tensor_obj.shape, + self.bit_tail, + device=tensor_obj.device, + dtype=self.bit_type, + ) + result = tensor_obj.view(self.bit_type) + result = TorchC.where( + TorchC.gt(TorchC.abs(tensor_obj), sub_normal), + self.bit_mode(result, noise), + result, + ).view(tensor_obj.dtype) + + self.is_added = True + return result + if isinstance(tensor_obj, dict): + return {key: self.add_bit_noise(value) for key, value in tensor_obj.items()} + if isinstance(tensor_obj, (tuple, list)): + return type(tensor_obj)([self.add_bit_noise(value) for value in tensor_obj]) + return tensor_obj + + def handle(self, params: DataParams) -> torch.Any: + """ + 对输入添加扰动并返回 + """ + print_info_log_rank_0( + f"[atat] Free benchmark: Perturbation is " + f"{PerturbationMode.BIT_NOISE} of {self.api_name}." + ) + params.perturbed_value = self.add_bit_noise(params.args[params.valid_input_index]) + return self.perturbed_result(params) + def _check_details(self, tensor_obj): """ 判断是否需要添加扰动, bit翻转 @@ -62,46 +105,3 @@ class BitNoiseLayer(NpuBaseLayer): if bit_len_type: self.bit_tail = 1 self.bit_type = bit_len_type - - def add_bit_noise(self, tensor_obj): - """ - 对输入添加噪声 - """ - # finfo应该列入黑名单 - - if isinstance(tensor_obj, torch.Tensor): - self._set_perturbation_bit(tensor_obj) - if not self.pre_check(tensor_obj): - return tensor_obj - sub_normal = torch.finfo(tensor_obj.dtype).smallest_normal - noise = TorchC.full( - tensor_obj.shape, - self.bit_tail, - device=tensor_obj.device, - dtype=self.bit_type, - ) - result = tensor_obj.view(self.bit_type) - result = TorchC.where( - TorchC.gt(TorchC.abs(tensor_obj), sub_normal), - self.bit_mode(result, noise), - result, - ).view(tensor_obj.dtype) - - self.is_added = True - return result - if isinstance(tensor_obj, dict): - return {key: self.add_bit_noise(value) for key, value in tensor_obj.items()} - if isinstance(tensor_obj, (tuple, list)): - return type(tensor_obj)([self.add_bit_noise(value) for value in tensor_obj]) - return tensor_obj - - def handle(self, params: DataParams) -> torch.Any: - """ - 对输入添加扰动并返回 - """ - print_info_log_rank_0( - f"[atat] Free benchmark: Perturbation is " - f"{PerturbationMode.BIT_NOISE} of {self.api_name}." - ) - params.perturbed_value = self.add_bit_noise(params.args[params.valid_input_index]) - return self.perturbed_result(params) diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/change_value.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/change_value.py index ab91bcb7eeea00085318a21c20bb9f03d69b8908..b4ee67384164cb73abd4e3f3cbaab77b1ffac293 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/change_value.py @@ -1,8 +1,8 @@ import torch from atat.pytorch.free_benchmark import print_warn_log_rank_0, print_info_log_rank_0 +from atat.pytorch.free_benchmark.common.enums import PerturbationMode from atat.pytorch.free_benchmark.common.params import DataParams from atat.pytorch.free_benchmark.common.utils import TorchC -from atat.pytorch.free_benchmark.common.enums import PerturbationMode from atat.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import ( NpuBaseLayer, ) @@ -14,18 +14,6 @@ class ChangeValueLayer(NpuBaseLayer): self.head: int = 0 self.tail: int = -1 - def _check_details(self, tensor_obj): - """ - 判断是否需要添加扰动, 首尾值交换 - """ - if tensor_obj.size(0) < 2: - print_warn_log_rank_0( - f"[atat] Free Benchmark: For {self.api_name}, " - f"size 0 must greater than 1. Cancel change value." - ) - return False - return True - def change_value(self, tensor_obj): """ 交换张量首尾 @@ -42,7 +30,7 @@ class ChangeValueLayer(NpuBaseLayer): temp_last = TorchC.clone(new_tensor[self.tail][self.tail]) new_tensor[self.head][self.head] = temp_last new_tensor[self.tail][self.tail] = temp_first - + self.is_added = True return new_tensor if isinstance(tensor_obj, dict): @@ -61,3 +49,15 @@ class ChangeValueLayer(NpuBaseLayer): ) params.perturbed_value = self.change_value(params.args[params.valid_input_index]) return self.perturbed_result(params) + + def _check_details(self, tensor_obj): + """ + 判断是否需要添加扰动, 首尾值交换 + """ + if tensor_obj.size(0) < 2: + print_warn_log_rank_0( + f"[atat] Free Benchmark: For {self.api_name}, " + f"size 0 must greater than 1. Cancel change value." + ) + return False + return True diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py index fb126972c6853b81d24db8138880601f9a3af21a..07c300c54630334af41f13fdd482f4f67b19d85b 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py @@ -1,8 +1,8 @@ import torch from atat.pytorch.free_benchmark import Const, print_info_log_rank_0 from atat.pytorch.free_benchmark.common.constant import CommonField -from atat.pytorch.free_benchmark.common.params import DataParams from atat.pytorch.free_benchmark.common.enums import PerturbationMode +from atat.pytorch.free_benchmark.common.params import DataParams from atat.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import ( NpuBaseLayer, ) @@ -10,24 +10,6 @@ from atat.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import ( class ImprovePrecisionLayer(NpuBaseLayer): - def _set_improve_valus(self, inputs): - # TODO why - if inputs.dtype in [torch.float16, torch.bfloat16]: - self.perturbed_value = torch.float32 - - def _change_dtype(self, inputs): - if hasattr(inputs, CommonField.DEVICE): - device = inputs.device - if device is CommonField.META: - new_inputs = inputs.to( - device=CommonField.META, dtype=self.perturbed_value - ) - else: - new_inputs = inputs.to(dtype=self.perturbed_value).to(device) - else: - new_inputs = inputs.to(dtype=self.perturbed_value) - return new_inputs - def improve_tensor_precision(self, tensor_obj): if ( isinstance(tensor_obj, torch.Tensor) @@ -62,3 +44,21 @@ class ImprovePrecisionLayer(NpuBaseLayer): new_kwargs["inplace"] = False params.perturbed_result = params.origin_func(*new_args, **new_kwargs) return params.perturbed_result + + def _set_improve_valus(self, inputs): + # TODO why + if inputs.dtype in [torch.float16, torch.bfloat16]: + self.perturbed_value = torch.float32 + + def _change_dtype(self, inputs): + if hasattr(inputs, CommonField.DEVICE): + device = inputs.device + if device is CommonField.META: + new_inputs = inputs.to( + device=CommonField.META, dtype=self.perturbed_value + ) + else: + new_inputs = inputs.to(dtype=self.perturbed_value).to(device) + else: + new_inputs = inputs.to(dtype=self.perturbed_value) + return new_inputs diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/no_change.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/no_change.py index 7ec5870fb72db30101f41a8ec057bf95d94da9b3..204e649d805a86a83509475f67c2cf477028f356 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/no_change.py @@ -1,7 +1,7 @@ import torch from atat.pytorch.free_benchmark import print_info_log_rank_0 -from atat.pytorch.free_benchmark.common.params import DataParams from atat.pytorch.free_benchmark.common.enums import PerturbationMode +from atat.pytorch.free_benchmark.common.params import DataParams from atat.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import ( NpuBaseLayer, ) @@ -16,7 +16,6 @@ class NoChangeLayer(NpuBaseLayer): self.is_added = True return tensor_obj - def handle(self, params: DataParams) -> torch.Any: """ 对输入添加扰动并返回 diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py index ca502365e1b1b4ae0b37e2ecc48bff3b203f765c..3784af0953022f1eb981ca26cf88765044f56f3f 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py @@ -1,8 +1,7 @@ from abc import abstractmethod from typing import Any + import torch -from atat.pytorch.free_benchmark.common.constant import CommonField, ThresholdConfig -from atat.pytorch.free_benchmark.common.utils import TorchC from atat.pytorch.free_benchmark.common.params import DataParams from atat.pytorch.free_benchmark.perturbed_layers.base_layer import BaseLayer @@ -13,13 +12,22 @@ class NpuBaseLayer(BaseLayer): self.perturbed_value = None # 扰动的元素 self.is_added = False # 标记当前算子输入是否调整 + @staticmethod + def perturbed_result(params: DataParams) -> Any: + args_front = params.args[: params.valid_input_index] + args_rear = params.args[params.valid_input_index + 1:] + # 此处会将有inplace属性的算子换为非inplace + if "inplace" in params.kwargs: + params.kwargs["inplace"] = False + params.perturbed_result = params.origin_func( + *args_front, params.perturbed_value, *args_rear, **params.kwargs + ) + return params.perturbed_result + @abstractmethod def handle(self, params: DataParams) -> Any: pass - def _check_details(self, tensor_obj): - return True - def pre_check(self, tensor_obj): """ 检查张量是否符合标准(float类型且最大值大于对应精度最小值) @@ -33,14 +41,5 @@ class NpuBaseLayer(BaseLayer): return False return True - @staticmethod - def perturbed_result(params: DataParams) -> Any: - args_front = params.args[: params.valid_input_index] - args_rear = params.args[params.valid_input_index + 1 :] - # 此处会将有inplace属性的算子换为非inplace - if "inplace" in params.kwargs: - params.kwargs["inplace"] = False - params.perturbed_result = params.origin_func( - *args_front, params.perturbed_value, *args_rear, **params.kwargs - ) - return params.perturbed_result + def _check_details(self, tensor_obj): + return True diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/base_handler.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/base_handler.py index 1d59ef9fc3adc2f90a7145d825ce597e209758e4..0b6e7c151115dc09a3d5e83d70424fb9c85a275d 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/base_handler.py +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/base_handler.py @@ -7,7 +7,6 @@ from atat.pytorch.free_benchmark import ( Const, print_warn_log_rank_0, ) -from atat.pytorch.free_benchmark.common.utils import TorchC from atat.pytorch.free_benchmark.common.constant import ThresholdConfig from atat.pytorch.free_benchmark.common.enums import ( FuzzThreshold, @@ -15,6 +14,7 @@ from atat.pytorch.free_benchmark.common.enums import ( PerturbationMode, ) from atat.pytorch.free_benchmark.common.params import DataParams, HandlerParams, make_unequal_row +from atat.pytorch.free_benchmark.common.utils import TorchC class FuzzHandler(ABC): @@ -41,52 +41,45 @@ class FuzzHandler(ABC): abs_tol, ) - def get_ratio_from_specific_norm( - self, origin_output, perturbed_output, norm_type, abs_tol - ): - if norm_type == NormType.ENDLESS_NORM: - return self.get_endless_norm(origin_output, perturbed_output, abs_tol) - return ThresholdConfig.COMP_CONSISTENT - @staticmethod def convert_overflow_ratio_to_consistent(ratio): if math.isnan(ratio) or math.isinf(ratio): return ThresholdConfig.COMP_CONSISTENT return ratio + @abstractmethod + def get_threshold(self, dtype): + pass + + @abstractmethod + def handle(self, data_params: DataParams) -> Any: + pass + + def get_ratio_from_specific_norm( + self, origin_output, perturbed_output, norm_type, abs_tol + ): + if norm_type == NormType.ENDLESS_NORM: + return self.get_endless_norm(origin_output, perturbed_output, abs_tol) + return ThresholdConfig.COMP_CONSISTENT + def get_endless_norm(self, origin_output, perturbed_output, abs_tol): - try: - ratio_tensor1 = TorchC.where( - TorchC.gt(TorchC.abs(perturbed_output), abs_tol), - TorchC.div( - TorchC.abs(origin_output), - TorchC.add(TorchC.abs(perturbed_output), abs_tol), - ), - 1, - ) - ratio_tensor2 = TorchC.where( - TorchC.gt(TorchC.abs(origin_output), abs_tol), - TorchC.div( - TorchC.abs(perturbed_output), - TorchC.add(TorchC.abs(origin_output), abs_tol), - ), - 1, - ) - except: - ratio_tensor1 = TorchC.where( - TorchC.gt(TorchC.abs(perturbed_output.to(torch.float32)), abs_tol), - TorchC.div( - origin_output.to(torch.float32), perturbed_output.to(torch.float32) - ), - 1, - ) - ratio_tensor2 = TorchC.where( - TorchC.gt(TorchC.abs(origin_output.to(torch.float32)), abs_tol), - TorchC.div( - perturbed_output.to(torch.float32), origin_output.to(torch.float32) - ), - 1, - ) + ratio_tensor1 = TorchC.where( + TorchC.gt(TorchC.abs(perturbed_output), abs_tol), + TorchC.div( + TorchC.abs(origin_output), + TorchC.add(TorchC.abs(perturbed_output), abs_tol), + ), + 1, + ) + ratio_tensor2 = TorchC.where( + TorchC.gt(TorchC.abs(origin_output), abs_tol), + TorchC.div( + TorchC.abs(perturbed_output), + TorchC.add(TorchC.abs(origin_output), abs_tol), + ), + 1, + ) + norm1 = self.convert_overflow_ratio_to_consistent( TorchC.max(ratio_tensor1).item() ) @@ -117,31 +110,21 @@ class FuzzHandler(ABC): if self.params.fuzz_stage == Const.BACKWARD: abs_tol = ThresholdConfig.BACKWARD_OUTPUT_LOWER_BOUND else: - abs_tol = abs_tol**0.5 + abs_tol = abs_tol ** 0.5 return self.get_ratio_from_specific_norm( origin_output, perturbed_output, norm_type, abs_tol ) - @abstractmethod - def get_threshold(self, dtype): - pass - - def _get_default_threshold(self, dtype): - if self.params.pert_mode == PerturbationMode.NO_CHANGE: - threshold = ThresholdConfig.COMP_CONSISTENT - else: - threshold = ThresholdConfig.DTYPE_PER_THD.get( - dtype, ThresholdConfig.DTYPE_PER_THD.get(torch.float32) - ) - return threshold - def npu_compare( - self, origin_output, perturbed_output + self, origin_output, perturbed_output ) -> Tuple[bool, Optional[float]]: if isinstance(perturbed_output, int): return origin_output == perturbed_output, None elif isinstance(perturbed_output, float): + if perturbed_output == 0: + origin_output += FuzzThreshold.F32_THD + perturbed_output = FuzzThreshold.F32_THD return ( math.isclose(origin_output, perturbed_output), origin_output / perturbed_output, @@ -190,7 +173,7 @@ class FuzzHandler(ABC): max_fuzz_ratio if ratio is None else max(max_fuzz_ratio, ratio) ) data_params.is_consistent = ( - is_consistent and data_params.is_consistent + is_consistent and data_params.is_consistent ) if not is_consistent and data_params.grad_unequal_flag: self.unequal_rows.append( @@ -205,9 +188,14 @@ class FuzzHandler(ABC): ) return npu_consistent, max_fuzz_ratio - @abstractmethod - def handle(self, data_params: DataParams) -> Any: - pass - def get_unequal_rows(self): return self.unequal_rows + + def _get_default_threshold(self, dtype): + if self.params.pert_mode == PerturbationMode.NO_CHANGE: + threshold = ThresholdConfig.COMP_CONSISTENT + else: + threshold = ThresholdConfig.DTYPE_PER_THD.get( + dtype, ThresholdConfig.DTYPE_PER_THD.get(torch.float32) + ) + return threshold diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/preheat_handler.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/preheat_handler.py index b8ff3bccf00c2dbe699159b4f77da86c75ae4062..1e70067b93d3031d5ac640e5e442087bca9a63aa 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/preheat_handler.py +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/preheat_handler.py @@ -1,16 +1,15 @@ +import math from typing import Any -import torch -import math from atat.pytorch.free_benchmark import print_info_log_rank_0, print_warn_log_rank_0 from atat.pytorch.free_benchmark.common.constant import ThresholdConfig +from atat.pytorch.free_benchmark.common.counter import preheat_counter from atat.pytorch.free_benchmark.common.enums import DeviceType -from atat.pytorch.free_benchmark.common.params import DataParams, make_unequal_row +from atat.pytorch.free_benchmark.common.params import DataParams +from atat.pytorch.free_benchmark.common.params import HandlerParams from atat.pytorch.free_benchmark.common.utils import Tools from atat.pytorch.free_benchmark.compare.single_benchmark import SingleCompare -from atat.pytorch.free_benchmark.common.counter import preheat_counter from atat.pytorch.free_benchmark.result_handlers.base_handler import FuzzHandler -from atat.pytorch.free_benchmark.common.params import HandlerParams class PreheatHandler(FuzzHandler): @@ -22,6 +21,75 @@ class PreheatHandler(FuzzHandler): def get_threshold(self, dtype): return preheat_counter.get_api_thd(self.pure_name, dtype) + def compare_npu_and_cpu(self, data_params: DataParams): + args = Tools.convert_device_and_dtype( + data_params.args, DeviceType.CPU, change_dtype=True + ) + kwargs = Tools.convert_device_and_dtype( + data_params.kwargs, DeviceType.CPU, change_dtype=True + ) + cpu_result = data_params.origin_func(*args, **kwargs) + return SingleCompare().compare_seq(data_params.original_result, cpu_result) + + def preheat(self, max_fuzz_ratio, cpu_consistent, first_dtype): + # 存储当前step所有输出比值和对应npu\cpu比对结果 + preheat_counter.update_preheat_record( + self.pure_name, + first_dtype, + (max_fuzz_ratio, cpu_consistent), + ) + if self._need_adjust_threshold(): + self._adjust_threshold() + + def handle(self, data_params: DataParams) -> Any: + + if isinstance(data_params.perturbed_result, bool) or not Tools.is_float_tensor( + data_params.perturbed_result + ): + return data_params.original_result + + if self.params.step == 0: + preheat_counter.add_one_step_used_api(self.pure_name) + return data_params.original_result + + # 如果当前api,step需要预热 + npu_consistent, max_fuzz_ratio = self.cmp_output_npu(data_params) + data_params.is_consistent = npu_consistent + + preheat_counter.check_step(self.params.step) + + if self.params.preheat_config.get("preheat_step") <= self.params.step: + return data_params.original_result + + if not data_params.grad_unequal_flag: + data_params.grad_unequal_flag = True + data_params.is_consistent = False + return data_params.original_result + preheat_counter.add_api_called_time(self.pure_name) + + if not self._is_take_a_sample(): + return data_params.original_result + + cpu_consistent = True + try: + cpu_consistent = self.compare_npu_and_cpu(data_params) + except Exception as e: + print_warn_log_rank_0( + f"[atat] Free Benchmark: For {self.params.api_name}, " + f"when campare to cpu exception raise {e}" + ) + try: + first_dtype = Tools.get_first_tensor_dtype(data_params.perturbed_result) + except RuntimeError: + print_warn_log_rank_0( + f"[atat] Free Benchmark: For {self.params.api_name}, " + f"the output sequence does not contain tensors." + ) + if preheat_counter.get_api_preheat(self.pure_name, str(first_dtype)): + self.preheat(max_fuzz_ratio, cpu_consistent, first_dtype) + + return data_params.original_result + def _is_take_a_sample(self) -> bool: need_sample_set = self._get_need_sample_set() curr_called_seq = preheat_counter.get_api_called_time(self.pure_name) @@ -61,17 +129,6 @@ class PreheatHandler(FuzzHandler): need_sample_set.add(count) return need_sample_set - - def compare_npu_and_cpu(self, data_params: DataParams): - args = Tools.convert_device_and_dtype( - data_params.args, DeviceType.CPU, change_dtype=True - ) - kwargs = Tools.convert_device_and_dtype( - data_params.kwargs, DeviceType.CPU, change_dtype=True - ) - cpu_result = data_params.origin_func(*args, **kwargs) - return SingleCompare().compare_seq(data_params.original_result, cpu_result) - def _need_adjust_threshold(self) -> bool: sample_count_per_step = self._get_sample_count_per_step() sampled_time = preheat_counter.get_api_sample_time(self.pure_name) @@ -112,63 +169,3 @@ class PreheatHandler(FuzzHandler): preheat_counter.update_api_thd( self.pure_name, dtype_str, new_thd, threshold ) - - def preheat(self, max_fuzz_ratio, cpu_consistent, first_dtype): - # 存储当前step所有输出比值和对应npu\cpu比对结果 - preheat_counter.update_preheat_record( - self.pure_name, - first_dtype, - (max_fuzz_ratio, cpu_consistent), - ) - if self._need_adjust_threshold(): - self._adjust_threshold() - - def handle(self, data_params: DataParams) -> Any: - - if isinstance(data_params.perturbed_result, bool) or not Tools.is_float_tensor( - data_params.perturbed_result - ): - return data_params.original_result - - if self.params.step == 0: - preheat_counter.add_one_step_used_api(self.pure_name) - return data_params.original_result - - # 如果当前api,step需要预热 - npu_consistent, max_fuzz_ratio = self.cmp_output_npu(data_params) - data_params.is_consistent = npu_consistent - - preheat_counter.check_step(self.params.step) - - if self.params.preheat_config.get("preheat_step") <= self.params.step: - return data_params.original_result - - if not data_params.grad_unequal_flag: - data_params.grad_unequal_flag = True - data_params.is_consistent = False - return data_params.original_result - preheat_counter.add_api_called_time(self.pure_name) - - - if not self._is_take_a_sample(): - return data_params.original_result - - cpu_consistent = True - try: - cpu_consistent = self.compare_npu_and_cpu(data_params) - except Exception as e: - print_warn_log_rank_0( - f"[atat] Free Benchmark: For {self.params.api_name}, " - f"when campare to cpu exception raise {e}" - ) - try: - first_dtype = Tools.get_first_tensor_dtype(data_params.perturbed_result) - except RuntimeError: - print_warn_log_rank_0( - f"[atat] Free Benchmark: For {self.params.api_name}, " - f"the output sequence does not contain tensors." - ) - if preheat_counter.get_api_preheat(self.pure_name, str(first_dtype)): - self.preheat(max_fuzz_ratio, cpu_consistent, first_dtype) - - return data_params.original_result diff --git a/debug/accuracy_tools/atat/pytorch/functional/data_collector.py b/debug/accuracy_tools/atat/pytorch/functional/data_collector.py index 7964c955db64682a1726b131118d5b53e9d17c8a..8e4011a054a85736d005063ca122b8bc0885c9fb 100644 --- a/debug/accuracy_tools/atat/pytorch/functional/data_collector.py +++ b/debug/accuracy_tools/atat/pytorch/functional/data_collector.py @@ -1,11 +1,13 @@ import os + import torch -from ..module_processer import ModuleProcesser -from .scope import build_scope, ListScope + +from .data_processor import build_data_processor, DataProcessor from .json_writer import DataWriter +from .scope import build_scope, ListScope from ..common.log import print_info_log, print_warn_log from ..common.utils import Const -from .data_processor import build_data_processor, DataProcessor +from ..module_processer import ModuleProcesser try: import torch_npu @@ -37,12 +39,6 @@ class DataCollector: else: self.scope = build_scope(None, self.config.scope, self.config.list) - def if_return_forward_new_output(self): - return self.data_processor.if_return_forward_new_output() - - def get_forward_new_output(self): - return self.data_processor.get_forward_new_output() - @property def dump_data_dir(self): return self.data_writer.dump_tensor_data_dir @@ -51,6 +47,20 @@ class DataCollector: def dump_file_path(self): return self.data_writer.dump_file_path + @staticmethod + def check_scope_and_pid(scope, name, pid): + return (not scope or scope.check(name)) and pid == os.getpid() + + @staticmethod + def is_inplace(module): + return getattr(module, "op_is_inplace", False) + + def if_return_forward_new_output(self): + return self.data_processor.if_return_forward_new_output() + + def get_forward_new_output(self): + return self.data_processor.get_forward_new_output() + def visit_and_clear_overflow_status(self, api_or_module_name): self.data_processor.visit_and_clear_overflow_status(api_or_module_name) @@ -68,14 +78,6 @@ class DataCollector: self.data_writer.update_data(data_info) return msg - @staticmethod - def check_scope_and_pid(scope, name, pid): - return (not scope or scope.check(name)) and pid == os.getpid() - - @staticmethod - def is_inplace(module): - return getattr(module, "op_is_inplace", False) - def pre_forward_data_collect(self, name, module, pid, module_input_output): backward_name = name.replace("forward", "backward") if self.check_scope_and_pid(self.scope, backward_name, pid): diff --git a/debug/accuracy_tools/atat/pytorch/functional/data_processor.py b/debug/accuracy_tools/atat/pytorch/functional/data_processor.py index 1ef1b79acb2172daa3bc85d11ffe4049d1bca942..5a8ce8b16a74e9107331a4458a1f89bb3ba628e2 100644 --- a/debug/accuracy_tools/atat/pytorch/functional/data_processor.py +++ b/debug/accuracy_tools/atat/pytorch/functional/data_processor.py @@ -1,21 +1,23 @@ -import torch -import zlib -import numpy as np -import os import inspect +import os +import zlib from dataclasses import dataclass, asdict -import torch_npu from typing import Tuple, List, Dict, Optional, Union + +import numpy as np +import torch +import torch_npu + +from ..common import recursive_apply_transform from ..common.exceptions import MsaccException from ..common.file_check import path_len_exceeds_limit, change_mode, FileCheckConst from ..common.log import print_warn_log from ..common.utils import Const -from ..common import recursive_apply_transform -from ..functional. json_writer import DataWriter from ..free_benchmark import FreeBenchmarkCheck, UnequalRow bits_for_overflow = 8 + def build_data_processor(config, data_writer): if config.task == DataProcessor.full: return FullTensorDataProcessor(config, data_writer) @@ -27,12 +29,12 @@ def build_data_processor(config, data_writer): return FreeBenchmarkDataProcessor(config, data_writer) else: raise MsaccException(MsaccException.INVALID_PARAM_ERROR, - "task should be in [{}, {}, {}, {}]".format( - DataProcessor.full, - DataProcessor.summary, - DataProcessor.overflow, - DataProcessor.free_benchmark - )) + "task should be in [{}, {}, {}, {}]".format( + DataProcessor.full, + DataProcessor.summary, + DataProcessor.overflow, + DataProcessor.free_benchmark + )) @dataclass @@ -44,14 +46,14 @@ class ModuleForwardInputsOutputs: @property def args_tuple(self): if not isinstance(self.args, tuple): - return (self.args, ) + return (self.args,) else: return self.args @property def output_tuple(self): if not isinstance(self.output, tuple): - return (self.output, ) + return (self.output,) else: return self.output @@ -68,18 +70,26 @@ class ModuleBackwardInputsOutputs: @property def grad_input_tuple(self): if not isinstance(self.grad_input, tuple): - return (self.grad_input, ) + return (self.grad_input,) else: return self.grad_input @property def grad_output_tuple(self): if not isinstance(self.grad_output, tuple): - return (self.grad_output, ) + return (self.grad_output,) else: return self.grad_output +class TensorStatInfo: + def __init__(self, max_val=None, min_val=None, mean_val=None, norm_val=None): + self.max = max_val + self.min = min_val + self.mean = mean_val + self.norm = norm_val + + class DataProcessor: full = "tensor" summary = "statistics" @@ -104,13 +114,6 @@ class DataProcessor: self._return_forward_new_output = False self._forward_new_output = None - def if_return_forward_new_output(self): - return self._return_forward_new_output - - def get_forward_new_output(self): - self._return_forward_new_output = False - return self._forward_new_output - @staticmethod def get_md5_for_tensor(x): if x.dtype == torch.bfloat16: @@ -156,69 +159,80 @@ class DataProcessor: return builtin_type(arg), type(arg).__name__ return arg, '' - def update_iter(self, current_iter): - self.current_iter = current_iter - - def visit_and_clear_overflow_status(self, api_or_module_name): - if self.current_api_or_module_name != api_or_module_name: - self.current_api_or_module_name = api_or_module_name - self.has_overflow = False + @staticmethod + def handle_tensor_extremum_nan_inf(data_clone, operator): + data_nan = torch._C._VariableFunctionsClass.isnan(data_clone) + if int(torch._C._VariableFunctionsClass.sum(data_nan)) == data_clone.numel(): + return float('nan') + finite_mask = torch._C._VariableFunctionsClass.isfinite(data_clone) + if int(torch._C._VariableFunctionsClass.sum(finite_mask)) > 0: + finite_values = data_clone[finite_mask] + return torch._C._VariableFunctionsClass.max(finite_values).item() if operator == 'max' else \ + torch._C._VariableFunctionsClass.min(finite_values).item() + else: + data_no_nan = data_clone[~data_nan] + return torch._C._VariableFunctionsClass.max(data_no_nan).item() if operator == 'max' else \ + torch._C._VariableFunctionsClass.min(data_no_nan).item() - def _analyze_numpy(self, value, numpy_type): - single_arg = {} - single_arg.update({"type": numpy_type}) - single_arg.update({"value": value}) - return single_arg + @staticmethod + def analyze_api_call_stack(name): + stack_str = [] + for (_, path, line, func, code, _) in inspect.stack()[5:]: + if not code: + continue + stack_line = " ".join([ + "File", ", ".join([ + path, + " ".join(["line", str(line)]), + " ".join(["in", func]), + " ".join(["\n", code[0].strip()]) + ]) + ]) + stack_str.append(stack_line) + stack_info_struct = {name: stack_str} + return stack_info_struct def get_stat_info(self, data): + tensor_stat = TensorStatInfo() if data.is_meta: - return + return tensor_stat data_clone = data.detach() if data_clone.numel() == 0: - tensor_max = None - tensor_min = None - tensor_mean = None - tensor_norm = None + return tensor_stat elif data_clone.dtype == torch.bool: - tensor_max = True in data_clone - tensor_min = False not in data_clone - tensor_mean = None - tensor_norm = None - elif not len(data_clone.shape): - tensor_max = data_clone.item() - tensor_min = tensor_max - tensor_mean = tensor_max - tensor_norm = tensor_max + tensor_stat.max = True in data_clone + tensor_stat.min = False not in data_clone + tensor_stat.mean = None + tensor_stat.norm = None + elif not data_clone.shape: + tensor_stat.max = data_clone.item() + tensor_stat.min = tensor_stat.max + tensor_stat.mean = tensor_stat.max + tensor_stat.norm = tensor_stat.max else: if not data_clone.is_floating_point(): data_clone = data_clone.float() - tensor_max = torch._C._VariableFunctionsClass.max(data_clone).item() - tensor_min = torch._C._VariableFunctionsClass.min(data_clone).item() - tensor_mean = torch._C._VariableFunctionsClass.mean(data_clone).item() - tensor_norm = torch._C._VariableFunctionsClass.norm(data_clone).item() + tensor_stat.max = torch._C._VariableFunctionsClass.max(data_clone).item() + tensor_stat.min = torch._C._VariableFunctionsClass.min(data_clone).item() + tensor_stat.mean = torch._C._VariableFunctionsClass.mean(data_clone).item() + tensor_stat.norm = torch._C._VariableFunctionsClass.norm(data_clone).item() - return tensor_max, tensor_min, tensor_mean, tensor_norm + return tensor_stat - def _analyze_builtin(self, arg): - single_arg = {} - if isinstance(arg, slice): - single_arg.update({"type": "slice"}) - # slice参数中可能存在tensor类型,json序列化,需要转换为python数值类型 - values = [ - value if not isinstance(value, torch.Tensor) else value.item() - for value in [arg.start, arg.stop, arg.step] - ] - single_arg.update({"value": values}) - else: - single_arg.update({"type": type(arg).__name__}) - single_arg.update({"value": arg}) - return single_arg + def if_return_forward_new_output(self): + return self._return_forward_new_output - def _analyze_torch_size(self, arg): - single_arg = {} - single_arg.update({"type": "torch.Size"}) - single_arg.update({"value": list(arg)}) - return single_arg + def get_forward_new_output(self): + self._return_forward_new_output = False + return self._forward_new_output + + def update_iter(self, current_iter): + self.current_iter = current_iter + + def visit_and_clear_overflow_status(self, api_or_module_name): + if self.current_api_or_module_name != api_or_module_name: + self.current_api_or_module_name = api_or_module_name + self.has_overflow = False def is_dump_for_data_mode(self, forward_backward, input_output): """ @@ -235,56 +249,6 @@ class DataProcessor: forward_backward in self.config.data_mode or input_output in self.config.data_mode) - @staticmethod - def handle_tensor_extremum_nan_inf(data_clone, operator): - data_nan = torch._C._VariableFunctionsClass.isnan(data_clone) - if int(torch._C._VariableFunctionsClass.sum(data_nan)) == data_clone.numel(): - return float('nan') - finite_mask = torch._C._VariableFunctionsClass.isfinite(data_clone) - if int(torch._C._VariableFunctionsClass.sum(finite_mask)) > 0: - finite_values = data_clone[finite_mask] - return torch._C._VariableFunctionsClass.max(finite_values).item() if operator == 'max' else \ - torch._C._VariableFunctionsClass.min(finite_values).item() - else: - data_no_nan = data_clone[~data_nan] - return torch._C._VariableFunctionsClass.max(data_no_nan).item() if operator == 'max' else \ - torch._C._VariableFunctionsClass.min(data_no_nan).item() - - def _analyze_maybe_overflow_tensor(self, tensor_json, tensor): - data_clone = tensor.detach() - if hasattr(torch_npu._C, '_npu_is_support_inf_nan') and torch_npu._C._npu_is_support_inf_nan(): - if tensor_json['Max'] is None: - return - if np.isinf(tensor_json['Max']) or np.isnan(tensor_json['Max']): - tensor_json['Max_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(data_clone, "max") - self.has_overflow = True - if np.isinf(tensor_json['Min']) or np.isnan(tensor_json['Min']): - tensor_json['Min_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(data_clone, "min") - self.has_overflow = True - else: - self.has_overflow = check_overflow_npu() - if self.has_overflow: - clear_overflow_npu() - - def _analyze_tensor(self, tensor, suffix): - tensor_max, tensor_min, tensor_mean, tensor_norm = self.get_stat_info(tensor) - - tensor_json = {} - tensor_json.update({'type': 'torch.Tensor'}) - tensor_json.update({'dtype': str(tensor.dtype)}) - tensor_json.update({"shape": tensor.shape}) - tensor_json.update({"Max": tensor_max}) - tensor_json.update({"Min": tensor_min}) - self._analyze_maybe_overflow_tensor(tensor_json, tensor) - tensor_json.update({"Mean": tensor_mean}) - tensor_json.update({"Norm": tensor_norm}) - tensor_json.update({"requires_grad": tensor.requires_grad}) - if self.config.summary_mode == "md5": - tensor_md5 = self.get_md5_for_tensor(tensor) - tensor_json.update({"md5": tensor_md5}) - - return tensor_json - def analyze_single_element(self, element, suffix_stack): if suffix_stack and suffix_stack[-1] in self.torch_object_key: return self.torch_object_key[suffix_stack[-1]](element) @@ -301,35 +265,18 @@ class DataProcessor: if isinstance(element, (bool, int, float, str, slice)): return self._analyze_builtin(element) + return {} def analyze_element(self, element): return recursive_apply_transform(element, self.analyze_single_element) - @staticmethod - def analyze_api_call_stack(name): - stack_str = [] - for (_, path, line, func, code, _) in inspect.stack()[5:]: - if not code: - continue - stack_line = " ".join([ - "File", ", ".join([ - path, - " ".join(["line", str(line)]), - " ".join(["in", func]), - " ".join(["\n", code[0].strip()]) - ]) - ]) - stack_str.append(stack_line) - stack_info_struct = {name: stack_str} - return stack_info_struct - def analyze_pre_forward(self, name, module, - module_input_output: ModuleForwardInputsOutputs): + module_input_output: ModuleForwardInputsOutputs): pass def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs): api_info_struct = {} - if self.is_dump_for_data_mode(Const.FORWARD, Const.INPUT): # check whether data_mode contains forward or input + if self.is_dump_for_data_mode(Const.FORWARD, Const.INPUT): # check whether data_mode contains forward or input api_info_struct[name] = {} self.api_data_category = Const.INPUT args_info_list = self.analyze_element(module_input_output.args_tuple) @@ -339,7 +286,8 @@ class DataProcessor: kwargs_info_list = self.analyze_element(module_input_output.kwargs) api_info_struct[name][Const.INPUT_KWARGS] = kwargs_info_list - if self.is_dump_for_data_mode(Const.FORWARD, Const.OUTPUT): # check whether data_mode contains forward or output + if self.is_dump_for_data_mode(Const.FORWARD, + Const.OUTPUT): # check whether data_mode contains forward or output api_info_struct[name] = api_info_struct.get(name, {}) self.api_data_category = Const.OUTPUT output_info_list = self.analyze_element(module_input_output.output_tuple) @@ -372,7 +320,6 @@ class DataProcessor: return api_info_struct - def analyze_backward(self, name, module, module_input_output: ModuleBackwardInputsOutputs): api_info_struct = {} if self.is_dump_for_data_mode(Const.BACKWARD, Const.OUTPUT): @@ -389,11 +336,76 @@ class DataProcessor: return api_info_struct + def _analyze_numpy(self, value, numpy_type): + single_arg = {} + single_arg.update({"type": numpy_type}) + single_arg.update({"value": value}) + return single_arg -class FullTensorDataProcessor(DataProcessor): + def _analyze_builtin(self, arg): + single_arg = {} + if isinstance(arg, slice): + single_arg.update({"type": "slice"}) + # slice参数中可能存在tensor类型,json序列化,需要转换为python数值类型 + values = [ + value if not isinstance(value, torch.Tensor) else value.item() + for value in [arg.start, arg.stop, arg.step] + ] + single_arg.update({"value": values}) + else: + single_arg.update({"type": type(arg).__name__}) + single_arg.update({"value": arg}) + return single_arg + + def _analyze_torch_size(self, arg): + single_arg = {} + single_arg.update({"type": "torch.Size"}) + single_arg.update({"value": list(arg)}) + return single_arg + + def _analyze_maybe_overflow_tensor(self, tensor_json, tensor): + data_clone = tensor.detach() + if hasattr(torch_npu._C, '_npu_is_support_inf_nan') and torch_npu._C._npu_is_support_inf_nan(): + if tensor_json[Const.MAX] is None: + return + if np.isinf(tensor_json[Const.MAX]) or np.isnan(tensor_json[Const.MAX]): + tensor_json['Max_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(data_clone, "max") + self.has_overflow = True + if np.isinf(tensor_json[Const.MIN]) or np.isnan(tensor_json[Const.MIN]): + tensor_json['Min_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(data_clone, "min") + self.has_overflow = True + else: + self.has_overflow = check_overflow_npu() + if self.has_overflow: + clear_overflow_npu() def _analyze_tensor(self, tensor, suffix): + tensor_stat = self.get_stat_info(tensor) + + tensor_json = {} + tensor_json.update({'type': 'torch.Tensor'}) + tensor_json.update({'dtype': str(tensor.dtype)}) + tensor_json.update({"shape": tensor.shape}) + tensor_json.update({"Max": tensor_stat.max}) + tensor_json.update({"Min": tensor_stat.min}) + self._analyze_maybe_overflow_tensor(tensor_json, tensor) + tensor_json.update({"Mean": tensor_stat.mean}) + tensor_json.update({"Norm": tensor_stat.norm}) + tensor_json.update({"requires_grad": tensor.requires_grad}) + if self.config.summary_mode == "md5": + tensor_md5 = self.get_md5_for_tensor(tensor) + tensor_json.update({"md5": tensor_md5}) + + return tensor_json + + +class FullTensorDataProcessor(DataProcessor): + + def __init__(self, config, data_writer): + super().__init__(config, data_writer) self.data_path = self.data_writer.dump_tensor_data_dir + + def _analyze_tensor(self, tensor, suffix): dump_data_name = (self.current_api_or_module_name + Const.SEP + self.api_data_category + Const.SEP + suffix + ".pt") file_path = os.path.join(self.data_writer.dump_tensor_data_dir, dump_data_name) @@ -417,7 +429,6 @@ class OverflowTensorDataProcessor(DataProcessor): self.overflow_nums = config.overflow_num def _analyze_tensor(self, tensor, suffix): - self.data_path = self.data_writer.dump_tensor_data_dir dump_data_name = (self.current_api_or_module_name + Const.SEP + self.api_data_category + Const.SEP + suffix + ".pt") file_path = os.path.join(self.data_writer.dump_tensor_data_dir, dump_data_name) @@ -437,7 +448,7 @@ class OverflowTensorDataProcessor(DataProcessor): return api_info_struct if self.has_overflow else None def analyze_backward(self, name, module, - module_input_output: ModuleBackwardInputsOutputs): + module_input_output: ModuleBackwardInputsOutputs): self.has_overflow = False api_info_struct = super().analyze_backward(name, module, module_input_output) self.maybe_save_overflow_data_and_check_overflow_times() @@ -483,7 +494,7 @@ class FreeBenchmarkDataProcessor(DataProcessor): return def analyze_pre_forward(self, name, module, - module_input_output: ModuleForwardInputsOutputs): + module_input_output: ModuleForwardInputsOutputs): args = module_input_output.args kwargs = module_input_output.kwargs self.checker.pre_forward(name, module, self, args, kwargs) @@ -495,7 +506,7 @@ class FreeBenchmarkDataProcessor(DataProcessor): module_input_output.args, module_input_output.kwargs, module_input_output.output, - ) + ) self.update_unequal_rows(unequal_rows) if self.checker.if_fix(): self._return_forward_new_output = True @@ -507,11 +518,11 @@ class FreeBenchmarkDataProcessor(DataProcessor): return None - def overflow_debug_mode_enable(): overflow_mode = os.getenv(OverflowConst.OVERFLOW_DEBUG_MODE_ENABLE, Const.ENV_DISABLE) return overflow_mode == Const.ENV_ENABLE + def check_overflow_npu(): if overflow_debug_mode_enable(): float_status = torch.zeros(bits_for_overflow).npu() @@ -523,6 +534,7 @@ def check_overflow_npu(): else: return torch_npu._C._check_overflow_npu() + def clear_overflow_npu(): if overflow_debug_mode_enable(): float_status = torch.zeros(bits_for_overflow).npu() @@ -530,6 +542,7 @@ def clear_overflow_npu(): else: torch_npu._C._clear_overflow_npu() + class OverflowConst: """ Class for Overflow diff --git a/debug/accuracy_tools/atat/pytorch/functional/json_writer.py b/debug/accuracy_tools/atat/pytorch/functional/json_writer.py index 0fee3aa9731aa79c2f0e5857fb8596a86e86b6d7..216d24882e44b98cc606e4cab5e417602015118c 100644 --- a/debug/accuracy_tools/atat/pytorch/functional/json_writer.py +++ b/debug/accuracy_tools/atat/pytorch/functional/json_writer.py @@ -1,16 +1,14 @@ -import os import csv -from pathlib import Path import json +import os +from pathlib import Path + from ..common.file_check import FileCheckConst, change_mode from ..common.log import print_info_log_rank_0 from ..common.utils import Const class DataWriter: # TODO: UT - # dump_json_name = "dump.json" - # stack_json_name = "stack.json" - # construct_json_name = "construct.json" def __init__(self, init_json=None) -> None: self.dump_count = 0 @@ -18,17 +16,31 @@ class DataWriter: # TODO: UT self.dump_file_path = None # os.path.join(dump_dir, DataWriter.dump_json_name) self.stack_file_path = None # os.path.join(dump_dir, DataWriter.stack_json_name) self.construct_file_path = None # os.path.join(dump_dir, DataWriter.construct_json_name) - self.free_benchmark_file_path = None + self.free_benchmark_file_path = None self.dump_tensor_data_dir = None self.buffer_size = 1000 - self.cache_data = {"data": {}} + self.cache_data = {Const.DATA: {}} self.cache_stack = {} self.cache_construct = {} + @staticmethod + def write_data_to_csv(result: list, result_header: tuple, file_path: str): + if len(result) == 0: + return + is_exists = os.path.exists(file_path) + append = "a+" if is_exists else "w+" + with os.fdopen( + os.open(file_path, Const.WRITE_FLAGS, FileCheckConst.DATA_FILE_AUTHORITY), append, newline="" + ) as csv_file: + spawn_writer = csv.writer(csv_file) + if not is_exists: + spawn_writer.writerow(result_header) + spawn_writer.writerows([result, ]) + def initialize_json_file(self, **kwargs): - kwargs.update({"dump_data_dir": self.dump_tensor_data_dir, "data": {}}) + kwargs.update({"dump_data_dir": self.dump_tensor_data_dir, Const.DATA: {}}) with os.fdopen( - os.open(self.dump_file_path, Const.OVERWRITE_FLAGS, FileCheckConst.DATA_FILE_AUTHORITY), 'w' + os.open(self.dump_file_path, Const.OVERWRITE_FLAGS, FileCheckConst.DATA_FILE_AUTHORITY), 'w' ) as f: json.dump(kwargs, f) @@ -42,7 +54,8 @@ class DataWriter: # TODO: UT Path(self.construct_file_path).touch() change_mode(self.construct_file_path, FileCheckConst.DATA_FILE_AUTHORITY) - def update_dump_paths(self, dump_file_path, stack_file_path, construct_file_path, dump_data_dir, free_benchmark_file_path): + def update_dump_paths(self, dump_file_path, stack_file_path, construct_file_path, dump_data_dir, + free_benchmark_file_path): self.dump_file_path = dump_file_path self.stack_file_path = stack_file_path self.construct_file_path = construct_file_path @@ -51,13 +64,13 @@ class DataWriter: # TODO: UT def update_data(self, new_data): key = next(iter(new_data.keys())) # assert len(new_data.keys()) == 1 - if key in self.cache_data["data"]: - self.cache_data["data"][key].update(new_data[key]) + if key in self.cache_data[Const.DATA]: + self.cache_data[Const.DATA][key].update(new_data[key]) else: - self.cache_data["data"].update(new_data) + self.cache_data[Const.DATA].update(new_data) def flush_data_when_buffer_is_full(self): - if len(self.cache_data["data"]) >= self.buffer_size: + if len(self.cache_data[Const.DATA]) >= self.buffer_size: self.write_data_json(self.dump_file_path) def update_stack(self, new_data): @@ -77,13 +90,13 @@ class DataWriter: # TODO: UT else: self.init_json['data_path'] = self.dump_tensor_data_dir data_to_write = self.init_json - data_to_write['data'].update(self.cache_data['data']) + data_to_write[Const.DATA].update(self.cache_data[Const.DATA]) with open(file_path, 'w+') as f: fcntl.flock(f, fcntl.LOCK_EX) json.dump(data_to_write, f, indent=1) fcntl.flock(f, fcntl.LOCK_UN) - self.cache_data["data"].clear() + self.cache_data[Const.DATA].clear() def write_stack_info_json(self, file_path): import fcntl @@ -103,18 +116,3 @@ class DataWriter: # TODO: UT self.write_data_json(self.dump_file_path) self.write_stack_info_json(self.stack_file_path) self.write_construct_info_json(self.construct_file_path) - - @staticmethod - def write_data_to_csv(result: list, result_header: tuple, file_path: str): - if len(result) == 0: - return - is_exists = os.path.exists(file_path) - append = "a+" if is_exists else "w+" - with os.fdopen( - os.open(file_path, Const.WRITE_FLAGS, FileCheckConst.DATA_FILE_AUTHORITY), append, newline="" - ) as csv_file: - spawn_writer = csv.writer(csv_file) - if not is_exists: - spawn_writer.writerow(result_header) - spawn_writer.writerows([result,]) - \ No newline at end of file diff --git a/debug/accuracy_tools/atat/pytorch/functional/scope.py b/debug/accuracy_tools/atat/pytorch/functional/scope.py index e557b876b1b00beef60dd623175374ad20d6a287..cced68f08a3cda86fddd836fabbb50b060994e7a 100644 --- a/debug/accuracy_tools/atat/pytorch/functional/scope.py +++ b/debug/accuracy_tools/atat/pytorch/functional/scope.py @@ -3,7 +3,11 @@ from ..common.exceptions import ScopeException from ..common.utils import Const -def build_scope(scope_class, scope=[], api_list=[]): +def build_scope(scope_class, scope=None, api_list=None): + if api_list is None: + api_list = [] + if scope is None: + scope = [] if not scope and not api_list: return None if scope_class: @@ -30,6 +34,11 @@ class BaseScope(ABC): Module_Type_Module = "Module" Module_Type_API = "api" + def __init__(self, scope, api_list): + scope, api_list = self.rectify_args(scope, api_list) + self.scope = scope + self.api_list = api_list + @staticmethod def rectify_args(scope, api_list): if not isinstance(api_list, list): @@ -51,10 +60,9 @@ class BaseScope(ABC): f"scope列表元素要求类型为字符串,实际类型为{type(s)}.") return scope, api_list - def __init__(self, scope, api_list): - scope, api_list = self.rectify_args(scope, api_list) - self.scope = scope - self.api_list = api_list + @abstractmethod + def check(self, name): + pass def check_api_list(self, api_name): if not self.api_list: @@ -62,11 +70,7 @@ class BaseScope(ABC): for api_str in self.api_list: if api_str in api_name: return True - - @abstractmethod - def check(self, name): - pass - + return False class ListScope(BaseScope): @staticmethod @@ -83,6 +87,12 @@ class ListScope(BaseScope): class RangeScope(BaseScope, ABC): + + def __init__(self, *args): + super().__init__(*args) + self.in_scope = False + self.is_valid = self.check_scope_is_valid() + @staticmethod def rectify_args(scope, api_list): scope, api_list = super(RangeScope, RangeScope).rectify_args(scope, api_list) @@ -99,11 +109,6 @@ class RangeScope(BaseScope, ABC): def check_scope_is_valid(self): pass - def __init__(self, *args): - super().__init__(*args) - self.in_scope = False - self.is_valid = self.check_scope_is_valid() - def begin_module(self, module_name): pass @@ -169,6 +174,3 @@ class ModuleRangeScope(RangeScope): if not self.scope or self.in_scope: return self.check_api_list(module_name) return False - - - diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/api_registry.py b/debug/accuracy_tools/atat/pytorch/hook_module/api_registry.py index 003a8699cd750a424bf989ae9d1b3fac78f76650..2b4b6a8579a958862f279e015005029ca0c51d2b 100644 --- a/debug/accuracy_tools/atat/pytorch/hook_module/api_registry.py +++ b/debug/accuracy_tools/atat/pytorch/hook_module/api_registry.py @@ -17,14 +17,16 @@ import torch import torch.distributed as dist + from . import wrap_torch, wrap_functional, wrap_tensor, wrap_vf, wrap_distributed, wrap_aten -from .wrap_torch import get_torch_ops +from .wrap_aten import get_aten_ops +from .wrap_distributed import get_distributed_ops from .wrap_functional import get_functional_ops from .wrap_tensor import get_tensor_ops +from .wrap_torch import get_torch_ops from .wrap_vf import get_vf_ops -from .wrap_distributed import get_distributed_ops -from .wrap_aten import get_aten_ops -from ..common.utils import torch_without_guard_version, npu_distributed_api, is_gpu +from ..common.utils import torch_without_guard_version, npu_distributed_api, is_gpu, Const + torch_version_above_2 = torch.__version__.split('+')[0] > '2.0' if not is_gpu: @@ -108,19 +110,19 @@ class ApiRegistry: self.store_ori_attr(torch.Tensor, get_tensor_ops(), self.tensor_ori_attr) wrap_tensor.wrap_tensor_ops_and_bind(hook) for attr_name in dir(wrap_tensor.HOOKTensor): - if attr_name.startswith("wrap_"): + if attr_name.startswith(Const.ATTR_NAME_PREFIX): self.tensor_hook_attr[attr_name[5:]] = getattr(wrap_tensor.HOOKTensor, attr_name) self.store_ori_attr(torch, get_torch_ops(), self.torch_ori_attr) wrap_torch.wrap_torch_ops_and_bind(hook) for attr_name in dir(wrap_torch.HOOKTorchOP): - if attr_name.startswith("wrap_"): + if attr_name.startswith(Const.ATTR_NAME_PREFIX): self.torch_hook_attr[attr_name[5:]] = getattr(wrap_torch.HOOKTorchOP, attr_name) self.store_ori_attr(torch.nn.functional, get_functional_ops(), self.functional_ori_attr) wrap_functional.wrap_functional_ops_and_bind(hook) for attr_name in dir(wrap_functional.HOOKFunctionalOP): - if attr_name.startswith("wrap_"): + if attr_name.startswith(Const.ATTR_NAME_PREFIX): self.functional_hook_attr[attr_name[5:]] = getattr(wrap_functional.HOOKFunctionalOP, attr_name) self.store_ori_attr(dist, get_distributed_ops(), self.distributed_ori_attr) @@ -128,9 +130,9 @@ class ApiRegistry: if not is_gpu and not torch_without_guard_version: self.store_ori_attr(torch_npu.distributed, npu_distributed_api, self.npu_distributed_ori_attr) for attr_name in dir(wrap_distributed.HOOKDistributedOP): - if attr_name.startswith("wrap_"): + if attr_name.startswith(Const.ATTR_NAME_PREFIX): self.distributed_hook_attr[attr_name[5:]] = getattr(wrap_distributed.HOOKDistributedOP, attr_name) - if not is_gpu and not torch_without_guard_version and attr_name[5:] in npu_distributed_api: + if not is_gpu and not torch_without_guard_version and attr_name[5:] in npu_distributed_api: self.npu_distributed_hook_attr[attr_name[5:]] = getattr(wrap_distributed.HOOKDistributedOP, attr_name) @@ -138,20 +140,20 @@ class ApiRegistry: self.store_ori_attr(torch.ops.aten, get_aten_ops(), self.aten_ori_attr) wrap_aten.wrap_aten_ops_and_bind(hook) for attr_name in dir(wrap_aten.HOOKAtenOP): - if attr_name.startswith("wrap_"): + if attr_name.startswith(Const.ATTR_NAME_PREFIX): self.aten_hook_attr[attr_name[5:]] = getattr(wrap_aten.HOOKAtenOP, attr_name) self.store_ori_attr(torch._VF, get_vf_ops(), self.vf_ori_attr) wrap_vf.wrap_vf_ops_and_bind(hook) for attr_name in dir(wrap_vf.HOOKVfOP): - if attr_name.startswith("wrap_"): + if attr_name.startswith(Const.ATTR_NAME_PREFIX): self.vf_hook_attr[attr_name[5:]] = getattr(wrap_vf.HOOKVfOP, attr_name) if not is_gpu: self.store_ori_attr(torch_npu, get_npu_ops(), self.torch_npu_ori_attr) wrap_npu_custom.wrap_npu_ops_and_bind(hook) for attr_name in dir(wrap_npu_custom.HOOKNpuOP): - if attr_name.startswith("wrap_"): + if attr_name.startswith(Const.ATTR_NAME_PREFIX): self.torch_npu_hook_attr[attr_name[5:]] = getattr(wrap_npu_custom.HOOKNpuOP, attr_name) diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_vf.py b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_vf.py index 08d47308e077981e65193eea71874d4f9432c6c0..9303aec6e04c01e0aa969141e33f54c75ebb8ca4 100644 --- a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_vf.py +++ b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_vf.py @@ -21,8 +21,8 @@ import torch import yaml from .hook_module import HOOKModule -from ..common.utils import torch_device_guard, Const from ..common.file_check import FileOpen +from ..common.utils import torch_device_guard, Const cur_path = os.path.dirname(os.path.realpath(__file__)) yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") @@ -32,8 +32,6 @@ with FileOpen(yaml_path, 'r') as f: def get_vf_ops(): global WrapVfOps - # _all_functional_ops = dir(torch.nn.functional) - # assert set(WrapFunctionalOps) <= set(_all_functional_ops) return WrapVfOps diff --git a/debug/accuracy_tools/atat/pytorch/pt_config.py b/debug/accuracy_tools/atat/pytorch/pt_config.py index a0691915cffc93b4a4505b2453560043b44cdc40..46d9b70cc9f9ef38aca9dc29c4df50be2d90e535 100644 --- a/debug/accuracy_tools/atat/pytorch/pt_config.py +++ b/debug/accuracy_tools/atat/pytorch/pt_config.py @@ -1,11 +1,12 @@ -import os import json +import os + from ..core.common_config import CommonConfig, BaseConfig -from ..core.utils import Const from ..core.file_check_util import FileOpen +from ..core.utils import Const -#特定任务配置类 +# 特定任务配置类 class TensorConfig(BaseConfig): def __init__(self, json_config): super().__init__(json_config) @@ -26,7 +27,7 @@ class StatisticsConfig(BaseConfig): def _check_summary_mode(self): if self.summary_mode and self.summary_mode not in ["statistics", "md5"]: raise Exception("summary_mode is invalid") - + class OverflowCheckConfig(BaseConfig): def __init__(self, json_config): @@ -34,13 +35,14 @@ class OverflowCheckConfig(BaseConfig): self.overflow_num = json_config.get("overflow_nums") self.check_mode = json_config.get("check_mode") self.check_overflow_config() - + def check_overflow_config(self): if self.overflow_num is not None and not isinstance(self.overflow_num, int): raise Exception("overflow_num is invalid") if self.check_mode is not None and self.check_mode not in ["all", "aicore", "atomic"]: raise Exception("check_mode is invalid") - + + class FreeBenchmarkCheckConfig(BaseConfig): def __init__(self, json_config): super().__init__(json_config) @@ -55,8 +57,10 @@ class FreeBenchmarkCheckConfig(BaseConfig): self.check_freebenchmark_config() def check_freebenchmark_config(self): - if self.if_preheat and self.handler_type == "fix": + if self.if_preheat and self.handler_type == "fix": raise Exception("Preheating is not supported in fix handler type") + if self.preheat_step and self.preheat_step == 0: + raise Exception("preheat_step cannot be 0") def parse_task_config(task, json_config): default_dic = {} @@ -87,4 +91,4 @@ def parse_json_config(json_file_path, task): task_config = parse_task_config(task, json_config) else: task_config = parse_task_config(common_config.task, json_config) - return common_config, task_config \ No newline at end of file + return common_config, task_config