diff --git a/debug/accuracy_tools/atat/pytorch/common/utils.py b/debug/accuracy_tools/atat/pytorch/common/utils.py index e88d506b2c340f9b6141c2e0bb775a693d61a16c..fbd92d87297c7bf611bf4c7c8c72b62acef4baa1 100644 --- a/debug/accuracy_tools/atat/pytorch/common/utils.py +++ b/debug/accuracy_tools/atat/pytorch/common/utils.py @@ -169,6 +169,7 @@ class Const: NUMPY_SUFFIX = ".npy" ONE_GB = 1 * 1024 * 1024 * 1024 TEN_GB = 10 * 1024 * 1024 * 1024 + ONE_MB = 1 * 1024 * 1024 FILE_PATTERN = r'^[a-zA-Z0-9_./-]+$' FILE_NAME_LENGTH = 255 DIRECTORY_LENGTH = 4096 diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/common/constant.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/common/constant.py index 9b72437f2280ca44a20fc5e370f1cfd9b9ea3ac4..9e25ef802f926e6b04360011970fc722d933da17 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/common/constant.py +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/common/constant.py @@ -59,6 +59,8 @@ class ThresholdConfig: torch.bfloat16: BenchmarkThd(2**-8, 1.0, 2**-8, 1e-4), } + TENSOR_SPLIT_MAX_CHUNK = 128 + class PreheatConfig: IF_PREHEAT = "if_preheat" diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/common/utils.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/common/utils.py index 24d25967635b3dcfd1da89e1f54d3282fa1181ed..36e64bf16321544eee1367b6ffb5278cfe3a7253 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/common/utils.py +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/common/utils.py @@ -96,3 +96,5 @@ class TorchC: add = torch._C._VariableFunctionsClass.add bitwise_xor = torch._C._VariableFunctionsClass.bitwise_xor clone = torch._C._VariableFunctionsClass.clone + clamp = torch._C._VariableFunctionsClass.clamp + tensor_split = torch._C._VariableFunctionsClass.tensor_split diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/compare/grad_saver.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/compare/grad_saver.py index 
a8752656ed72bc21773aca2bb06d4e69d96a5c4b..521285d80c8bc7c05fb08c262fd5c94494781faa 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/compare/grad_saver.py +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/compare/grad_saver.py @@ -69,6 +69,8 @@ class GradSaver: f"[atat] Free benchmark: compare two vjp failed: api:{self.handler_params.api_name}." f"{e}" ) + self.data_params.perturbed_result = None + self.data_params.original_result = None def check_grad_input(self, origin_grad, new_grad_index): if self.perturbed_grad_input is None: @@ -167,6 +169,7 @@ class GradSaver: self.handler_params.pert_mode, ) layer.handle(self.data_params) + self.data_params.args = None self.perturbed_grad_input = tuple( [x.cpu() for x in self.data_params.perturbed_result] ) diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/base_handler.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/base_handler.py index 1d59ef9fc3adc2f90a7145d825ce597e209758e4..5801c4fe87e4b45429d66551787aae9e0c61484c 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/base_handler.py +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/base_handler.py @@ -1,4 +1,5 @@ import math +import numpy as np from abc import ABC, abstractmethod from typing import Any, Optional, Tuple @@ -32,9 +33,9 @@ class FuzzHandler(ABC): origin_ouput = origin_ouput.values perturbed_output = perturbed_output.values if hasattr(perturbed_output, "dtype"): - abs_tol = ThresholdConfig.ABS_TOL_VALUE_DICT.get(perturbed_output.dtype) + abs_tol = ThresholdConfig.ABS_TOL_VALUE_DICT.get(perturbed_output.dtype, FuzzThreshold.F32_THD) else: - abs_tol = FuzzThreshold.F32_THD.value + abs_tol = FuzzThreshold.F32_THD return ( origin_ouput.to(perturbed_output.dtype).to(perturbed_output.device), perturbed_output, @@ -54,48 +55,43 @@ class FuzzHandler(ABC): return ThresholdConfig.COMP_CONSISTENT return ratio + @staticmethod + def 
tensor_split_for_endless_norm(origin_output, perturbed_output): + """ + 对将投入误差值计算的扰动前后输出张量进行分块 + :param origin_output: 原始输出 + :param perturbed_output: 扰动后输出 + :return origin_output_chunks: 切块后原始输出列表 + :return perturbed_output_chunks: 切块后扰动后输出列表 + """ + single_output_mem = origin_output.element_size() * origin_output.nelement() / Const.ONE_MB + if single_output_mem == 0 or origin_output.ndim == 0: + return ([origin_output], [perturbed_output]) + chunks_exp = int(math.log(single_output_mem, 2)) - 4 + chunks = 2 ** chunks_exp + chunks = max(chunks, 1) + chunks = min(chunks, ThresholdConfig.TENSOR_SPLIT_MAX_CHUNK) + origin_output_chunks = TorchC.tensor_split(origin_output.view(-1), chunks) + perturbed_output_chunks = TorchC.tensor_split(perturbed_output.view(-1), chunks) + return origin_output_chunks, perturbed_output_chunks + def get_endless_norm(self, origin_output, perturbed_output, abs_tol): - try: - ratio_tensor1 = TorchC.where( - TorchC.gt(TorchC.abs(perturbed_output), abs_tol), - TorchC.div( - TorchC.abs(origin_output), - TorchC.add(TorchC.abs(perturbed_output), abs_tol), - ), - 1, - ) - ratio_tensor2 = TorchC.where( - TorchC.gt(TorchC.abs(origin_output), abs_tol), - TorchC.div( - TorchC.abs(perturbed_output), - TorchC.add(TorchC.abs(origin_output), abs_tol), - ), - 1, - ) - except: - ratio_tensor1 = TorchC.where( - TorchC.gt(TorchC.abs(perturbed_output.to(torch.float32)), abs_tol), - TorchC.div( - origin_output.to(torch.float32), perturbed_output.to(torch.float32) - ), - 1, - ) - ratio_tensor2 = TorchC.where( - TorchC.gt(TorchC.abs(origin_output.to(torch.float32)), abs_tol), - TorchC.div( - perturbed_output.to(torch.float32), origin_output.to(torch.float32) - ), - 1, - ) - norm1 = self.convert_overflow_ratio_to_consistent( - TorchC.max(ratio_tensor1).item() - ) - norm2 = self.convert_overflow_ratio_to_consistent( - TorchC.max(ratio_tensor2).item() - ) - norm3 = self.convert_overflow_ratio_to_consistent( - TorchC.min(ratio_tensor1).item() - ) + 
origin_output_chunks, perturbed_output_chunks = self.tensor_split_for_endless_norm(origin_output, perturbed_output) + norm1 = -np.inf + norm2 = -np.inf + norm3 = np.inf + for i, chunk_origin in enumerate(origin_output_chunks): + if chunk_origin.nelement() == 0: + break + chunk_perturbed = perturbed_output_chunks[i] + ratio_tensor1 = TorchC.where(TorchC.abs(chunk_perturbed) > abs_tol, + TorchC.div(TorchC.clamp(chunk_origin, min=abs_tol), TorchC.clamp(chunk_perturbed, min=abs_tol)), 1) + ratio_tensor2 = TorchC.where(TorchC.abs(chunk_origin) > abs_tol, + TorchC.div(TorchC.clamp(chunk_perturbed, min=abs_tol), TorchC.clamp(chunk_origin, min=abs_tol)), 1) + norm1 = max(norm1, self.convert_overflow_ratio_to_consistent(TorchC.max(ratio_tensor1).item())) + norm2 = max(norm2, self.convert_overflow_ratio_to_consistent(TorchC.max(ratio_tensor2).item())) + norm3 = min(norm3, self.convert_overflow_ratio_to_consistent(TorchC.min(ratio_tensor1).item())) + if norm3 < 0: ratio = ThresholdConfig.SYMBOL_FLIPPING else: diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/check_handler.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/check_handler.py index 2f590855f1b96e0a6475c87c9b3dfdafd0288332..742f0dc8f93a10d458a0428ce7027d7bf5548c0b 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/check_handler.py +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/check_handler.py @@ -10,7 +10,6 @@ from atat.pytorch.free_benchmark.result_handlers.base_handler import FuzzHandler class CheckerHandler(FuzzHandler): - @staticmethod def other_compare(self, data_params: DataParams) -> bool: is_consistent = SingleCompare().compare_seq( data_params.original_result, data_params.perturbed_result