diff --git a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/params.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/params.py index 326f623be728c13b2a39627b67b3fca63babe3eb..5d88912024fc93858d27f1f611f4d30b3d2cd5c7 100644 --- a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/params.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/params.py @@ -39,7 +39,6 @@ class DataParams: origin_func: Optional[Callable] = None api_type: Optional[str] = None fuzz_stage: Optional[str] = None - grad_unequal_flag: Optional[bool] = True @dataclass @@ -127,6 +126,8 @@ def make_unequal_row( ) if isinstance(ratio, float): row.max_rel = ratio - 1 + if isinstance(ratio, str): + row.max_rel = ratio origin_tensor = data_params.original_result perturbed_tensor = data_params.perturbed_result if index is not None: diff --git a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/utils.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/utils.py index 2d865ea0348a02ec1ed930c6964da2aa842e8984..e3fd2b69fef2772354401a22344376258e77a008 100644 --- a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/utils.py @@ -124,6 +124,7 @@ class TorchC: abs = torch._C._VariableFunctionsClass.abs where = torch._C._VariableFunctionsClass.where div = torch._C._VariableFunctionsClass.div + mul = torch._C._VariableFunctionsClass.mul max = torch._C._VariableFunctionsClass.max min = torch._C._VariableFunctionsClass.min gt = torch._C._VariableFunctionsClass.gt @@ -138,3 +139,5 @@ class TorchC: tensor_split = torch._C._VariableFunctionsClass.tensor_split stack = torch._C._VariableFunctionsClass.stack reshape = torch._C._VariableFunctionsClass.reshape + nan_to_num = torch._C._VariableFunctionsClass.nan_to_num + aminmax = torch._C._VariableFunctionsClass.aminmax diff --git a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/compare/grad_saver.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/compare/grad_saver.py index efdf23999bc1abfa01e7a6bb14960a0aef0d4671..58cfea45d00459db65355a2cdba4471bac7b754e 100644 --- a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/compare/grad_saver.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/compare/grad_saver.py @@ -82,13 +82,11 @@ class GradSaver: data_params = DataParams() data_params.original_result = origin_grad data_params.perturbed_result = perturbed_grad - data_params.grad_unequal_flag = False data_params.valid_input_index = index try: handler.handle(data_params) if not data_params.is_consistent: self.is_compare = False - data_params.grad_unequal_flag = True data_params.is_consistent = True data_params.perturbed_result = self.perturbed_grad_input data_params.original_result = self.origin_grad_input diff --git a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/base_handler.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/base_handler.py index 0c5282faa7d37c9348312f514e66c410032a0c73..e89ca5d9f43bcc1f5bfe850c2db17c3ad9dc4688 100644 --- a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/base_handler.py @@ -89,12 +89,6 @@ class FuzzHandler(ABC): ) return origin_output_chunks, perturbed_output_chunks - @staticmethod - def convert_overflow_ratio_to_consistent(ratio): - if math.isnan(ratio) or math.isinf(ratio): - return ThresholdConfig.COMP_CONSISTENT - return ratio - @abstractmethod def get_threshold(self, dtype): pass @@ -107,10 +101,10 @@ class FuzzHandler(ABC): self, origin_output, perturbed_output, norm_type, abs_tol ): if norm_type == NormType.ENDLESS_NORM: - return self.calculate_error(origin_output, perturbed_output, abs_tol) + return self.calculate_max_ratio(origin_output, perturbed_output, abs_tol) return ThresholdConfig.COMP_CONSISTENT - def calculate_error(self, origin_output, perturbed_output, abs_tol): + def calculate_max_ratio(self, origin_output, perturbed_output, abs_tol): origin_output_chunks, perturbed_output_chunks = ( self.tensor_split_for_error_calculate(origin_output, perturbed_output) ) @@ -122,42 +116,30 @@ class FuzzHandler(ABC): raise FreeBenchmarkException( FreeBenchmarkException.OutputIndexError, err_msg ) - norm1 = -np.inf - norm2 = -np.inf - norm3 = np.inf + + max_ratio = ThresholdConfig.COMP_CONSISTENT for i, chunk_origin in enumerate(origin_output_chunks): if chunk_origin.nelement() == 0: break chunk_perturbed = perturbed_output_chunks[i] - ratio_tensor1 = TorchC.where( - TorchC.abs(chunk_perturbed) > abs_tol, - TorchC.div( - TorchC.clamp(chunk_origin, min=abs_tol), - TorchC.clamp(chunk_perturbed, min=abs_tol), - ), - 1, - ) - ratio_tensor2 = TorchC.where( - TorchC.abs(chunk_origin) > abs_tol, - TorchC.div( - TorchC.clamp(chunk_perturbed, min=abs_tol), - TorchC.clamp(chunk_origin, min=abs_tol), - ), - 1, + # 如果乘积最小值 < 极小值乘积的负值,认为存在非极小值符号相反的情况 + if TorchC.lt( + TorchC.min(TorchC.mul(chunk_origin, chunk_perturbed)), -(abs_tol**2) + ): + return ThresholdConfig.SYMBOL_FLIPPING + # 求A/B B/A的比值前,将值限制在大于极小值范围内 + clamp_origin = TorchC.clamp(TorchC.abs(chunk_origin), min=abs_tol) + clamp_perturbed = TorchC.clamp(TorchC.abs(chunk_perturbed), min=abs_tol) + # 对于计算结果为nan的情况,认为两者没有差异 + ratio_tensor = TorchC.nan_to_num( + TorchC.div(clamp_origin, clamp_perturbed), + nan=ThresholdConfig.COMP_CONSISTENT, ) - norm_values = TorchC.stack( - [TorchC.max(ratio_tensor1), TorchC.max(ratio_tensor2)] - ) - max_ratio1, max_ratio2 = norm_values.tolist() - norm1 = max(norm1, self.convert_overflow_ratio_to_consistent(max_ratio1)) - norm2 = max(norm2, self.convert_overflow_ratio_to_consistent(max_ratio2)) - norm3 = min(norm3, self.convert_overflow_ratio_to_consistent(max_ratio1)) - - if norm3 < 0: - ratio = ThresholdConfig.SYMBOL_FLIPPING - else: - ratio = max(norm1, norm2) - return ratio + # 求A/B 和 B/A比值最大值,其中 B/A的最大值为 A/B的最小值的倒数 + min_ratio, max_ratio = TorchC.stack([*TorchC.aminmax(ratio_tensor)]).tolist() + min_ratio_reciprocal = np.inf if min_ratio == 0 else 1 / min_ratio + max_ratio = max(max_ratio, min_ratio_reciprocal, max_ratio) + return max_ratio def ratio_calculate(self, origin_output, perturbed_output, norm_type) -> float: try: @@ -220,10 +202,12 @@ class FuzzHandler(ABC): ) npu_consistent = is_consistent max_fuzz_ratio = ( - max_fuzz_ratio if ratio is None else max(max_fuzz_ratio, ratio) + max_fuzz_ratio + if not isinstance(ratio, (int, float)) + else max(max_fuzz_ratio, ratio) ) - data_params.is_consistent = is_consistent and data_params.is_consistent - if not is_consistent and data_params.grad_unequal_flag: + data_params.is_consistent = is_consistent + if not is_consistent: self.unequal_rows.append( make_unequal_row(data_params, self.params, ratio=ratio) ) @@ -235,12 +219,12 @@ class FuzzHandler(ABC): ) npu_consistent = npu_consistent and is_consistent max_fuzz_ratio = ( - max_fuzz_ratio if ratio is None else max(max_fuzz_ratio, ratio) - ) - data_params.is_consistent = ( - is_consistent and data_params.is_consistent + max_fuzz_ratio + if not isinstance(ratio, (int, float)) + else max(max_fuzz_ratio, ratio) ) - if not is_consistent and data_params.grad_unequal_flag: + data_params.is_consistent = is_consistent + if not is_consistent: self.unequal_rows.append( make_unequal_row( data_params, self.params, ratio=ratio, index=index_ diff --git a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py index 1bfed53e6d8223a3a3f90004b0b9e2915c1729a6..5bfc672df18b97eb47b2ffed3ab9bf6b66dd03e7 100644 --- a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py @@ -75,10 +75,6 @@ class PreheatHandler(FuzzHandler): if self.params.preheat_config.get("preheat_step") <= self.params.step: return data_params.original_result - if not data_params.grad_unequal_flag: - data_params.grad_unequal_flag = True - data_params.is_consistent = False - return data_params.original_result preheat_counter.add_api_called_time(self.pure_name) if not self._is_take_a_sample(): diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_calculate_max_ratio.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_calculate_max_ratio.py new file mode 100644 index 0000000000000000000000000000000000000000..cbbaceb49acdef603305c704d167416318e00f87 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_calculate_max_ratio.py @@ -0,0 +1,69 @@ +from unittest import TestCase + +import torch +from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig +from msprobe.pytorch.free_benchmark.common.params import HandlerParams +from msprobe.pytorch.free_benchmark.result_handlers.check_handler import CheckerHandler + + +class TestFuzzHandler(TestCase): + + def setUp(self) -> None: + self.api_name = "test_api" + self.handler = CheckerHandler(HandlerParams(api_name=self.api_name)) + self.abs_tol = 1e-4 + + def test_calculate_max_ratio_with_equal_outputs(self): + # 测试两个输出相等时,比值应该接近1 + origin_output = torch.tensor([1.0, 2.0, 3.0]) + perturbed_output = torch.tensor([1.0, 2.0, 3.0]) + max_ratio = self.handler.calculate_max_ratio( + origin_output, perturbed_output, self.abs_tol + ) + self.assertAlmostEqual(max_ratio, 1.0) + + def test_calculate_max_ratio_with_different_outputs(self): + # 测试两个输出不同时,比值应该为最大的比值 + origin_output = torch.tensor([1.0, 2.0, 1e-4]) + perturbed_output = torch.tensor([1.3, 2.7, 1e-3]) + max_ratio = self.handler.calculate_max_ratio( + origin_output, perturbed_output, self.abs_tol + ) + result = torch.tensor(1e-3) / torch.tensor(1e-4) + self.assertAlmostEqual(max_ratio, result.item()) + + def test_calculate_max_ratio_with_tol_elements(self): + # 测试忽略绝对值小于极小值的情况,小于的全部变为极小值计算 + origin_output = torch.tensor([1.0, 1e-8, 1e-6]) + perturbed_output = torch.tensor([1.0, 1e-4, -1e-8]) + max_ratio = self.handler.calculate_max_ratio( + origin_output, perturbed_output, self.abs_tol + ) + self.assertAlmostEqual(max_ratio, 1.0) + + def test_calculate_max_ratio_with_symbol_flipping(self): + # 测试乘积符号相反时,应该返回SYMBOL_FLIPPING + origin_output = torch.tensor([1.0, -2.0, 3.0]) + perturbed_output = torch.tensor([1.0, 2.0, 3.0]) + result = self.handler.calculate_max_ratio( + origin_output, perturbed_output, self.abs_tol + ) + self.assertEqual(result, ThresholdConfig.SYMBOL_FLIPPING) + + def test_calculate_max_ratio_with_nan_values(self): + # 测试包含NaN值时,函数应该正确计算 + origin_output = torch.tensor([1.0, float("nan"), 2.0]) + perturbed_output = torch.tensor([1.1, float("nan"), 2.4]) + max_ratio = self.handler.calculate_max_ratio( + origin_output, perturbed_output, self.abs_tol + ) + self.assertAlmostEqual(max_ratio, 1.2) + + def test_calculate_max_ratio_with_empty_chunks(self): + # 测试空的输出块时,函数应该正确处理 + origin_output = torch.tensor([]) + perturbed_output = torch.tensor([]) + max_ratio = self.handler.calculate_max_ratio( + origin_output, perturbed_output, self.abs_tol + ) + self.assertEqual(max_ratio, ThresholdConfig.COMP_CONSISTENT)