diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/__init__.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ad282e6739a64b8710b2f5beed43a9a1c6546f4d --- /dev/null +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/__init__.py @@ -0,0 +1,8 @@ +from atat.pytorch.common import print_error_log_rank_0, print_info_log_rank_0 +from atat.pytorch.common.exceptions import FreeBenchmarkException +from atat.pytorch.common.utils import Const + +from .main import FreeBenchmarkCheck +from .common.params import UnequalRow + +__all__ = [FreeBenchmarkCheck, UnequalRow] diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/common/__init__.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/common/constant.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/common/constant.py new file mode 100644 index 0000000000000000000000000000000000000000..9b72437f2280ca44a20fc5e370f1cfd9b9ea3ac4 --- /dev/null +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/common/constant.py @@ -0,0 +1,66 @@ +from typing import Dict + +import numpy as np +import torch +from atat.pytorch.free_benchmark.common.enums import FuzzThreshold +from atat.pytorch.free_benchmark.common.params import BenchmarkThd + + +class CommonField: + DEVICE = "device" + META = "meta" + FUZZ_TENSOR = "fuzz_tensor" + REQUIRES_GRAD = "requires_grad" + HOLD_PLACE = "hold_place" + DISTRIBUTED_OP = "torch.distributed" + + +class ThresholdConfig: + PERTURBATION_VALUE_DICT: Dict = { + torch.bfloat16: FuzzThreshold.BF16_THD, + torch.float16: FuzzThreshold.F16_THD, + torch.float32: FuzzThreshold.F32_THD, + torch.float64: FuzzThreshold.F64_THD, + } + + ABS_TOL_VALUE_DICT: Dict = { + torch.bfloat16: FuzzThreshold.BF16_THD, + torch.float16: FuzzThreshold.F16_THD, + torch.float32: FuzzThreshold.F32_THD, + torch.float64: FuzzThreshold.F64_THD, + } + + # bit翻转需要匹配到等长或更长的整型 + PERTURBATION_BIT_DICT = { + torch.bfloat16: torch.int16, + torch.float16: torch.int16, + torch.float32: torch.int32, + torch.float64: torch.int64, + } + + # 输入噪声下界 + NOISE_INPUT_LOWER_BOUND = 1e-8 + COMP_CONSISTENT = 1.0 + COMP_NAN = np.nan + SYMBOL_FLIPPING = "symbol_flipping" + BACKWARD_OUTPUT_LOWER_BOUND = 1e-3 + SMALL_VALUE = 1.0 + # 预热初始阈值 + PREHEAT_INITIAL_THD = 2.05 + API_THD_STEP = 2.0 + + DTYPE_PER_THD = { + torch.float16: 1.002, + torch.float32: 1.0002, + } + BENCHMARK_THD_DICT = { + torch.float32: BenchmarkThd(2**-14, 1.0, 2**-14, 1e-4), + torch.float16: BenchmarkThd(2**-11, 1.0, 2**-11, 1e-4), + torch.bfloat16: BenchmarkThd(2**-8, 1.0, 2**-8, 1e-4), + } + + +class PreheatConfig: + IF_PREHEAT = "if_preheat" + PREHEAT_STEP = "preheat_step" + MAX_SAMPLE = "max_sample" diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/common/counter.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/common/counter.py new file mode 100644 index 0000000000000000000000000000000000000000..3d1361a72f16afcdeb1ea810384f7be67c4fe573 --- /dev/null +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/common/counter.py @@ -0,0 +1,67 @@ +from collections import defaultdict +from atat.pytorch.free_benchmark.common.constant import ThresholdConfig + + +class PreheatCounter: + def __init__(self) -> None: + self.api_called_time: dict = defaultdict(int) + self.api_sampled_time: dict = defaultdict(int) + 
        self.one_step_used_api: dict = defaultdict(int)
+        self.api_thd: dict = defaultdict(dict)
+        self.preheat_record: dict = defaultdict(dict)
+        self.dtype_map: dict = {}
+        self.if_preheat: dict = defaultdict(dict)
+        self.step: int = 0
+
+    def reset(self):
+        self.__init__()
+
+    def add_api_called_time(self, api_name: str):
+        self.api_called_time[api_name] += 1
+
+    def get_api_called_time(self, api_name: str) -> int:
+        return self.api_called_time[api_name]
+
+    def add_api_sampled_time(self, api_name: str):
+        self.api_sampled_time[api_name] += 1
+
+    def get_api_sampled_time(self, api_name: str) -> int:
+        return self.api_sampled_time[api_name]
+
+    def add_one_step_used_api(self, api_name: str):
+        self.one_step_used_api[api_name] += 1
+
+    def get_one_step_used_api(self, api_name: str):
+        return self.one_step_used_api[api_name]
+
+    def update_preheat_record(self, step, api_name, dtype, cmp_result):
+        # Record the CPU benchmark comparison results collected during the preheat stage
+        if step != self.step:
+            self.preheat_record = defaultdict(dict)
+            self.step = step
+        if str(dtype) not in self.preheat_record[api_name].keys():
+            self.preheat_record[api_name][str(dtype)] = list()
+        self.preheat_record[api_name][str(dtype)].append(cmp_result)
+        self.dtype_map[str(dtype)] = dtype
+
+    def update_api_thd(self, api_name, dtype, threshold, dthreshold):
+        self.api_thd[api_name][str(dtype)] = (
+            threshold if threshold > dthreshold else dthreshold
+        )
+
+    def get_api_thd(self, api_name, dtype):
+        if str(dtype) not in self.api_thd[api_name]:
+            self.api_thd[api_name][str(dtype)] = ThresholdConfig.PREHEAT_INITIAL_THD
+            self.dtype_map[str(dtype)] = dtype
+        return self.api_thd[api_name][str(dtype)]
+
+    def set_api_preheat(self, api_name, dtype_str, is_preheat=True):
+        # Mark dtypes that are inconsistent with the CPU so they are no longer preheated
+        self.if_preheat[api_name][dtype_str] = is_preheat
+
+    def get_api_preheat(self, api_name, dtype):
+        # Dtypes marked as inconsistent with the CPU are no longer preheated
+        if str(dtype) not in self.if_preheat[api_name]:
+            return True
+        return self.if_preheat[api_name][str(dtype)]
+
+preheat_counter = PreheatCounter()
\ No newline at end of file
diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/common/enums.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/common/enums.py
new file mode 100644
index 0000000000000000000000000000000000000000..bfb1bbaa40dc2a535a02aa914f823906b0a374ab
--- /dev/null
+++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/common/enums.py
@@ -0,0 +1,37 @@
+class PerturbationMode:
+    ADD_NOISE = "add_noise"
+    CHANGE_VALUE = "change_value"
+    IMPROVE_PRECISION = "improve_precision"
+    NO_CHANGE = "no_change"
+    BIT_NOISE = "bit_noise"
+    TO_CPU = "to_cpu"
+
+
+class DeviceType:
+    NPU = "npu"
+    CPU = "cpu"
+
+
+class FuzzThreshold:
+    BF16_THD = 1e-4
+    F16_THD = 1e-6
+    F32_THD = 1e-8
+    F64_THD = 1e-16
+
+
+class NormType:
+    ONE_NORM = (1, "one_norm")
+    TWO_NORM = (2, "two_norm")
+    ENDLESS_NORM = (3, "endless_norm")
+
+
+class HandlerType:
+    CHECK = "check"
+    PREHEAT = "preheat"
+    FIX = "fix"
+
+
+class FuzzLevel:
+    BASE_LEVEL = "L1"
+    ADV_LEVEL = "L2"
+    REAL_LEVEL = "L3"
diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/common/params.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/common/params.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c64e77cfdac7777913bd3925380bc88b8571500
--- /dev/null
+++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/common/params.py
@@ -0,0 +1,133 @@
+from abc import ABC
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+import torch
+from atat.pytorch.free_benchmark import Const, 
print_error_log_rank_0 +from atat.pytorch.free_benchmark.common.enums import ( + DeviceType, + FuzzLevel, + PerturbationMode, +) +from atat.pytorch.free_benchmark.common.utils import Tools + + +@dataclass +class DataParams: + args: Optional[Tuple] = None + kwargs: Optional[Dict] = None + index: Optional[int] = None + original_result: Optional[Any] = None + perturbed_result: Optional[Any] = None + is_consistent: Optional[bool] = True + perturbed_value: Optional[Any] = None + origin_func: Optional[Callable] = None + cpu_grads: Optional[Any] = None + api_type: Optional[str] = None + fuzz_stage: Optional[str] = None + grad_unequal_flag: Optional[bool] = True + + +@dataclass +class HandlerParams: + handler_type: Optional[str] = None + api_name: Optional[str] = None + pert_mode: Optional[PerturbationMode] = None + step: Optional[int] = None + fuzz_stage: Optional[str] = None + fuzz_device: Optional[DeviceType] = None + preheat_config: Optional[Dict] = None + fuzz_level: Optional[str] = None + + +@dataclass +class UnequalRow: + rank: Optional[int] = None + pert_mode: Optional[PerturbationMode] = None + stage: Optional[str] = None + step: Optional[int] = None + api_name: Optional[str] = None + max_rel: Optional[float] = None + dtype: Optional[str] = None + shape: Optional[str] = None + output_index: Optional[int] = None + + +@dataclass +class BenchmarkThd: + rtol: Optional[float] = None # 相对误差阈值 + small_value: Optional[float] = None # 小值域 + small_value_atol: Optional[float] = None # 小值域绝对阈值 + err_balance: Optional[float] = None # 误差均衡性 + + +def check_args_type(args: Tuple) -> int: + for i, arg in enumerate(args): + if torch.is_tensor(arg): + if arg.is_meta: + continue + if not torch.is_floating_point(arg): + continue + return i + if isinstance(arg, (List, Tuple, Dict)): + return i + return -1 + + +def data_pre_deal(name, func, args, kwargs): + data_params = DataParams(args=args, kwargs=kwargs, origin_func=func) + # data_params.api_type = name.split(Const.SEP)[0] + # TODO unsupported api type + index = check_args_type(args) + data_params.index = index + if index == -1: + print_error_log_rank_0( + f"[atat] Free benchmark: 无标杆工具不支持当前算子的输入类型 {name}." 
+ ) + return data_params + + +def make_handler_params(name, config, step): + handler_params = HandlerParams() + handler_params.api_name = name + handler_params.step = step + handler_params.handler_type = config.handler_type + handler_params.fuzz_stage = config.fuzz_stage + handler_params.fuzz_device = config.fuzz_device + handler_params.preheat_config = config.preheat_config + handler_params.fuzz_level = config.fuzz_level + handler_params.pert_mode = config.pert_mode + return handler_params + + +def make_unequal_row( + data_params: DataParams, + handle_params: HandlerParams, + ratio: float = None, + index: int = None, +): + row = UnequalRow( + api_name=handle_params.api_name, + pert_mode=handle_params.pert_mode, + output_index=index, + stage=handle_params.fuzz_stage, + step=handle_params.step, + ) + if isinstance(ratio, float): + row.max_rel = ratio - 1 + origin_tensor = data_params.original_result + perturbed_tensor = data_params.perturbed_result + if index: + origin_tensor = origin_tensor[index] + perturbed_tensor = perturbed_tensor[index] + row.output_index = index + if isinstance(origin_tensor, torch.Tensor): + row.dtype = origin_tensor.dtype + row.shape = origin_tensor.shape + row.rank = Tools.get_dist_rank() + # 以下暂不支持 + if handle_params.fuzz_level == FuzzLevel.ADV_LEVEL: + pass + if handle_params.fuzz_level == FuzzLevel.REAL_LEVEL: + pass + return row diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/common/utils.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/common/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..24d25967635b3dcfd1da89e1f54d3282fa1181ed --- /dev/null +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/common/utils.py @@ -0,0 +1,98 @@ +import torch +from atat.pytorch.free_benchmark.common.enums import DeviceType + + +class Tools: + + @staticmethod + def is_float_tensor(tensor) -> bool: + if isinstance(tensor, torch.Tensor) and torch.is_floating_point(tensor): + return True + if isinstance(tensor, (list, tuple)): + for value in tensor: + if isinstance(value, torch.Tensor) and torch.is_floating_point(value): + return True + return False + + @staticmethod + def get_dist_rank(): + try: + return torch.distributed.get_rank() + except RuntimeError: + return 0 + + @staticmethod + def get_first_tensor_dtype(tensor_seq): + if isinstance(tensor_seq, torch.Tensor): + return tensor_seq.dtype + if isinstance(tensor_seq, (list, tuple)): + for object_ in tensor_seq: + if isinstance(object_, torch.Tensor): + return object_.dtype + raise RuntimeError("The sequence does not contain tensors.") + + @staticmethod + def get_pure_api_name(api_name: str): + return api_name.rsplit(".", 2)[0] + + @staticmethod + def convert_device_and_dtype( + tensor_seq, device: str = DeviceType.CPU, change_dtype: bool = False + ): + if isinstance(tensor_seq, torch.Tensor): + if change_dtype and tensor_seq.dtype in [torch.float16, torch.bfloat16]: + return tensor_seq.detach().to(device).to(torch.float32) + return tensor_seq.detach().to(device) + if isinstance(tensor_seq, dict): + return { + key: Tools.convert_device_and_dtype(value, device, change_dtype) + for key, value in tensor_seq.items() + } + if isinstance(tensor_seq, (tuple, list)): + return type(tensor_seq)( + [ + Tools.convert_device_and_dtype(value, device, change_dtype) + for value in tensor_seq + ] + ) + return tensor_seq + + @staticmethod + def convert_fuzz_output_to_origin(origin, perturbed): + if isinstance(origin, torch.Tensor): + origin.data = perturbed.to(origin.dtype).to(origin.device) + 
return origin + if isinstance(origin, dict): + output = dict() + for key, value in origin.items(): + output[key] = Tools.convert_fuzz_output_to_origin(value, perturbed[key]) + return output + if isinstance(origin, (tuple, list)): + result = list() + for index_, value in enumerate(origin): + result.append( + Tools.convert_fuzz_output_to_origin(value, perturbed[index_]) + ) + return type(origin)(result) + return origin + +class TorchC: + sum = torch._C._VariableFunctionsClass.sum + isinf = torch._C._VariableFunctionsClass.isinf + isfinite = torch._C._VariableFunctionsClass.isfinite + isnan = torch._C._VariableFunctionsClass.isnan + logical_not = torch._C._VariableFunctionsClass.logical_not + subtract = torch._C._VariableFunctionsClass.subtract + abs = torch._C._VariableFunctionsClass.abs + where = torch._C._VariableFunctionsClass.where + div = torch._C._VariableFunctionsClass.div + max = torch._C._VariableFunctionsClass.max + min = torch._C._VariableFunctionsClass.min + gt = torch._C._VariableFunctionsClass.gt + ge = torch._C._VariableFunctionsClass.ge + lt = torch._C._VariableFunctionsClass.lt + mean = torch._C._VariableFunctionsClass.mean + full = torch._C._VariableFunctionsClass.full + add = torch._C._VariableFunctionsClass.add + bitwise_xor = torch._C._VariableFunctionsClass.bitwise_xor + clone = torch._C._VariableFunctionsClass.clone diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/compare/grad_saver.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/compare/grad_saver.py new file mode 100644 index 0000000000000000000000000000000000000000..1d58eadc606e4cdf572b6f5078a218c7f4c266c7 --- /dev/null +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/compare/grad_saver.py @@ -0,0 +1,172 @@ +import torch +from atat.pytorch.free_benchmark import print_info_log_rank_0, print_error_log_rank_0 +from atat.pytorch.free_benchmark.common.params import DataParams, HandlerParams +from atat.pytorch.free_benchmark.common.constant import CommonField +from atat.pytorch.free_benchmark.common.utils import Tools +from atat.pytorch.free_benchmark.result_handlers.handler_factory import ( + FuzzHandlerFactory, +) +from atat.pytorch.free_benchmark.perturbed_layers.layer_factory import LayerFactory + + +class GradSaver: + + def __init__(self, origin_func, handler_params: HandlerParams): + + self.handler_params = handler_params + self.api_name = handler_params.api_name + self.origin_func = origin_func + self.data_params = DataParams() + self.is_compare = True + self.kwargs = dict() + self.perturbed_grad_input = tuple() + self.origin_grad_input = tuple() + self.need_grad_flag = list() + self.backward_input = tuple() + + def register_compare_func_for_inputs(self, inputs, data_processor): + _index = 0 + for j, obj in enumerate(inputs): + if torch.is_tensor(obj) and obj.requires_grad: + + def compare_func(grad, new_grad_index=_index, input_index=j): + if not self.is_compare: + return grad + try: + perturbed_grad = self.check_grad_input(grad, new_grad_index) + handler = FuzzHandlerFactory.create(self.handler_params) + self.compare_grad_results( + handler, grad, perturbed_grad, index=input_index + ) + data_processor.update_unequal_rows(handler.get_unequal_rows()) + except Exception as e: + print_error_log_rank_0( + f"[atat] Free benchmark: grad compara error: {e}" + ) + return grad + return grad + + obj.register_hook(compare_func) + _index += 1 + + def compare_grad_results(self, handler, origin_grad, perturbed_grad, index): + # TODO get dtype? 
+ self.data_params.original_result = origin_grad + self.data_params.perturbed_result = perturbed_grad + self.data_params.grad_unequal_flag = False + self.data_params.index = index + try: + handler.handle(self.data_params) + if not self.data_params.is_consistent: + self.is_compare = False + self.data_params.grad_unequal_flag = True + self.data_params.is_consistent = True + self.data_params.perturbed_result = self.perturbed_grad_input + self.data_params.original_result = self.origin_grad_input + handler.handle(self.data_params) + except Exception as e: + print_error_log_rank_0( + f"[atat] Free benchmark: compare two vjp failed: api:{self.handler_params.api_name}." + f"{e}" + ) + + def check_grad_input(self, origin_grad, new_grad_index): + if self.perturbed_grad_input is None: + print_info_log_rank_0( + f"[atat] Free benchmark: grad not exsits : {self.api_name}." + ) + return None + try: + with torch.no_grad(): + perturbed_grad = self.perturbed_grad_input[new_grad_index].to( + origin_grad.device + ) + except IndexError: + print_error_log_rank_0( + f"[atat] Free benchmark: grad index out of range. api:{self.handler_params.api_name}." + f"index:{new_grad_index}, perturbation grad len {len(self.perturbed_grad_input)}" + ) + return None + if origin_grad.shape != perturbed_grad.shape: + print_error_log_rank_0( + f"[atat] Free benchmark: grad shapes are unconsistent. api:{self.handler_params.api_name}." + f"origin:{origin_grad.shape}, perturbation: {perturbed_grad.shape}" + ) + return None + return perturbed_grad + + def cache_backward_input(self, inputs): + _inputs = [] + with torch.no_grad(): + for object_ in inputs: + if torch.is_tensor(object_): + _inputs.append( + { + CommonField.DEVICE: object_.device, + CommonField.FUZZ_TENSOR: object_.cpu(), + CommonField.REQUIRES_GRAD: object_.requires_grad, + } + ) + else: + _inputs.append(object_) + self.backward_input = _inputs + + def get_vjp_input(self): + inner_args_tmp = [] + need_grad_tensors = [] + for object_ in self.backward_input: + if isinstance(object_, dict) and CommonField.FUZZ_TENSOR in object_.keys(): + tensor_ = torch.tensor( + object_.get(CommonField.FUZZ_TENSOR).data, + dtype=object_.get(CommonField.FUZZ_TENSOR).dtype, + device=object_.get(CommonField.DEVICE), + requires_grad=object_.get(CommonField.REQUIRES_GRAD), + ) + + if tensor_.requires_grad: + inner_args_tmp.append(CommonField.HOLD_PLACE) + need_grad_tensors.append(tensor_) + self.need_grad_flag.append(True) + else: + self.need_grad_flag.append(False) + inner_args_tmp.append(tensor_) + else: + self.need_grad_flag.append(False) + inner_args_tmp.append(object_) + + return need_grad_tensors, tuple(inner_args_tmp) + + def get_grad_input_from_vjp(self, need_grad_tensors, grad_output, inner_args): + def vjp_func(*inputs): + _real_input = [] + index_ = 0 + for object_ in inner_args: + if object_ is CommonField.HOLD_PLACE: + _real_input.append(inputs[index_]) + index_ += 1 + else: + _real_input.append(object_) + kwargs = self.kwargs.copy() + if 'inplace' in kwargs: + kwargs['inplace'] = False + return self.origin_func(*_real_input, **kwargs) + + _, grad_input = torch.autograd.functional.vjp( + vjp_func, tuple(need_grad_tensors), grad_output + ) + return grad_input + + def calculate_perturbed_grad_input(self, grad_output, need_grad_tensors, inner_args): + self.data_params.args = [need_grad_tensors, grad_output, inner_args] + self.data_params.kwargs = {} + self.data_params.index = 0 + self.data_params.origin_func = self.get_grad_input_from_vjp + layer = LayerFactory.create( + 
            self.handler_params.api_name,
+            self.handler_params.fuzz_device,
+            self.handler_params.pert_mode,
+        )
+        layer.handle(self.data_params)
+        self.perturbed_grad_input = tuple(
+            [x.cpu() for x in self.data_params.perturbed_result]
+        )
diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/compare/single_benchmark.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/compare/single_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d13bfd190e3f1a60c66310a6af69845affa5bf6
--- /dev/null
+++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/compare/single_benchmark.py
@@ -0,0 +1,100 @@
+import torch
+import math
+
+from atat.pytorch.free_benchmark.common.utils import TorchC
+from atat.pytorch.free_benchmark.common.constant import ThresholdConfig
+
+
+class SingleCompare:
+    def __init__(self) -> None:
+        self.relative_err = None
+        self.absolute_err = None
+        self.eb = None
+        self.threshold = None
+
+    def compare_seq(self, actual, golden):
+        if isinstance(golden, torch.Tensor):
+            return self.compare_tensor_seq(actual, golden)
+        elif isinstance(golden, dict):
+            return self.compare_dict_seq(actual, golden)
+        elif isinstance(golden, (tuple, list)):
+            return self.compare_list_seq(actual, golden)
+        elif isinstance(golden, float):
+            return self.compare_float_seq(actual, golden)
+        else:
+            return self.compare_other_seq(actual, golden)
+
+    def compare_tensor_seq(self, actual, golden):
+        self.threshold = ThresholdConfig.BENCHMARK_THD_DICT.get(
+            actual.dtype, ThresholdConfig.BENCHMARK_THD_DICT.get(torch.float32)
+        )
+        if self.filter_overflow(golden) > 0:
+            raise RuntimeError("inf and nan in the golden tensor are not supported.")
+        actual = self.replace_inf_or_nan(actual)
+        actual = actual.to(torch.float64)
+        golden = golden.to(torch.float64).to(actual.device)
+        self._cal_compare_metrics(actual, golden)
+        if self.absolute_err > self.threshold.small_value_atol:
+            return False
+        if self.relative_err > self.threshold.rtol:
+            return False
+        if self.eb > self.threshold.err_balance:
+            return False
+        return True
+
+    def _cal_compare_metrics(self, actual, golden):
+        diff_value = TorchC.subtract(actual, golden)
+        diff_abs = TorchC.abs(diff_value)
+        golden_abs = TorchC.abs(golden)
+        # elements in the small-value range use the absolute error
+        self.absolute_err = TorchC.max(TorchC.where(
+            TorchC.lt(TorchC.abs(actual), self.threshold.small_value), diff_abs, 0
+        ))
+        diff_rel = TorchC.div(diff_abs, golden_abs)
+        # elements outside the small-value range use the relative error
+        self.relative_err = TorchC.max(TorchC.where(
+            TorchC.ge(TorchC.abs(actual), self.threshold.small_value), diff_rel, 0
+        ))
+        # error balance
+        divided = TorchC.where(
+            TorchC.ge(TorchC.abs(golden), self.threshold.small_value), golden_abs, 1
+        )
+        self.eb = TorchC.mean(TorchC.div(diff_value, divided))
+
+    def compare_dict_seq(self, actual, golden):
+        if len(actual) != len(golden):
+            return False
+        for key, value in golden.items():
+            if not self.compare_seq(value, actual.get(key)):
+                return False
+        return True
+
+    def compare_list_seq(self, actual, golden):
+        if len(actual) != len(golden):
+            return False
+        for index_, value in enumerate(golden):
+            if not self.compare_seq(value, actual[index_]):
+                return False
+        return True
+
+    def compare_float_seq(self, actual, golden):
+        return math.isclose(actual, golden)
+
+    def compare_other_seq(self, actual, golden):
+        return actual == golden
+
+    @staticmethod
+    def filter_overflow(tensor) -> int:
+        inf_num = TorchC.sum(TorchC.isinf(tensor))
+        nan_num = TorchC.sum(TorchC.isnan(tensor))
+        return inf_num + nan_num
+
+    @staticmethod
+    def replace_inf_or_nan(tensor):
+        finite_mask = TorchC.isfinite(tensor)
+        inf_or_nan_mask = TorchC.logical_not(finite_mask)
+        inf_or_nan_num = TorchC.sum(inf_or_nan_mask).item()
+        if inf_or_nan_num > 0:
+            tensor[inf_or_nan_mask] = 1
+        return tensor
diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/main.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef7897a6583de4115bc8e55015145f4cec3f5834
--- /dev/null
+++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/main.py
@@ -0,0 +1,97 @@
+import importlib
+from abc import ABC
+
+import torch
+from atat.pytorch.free_benchmark import Const, print_error_log_rank_0
+
+from atat.pytorch.free_benchmark.common.params import data_pre_deal, make_handler_params
+from atat.pytorch.free_benchmark.common.enums import (
+    PerturbationMode,
+    FuzzLevel,
+    DeviceType,
+)
+from atat.pytorch.free_benchmark.compare.grad_saver import GradSaver
+from atat.pytorch.free_benchmark.perturbed_layers.layer_factory import LayerFactory
+from atat.pytorch.free_benchmark.result_handlers.handler_factory import (
+    FuzzHandlerFactory,
+)
+
+
+class FreeBenchmarkCheck(ABC):
+
+    def __init__(self, config) -> None:
+        super().__init__()
+        self.config = config
+        if self.config.pert_mode is None:
+            self.config.pert_mode = PerturbationMode.IMPROVE_PRECISION
+        if self.config.fuzz_level is None:
+            self.config.fuzz_level = FuzzLevel.BASE_LEVEL
+        if self.config.fuzz_device is None:
+            self.config.fuzz_device = DeviceType.NPU
+        self.current_iter = 0
+
+    def update_iter(self, update_iter):
+        self.current_iter = update_iter
+
+    def pre_forward(self, name, module, data_processor, args, kwargs):
+        if self.config.fuzz_stage != Const.BACKWARD:
+            return
+        # TODO: only the check handler is supported for the backward stage
+        origin_func = (
+            module._slow_forward if torch._C._get_tracing_state() else module.forward
+        )
+        handler_params = make_handler_params(name, self.config, self.current_iter)
+        grad_saver = GradSaver(origin_func, handler_params)
+        grad_saver.kwargs = kwargs
+        grad_saver.register_compare_func_for_inputs(args, data_processor)
+        grad_saver.cache_backward_input(args)
+        setattr(module, "grad_saver", grad_saver)
+
+    def forward(self, name, module, args, kwargs, output):
+        if self.config.fuzz_stage != Const.FORWARD:
+            return output, []
+        origin_func = (
+            module._slow_forward if torch._C._get_tracing_state() else module.forward
+        )
+        data_params = data_pre_deal(name, origin_func, args, kwargs)
+        if data_params.index == -1:
+            return output, []
+        data_params.original_result = output
+        data_params.fuzz_stage = self.config.fuzz_stage
+
+        layer = LayerFactory.create(
+            name, self.config.fuzz_device, self.config.pert_mode
+        )
+        layer.handle(data_params)
+        handler_params = make_handler_params(name, self.config, self.current_iter)
+        handler = FuzzHandlerFactory.create(handler_params)
+        handler.handle(data_params)
+        return output, handler.get_unequal_rows()
+
+    def backward(self, name, module, grad_output):
+        if self.config.fuzz_stage != Const.BACKWARD:
+            return
+        try:
+            grad_saver = getattr(module, "grad_saver")
+        except AttributeError:
+            print_error_log_rank_0(
+                f"[atat] Free benchmark: get grad saver failed. 
api_name:{name}" + ) + return + + _new_grad_output = grad_output + try: + need_grad_tensors, _inner_args = grad_saver.get_vjp_input() + origin_grad_input = grad_saver.get_grad_input_from_vjp( + tuple(need_grad_tensors), _new_grad_output, _inner_args + ) + grad_saver.origin_grad_input = tuple([x.cpu() for x in origin_grad_input]) + grad_saver.calculate_perturbed_grad_input( + _new_grad_output, need_grad_tensors, _inner_args + ) + except Exception as e: + print_error_log_rank_0( + f"[atat] Free benchmark: grad vjp calculate failed. api_name:{name} error: {e}" + ) + return diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/__init__.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/base_layer.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/base_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..aa572fd8e8dc8b62493dfa1fecc587b934c83a99 --- /dev/null +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/base_layer.py @@ -0,0 +1,13 @@ +from abc import ABC, abstractmethod +from typing import Any + +from atat.pytorch.free_benchmark.common.params import DataParams + + +class BaseLayer(ABC): + def __init__(self, api_name: str) -> None: + self.api_name = api_name + + @abstractmethod + def handle(self, params: DataParams) -> Any: + pass diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/layer_factory.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/layer_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..0d09438ce04132c9c5c301d758dc06818805082e --- /dev/null +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/layer_factory.py @@ -0,0 +1,41 @@ +from atat.pytorch.free_benchmark import FreeBenchmarkException +from atat.pytorch.free_benchmark.common.enums import DeviceType, PerturbationMode +from atat.pytorch.free_benchmark.perturbed_layers.npu.improve_precision import ( + ImprovePrecisionLayer, +) +from atat.pytorch.free_benchmark.perturbed_layers.npu.add_noise import AddNoiseLayer +from atat.pytorch.free_benchmark.perturbed_layers.npu.bit_noise import BitNoiseLayer +from atat.pytorch.free_benchmark.perturbed_layers.npu.no_change import NoChangeLayer +from atat.pytorch.free_benchmark.perturbed_layers.npu.change_value import ( + ChangeValueLayer, +) +from atat.pytorch.free_benchmark.perturbed_layers.run_cpu import CpuLayer + + +class LayerFactory: + layers = { + DeviceType.NPU: { + PerturbationMode.ADD_NOISE: AddNoiseLayer, + PerturbationMode.CHANGE_VALUE: ChangeValueLayer, + PerturbationMode.NO_CHANGE: NoChangeLayer, + PerturbationMode.BIT_NOISE: BitNoiseLayer, + PerturbationMode.IMPROVE_PRECISION: ImprovePrecisionLayer, + }, + DeviceType.CPU: {PerturbationMode.TO_CPU: CpuLayer}, + } + + @staticmethod + def create(api_name: str, device_type: str, mode: str): + layer = LayerFactory.layers.get(device_type) + if not layer: + raise FreeBenchmarkException( + FreeBenchmarkException.UnsupportedType, + f"无标杆工具不支持当前设备 {device_type}", + ) + layer = layer.get(mode) + if not layer: + raise FreeBenchmarkException( + FreeBenchmarkException.UnsupportedType, + f"无标杆工具无法识别该扰动因子 {mode} on {device_type}", + ) + return layer(api_name) diff --git 
a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/__init__.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py new file mode 100644 index 0000000000000000000000000000000000000000..0b922d09791f04ac4613c37bc3a0246c375eb266 --- /dev/null +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py @@ -0,0 +1,96 @@ +import torch +from atat.pytorch.free_benchmark import ( + print_info_log_rank_0, + print_error_log_rank_0, +) +from atat.pytorch.free_benchmark.common.constant import ThresholdConfig +from atat.pytorch.free_benchmark.common.params import DataParams +from atat.pytorch.free_benchmark.common.utils import TorchC +from atat.pytorch.free_benchmark.common.enums import PerturbationMode +from atat.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import ( + NpuBaseLayer, +) + + +class AddNoiseLayer(NpuBaseLayer): + + def _get_noise(self, tensor_obj): + dtype = tensor_obj.dtype + device = str(tensor_obj.device) + noise = TorchC.full( + tensor_obj.shape, + self.perturbed_value, + device=device, + dtype=dtype, + ) + return noise + + def _check_details(self, tensor_obj): + """ + 判断是否需要添加扰动 + """ + if not self.perturbed_value: + print_error_log_rank_0( + f"[atat] Free Benchmark: For {self.api_name}, " + f"dtype unsupported. Cancel perturbation." + ) + return False + if tensor_obj.numel() == 0 or not torch.is_floating_point(tensor_obj): + print_info_log_rank_0( + f"[atat] Free benchmark: For {self.api_name}, unsupported tensor types." + f" Cancel adding noise." + ) + return False + abs_tol = ThresholdConfig.ABS_TOL_VALUE_DICT.get( + tensor_obj.dtype, ThresholdConfig.NOISE_INPUT_LOWER_BOUND + ) + try: + max_val = TorchC.max(TorchC.abs(tensor_obj)).item() + except Exception: + print_info_log_rank_0( + f"[atat] Free Benchmark: For {self.api_name}, " + f"when calculate maximun value, tensor is changed to float32." + ) + max_val = TorchC.max(TorchC.abs(tensor_obj.to(torch.float32))).item() + if max_val < abs_tol: + print_info_log_rank_0( + f"[atat] Free Benchmark: For {self.api_name}, " + f"Maximun value is less than the minimun threshold. Cancel add noise." + ) + return False + return True + + def add_noise(self, tensor_obj): + self.perturbed_value = ThresholdConfig.PERTURBATION_VALUE_DICT.get( + tensor_obj.dtype + ) + if isinstance(tensor_obj, torch.Tensor): + self.perturbed_value = ThresholdConfig.PERTURBATION_VALUE_DICT.get( + tensor_obj.dtype + ) + if not self.pre_check(tensor_obj): + return tensor_obj + noise = self._get_noise(tensor_obj) + result = TorchC.where( + TorchC.abs(tensor_obj) > self.perturbed_value**0.5, + TorchC.add(noise, tensor_obj), + tensor_obj, + ).to(tensor_obj.dtype) + self.is_added = True + return result + if isinstance(tensor_obj, dict): + return {key: self.add_noise(value) for key, value in tensor_obj.items()} + if isinstance(tensor_obj, (tuple, list)): + return type(tensor_obj)([self.add_noise(value) for value in tensor_obj]) + return tensor_obj + + def handle(self, params: DataParams) -> torch.Any: + """ + 对输入添加扰动并返回 + """ + print_info_log_rank_0( + f"[atat] Free benchmark: Perturbation is " + f"{PerturbationMode.ADD_NOISE} of {self.api_name}." 
+ ) + params.perturbed_value = self.add_noise(params.args[params.index]) + return self.perturbed_result(params) diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py new file mode 100644 index 0000000000000000000000000000000000000000..af1ca1e927c71fa4db64c4ceaecddc20980a07d2 --- /dev/null +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py @@ -0,0 +1,108 @@ +import torch +from atat.pytorch.free_benchmark import ( + print_info_log_rank_0, + print_error_log_rank_0, +) +from atat.pytorch.free_benchmark.common.constant import ThresholdConfig +from atat.pytorch.free_benchmark.common.params import DataParams +from atat.pytorch.free_benchmark.common.utils import TorchC +from atat.pytorch.free_benchmark.common.enums import PerturbationMode +from atat.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import ( + NpuBaseLayer, +) + + +class BitNoiseLayer(NpuBaseLayer): + def __init__(self, api_name): + super().__init__(api_name) + self.bit_mode = TorchC.bitwise_xor + self.bit_tail: int = 1 + self.bit_type = None + + def _check_details(self, tensor_obj): + """ + 判断是否需要添加扰动, bit翻转 + """ + if not self.bit_type: + print_error_log_rank_0( + f"[atat] Free Benchmark: For {self.api_name}, " + f"dtype unsupported. Cancel perturbation." + ) + return False + if tensor_obj.numel() == 0 or not torch.is_floating_point(tensor_obj): + print_info_log_rank_0( + f"[atat] Free benchmark: For {self.api_name}, unsupported tensor types." + f" Cancel adding noise." + ) + return False + abs_tol = ThresholdConfig.ABS_TOL_VALUE_DICT.get( + tensor_obj.dtype, ThresholdConfig.NOISE_INPUT_LOWER_BOUND + ) + try: + max_val = TorchC.max(TorchC.abs(tensor_obj)).item() + except Exception: + print_info_log_rank_0( + f"[atat] Free Benchmark: For {self.api_name}, " + f"when calculate maximun value, tensor is changed to float32." + ) + max_val = TorchC.max(TorchC.abs(tensor_obj.to(torch.float32))).item() + if max_val < abs_tol: + print_info_log_rank_0( + f"[atat] Free Benchmark: For {self.api_name}, " + f"Maximun value is less than the minimun threshold. Cancel add noise." 
+ ) + return False + return True + + def _set_perturbation_bit(self, tensor_obj): + """ + 根据不同浮点数确定不同位数扰动值 + """ + bit_len_type = ThresholdConfig.PERTURBATION_BIT_DICT.get(tensor_obj.dtype) + if bit_len_type: + self.bit_tail = 1 + self.bit_type = bit_len_type + + def add_bit_noise(self, tensor_obj): + """ + 对输入添加噪声 + """ + # finfo应该列入黑名单 + + self._set_perturbation_bit(tensor_obj) + if isinstance(tensor_obj, torch.Tensor): + self._set_perturbation_bit(tensor_obj) + if not self.pre_check(tensor_obj): + return tensor_obj + sub_normal = torch.finfo(tensor_obj.dtype).smallest_normal + noise = TorchC.full( + tensor_obj.shape, + self.bit_tail, + device=tensor_obj.device, + dtype=self.bit_type, + ) + result = tensor_obj.view(self.bit_type) + result = TorchC.where( + TorchC.abs(tensor_obj) > sub_normal, + self.bit_mode(result, noise), + result, + ).view(tensor_obj.dtype) + + self.is_added = True + return result + if isinstance(tensor_obj, dict): + return {key: self.add_bit_noise(value) for key, value in tensor_obj.items()} + if isinstance(tensor_obj, (tuple, list)): + return type(tensor_obj)([self.add_bit_noise(value) for value in tensor_obj]) + return tensor_obj + + def handle(self, params: DataParams) -> torch.Any: + """ + 对输入添加扰动并返回 + """ + print_info_log_rank_0( + f"[atat] Free benchmark: Perturbation is " + f"{PerturbationMode.BIT_NOISE} of {self.api_name}." + ) + params.perturbed_value = self.add_bit_noise(params.args[params.index]) + return self.perturbed_result(params) diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/change_value.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/change_value.py new file mode 100644 index 0000000000000000000000000000000000000000..164ae0396bde3ebc96a9754622cfc7fb8d45687d --- /dev/null +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/change_value.py @@ -0,0 +1,63 @@ +import torch +from atat.pytorch.free_benchmark import print_info_log_rank_0 +from atat.pytorch.free_benchmark.common.params import DataParams +from atat.pytorch.free_benchmark.common.utils import TorchC +from atat.pytorch.free_benchmark.common.enums import PerturbationMode +from atat.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import ( + NpuBaseLayer, +) + + +class ChangeValueLayer(NpuBaseLayer): + def __init__(self, api_name): + super().__init__(api_name) + self.head: int = 0 + self.tail: int = 1 + + def _check_details(self, tensor_obj): + """ + 判断是否需要添加扰动, bit翻转 + """ + if tensor_obj.size(0) < 2: + print_info_log_rank_0( + f"[atat] Free Benchmark: For {self.api_name}, " + f"size 0 must greater than 1. Cancel change value." 
+            )
+            return False
+        return True
+
+    def change_value(self, tensor_obj):
+        """
+        Swap the first and the last elements of the tensor.
+        """
+        if isinstance(tensor_obj, torch.Tensor) and self.pre_check(tensor_obj):
+            new_tensor = TorchC.clone(tensor_obj)
+            if new_tensor.ndim == 1:
+                temp_first = TorchC.clone(new_tensor[self.head])
+                temp_last = TorchC.clone(new_tensor[self.tail])
+                new_tensor[self.head] = temp_last
+                new_tensor[self.tail] = temp_first
+            else:
+                temp_first = TorchC.clone(new_tensor[self.head][self.head])
+                temp_last = TorchC.clone(new_tensor[self.tail][self.tail])
+                new_tensor[self.head][self.head] = temp_last
+                new_tensor[self.tail][self.tail] = temp_first
+
+            self.is_added = True
+            return new_tensor
+        if isinstance(tensor_obj, dict):
+            return {key: self.change_value(value) for key, value in tensor_obj.items()}
+        if isinstance(tensor_obj, (tuple, list)):
+            return type(tensor_obj)([self.change_value(value) for value in tensor_obj])
+        return tensor_obj
+
+    def handle(self, params: DataParams) -> torch.Any:
+        """
+        Apply the perturbation to the input and return the perturbed result.
+        """
+        print_info_log_rank_0(
+            f"[atat] Free benchmark: Perturbation is "
+            f"{PerturbationMode.CHANGE_VALUE} of {self.api_name}."
+        )
+        params.perturbed_value = self.change_value(params.args[params.index])
+        return self.perturbed_result(params)
diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb126972c6853b81d24db8138880601f9a3af21a
--- /dev/null
+++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py
@@ -0,0 +1,64 @@
+import torch
+from atat.pytorch.free_benchmark import Const, print_info_log_rank_0
+from atat.pytorch.free_benchmark.common.constant import CommonField
+from atat.pytorch.free_benchmark.common.params import DataParams
+from atat.pytorch.free_benchmark.common.enums import PerturbationMode
+from atat.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import (
+    NpuBaseLayer,
+)
+
+
+class ImprovePrecisionLayer(NpuBaseLayer):
+
+    def _set_improve_values(self, inputs):
+        # TODO why
+        if inputs.dtype in [torch.float16, torch.bfloat16]:
+            self.perturbed_value = torch.float32
+
+    def _change_dtype(self, inputs):
+        if hasattr(inputs, CommonField.DEVICE):
+            device = inputs.device
+            if str(device) == CommonField.META:
+                new_inputs = inputs.to(
+                    device=CommonField.META, dtype=self.perturbed_value
+                )
+            else:
+                new_inputs = inputs.to(dtype=self.perturbed_value).to(device)
+        else:
+            new_inputs = inputs.to(dtype=self.perturbed_value)
+        return new_inputs
+
+    def improve_tensor_precision(self, tensor_obj):
+        if (
+            isinstance(tensor_obj, torch.Tensor)
+            and torch.is_floating_point(tensor_obj)
+            and tensor_obj.dtype not in [torch.float32, torch.float64]
+        ):
+            self._set_improve_values(tensor_obj)
+            tensor_obj = self._change_dtype(tensor_obj)
+            return tensor_obj
+        if isinstance(tensor_obj, dict):
+            return {
+                key: self.improve_tensor_precision(value)
+                for key, value in tensor_obj.items()
+            }
+        if isinstance(tensor_obj, (tuple, list)):
+            return type(tensor_obj)(
+                [self.improve_tensor_precision(value) for value in tensor_obj]
+            )
+        return tensor_obj
+
+    def handle(self, params: DataParams) -> torch.Any:
+        print_info_log_rank_0(
+            f"[atat] Free benchmark: Perturbation is "
+            f"{PerturbationMode.IMPROVE_PRECISION} of {self.api_name}."
+ ) + new_args = self.improve_tensor_precision(params.args) + if params.fuzz_stage == Const.BACKWARD: + new_kwargs = {} + else: + new_kwargs = self.improve_tensor_precision(params.kwargs) + if "inplace" in new_kwargs: + new_kwargs["inplace"] = False + params.perturbed_result = params.origin_func(*new_args, **new_kwargs) + return params.perturbed_result diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/no_change.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/no_change.py new file mode 100644 index 0000000000000000000000000000000000000000..bc9850c7637085da25b81824e9a5c06f88c99033 --- /dev/null +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/no_change.py @@ -0,0 +1,29 @@ +import torch +from atat.pytorch.free_benchmark import print_info_log_rank_0 +from atat.pytorch.free_benchmark.common.params import DataParams +from atat.pytorch.free_benchmark.common.enums import PerturbationMode +from atat.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import ( + NpuBaseLayer, +) + + +class NoChangeLayer(NpuBaseLayer): + + def no_change(self, tensor_obj): + """ + 交换张量首尾 + """ + self.is_added = True + return tensor_obj + + + def handle(self, params: DataParams) -> torch.Any: + """ + 对输入添加扰动并返回 + """ + print_info_log_rank_0( + f"[atat] Free benchmark: Perturbation is " + f"{PerturbationMode.NO_CHANGE} of {self.api_name}." + ) + params.perturbed_value = self.no_change(params.args[params.index]) + return self.perturbed_result(params) diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py new file mode 100644 index 0000000000000000000000000000000000000000..bbbe1511c3088c5dac31722d2de58a447f437f92 --- /dev/null +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py @@ -0,0 +1,51 @@ +from abc import abstractmethod +from typing import Any +import torch + +from atat.pytorch.free_benchmark import ( + print_info_log_rank_0, + print_error_log_rank_0, +) +from atat.pytorch.free_benchmark.common.constant import CommonField, ThresholdConfig +from atat.pytorch.free_benchmark.common.utils import TorchC +from atat.pytorch.free_benchmark.common.params import DataParams +from atat.pytorch.free_benchmark.perturbed_layers.base_layer import BaseLayer + + +class NpuBaseLayer(BaseLayer): + def __init__(self, api_name: str) -> None: + super().__init__(api_name) + self.perturbed_value = None # 扰动的元素 + self.is_added = False # 标记当前算子输入是否调整 + + @abstractmethod + def handle(self, params: DataParams) -> Any: + pass + + def _check_details(self, tensor_obj): + return True + + def pre_check(self, tensor_obj): + """ + 检查张量是否符合标准(float类型且最大值大于对应精度最小值) + """ + # 只针对第一个满足要求的添加扰动 + if self.is_added: + return False + if not torch.is_floating_point(tensor_obj): + return False + if not self._check_details(tensor_obj): + return False + return True + + @staticmethod + def perturbed_result(params: DataParams) -> Any: + args_front = params.args[: params.index] + args_rear = params.args[params.index + 1 :] + # 此处会将有inplace属性的算子换为非inplace + if "inplace" in params.kwargs: + params.kwargs["inplace"] = False + params.perturbed_result = params.origin_func( + *args_front, params.perturbed_value, *args_rear, **params.kwargs + ) + return params.perturbed_result diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/run_cpu.py 
b/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/run_cpu.py new file mode 100644 index 0000000000000000000000000000000000000000..e6ae8ea607b21ae2d96a4057e69a05665d872b76 --- /dev/null +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/run_cpu.py @@ -0,0 +1,21 @@ +import torch +from atat.pytorch.free_benchmark import print_info_log_rank_0 +from atat.pytorch.free_benchmark.common.params import DataParams +from atat.pytorch.free_benchmark.common.utils import Tools +from atat.pytorch.free_benchmark.common.enums import DeviceType +from atat.pytorch.free_benchmark.perturbed_layers.base_layer import BaseLayer + + +class CpuLayer(BaseLayer): + + def handle(self, params: DataParams) -> torch.Any: + + print_info_log_rank_0( + f"[atat] Free benchmark: Perturbation is to_cpu of {self.api_name}." + ) + new_args = Tools.convert_device_and_dtype(params.args, DeviceType.CPU, change_dtype=True) + new_kwargs = Tools.convert_device_and_dtype(params.kwargs, DeviceType.CPU, change_dtype=True) + params.perturbed_result = params.origin_func(*new_args, **new_kwargs) + if "inplace" in new_kwargs: + new_kwargs["inplace"] = False + return self.perturbed_result(params) diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/__init__.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/base_handler.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/base_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..3dac3969df34ac530fd594db5bde717581b634ab --- /dev/null +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/base_handler.py @@ -0,0 +1,161 @@ +import math +from abc import ABC, abstractmethod +from typing import Any, Optional, Tuple + +import torch +from atat.pytorch.free_benchmark import ( + Const, + print_error_log_rank_0, +) +from atat.pytorch.free_benchmark.common.utils import TorchC +from atat.pytorch.free_benchmark.common.constant import ThresholdConfig +from atat.pytorch.free_benchmark.common.enums import FuzzThreshold, NormType, PerturbationMode +from atat.pytorch.free_benchmark.common.params import DataParams, HandlerParams + + + +class FuzzHandler(ABC): + def __init__(self, params: HandlerParams) -> None: + self.params = params + self.unequal_rows = [] + + @staticmethod + def pre_process(origin_ouput, perturbed_output): + if ( + isinstance(origin_ouput, tuple) + and hasattr(origin_ouput, "values") + and hasattr(origin_ouput, "indices") + ): + origin_ouput = origin_ouput.values + perturbed_output = perturbed_output.values + if hasattr(perturbed_output, "dtype"): + abs_tol = ThresholdConfig.ABS_TOL_VALUE_DICT.get(perturbed_output.dtype) + else: + abs_tol = FuzzThreshold.F32_THD.value + return ( + origin_ouput.to(perturbed_output.dtype).to(perturbed_output.device), + perturbed_output, + abs_tol, + ) + + def get_ratio_from_specific_norm( + self, origin_output, perturbed_output, norm_type, abs_tol + ): + if norm_type == NormType.ENDLESS_NORM: + return self.get_endless_norm(origin_output, perturbed_output, abs_tol) + return ThresholdConfig.COMP_CONSISTENT + + @staticmethod + def convert_overflow_ratio_to_consistent(ratio): + if math.isnan(ratio) or math.isinf(ratio): + return ThresholdConfig.COMP_CONSISTENT + return ratio + + def get_endless_norm(self, origin_output, 
perturbed_output, abs_tol): + try: + ratio_tensor1 = TorchC.where( + TorchC.gt(TorchC.abs(perturbed_output), abs_tol), + TorchC.div(origin_output, perturbed_output), + 1, + ) + ratio_tensor2 = TorchC.where( + TorchC.gt(TorchC.abs(origin_output), abs_tol), + TorchC.div(perturbed_output, origin_output), + 1, + ) + except: + ratio_tensor1 = TorchC.where( + TorchC.gt(TorchC.abs(perturbed_output.to(torch.float32)), abs_tol), + TorchC.div( + origin_output.to(torch.float32), perturbed_output.to(torch.float32) + ), + 1, + ) + ratio_tensor2 = TorchC.where( + TorchC.gt(TorchC.abs(origin_output.to(torch.float32)), abs_tol), + TorchC.div( + perturbed_output.to(torch.float32), origin_output.to(torch.float32) + ), + 1, + ) + norm1 = self.convert_overflow_ratio_to_consistent( + TorchC.max(ratio_tensor1).item() + ) + norm2 = self.convert_overflow_ratio_to_consistent( + TorchC.max(ratio_tensor2).item() + ) + norm3 = self.convert_overflow_ratio_to_consistent( + TorchC.min(ratio_tensor1).item() + ) + if norm3 < 0: + ratio = ThresholdConfig.SYMBOL_FLIPPING + else: + ratio = max(norm1, norm2) + return ratio + + def ratio_calculate(self, origin_output, perturbed_output, norm_type) -> float: + try: + origin_output, perturbed_output, abs_tol = self.pre_process( + origin_output, perturbed_output + ) + except Exception as e: + print_error_log_rank_0( + f"[atat] Free Benchmark: For {self.params.api_name}, " + f"when computing ratio," + f" y1 or y2 dtype is not supported {e}" + ) + return ThresholdConfig.COMP_NAN + if self.params.fuzz_stage == Const.BACKWARD: + abs_tol = ThresholdConfig.BACKWARD_OUTPUT_LOWER_BOUND + else: + abs_tol = abs_tol**0.5 + return self.get_ratio_from_specific_norm( + origin_output, perturbed_output, norm_type, abs_tol + ) + + @abstractmethod + def get_threshold(self, dtype): + pass + + def _get_default_threshold(self, dtype): + if self.params.pert_mode == PerturbationMode.NO_CHANGE: + threshold = ThresholdConfig.COMP_CONSISTENT + else: + threshold = ThresholdConfig.DTYPE_PER_THD.get( + dtype, ThresholdConfig.DTYPE_PER_THD.get(torch.float32) + ) + return threshold + + def npu_compare( + self, origin_output, perturbed_output + ) -> Tuple[bool, Optional[float]]: + + if isinstance(perturbed_output, int): + return origin_output == perturbed_output, None + elif isinstance(perturbed_output, float): + return ( + math.isclose(origin_output, perturbed_output), + origin_output / perturbed_output, + ) + elif not isinstance(perturbed_output, torch.Tensor): + print_error_log_rank_0( + f"[atat] Free Benchmark: For {self.params.api_name} " + f"The compare for output type {type(perturbed_output)} is not supported" + ) + + threshold = self.get_threshold(origin_output.dtype) + ratio = self.ratio_calculate( + origin_output, perturbed_output, norm_type=NormType.ENDLESS_NORM + ) + if ratio == ThresholdConfig.SYMBOL_FLIPPING: + is_consistent = False + else: + is_consistent = threshold >= ratio >= 1 / threshold + return is_consistent, ratio + + @abstractmethod + def handle(self, data_params: DataParams) -> Any: + pass + + def get_unequal_rows(self): + return self.unequal_rows diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/check_handler.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/check_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..ab433d9bf13c49bb5377822e4af70d18545c0c36 --- /dev/null +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/check_handler.py @@ -0,0 +1,66 @@ +from typing import Any + +import torch 
+from atat.pytorch.free_benchmark import print_error_log_rank_0
+from atat.pytorch.free_benchmark.common.enums import DeviceType
+from atat.pytorch.free_benchmark.compare.single_benchmark import SingleCompare
+from atat.pytorch.free_benchmark.common.params import DataParams, make_unequal_row
+from atat.pytorch.free_benchmark.common.utils import Tools
+from atat.pytorch.free_benchmark.result_handlers.base_handler import FuzzHandler
+
+
+class CheckerHandler(FuzzHandler):
+
+    def other_compare(self, data_params: DataParams):
+        is_consistent = SingleCompare().compare_seq(
+            data_params.original_result, data_params.perturbed_result
+        )
+        if not is_consistent:
+            self.unequal_rows.append(
+                make_unequal_row(data_params, self.params)
+            )
+
+    def cmp_output_npu(self, data_params: DataParams):
+        if isinstance(data_params.original_result, torch.Tensor):
+            is_consistent, ratio = self.npu_compare(
+                data_params.original_result, data_params.perturbed_result
+            )
+            data_params.is_consistent = is_consistent and data_params.is_consistent
+            if not is_consistent and data_params.grad_unequal_flag:
+                self.unequal_rows.append(
+                    make_unequal_row(data_params, self.params, ratio=ratio)
+                )
+
+        elif isinstance(data_params.original_result, (list, tuple)):
+            for index_, origin_item in enumerate(data_params.original_result):
+                is_consistent, ratio = self.npu_compare(
+                    origin_item, data_params.perturbed_result[index_]
+                )
+                data_params.is_consistent = is_consistent and data_params.is_consistent
+                if not is_consistent and data_params.grad_unequal_flag:
+                    self.unequal_rows.append(
+                        make_unequal_row(
+                            data_params, self.params, ratio=ratio, index=index_
+                        )
+                    )
+        return
+
+    def get_threshold(self, dtype):
+        return self._get_default_threshold(dtype)
+
+    def handle(self, data_params: DataParams) -> Any:
+        if isinstance(data_params.perturbed_result, bool) or not Tools.is_float_tensor(
+            data_params.perturbed_result
+        ):
+            return data_params.original_result
+        try:
+            if self.params.fuzz_device == DeviceType.NPU:
+                self.cmp_output_npu(data_params)
+            else:
+                self.other_compare(data_params)
+        except Exception as e:
+            print_error_log_rank_0(
+                f"[atat] Free Benchmark: For {self.params.api_name}, "
+                f"an exception was raised when comparing the results: {e}"
+            )
+        return data_params.original_result
diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/fix_handler.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/fix_handler.py
new file mode 100644
index 0000000000000000000000000000000000000000..3279375a2b47c326e0b0b89eb709b00ed3881c49
--- /dev/null
+++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/fix_handler.py
@@ -0,0 +1,20 @@
+from typing import Any
+
+from atat.pytorch.free_benchmark.common.params import DataParams
+from atat.pytorch.free_benchmark.common.utils import Tools
+from atat.pytorch.free_benchmark.result_handlers.base_handler import FuzzHandler
+from atat.pytorch.free_benchmark import print_error_log_rank_0
+
+
+class FixHandler(FuzzHandler):
+
+    def get_threshold(self, dtype):
+        # FuzzHandler.get_threshold is abstract; use the default so FixHandler can be instantiated
+        return self._get_default_threshold(dtype)
+
+    def handle(self, data_params: DataParams) -> Any:
+        try:
+            return Tools.convert_fuzz_output_to_origin(
+                data_params.original_result, data_params.perturbed_result
+            )
+        except Exception as e:
+            print_error_log_rank_0(
+                f"[atat] Free Benchmark: For {self.params.api_name} "
+                f"Fix output failed: {e}.
" + ) diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/handler_factory.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/handler_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..9ecd2d7882a60220d1b1d7dfba775f91ac35cafa --- /dev/null +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/handler_factory.py @@ -0,0 +1,32 @@ +from atat.pytorch.free_benchmark import FreeBenchmarkException +from atat.pytorch.free_benchmark.common.constant import PreheatConfig +from atat.pytorch.free_benchmark.common.utils import Tools +from atat.pytorch.free_benchmark.common.enums import HandlerType +from atat.pytorch.free_benchmark.common.params import HandlerParams +from atat.pytorch.free_benchmark.result_handlers.check_handler import CheckerHandler +from atat.pytorch.free_benchmark.result_handlers.preheat_handler import PreheatHandler +from atat.pytorch.free_benchmark.result_handlers.fix_handler import FixHandler + + +class FuzzHandlerFactory: + + result_handlers = { + HandlerType.CHECK: CheckerHandler, + HandlerType.FIX: FixHandler, + HandlerType.PREHEAT: PreheatHandler, + } + + @staticmethod + def create(params: HandlerParams): + if_preheat = params.preheat_config.get(PreheatConfig.IF_PREHEAT) + if not if_preheat: + handler = FuzzHandlerFactory.result_handlers.get(params.handler_type) + else: + handler = FuzzHandlerFactory.result_handlers.get(HandlerType.PREHEAT) + # TODO + if not handler: + raise FreeBenchmarkException( + FreeBenchmarkException.UnsupportedType, + f"无标杆工具支持 [ {HandlerType.CHECK}、{HandlerType.PREHEAT}] 形式", + ) + return handler(params) diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/preheat_handler.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/preheat_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..1c79562ddff74f851b7cf38bdb464ab2b3e19cd3 --- /dev/null +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/preheat_handler.py @@ -0,0 +1,200 @@ +from typing import Any + +import torch +import math +from atat.pytorch.free_benchmark import print_info_log_rank_0, print_error_log_rank_0 +from atat.pytorch.free_benchmark.common.constant import ThresholdConfig +from atat.pytorch.free_benchmark.common.enums import DeviceType +from atat.pytorch.free_benchmark.common.params import DataParams, make_unequal_row +from atat.pytorch.free_benchmark.common.utils import Tools +from atat.pytorch.free_benchmark.compare.single_benchmark import SingleCompare +from atat.pytorch.free_benchmark.common.counter import preheat_counter +from atat.pytorch.free_benchmark.result_handlers.base_handler import FuzzHandler +from atat.pytorch.free_benchmark.common.params import HandlerParams + + +class PreheatHandler(FuzzHandler): + + def __init__(self, params: HandlerParams) -> None: + super().__init__(params) + self.pure_name = Tools.get_pure_api_name(self.params.api_name) + + def get_threshold(self, dtype): + return preheat_counter.get_api_thd(self.pure_name, dtype) + + def _is_take_a_sample(self) -> bool: + need_sample_set = self._get_need_sample_set() + curr_called_seq = preheat_counter.get_api_called_time(self.pure_name) + res = curr_called_seq in need_sample_set + if res: + total_count = preheat_counter.get_one_step_used_api(self.pure_name) + print_info_log_rank_0( + f"[atat] Free benchmark: preheat sample in step{self.params.step}" + f"api_name {self.params.api_name}, " + f"curr_called_seq: {curr_called_seq}/{total_count}" + ) 
+            return res
+
+    def _get_sample_count_per_step(self, total_count) -> int:
+        """
+        Number of samples that should be collected in each step.
+        """
+        preheat_step = self.params.preheat_config.get("preheat_step")
+        max_sample = self.params.preheat_config.get("max_sample")
+        return min(math.ceil(total_count / preheat_step), max_sample)
+
+    def _get_need_sample_set(self):
+        """
+        Set of call indices of this API that need to be sampled in the current step.
+        """
+        total_count = preheat_counter.get_one_step_used_api(self.pure_name)
+        # number of samples per step
+        sample_count_per_step = self._get_sample_count_per_step(total_count)
+        need_sample_set = set()
+        preheat_step = self.params.preheat_config.get("preheat_step")
+        for i in range(1, sample_count_per_step + 1):
+            count = (preheat_step * (i - 1) + self.params.step) % total_count
+            if count == 0:
+                count = total_count
+            need_sample_set.add(count)
+        return need_sample_set
+
+    def cmp_output_npu_with_preheat(self, data_params: DataParams):
+        npu_consistent = True
+        max_fuzz_ratio = 0
+        try:
+            if isinstance(data_params.original_result, torch.Tensor):
+                is_consistent, ratio = self.npu_compare(
+                    data_params.original_result, data_params.perturbed_result
+                )
+                npu_consistent = is_consistent
+                max_fuzz_ratio = (
+                    max_fuzz_ratio if ratio is None else max(max_fuzz_ratio, ratio)
+                )
+                data_params.is_consistent = is_consistent and data_params.is_consistent
+                if not is_consistent and data_params.grad_unequal_flag:
+                    self.unequal_rows.append(
+                        make_unequal_row(data_params, self.params, ratio=ratio)
+                    )
+
+            elif isinstance(data_params.original_result, (list, tuple)):
+                for index_, origin_item in enumerate(data_params.original_result):
+                    is_consistent, ratio = self.npu_compare(
+                        origin_item, data_params.perturbed_result[index_]
+                    )
+                    npu_consistent = npu_consistent and is_consistent
+                    max_fuzz_ratio = (
+                        max_fuzz_ratio if ratio is None else max(max_fuzz_ratio, ratio)
+                    )
+                    data_params.is_consistent = (
+                        is_consistent and data_params.is_consistent
+                    )
+                    if not is_consistent and data_params.grad_unequal_flag:
+                        self.unequal_rows.append(
+                            make_unequal_row(
+                                data_params, self.params, ratio=ratio, index=index_
+                            )
+                        )
+        except Exception as e:
+            print_error_log_rank_0(
+                f"[atat] Free Benchmark: For {self.params.api_name}, "
+                f"an exception was raised while comparing the results: {e}"
+            )
+        return npu_consistent, max_fuzz_ratio
+
+    def compare_npu_and_cpu(self, data_params: DataParams):
+        args = Tools.convert_device_and_dtype(
+            data_params.args, DeviceType.CPU, change_dtype=True
+        )
+        kwargs = Tools.convert_device_and_dtype(
+            data_params.kwargs, DeviceType.CPU, change_dtype=True
+        )
+        cpu_result = data_params.origin_func(*args, **kwargs)
+        return SingleCompare.compare_seq(data_params.original_result, cpu_result)
+
+    def _need_adjust_threshold(self) -> bool:
+        total_count = preheat_counter.get_one_step_used_api(self.pure_name)
+        sample_count_per_step = self._get_sample_count_per_step(total_count)
+        sampled_time = preheat_counter.get_api_sampled_time(self.pure_name)
+        return sampled_time >= sample_count_per_step
+
+    def _adjust_threshold_for_dtype(self, dtype_str, compare_result):
+        con_ratio = [ratio for ratio, is_consistent in compare_result if is_consistent]
+        incon_ratio = [
+            ratio for ratio, is_consistent in compare_result if not is_consistent
+        ]
+        old_thd = preheat_counter.get_api_thd(self.pure_name, dtype_str)
+        new_thd = old_thd
+        # both consistent and inconsistent samples exist
+        if con_ratio and incon_ratio:
+            if min(incon_ratio) > max(con_ratio):
+                new_thd = min(min(incon_ratio), old_thd)
+        elif con_ratio:
+            # a consistent ratio exceeded the old threshold, i.e. a missed report
+            if max(con_ratio) > old_thd:
+                new_thd = 1 + ((old_thd - 1) * ThresholdConfig.API_THD_STEP)
+            else:
+                new_thd = 1 + ((old_thd - 1) / ThresholdConfig.API_THD_STEP)
+        else:
+            new_thd = min(min(incon_ratio), old_thd)
+        if incon_ratio:
+            preheat_counter.set_api_preheat(self.pure_name, dtype_str, is_preheat=False)
+        return new_thd
+
+    def _adjust_threshold(self):
+        for dtype_str, compare_result in preheat_counter.preheat_record[
+            self.pure_name
+        ].items():
+            new_thd = self._adjust_threshold_for_dtype(dtype_str, compare_result)
+            threshold = self._get_default_threshold(
+                preheat_counter.dtype_map.get(dtype_str)
+            )
+            preheat_counter.update_api_thd(
+                self.pure_name, dtype_str, new_thd, threshold
+            )
+
+    def preheat(self, max_fuzz_ratio, cpu_consistent, first_dtype):
+        preheat_counter.add_api_sampled_time(self.pure_name)
+        # record every output ratio of the current step together with its NPU/CPU comparison result
+        preheat_counter.update_preheat_record(
+            self.params.step,
+            self.pure_name,
+            first_dtype,
+            (max_fuzz_ratio, cpu_consistent),
+        )
+        if self._need_adjust_threshold():
+            self._adjust_threshold()
+
+    def handle(self, data_params: DataParams) -> Any:
+        if not data_params.grad_unequal_flag:
+            data_params.grad_unequal_flag = True
+            data_params.is_consistent = False
+            return data_params.original_result
+        if self.params.step == 0:
+            return data_params.original_result
+
+        preheat_counter.add_api_called_time(self.pure_name)
+
+        if not self._is_take_a_sample():
+            return data_params.original_result
+        if isinstance(data_params.perturbed_result, bool) or not Tools.is_float_tensor(
+            data_params.perturbed_result
+        ):
+            return data_params.original_result
+
+        # preheat the current API at this step when required
+        npu_consistent, max_fuzz_ratio = self.cmp_output_npu_with_preheat(data_params)
+        data_params.is_consistent = npu_consistent
+        cpu_consistent = True
+        try:
+            cpu_consistent = self.compare_npu_and_cpu(data_params)
+        except Exception as e:
+            print_error_log_rank_0(
+                f"[atat] Free Benchmark: For {self.params.api_name}, "
+                f"an exception was raised while comparing to cpu: {e}"
+            )
+        first_dtype = Tools.get_first_tensor_dtype(data_params.perturbed_result)
+        if preheat_counter.get_api_preheat(self.pure_name, str(first_dtype)):
+            self.preheat(max_fuzz_ratio, cpu_consistent, first_dtype)
+
+        return data_params.original_result
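Note: the sketch below is a minimal standalone illustration of how the preheat sampling schedule above distributes samples across steps; the helper name need_sample_set and all numeric values are hypothetical and are not part of the tool, they merely mirror the arithmetic of _get_sample_count_per_step and _get_need_sample_set.

    import math

    def need_sample_set(step: int, total_count: int, preheat_step: int, max_sample: int) -> set:
        # per-step sample budget: spread the API's calls over the preheat steps, capped by max_sample
        per_step = min(math.ceil(total_count / preheat_step), max_sample)
        chosen = set()
        for i in range(1, per_step + 1):
            # rotate the sampled call indices with the step so different calls are checked over time
            idx = (preheat_step * (i - 1) + step) % total_count
            chosen.add(idx if idx != 0 else total_count)
        return chosen

    # hypothetical example: an API called 10 times per step, preheated over 5 steps, at most 3 samples per step
    print(need_sample_set(step=1, total_count=10, preheat_step=5, max_sample=3))  # -> {1, 6}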