diff --git a/README.md b/README.md index ea548a0bfcab119d7af06912587d5887adc8a0a8..623c205332c504b46e66c8fb40d4ea62dde292ef 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# 🚨 重要通知 +# 🚨 重要通知: **1. Ascend Training Tools 更名为 MindStudio Training Tools (mstt)。** diff --git a/debug/accuracy_tools/msprobe/docs/02.config_introduction.md b/debug/accuracy_tools/msprobe/docs/02.config_introduction.md index f9bcf3476a8c3d5b68049cb8e6af7190dd6bb33b..a5d9e27062f7aac28a5ced084f00a4b56160e0be 100644 --- a/debug/accuracy_tools/msprobe/docs/02.config_introduction.md +++ b/debug/accuracy_tools/msprobe/docs/02.config_introduction.md @@ -103,11 +103,11 @@ PyTorch 与 MindSpore 动态图场景下,"level"须为"L1";MindSpore 静态 参数解释是否必选 scope自定义检测 API 列表(仅 PyTorch 场景支持),list[str] 类型,默认值为空列表,当 list 也为空列表时,表示检测所有 API。需要在 [ ] 内配置具体 API 名(在 dump 的结果中查看)。与 list 参数不能同时配置。
配置示例:"scope": ["Torch.matmul.0.forward", "Tensor.pow.4.forward"]。否 list自定义检测 API 类型或 API 名称,list[str] 类型,默认值为空列表,表示检测所有 API(PyTorch 场景下还需 scope 也为空列表)。与 scope 参数不能同时配置。否 - PyTorch 场景:指定某一类 API,对某一类的 API 进行无标杆比对。
配置示例:"list": ["relu"]。 + PyTorch 场景:指定某一类 API,对某一类的 API 进行无标杆比对。
配置示例:"list": ["relu"]。针对任务auto_fix, 须知其中scale功能支持matmul,bmm,softmax,linear,其他算子将跳过scale使用其他修复方法。 MindSpore 场景:指定 API 名称,对列表中的 API 进行检测。
配置示例:"list": ["mindspore.mint.div", "mindspore.ops.bmm", "mindspore.Tensor.__add__"]。 fuzz_device标杆设备,str 类型。可选参数:
"npu":无标杆,通过添加扰动因子进行比对,默认值;
"cpu":以 CPU 为标杆,pert_mode 须配置为"to_cpu"(仅 PyTorch 场景支持)。
配置示例:"fuzz_device": "npu"。否 - pert_mode无标杆扰动因子,str 类型。可选参数:
"improve_precision":对输入做升精度,默认值;
"add_noise":对输入增加噪声;
"no_change":不加扰动直接二次执行;
"bit_noise":输入的末位比特翻转,MindSpore 场景不支持 BF16 类型的向量;
"change_value":输入的张量首尾值调换;
"to_cpu":在 CPU 等价执行(仅 PyTorch 场景支持)。
配置示例:"pert_mode": "improve_precision"。否 - handler_type处理类型,可选参数:
"check":进行无标杆比对检查,默认值;
"fix":将扰动后的 API 输出结果覆盖原始 API 输出结果,尝试将 Loss 曲线恢复正常,该模式下不支持预热功能与反向过程,且仅支持"improve_precision"、"to_cpu"( PyTorch 场景)两种扰动因子。
配置示例:"handler_type": "check"。否 + pert_mode无标杆扰动因子,str 类型。可选参数:
"improve_precision":对输入做升精度,默认值;
"add_noise":对输入增加噪声;
"no_change":不加扰动直接二次执行;
"bit_noise":输入的末位比特翻转,MindSpore 场景不支持 BF16 类型的向量;
"change_value":输入的张量首尾值调换;
"to_cpu":在 CPU 等价执行(仅 PyTorch 场景支持)。
"auto_fix":使用scale、切精度、同步等方法快速排除和恢复算子问题。
配置示例:"pert_mode": "improve_precision"。否 + handler_type处理类型,可选参数:
"check":进行无标杆比对检查,默认值;
"fix":将扰动后的 API 输出结果覆盖原始 API 输出结果,尝试将 Loss 曲线恢复正常,该模式下不支持预热功能与反向过程,且仅支持"improve_precision"、"to_cpu"( PyTorch 场景)、"auto_fix"( PyTorch 场景)三种扰动因子。
配置示例:"handler_type": "check"。否 fuzz_level无标杆数据 dump 级别,即选择比对结果文件应输出的表头属性,当前仅支持取值为:"L1"。输出结果详见 1.6.1 无标杆比对数据存盘格式。否 fuzz_stage比对过程,选择对 API 前向或反向进行无标杆比对,可选参数:
"forward":前向,默认值;
"backward":反向, 仅 PyTorch 场景支持。当 fuzz_stage 为 "backward" 时,handler_type 只能为 "check"。
配置示例:"fuzz_stage": "backward"。否 if_preheat预热功能(仅 PyTorch 场景支持),bool 类型。开启功能后工具可以根据每次迭代的输出调整精度算法的阈值,从而更准确地找出存在精度问题的 API。当"handler_type": "fix"时,不支持预热。可选参数:
true(开启)或 false(关闭),默认关闭。
配置示例:"if_preheat": "true"。否 diff --git a/debug/accuracy_tools/msprobe/docs/15.free_benchmarking_PyTorch.md b/debug/accuracy_tools/msprobe/docs/15.free_benchmarking_PyTorch.md index a2bc2112c16444eb76838c0f931fc3e94b502df7..bd345813400d31eacb9e425dd4c4ca9837a79e0f 100644 --- a/debug/accuracy_tools/msprobe/docs/15.free_benchmarking_PyTorch.md +++ b/debug/accuracy_tools/msprobe/docs/15.free_benchmarking_PyTorch.md @@ -1,7 +1,7 @@ # PyTorch 场景的无标杆比对 ## 1 简介 -* 本工具的目标是在不依赖标杆数据的情况下,检测模型训练中可能存在的精度问题API级别算子,并提供升精度和tocpu接口快速验证。 +* 本工具的目标是在不依赖标杆数据的情况下,检测模型训练中可能存在的精度问题API级别算子,并提供升精度和tocpu接口快速验证,以及针对算子的快速恢复。 * 工具基于**数值病态分析理论**:对算子的输入增加很小的扰动,从而放大输出值异常现象;检测算子原始输出和扰动后输出间误差是否符合精度标准。 * 该工具的**特点**有: @@ -10,6 +10,7 @@ * 推荐使用场景(针对**算子精度问题**): * **暂无标杆数据**,模型Loss异常,要做精度问题算子排查; * **验证可疑算子**,要做进一步确认,验证是否对模型Loss有影响; + * **可疑算子快速恢复**,使用scale、切精度、同步等方法快速排除和恢复算子问题; * 低精度模型效果不如高精度,要做精度问题算子排查。 * 该工具的约束 * 仅支持Pytorch2.x场景; @@ -20,7 +21,7 @@ 2. **扰动因子**:基于torch.nn.Module的hook机制,在注册的hook函数中对算子输入进行特定类型扰动。 3. **误差分析**: * **check**: 在hook函数中二次执行算子得到扰动后的算子输出,计算扰动后输出与原始输出的相对误差,查看是否符合精度标准; - * **fix**: 需要做验证时,可以选择将特定扰动类型(升精度,to cpu)的输出替换原始输出,观察对模型Loss是否有影响。 + * **fix**: 需要做验证时,可以选择将特定扰动类型(升精度,to cpu)的输出替换原始输出,观察对模型Loss是否有影响;需要恢复算子时,可以选择自动恢复,工具将自动执行——检测前向中Nan/inf/全0问题,然后基于缩放->切高精度->Synchronize->Contiguous->引导tocpu的顺序进行排查替换。 4. **精度风险算子**:不达标精度标准的,最终会在输出件中展示 ![alt text](./img/free_benchmark_framework.png) @@ -87,7 +88,7 @@ D-->config.json配置 - +
参数是否必选可配置项适用场景
scope自定义需要通过指定算子名来限制算子插桩范围 如:["Torch.matmul.0.forward", "Tensor.pow.4.forward"]。
list自定义需要通过指定算子类型来限制算子插桩范围 如:["relu"] 会匹配所有算子名中包含relu的算子。
list自定义需要通过指定算子类型来限制算子插桩范围 如:["relu"] 会匹配所有算子名中包含relu的算子。针对任务auto_fix, 须知其中scale功能支持matmul,bmm,softmax,linear,其他算子将跳过scale使用其他修复方法。
fuzz_stage"forward"(默认)需要进行算子前向计算的精度问题排查或验证可疑算子。
"backward"需要进行算子反向计算的精度问题排查,不支持仅反向验证,前向验证包括反向。
@@ -96,12 +97,13 @@ D-->config.json配置 - + +
参数是否必选可配置项适用场景
pert_mode"improve_precision" (默认)(常用)(可做验证) 插桩算子可能在低精度下有精度问题,扰动因子会将输入的低精度向量升精度。
pert_mode"improve_precision" (默认)(常用)(可做验证) 插桩算子可能在低精度下有精度问题,扰动因子会将输入的低精度向量升精度。
"bit_noise"(常用)插桩算子可能在轻微扰动下暴露精度问题,扰动因子会将输入向量最后一个比特位翻转。
"add_noise"插桩算子可能在轻微扰动下暴露精度问题,扰动因子会为输入向量增加一个极小。
"change_value"插桩算子可能存在大数吃小数问题,扰动因子会交换输入向量的首尾。
"no_change"插桩算子可能存在数值稳定性精度问题,扰动因子会复制原始输。
"to_cpu"(可做验证) 插桩算子可能在同 CPU 精度表现不一致,扰动因子会将输入转至 CPU,需要配合 fuzz_device="cpu"使用。
"auto_fix"(专做修复) 已有怀疑算子,实现自动恢复,检测前向中Nan/inf/全0问题,按照缩放->切高精度->Synchronize->Contiguous->引导tocpu的顺序进行排查替换,快速恢复。
fuzz_device"npu" (默认)pert_mode 不需要to cpu操作。
"cpu"pert_mode 须配置为"to_cpu",目前仅支持"to cpu"扰动因子。
@@ -111,7 +113,7 @@ D-->config.json配置 - +
参数是否必选可配置项适用场景
handler_type"check"(默认)要做精度问题算子排查,输出扰动前后不符合精度标准的算子,支持所有扰动因子。
"fix"要做可疑算子验证,用扰动后输出替换原始输出,支持"improve_precision","to_cpu"两种扰动因子。
"fix"要做可疑算子验证,用扰动后输出替换原始输出,支持"improve_precision","to_cpu"两种扰动因子;要做快速修复,用扰动后输出替换原始输出,支持"auto_fix"。
### 3.3 在模型脚本中开启工具 diff --git a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/enums.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/enums.py index 181631c624ada4c9d0613a05c55cd9f11f0d66d5..dac78f016a09792539a15ef4a9f11bb3af1bf0e8 100644 --- a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/enums.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/enums.py @@ -8,6 +8,7 @@ class PerturbationMode: NO_CHANGE = "no_change" BIT_NOISE = "bit_noise" TO_CPU = "to_cpu" + AUTO = "auto_fix" class DeviceType: @@ -48,6 +49,7 @@ class PytorchFreeBenchmarkConst: PerturbationMode.NO_CHANGE, PerturbationMode.BIT_NOISE, PerturbationMode.TO_CPU, + PerturbationMode.AUTO, ] DEFAULT_MODE = PerturbationMode.IMPROVE_PRECISION DEVICE_LIST = [DeviceType.NPU, DeviceType.CPU] @@ -57,7 +59,7 @@ class PytorchFreeBenchmarkConst: FUZZ_LEVEL_LIST = [FuzzLevel.BASE_LEVEL] DEFAULT_FUZZ_LEVEL = FuzzLevel.BASE_LEVEL FUZZ_STAGE_LIST = [Const.FORWARD, Const.BACKWARD] - FIX_MODE_LIST = [PerturbationMode.IMPROVE_PRECISION, PerturbationMode.TO_CPU] + FIX_MODE_LIST = [PerturbationMode.IMPROVE_PRECISION, PerturbationMode.TO_CPU, PerturbationMode.AUTO] DEFAULT_FUZZ_STAGE = Const.FORWARD DEFAULT_PREHEAT_STEP = 15 DEFAULT_MAX_SAMPLE = 20 diff --git a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/utils.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/utils.py index 391c2ceaca09dc5a53e475fffe9429fe3052ff7f..3dbded07e940efc6b9e0e8f4e98ded9f140df4e4 100644 --- a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/utils.py @@ -16,7 +16,7 @@ import torch from msprobe.core.common.exceptions import FreeBenchmarkException from msprobe.pytorch.free_benchmark.common.enums import DeviceType - +from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode class Tools: @@ -74,8 +74,11 @@ class Tools: return tensor_seq @staticmethod - def convert_fuzz_output_to_origin(origin, perturbed): + def convert_fuzz_output_to_origin(origin, perturbed, pert_mode): if isinstance(origin, torch.Tensor) and isinstance(perturbed, torch.Tensor): + if pert_mode == PerturbationMode.AUTO: + origin = perturbed + return origin origin.data = perturbed.to(origin.dtype).to(origin.device) return origin if isinstance(origin, dict) and isinstance(perturbed, dict): diff --git a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/main.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/main.py index 66d7b7e10429dbfb939cdfa005422ce4f8e48f99..1d3d6fdfe37d93fca3cd9003475a2c32996bb45b 100644 --- a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/main.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/main.py @@ -88,6 +88,9 @@ class FreeBenchmarkCheck(ABC): layer.handle(data_params) handler_params = make_handler_params(name, self.config, self.current_iter) handler = FuzzHandlerFactory.create(handler_params) + if handler_params.pert_mode == PerturbationMode.AUTO: + perturbed_output = handler.handle(data_params, handler_params.pert_mode) + return perturbed_output, handler.get_unequal_rows() perturbed_output = handler.handle(data_params) return perturbed_output, handler.get_unequal_rows() diff --git a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py index 79256cd4063fb1b2db231fd53242bb725209b132..15005ecd394101bf412674c046789dacf1103116 100644 --- a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py @@ -25,7 +25,7 @@ from msprobe.pytorch.free_benchmark.perturbed_layers.npu.improve_precision impor ) from msprobe.pytorch.free_benchmark.perturbed_layers.npu.no_change import NoChangeLayer from msprobe.pytorch.free_benchmark.perturbed_layers.run_cpu import CpuLayer - +from msprobe.pytorch.free_benchmark.perturbed_layers.npu.auto_fix import AutoLayer class LayerFactory: layers = { @@ -35,6 +35,7 @@ class LayerFactory: PerturbationMode.NO_CHANGE: NoChangeLayer, PerturbationMode.BIT_NOISE: BitNoiseLayer, PerturbationMode.IMPROVE_PRECISION: ImprovePrecisionLayer, + PerturbationMode.AUTO: AutoLayer, }, DeviceType.CPU: {PerturbationMode.TO_CPU: CpuLayer}, } diff --git a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/auto_fix.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/auto_fix.py new file mode 100644 index 0000000000000000000000000000000000000000..4941af49719ec10e97dd62609325318a15e111c0 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/auto_fix.py @@ -0,0 +1,276 @@ +# # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch_npu +import numpy as np + +from typing import Any, Callable, Dict, List, Optional, Tuple + +from msprobe.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import NpuBaseLayer +from msprobe.pytorch.free_benchmark.common.utils import Tools +from msprobe.pytorch.free_benchmark.common.enums import DeviceType +from msprobe.core.common.const import Const +from msprobe.pytorch.free_benchmark import logger +from msprobe.pytorch.free_benchmark.common.constant import CommonField +from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode +from msprobe.pytorch.free_benchmark.common.params import DataParams + +class ScaleConst: + """ + Class for ScaleLayer's const + """ + SOFTMAX_NAME = "softmax" + LINEAR_NAME = "linear" + MATMUL_NAME = "matmul" + BMM_NAME = "bmm" + + FP16_EPS = torch.finfo(torch.float16).tiny # TODO dtype fix + FP16_UB = torch.finfo(torch.float16).max + + + SQRT_UB = torch.tensor([np.sqrt(FP16_UB)], dtype=torch.float16).npu() + SQRT_UB_INV = torch.tensor([1 / np.sqrt(FP16_UB)], dtype=torch.float16).npu() + + COMMUNICATION_NAMES = ["all_reduce","all_gather","reduce_scatter"] + +class AutoLayer(NpuBaseLayer): + def check_catastrophe(self, tensor_obj): + if isinstance(tensor_obj, torch.Tensor): + if torch.all(tensor_obj.eq(0)): + return True + if torch.isinf(tensor_obj).any(): + return True + if torch.isnan(tensor_obj).any(): + return True + return False + if isinstance(tensor_obj, dict): + return any(self.check_catastrophe(value) for value in tensor_obj.values()) + if isinstance(tensor_obj, (tuple, list)): + return any(self.check_catastrophe(value) for value in tensor_obj) + return False + + def tensor_scale(self, tensor_obj,unscale=False): + if isinstance(tensor_obj, torch.Tensor): + if(unscale): + tensor_obj = self._unscale(tensor_obj) + else: + tensor_obj = self._scale(tensor_obj) + self.is_added = True + return tensor_obj + if isinstance(tensor_obj, dict): + return { + key: self.tensor_scale(value) + for key, value in tensor_obj.items() + } + if isinstance(tensor_obj, (tuple, list)): + return type(tensor_obj)( + [self.tensor_scale(value) for value in tensor_obj] + ) + return tensor_obj + + def tensor_contiguous(self, tensor_obj): + if isinstance(tensor_obj, torch.Tensor): + return tensor_obj.contiguous() + if isinstance(tensor_obj, dict): + return { + key: self.tensor_contiguous(value) + for key, value in tensor_obj.items() + } + if isinstance(tensor_obj, (tuple, list)): + return type(tensor_obj)( + [self.tensor_contiguous(value) for value in tensor_obj] + ) + return tensor_obj + + def improve_tensor_precision(self, tensor_obj): + if ( + isinstance(tensor_obj, torch.Tensor) + and torch.is_floating_point(tensor_obj) + and tensor_obj.dtype not in [torch.float32, torch.float64] + ): + self._set_improve_values(tensor_obj) + tensor_obj = self._change_dtype(tensor_obj) + self.is_added = True + return tensor_obj + if isinstance(tensor_obj, dict): + return { + key: self.improve_tensor_precision(value) + for key, value in tensor_obj.items() + } + if isinstance(tensor_obj, (tuple, list)): + return type(tensor_obj)( + [self.improve_tensor_precision(value) for value in tensor_obj] + ) + return tensor_obj + + def handle(self, params: DataParams): + is_scale_applicable = ( + ScaleConst.SOFTMAX_NAME in self.api_name or + ScaleConst.LINEAR_NAME in self.api_name or + ScaleConst.MATMUL_NAME in self.api_name or + ScaleConst.BMM_NAME in self.api_name + ) + + self.scale_factor = 1.0 + params.perturbed_result = params.original_result + if not self.check_catastrophe(params.perturbed_result): + return params.perturbed_result + logger.info_on_rank_0( + f"[msprobe] Free benchmark: An Problem shows here. " + ) + + #! Try Scale + if is_scale_applicable: + logger.info_on_rank_0( + f"[msprobe] Free benchmark: Perturbation is " + f"{PerturbationMode.AUTO} of {self.api_name}. " + f"Trying Scale for this." + ) + new_args = self.tensor_scale(params.args) + params.perturbed_result = params.origin_func( + *new_args, **params.kwargs) + + if (ScaleConst.SOFTMAX_NAME in self.api_name): + params.perturbed_result = self.tensor_scale( + params.perturbed_result, unscale=True) + try: + new_args1 = params.perturbed_result, *new_args[1:] + params.perturbed_result = params.origin_func(*new_args1, + **params.kwargs) + except KeyError as e: + logger.info_on_rank_0( + f"[msprobe] Free benchmark: Something was wrong during softmax recalc!!!") + else: + params.perturbed_result = self.tensor_scale( + params.perturbed_result, unscale=True) + + if not self.check_catastrophe(params.perturbed_result): + logger.info_on_rank_0( + f"[msprobe] Free benchmark: Autofix-'Scaler' is Useful, " + f"Problem solved." + ) + return params.perturbed_result + + #! Try improve precision + if self.check_catastrophe(params.perturbed_result): + logger.info_on_rank_0( + f"[msprobe] Free benchmark: 'Scaler' is Useless. " + f"Trying 'improve precision' for " + f"{PerturbationMode.AUTO} of {self.api_name}." + ) + new_args = self.improve_tensor_precision(params.args) + if params.fuzz_stage == Const.BACKWARD: + new_kwargs = {} + else: + new_kwargs = self.improve_tensor_precision(params.kwargs) + if "inplace" in new_kwargs: + new_kwargs["inplace"] = False + params.perturbed_result = params.origin_func(*new_args, **new_kwargs) + + if not self.check_catastrophe(params.perturbed_result): + logger.info_on_rank_0( + f"[msprobe] Free benchmark: Autofix-'Improve Precision' is Useful, " + f"Problem solved." + ) + return params.perturbed_result + + #! Try Synchronize + if self.check_catastrophe(params.perturbed_result): + logger.info_on_rank_0( + f"[msprobe] Free benchmark: 'Improve Precision' is Useless " + f"Trying 'Synchronize' for " + f"{PerturbationMode.AUTO} of {self.api_name}." + ) + torch_npu.npu.synchronize() + params.perturbed_result = params.origin_func(*params.args, **params.kwargs) + torch_npu.npu.synchronize() + + if not self.check_catastrophe(params.perturbed_result): + logger.info_on_rank_0( + f"[msprobe] Free benchmark: Autofix-'Synchronize' is Useful, " + f"Problem solved." + ) + return params.perturbed_result + + #! Try Contiguous + if self.check_catastrophe(params.perturbed_result): + logger.info_on_rank_0( + f"[msprobe] Free benchmark: 'Synchronize' is Useless, too. " + f"Trying 'Contiguous' for" + f"{PerturbationMode.AUTO} of {self.api_name}." + ) + new_args = self.tensor_contiguous(params.args) + new_kwargs = self.tensor_contiguous(params.kwargs) + params.perturbed_result = params.origin_func(*new_args, **new_kwargs) + + if not self.check_catastrophe(params.perturbed_result): + logger.info_on_rank_0( + f"[msprobe] Free benchmark: Autofix-'Contiguous' is Useful, " + f"Problem solved." + ) + return params.perturbed_result + + #! Hint to 'tocpu' + if self.check_catastrophe(params.perturbed_result): + logger.info_on_rank_0( + f"[msprobe] Free benchmark: 'Contiguous' is Useless, too. " + f"Please set pert_mode to 'To_cpu' for further check." + ) + return params.original_result + + def _get_scale_factor(self, inputs): + nominator = ScaleConst.SQRT_UB + x_norm = torch.norm(inputs, p=1, dim=-1).max() + if(torch.isfinite(x_norm).all()): + denominator = torch.maximum(ScaleConst.SQRT_UB, x_norm) + else: + # if L1 norm of inputs is inf scale to 1 / sqrt(FP16_UB) + return ScaleConst.SQRT_UB_INV.to(torch.get_device(inputs)) + computed_scale = (nominator / denominator).to(torch.get_device(inputs)) + return computed_scale + + def _scale(self, inputs): + cur_scale = self._get_scale_factor(inputs) + self.scale_factor = max(ScaleConst.FP16_EPS, + cur_scale * self.scale_factor) + scaled_inputs = inputs * cur_scale + return scaled_inputs + + def _unscale(self, output): + if (ScaleConst.SOFTMAX_NAME in self.api_name): + unscaled_outputs = (torch.log(output + ScaleConst.FP16_EPS) / self.scale_factor) + return unscaled_outputs + else: + rescale_factor = max(self.scale_factor, (torch.max(output)) / ScaleConst.SQRT_UB) + unscaled_outputs = output / rescale_factor + return unscaled_outputs + + def _set_improve_values(self, inputs): + if inputs.dtype in [torch.float16, torch.bfloat16]: + self.perturbed_value = torch.float32 + + def _change_dtype(self, inputs): + if hasattr(inputs, CommonField.DEVICE): + device = inputs.device + if device is CommonField.META: + new_inputs = inputs.to( + device=CommonField.META, dtype=self.perturbed_value + ) + else: + new_inputs = inputs.to(dtype=self.perturbed_value).to(device) + else: + new_inputs = inputs.to(dtype=self.perturbed_value) + return new_inputs \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py index d0b918402dd28d2c9f92070dceb1ea1e5e1d3d53..b70b0b6bce67f9dc10cc15c2f387d6c546b8f3d4 100644 --- a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py @@ -20,17 +20,17 @@ from msprobe.pytorch.free_benchmark import logger from msprobe.pytorch.free_benchmark.common.params import DataParams from msprobe.pytorch.free_benchmark.common.utils import Tools from msprobe.pytorch.free_benchmark.result_handlers.base_handler import FuzzHandler - +from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode class FixHandler(FuzzHandler): def get_threshold(self, dtype): return self._get_default_threshold(dtype) - def handle(self, data_params: DataParams) -> Any: + def handle(self, data_params: DataParams, pert_mode: PerturbationMode = None) -> Any: try: return Tools.convert_fuzz_output_to_origin( - data_params.original_result, data_params.perturbed_result + data_params.original_result, data_params.perturbed_result, pert_mode ) except FreeBenchmarkException as e: logger.warning(