From 207cc468945fc72d03dec689708cc35b431cf369 Mon Sep 17 00:00:00 2001 From: jiandaobao Date: Thu, 19 Sep 2024 11:13:22 +0800 Subject: [PATCH 1/2] =?UTF-8?q?=E6=97=A0=E6=A0=87=E6=9D=86=E5=A0=86?= =?UTF-8?q?=E6=A0=88=E5=86=85=E5=AE=B9=E8=B0=83=E6=95=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../msprobe/core/data_dump/data_collector.py | 25 ++++++++++---- .../data_processor/pytorch_processor.py | 33 +++++++++++++++---- .../msprobe/core/data_dump/json_writer.py | 3 ++ .../pytorch/free_benchmark/common/params.py | 4 +++ .../free_benchmark/compare/grad_saver.py | 3 +- 5 files changed, 54 insertions(+), 14 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py b/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py index 166113f146..3a26d3f377 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py @@ -13,7 +13,8 @@ def build_data_collector(config): class DataCollector: multi_output_apis = ["_sort_", "npu_flash_attention"] - tasks_need_tensor_data = [Const.OVERFLOW_CHECK, Const.TENSOR, Const.FREE_BENCHMARK] + tasks_need_tensor_data = [Const.OVERFLOW_CHECK, Const.TENSOR] + tasks_no_need_jsons = [Const.FREE_BENCHMARK,] level_without_construct = ["L1", "L2"] def __init__(self, config): @@ -57,6 +58,9 @@ class DataCollector: self.data_processor.update_api_or_module_name(api_or_module_name) def write_json(self): + # 无标杆场景不处理json数据 + if self.config.task in self.tasks_no_need_jsons: + return self.data_writer.write_json() def update_data(self, data_info, msg=''): @@ -71,9 +75,13 @@ class DataCollector: return msg def pre_forward_data_collect(self, name, module, pid, module_input_output): - backward_name = name.replace(Const.FORWARD, Const.BACKWARD) - if self.check_scope_and_pid(self.scope, backward_name, pid): - self.data_processor.analyze_pre_forward(backward_name, module, module_input_output) + if self.config.task == Const.FREE_BENCHMARK: + # 无标杆场景在反向时需要在preforwardhook中进行特殊处理 + backward_name = name.replace(Const.FORWARD, Const.BACKWARD) + if not self.check_scope_and_pid(self.scope, backward_name, pid): + return + self.data_processor.free_benchmark_pre_forward(backward_name, module, module_input_output) + return if not self.is_inplace(module) or not self.check_scope_and_pid(self.scope, name, pid): return logger.info(f"API {name} is inplace.") @@ -89,7 +97,8 @@ class DataCollector: data_info = self.data_processor.analyze_forward(name, module, module_input_output) else: data_info = self.data_processor.analyze_forward_inplace(name, module_input_output) - if self.config.level == "L2": + # 避免在l2 和无标杆场景采集堆栈信息 + if self.config.level == "L2" or self.config.task == Const.FREE_BENCHMARK: return self.data_writer.update_stack(self.data_processor.analyze_api_call_stack(name)) self.handle_data(name, data_info, flush=self.data_processor.is_terminated) @@ -125,6 +134,9 @@ class DataCollector: self.data_writer.update_construct(self.module_processor.module_node) def handle_data(self, name, data_info, flush=False): + # 无标杆场景不处理json数据 + if self.config.task in self.tasks_no_need_jsons: + return if data_info: msg = f"msprobe is collecting data on {name}. 
" msg = self.update_data(data_info, msg) @@ -136,7 +148,8 @@ class DataCollector: def update_dump_paths(self, *args): self.data_writer.update_dump_paths(*args) - self.data_writer.initialize_json_file(task=self.config.task, level=self.config.level) + if self.config.task not in self.tasks_no_need_jsons: + self.data_writer.initialize_json_file(task=self.config.task, level=self.config.level) def update_iter(self, current_iter): self.data_processor.update_iter(current_iter) diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py index 4efe0490fe..5bbfecba97 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py @@ -7,6 +7,7 @@ import torch from msprobe.core.common.file_utils import path_len_exceeds_limit, change_mode from msprobe.core.common.log import logger from msprobe.core.common.const import Const, OverflowConst, FileCheckConst +from msprobe.core.common.utils import check_op_str_pattern_valid from msprobe.core.data_dump.data_processor.base import BaseDataProcessor, ModuleBackwardInputsOutputs, \ ModuleForwardInputsOutputs, TensorStatInfo from msprobe.pytorch.free_benchmark import FreeBenchmarkCheck, UnequalRow @@ -265,17 +266,23 @@ class FreeBenchmarkDataProcessor(PytorchDataProcessor): def __init__(self, config, data_writer): super().__init__(config, data_writer) self.checker = FreeBenchmarkCheck(config=config) + self._backward_stack_info_dict = {} self._return_forward_new_output = None self._forward_new_output = None def update_iter(self, current_iter): super().update_iter(current_iter) self.checker.update_iter(current_iter) - - def update_unequal_rows(self, unequal_rows: List[UnequalRow]): - if not unequal_rows: - return + + def pop_backward_stack_info(self, backward_name): + ret = None + if backward_name in self._backward_stack_info_dict: + ret = self._backward_stack_info_dict.pop(backward_name) + return ret + + def update_unequal_rows(self, unequal_rows: List[UnequalRow], stack_info=None): for row in unequal_rows: + row.update_stack_info(stack_info) data_dict = asdict(row) self.data_writer.write_data_to_csv( data_dict.values(), @@ -284,8 +291,9 @@ class FreeBenchmarkDataProcessor(PytorchDataProcessor): ) return - def analyze_pre_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs): - self.checker.pre_forward(name, module, self, module_input_output.args, module_input_output.kwargs) + def free_benchmark_pre_forward(self, backward_name, module, module_input_output: ModuleForwardInputsOutputs): + self._backward_stack_info_dict.update(self.analyze_api_call_stack(backward_name)) + self.checker.pre_forward(backward_name, module, self, module_input_output.args, module_input_output.kwargs) def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs): new_output, unequal_rows = self.checker.forward( @@ -295,7 +303,18 @@ class FreeBenchmarkDataProcessor(PytorchDataProcessor): module_input_output.kwargs, module_input_output.output, ) - self.update_unequal_rows(unequal_rows) + if unequal_rows: + stack_info = self.analyze_api_call_stack(name).get(name) + if isinstance(stack_info, list): + for item in stack_info: + check_op_str_pattern_valid(item, name, stack=True) + else: + logger.warning( + f"Expected stack_info to be a list, " + f"but got {type(stack_info).__name__} for '{name}'" + ) + stack_info = None + 
self.update_unequal_rows(unequal_rows, stack_info=stack_info) if self.checker.if_fix(): self._return_forward_new_output = True self._forward_new_output = new_output diff --git a/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py b/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py index 5a14b5b527..1b309c0d2f 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py @@ -38,6 +38,9 @@ class DataWriter: change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY) def initialize_json_file(self, **kwargs): + if self.config.task == Const.FREE_BENCHMARK: + # 无标杆工具只创建dump_path文件夹 + return kwargs.update({"dump_data_dir": self.dump_tensor_data_dir, Const.DATA: {}}) save_json(self.dump_file_path, kwargs) diff --git a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/params.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/params.py index bbfc245a63..2457cb1bba 100644 --- a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/params.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/params.py @@ -49,6 +49,10 @@ class UnequalRow: dtype: Optional[str] = None shape: Optional[str] = None output_index: Optional[int] = None + stack_info: Optional[List] = None + + def update_stack_info(self, stack_info): + self.stack_info = stack_info @dataclass diff --git a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/compare/grad_saver.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/compare/grad_saver.py index e58223e597..856de7fc49 100644 --- a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/compare/grad_saver.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/compare/grad_saver.py @@ -25,6 +25,7 @@ class GradSaver: def register_compare_func_for_inputs(self, inputs, data_processor): _index = 0 + stack_info = data_processor.pop_backward_stack_info(self.api_name) for j, obj in enumerate(inputs): if torch.is_tensor(obj) and obj.requires_grad: @@ -37,7 +38,7 @@ class GradSaver: self.compare_grad_results( handler, grad, perturbed_grad, index=input_index ) - data_processor.update_unequal_rows(handler.get_unequal_rows()) + data_processor.update_unequal_rows(handler.get_unequal_rows(), stack_info=stack_info) except IndexError: logger.warning_on_rank_0( f"[msprobe] Free benchmark: grad index out of range. api:{self.handler_params.api_name}." 
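
Note on the data flow introduced by this patch: with task == FREE_BENCHMARK the collector no longer writes dump/stack JSON files; instead the pre-forward hook records the call stack under the backward API name, the processor caches it, and GradSaver later pops it and attaches it to every UnequalRow written to the CSV. A minimal, self-contained sketch of that cache-and-pop pattern follows; only the method names mirror the diff, while the sketch class itself, traceback.format_stack() (standing in for analyze_api_call_stack) and print() (standing in for the CSV writer) are illustrative assumptions.

    # Simplified sketch of the stack-caching flow added in this patch.
    # The real classes live in msprobe's pytorch_processor.py and grad_saver.py.
    import traceback
    from dataclasses import dataclass, asdict
    from typing import List, Optional


    @dataclass
    class UnequalRow:                        # mirrors params.py: new stack_info column
        api_name: str = ""
        stack_info: Optional[List] = None

        def update_stack_info(self, stack_info):
            self.stack_info = stack_info


    class FreeBenchmarkProcessorSketch:
        def __init__(self):
            # backward_name -> call stack captured in the pre-forward hook
            self._backward_stack_info_dict = {}

        def free_benchmark_pre_forward(self, backward_name):
            # Capture the stack once, keyed by the backward API name, so it can be
            # attached later when the gradient comparison finds unequal rows.
            self._backward_stack_info_dict[backward_name] = traceback.format_stack()

        def pop_backward_stack_info(self, backward_name):
            # pop() keeps the cache from growing across calls/iterations.
            return self._backward_stack_info_dict.pop(backward_name, None)

        def update_unequal_rows(self, unequal_rows, stack_info=None):
            for row in unequal_rows:
                row.update_stack_info(stack_info)
                print(asdict(row))           # real code writes a CSV row instead


    # Usage: the pre-forward hook caches the stack; the backward compare path pops it.
    proc = FreeBenchmarkProcessorSketch()
    proc.free_benchmark_pre_forward("Torch.matmul.0.backward")
    stack = proc.pop_backward_stack_info("Torch.matmul.0.backward")
    proc.update_unequal_rows([UnequalRow(api_name="Torch.matmul.0.backward")], stack_info=stack)
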
-- Gitee From 9ca2a3c13a083863925ff088997cd8ea6118d324 Mon Sep 17 00:00:00 2001 From: jiandaobao Date: Mon, 21 Oct 2024 17:56:43 +0800 Subject: [PATCH 2/2] =?UTF-8?q?TorchC=E6=9B=BF=E6=8D=A2=E4=B8=BAtorch.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../pytorch/free_benchmark/common/utils.py | 25 -------------- .../compare/single_benchmark.py | 33 +++++++++---------- .../perturbed_layers/npu/add_noise.py | 13 ++++---- .../perturbed_layers/npu/bit_noise.py | 13 ++++---- .../perturbed_layers/npu/change_value.py | 11 +++---- .../result_handlers/base_handler.py | 16 ++++----- 6 files changed, 41 insertions(+), 70 deletions(-) diff --git a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/utils.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/utils.py index 631beeb85c..f524490d06 100644 --- a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/utils.py @@ -75,28 +75,3 @@ class Tools: ) return type(origin)(result) return origin - -class TorchC: - sum = torch._C._VariableFunctionsClass.sum - isinf = torch._C._VariableFunctionsClass.isinf - isfinite = torch._C._VariableFunctionsClass.isfinite - isnan = torch._C._VariableFunctionsClass.isnan - logical_not = torch._C._VariableFunctionsClass.logical_not - subtract = torch._C._VariableFunctionsClass.subtract - abs = torch._C._VariableFunctionsClass.abs - where = torch._C._VariableFunctionsClass.where - div = torch._C._VariableFunctionsClass.div - max = torch._C._VariableFunctionsClass.max - min = torch._C._VariableFunctionsClass.min - gt = torch._C._VariableFunctionsClass.gt - ge = torch._C._VariableFunctionsClass.ge - lt = torch._C._VariableFunctionsClass.lt - mean = torch._C._VariableFunctionsClass.mean - full = torch._C._VariableFunctionsClass.full - add = torch._C._VariableFunctionsClass.add - bitwise_xor = torch._C._VariableFunctionsClass.bitwise_xor - clone = torch._C._VariableFunctionsClass.clone - clamp = torch._C._VariableFunctionsClass.clamp - tensor_split = torch._C._VariableFunctionsClass.tensor_split - stack = torch._C._VariableFunctionsClass.stack - reshape = torch._C._VariableFunctionsClass.reshape diff --git a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/compare/single_benchmark.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/compare/single_benchmark.py index 59239fcd00..bedb4fa104 100644 --- a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/compare/single_benchmark.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/compare/single_benchmark.py @@ -3,7 +3,6 @@ import math import torch from msprobe.pytorch.free_benchmark import logger from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig -from msprobe.pytorch.free_benchmark.common.utils import TorchC class SingleCompare: @@ -15,15 +14,15 @@ class SingleCompare: @staticmethod def filter_overflow(tensor) -> int: - inf_num = TorchC.sum(TorchC.isinf(tensor)) - nan_num = TorchC.sum(TorchC.isnan(tensor)) + inf_num = torch.sum(torch.isinf(tensor)) + nan_num = torch.sum(torch.isnan(tensor)) return inf_num + nan_num @staticmethod def replace_inf_or_nan(tensor): - finite_mask = TorchC.isfinite(tensor) - inf_or_nan_mask = TorchC.logical_not(finite_mask) - inf_or_nan_num = TorchC.sum(inf_or_nan_mask).item() + finite_mask = torch.isfinite(tensor) + inf_or_nan_mask = torch.logical_not(finite_mask) + inf_or_nan_num = torch.sum(inf_or_nan_mask).item() if inf_or_nan_num > 0: tensor[inf_or_nan_mask] = 
1 return tensor @@ -85,20 +84,20 @@ class SingleCompare: return True def _cal_compare_metrics(self, actual, golden): - diff_value = TorchC.subtract(actual, golden) - diff_abs = TorchC.abs(diff_value) - golden_abs = TorchC.abs(golden) + diff_value = torch.subtract(actual, golden) + diff_abs = torch.abs(diff_value) + golden_abs = torch.abs(golden) # 使用绝对误差的元素 - self.absolute_err = TorchC.max(TorchC.where( - TorchC.lt(TorchC.abs(actual), self.threshold.small_value), diff_abs, 0 + self.absolute_err = torch.max(torch.where( + torch.abs(actual) < self.threshold.small_value, diff_abs, 0 )) - diff_rel = TorchC.div(diff_abs, golden_abs) + diff_rel = torch.div(diff_abs, golden_abs) # 使用相对误差的元素 - self.relative_err = TorchC.max(TorchC.where( - TorchC.ge(TorchC.abs(actual), self.threshold.small_value), diff_rel, 0 + self.relative_err = torch.max(torch.where( + torch.abs(actual) >= self.threshold.small_value, diff_rel, 0 )) # 获取误差均衡性 - divided = TorchC.where( - TorchC.ge(TorchC.abs(golden), self.threshold.small_value), golden_abs, 1 + divided = torch.where( + torch.abs(golden) >= self.threshold.small_value, golden_abs, 1 ) - self.eb = TorchC.mean(TorchC.div(diff_value, divided)) + self.eb = torch.mean(torch.div(diff_value, divided)) diff --git a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py index 2ccc2bfcf7..5e3442a1f7 100644 --- a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py @@ -3,7 +3,6 @@ from msprobe.pytorch.free_benchmark import logger from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode from msprobe.pytorch.free_benchmark.common.params import DataParams -from msprobe.pytorch.free_benchmark.common.utils import TorchC from msprobe.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import ( NpuBaseLayer, ) @@ -19,9 +18,9 @@ class AddNoiseLayer(NpuBaseLayer): if not self.pre_check(tensor_obj): return tensor_obj noise = self._get_noise(tensor_obj) - result = TorchC.where( - TorchC.gt(TorchC.abs(tensor_obj), self.perturbed_value ** 0.5), - TorchC.add(noise, tensor_obj), + result = torch.where( + torch.abs(tensor_obj) > self.perturbed_value ** 0.5, + torch.add(noise, tensor_obj), tensor_obj, ).to(tensor_obj.dtype) self.is_added = True @@ -46,7 +45,7 @@ class AddNoiseLayer(NpuBaseLayer): def _get_noise(self, tensor_obj): dtype = tensor_obj.dtype device = str(tensor_obj.device) - noise = TorchC.full( + noise = torch.full( tensor_obj.shape, self.perturbed_value, device=device, @@ -74,13 +73,13 @@ class AddNoiseLayer(NpuBaseLayer): tensor_obj.dtype, ThresholdConfig.NOISE_INPUT_LOWER_BOUND ) try: - max_val = TorchC.max(TorchC.abs(tensor_obj)).item() + max_val = torch.max(torch.abs(tensor_obj)).item() except Exception: logger.warning_on_rank_0( f"[msprobe] Free Benchmark: For {self.api_name}, " f"when calculate maximun value, tensor is changed to float32." 
) - max_val = TorchC.max(TorchC.abs(tensor_obj.to(torch.float32))).item() + max_val = torch.max(torch.abs(tensor_obj.to(torch.float32))).item() if max_val < abs_tol: logger.warning_on_rank_0( f"[msprobe] Free Benchmark: For {self.api_name}, " diff --git a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py index a0ac216917..76d2649c25 100644 --- a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py @@ -3,7 +3,6 @@ from msprobe.pytorch.free_benchmark import logger from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode from msprobe.pytorch.free_benchmark.common.params import DataParams -from msprobe.pytorch.free_benchmark.common.utils import TorchC from msprobe.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import ( NpuBaseLayer, ) @@ -12,7 +11,7 @@ from msprobe.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import class BitNoiseLayer(NpuBaseLayer): def __init__(self, api_name): super().__init__(api_name) - self.bit_mode = TorchC.bitwise_xor + self.bit_mode = torch.bitwise_xor self.bit_tail: int = 1 self.bit_type = None @@ -27,15 +26,15 @@ class BitNoiseLayer(NpuBaseLayer): if not self.pre_check(tensor_obj): return tensor_obj sub_normal = torch.finfo(tensor_obj.dtype).smallest_normal - noise = TorchC.full( + noise = torch.full( tensor_obj.shape, self.bit_tail, device=tensor_obj.device, dtype=self.bit_type, ) result = tensor_obj.view(self.bit_type) - result = TorchC.where( - TorchC.gt(TorchC.abs(tensor_obj), sub_normal), + result = torch.where( + torch.abs(tensor_obj) > sub_normal, self.bit_mode(result, noise), result, ).view(tensor_obj.dtype) @@ -79,13 +78,13 @@ class BitNoiseLayer(NpuBaseLayer): tensor_obj.dtype, ThresholdConfig.NOISE_INPUT_LOWER_BOUND ) try: - max_val = TorchC.max(TorchC.abs(tensor_obj)).item() + max_val = torch.max(torch.abs(tensor_obj)).item() except Exception: logger.warning_on_rank_0( f"[msprobe] Free Benchmark: For {self.api_name}, " f"when calculate maximun value, tensor is changed to float32." 
) - max_val = TorchC.max(TorchC.abs(tensor_obj.to(torch.float32))).item() + max_val = torch.max(torch.abs(tensor_obj.to(torch.float32))).item() if max_val < abs_tol: logger.info_on_rank_0( f"[msprobe] Free Benchmark: For {self.api_name}, " diff --git a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py index ae5bf9f03b..0ecbf05639 100644 --- a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py @@ -2,7 +2,6 @@ import torch from msprobe.pytorch.free_benchmark import logger from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode from msprobe.pytorch.free_benchmark.common.params import DataParams -from msprobe.pytorch.free_benchmark.common.utils import TorchC from msprobe.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import ( NpuBaseLayer, ) @@ -19,15 +18,15 @@ class ChangeValueLayer(NpuBaseLayer): 交换张量首尾 """ if isinstance(tensor_obj, torch.Tensor) and self.pre_check(tensor_obj): - new_tensor = TorchC.clone(tensor_obj) + new_tensor = torch.clone(tensor_obj) if new_tensor.ndim == 1: - temp_first = TorchC.clone(new_tensor[self.head]) - temp_last = TorchC.clone(new_tensor[self.tail]) + temp_first = torch.clone(new_tensor[self.head]) + temp_last = torch.clone(new_tensor[self.tail]) new_tensor[self.head] = temp_last new_tensor[self.tail] = temp_first else: - temp_first = TorchC.clone(new_tensor[self.head][self.head]) - temp_last = TorchC.clone(new_tensor[self.tail][self.tail]) + temp_first = torch.clone(new_tensor[self.head][self.head]) + temp_last = torch.clone(new_tensor[self.tail][self.tail]) new_tensor[self.head][self.head] = temp_last new_tensor[self.tail][self.tail] = temp_first diff --git a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/base_handler.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/base_handler.py index e36f586735..fd98288daa 100644 --- a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/base_handler.py @@ -17,7 +17,7 @@ from msprobe.pytorch.free_benchmark.common.params import ( HandlerParams, make_unequal_row, ) -from msprobe.pytorch.free_benchmark.common.utils import Tools, TorchC +from msprobe.pytorch.free_benchmark.common.utils import Tools class FuzzHandler(ABC): @@ -61,8 +61,8 @@ class FuzzHandler(ABC): chunks = 2 ** chunks_exp chunks = max(chunks, 1) chunks = min(chunks, ThresholdConfig.TENSOR_SPLIT_MAX_CHUNK) - origin_output_chunks = TorchC.tensor_split(TorchC.reshape(origin_output, (-1,)), chunks) - perturbed_output_chunks = TorchC.tensor_split(TorchC.reshape(perturbed_output, (-1,)), chunks) + origin_output_chunks = torch.tensor_split(torch.reshape(origin_output, (-1,)), chunks) + perturbed_output_chunks = torch.tensor_split(torch.reshape(perturbed_output, (-1,)), chunks) return origin_output_chunks, perturbed_output_chunks @staticmethod @@ -95,11 +95,11 @@ class FuzzHandler(ABC): if chunk_origin.nelement() == 0: break chunk_perturbed = perturbed_output_chunks[i] - ratio_tensor1 = TorchC.where(TorchC.abs(chunk_perturbed) > abs_tol, - TorchC.div(TorchC.clamp(chunk_origin, min=abs_tol), TorchC.clamp(chunk_perturbed, min=abs_tol)), 1) - ratio_tensor2 = TorchC.where(TorchC.abs(chunk_origin) > abs_tol, - 
TorchC.div(TorchC.clamp(chunk_perturbed, min=abs_tol), TorchC.clamp(chunk_origin, min=abs_tol)), 1) - norm_values = TorchC.stack([TorchC.max(ratio_tensor1), TorchC.max(ratio_tensor2)]) + ratio_tensor1 = torch.where(torch.abs(chunk_perturbed) > abs_tol, + torch.div(torch.clamp(chunk_origin, min=abs_tol), torch.clamp(chunk_perturbed, min=abs_tol)), 1) + ratio_tensor2 = torch.where(torch.abs(chunk_origin) > abs_tol, + torch.div(torch.clamp(chunk_perturbed, min=abs_tol), torch.clamp(chunk_origin, min=abs_tol)), 1) + norm_values = torch.stack([torch.max(ratio_tensor1), torch.max(ratio_tensor2)]) max_ratio1, max_ratio2 = norm_values.tolist() norm1 = max(norm1, self.convert_overflow_ratio_to_consistent(max_ratio1)) norm2 = max(norm2, self.convert_overflow_ratio_to_consistent(max_ratio2)) -- Gitee
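
The second patch removes the TorchC helper, whose attributes were direct bindings to torch._C._VariableFunctionsClass functions (presumably kept to sidestep the tool's own API hooks), and switches the free-benchmark perturbation and comparison code to the public torch.* API, folding idioms such as TorchC.gt(TorchC.abs(x), t) into plain torch.abs(x) > t. A small, self-contained sketch of the comparison metrics as they read after the replacement is given below; it mirrors the "+" lines of _cal_compare_metrics, but the threshold value and sample tensors are made up, and zeros_like/ones_like replace the scalar "other" argument to torch.where purely to keep the sketch version-agnostic.

    import torch

    # Assumed threshold for illustration only; the real value comes from ThresholdConfig.
    small_value = 2 ** -10

    actual = torch.tensor([1.0002, 0.0001, 3.0])
    golden = torch.tensor([1.0000, 0.0002, 3.0])

    diff_value = torch.subtract(actual, golden)
    diff_abs = torch.abs(diff_value)
    golden_abs = torch.abs(golden)

    # Elements whose magnitude is below small_value are judged by absolute error.
    absolute_err = torch.max(
        torch.where(torch.abs(actual) < small_value, diff_abs, torch.zeros_like(diff_abs))
    )

    # Elements at or above small_value are judged by relative error.
    diff_rel = torch.div(diff_abs, golden_abs)
    relative_err = torch.max(
        torch.where(torch.abs(actual) >= small_value, diff_rel, torch.zeros_like(diff_rel))
    )

    # Error balance: mean signed error, normalised by |golden| where it is large enough.
    divided = torch.where(torch.abs(golden) >= small_value, golden_abs, torch.ones_like(golden_abs))
    eb = torch.mean(torch.div(diff_value, divided))

    print(absolute_err.item(), relative_err.item(), eb.item())
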