diff --git a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/main.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/main.py index 971776d1326409c8878849e7b09a4614ffbc16f5..69ece0a0c6a7a58fe8904bde470b16bb32c0d404 100644 --- a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/main.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/main.py @@ -10,7 +10,10 @@ from msprobe.pytorch.free_benchmark.common.enums import ( HandlerType, PerturbationMode, ) -from msprobe.pytorch.free_benchmark.common.params import data_pre_deal, make_handler_params +from msprobe.pytorch.free_benchmark.common.params import ( + data_pre_deal, + make_handler_params, +) from msprobe.pytorch.free_benchmark.compare.grad_saver import GradSaver from msprobe.pytorch.free_benchmark.perturbed_layers.layer_factory import LayerFactory from msprobe.pytorch.free_benchmark.result_handlers.handler_factory import ( @@ -70,9 +73,9 @@ class FreeBenchmarkCheck(ABC): layer.handle(data_params) handler_params = make_handler_params(name, self.config, self.current_iter) handler = FuzzHandlerFactory.create(handler_params) - handler.handle(data_params) - return data_params.perturbed_result, handler.get_unequal_rows() - + perturbed_output = handler.handle(data_params) + return perturbed_output, handler.get_unequal_rows() + def backward(self, name, module, grad_output): if not self.config.fuzz_stage == Const.BACKWARD: