From c2874dbabd2f7f460624d61d85d26ae40ff776e8 Mon Sep 17 00:00:00 2001 From: gitee Date: Wed, 2 Jul 2025 16:08:51 +0800 Subject: [PATCH 01/23] support fp8 --- .../msprobe/core/common/const.py | 2 +- .../api_accuracy_checker/common/utils.py | 28 +++ .../api_accuracy_checker/compare/algorithm.py | 29 ++++ .../api_accuracy_checker/compare/compare.py | 6 +- .../compare/compare_utils.py | 3 +- .../precision_standard/standard_register.py | 3 +- .../run_ut/data_generate.py | 38 +++- .../run_ut/run_overflow_check.py | 2 +- .../api_accuracy_checker/run_ut/run_ut.py | 4 +- .../run_ut/run_ut_utils.py | 164 +++++++++++++++++- .../run_ut/test_run_ut.py | 2 +- 11 files changed, 266 insertions(+), 15 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/common/const.py b/debug/accuracy_tools/msprobe/core/common/const.py index 4af7b5a855..d69c8c50bb 100644 --- a/debug/accuracy_tools/msprobe/core/common/const.py +++ b/debug/accuracy_tools/msprobe/core/common/const.py @@ -154,7 +154,7 @@ class Const: STACK = "stack" ATEN = "Aten" - MODULE_WHITE_LIST = ["torch", "numpy"] + MODULE_WHITE_LIST = ["torch", "numpy", "torch_npu"] FUNC_SKIP_LIST = ["construct", "__call__"] FILE_SKIP_LIST = ["msprobe", "MindSpeed"] diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/utils.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/utils.py index 5724f62623..72c4745b5e 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/utils.py @@ -259,3 +259,31 @@ def get_attribute(module_name, attribute_name): logger.error(f"Failed to get attribute {attribute_name} from module {module_name}: {e}") raise CompareException(CompareException.INVALID_ATTRIBUTE_ERROR) from e return attribute + + +def is_dtype_fp8(dtype): + """ + Function Description: + Check if the data type is float8. + Parameter: + dtype: Data type. + Return: + True or False. + """ + if dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: + return True + return False + + +def is_dtype_hif8(dtype): + """ + Function Description: + Check if the data type is HiFloat8Tensor. + Parameter: + dtype: Data type. + Return: + True or False. 
+ """ + if str(dtype) == "": + return True + return False diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/algorithm.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/algorithm.py index ddee254c2b..2b3a99fea7 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/algorithm.py @@ -22,6 +22,7 @@ import numpy as np from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import ULP_PARAMETERS from msprobe.pytorch.api_accuracy_checker.precision_standard.standard_config import StandardConfig +from msprobe.pytorch.api_accuracy_checker.common.utils import is_dtype_fp8 from msprobe.core.common.const import CompareConst @@ -212,6 +213,8 @@ def check_norm_value(normal_value_mask, rel_err, rtol): def get_ulp_err(bench_output, device_output, dtype): + if is_dtype_fp8(dtype): + return calc_ulp_err_fp8(bench_output, device_output) parameters = ULP_PARAMETERS.get(dtype) min_eb = parameters.get('min_eb')[0] exponent_num = parameters.get('exponent_num')[0] @@ -227,6 +230,32 @@ def get_ulp_err(bench_output, device_output, dtype): return ulp_err +def calc_ulp_err_fp8(bench_output, device_output): + # compute ulp error of FP8 + x = np.float64(bench_output) + H8 = np.float64(device_output) + + EX = np.log2(abs(x) + 2**(-1000)) + EX[EX < -22] = -22 + E = np.floor(EX) # Exponent + + Eabs = np.abs(E) + Wm = np.zeros_like(x) # Mantissa width Init + Wm[Eabs <= 15] = 1 + Wm[Eabs <= 7 ] = 2 + Wm[Eabs <= 3 ] = 3 + ulp_err = (H8 - x) * 2 ** (-E + Wm) # for Wm = 1~3 + + S_EX = EX * np.where(x >= 0, 1, -1) + EH = np.log2(abs(H8) + 2**(-1000)) + + S_EH = EH * np.where(H8 >= 0, 1, -1) + ulp_err1 = S_EH - S_EX # for Wm = 0 + + ulp_err[Wm == 0] = ulp_err1[Wm == 0] # Merge 2 cases + return ulp_err + + def calc_ulp_err(bench_output, device_output, eb, exponent_num, data_type): return (device_output.astype(data_type) - bench_output).astype(data_type) * \ np.exp2(-eb + exponent_num).astype(data_type) diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py index c12a54c18a..058ac105b1 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py @@ -360,8 +360,8 @@ class Comparator: compare_column, npu_dtype) return status, compare_column, message - def _perform_comparison(self, api_name, input_data): - comparison_func = self.registry.get_comparison_function(api_name, None) + def _perform_comparison(self, api_name, input_data, dtype): + comparison_func = self.registry.get_comparison_function(api_name, dtype) comparison_func(input_data) def _compare_float_tensor(self, api_name, bench_output, device_output, compare_column, dtype): @@ -371,7 +371,7 @@ class Comparator: rel_err_orign = get_rel_err_origin(abs_err, abs_bench_with_eps) input_data = CompareInput(bench_output, device_output, compare_column, dtype, rel_err_orign) if str(dtype) in BENCHMARK_COMPARE_SUPPORT_LIST: - self._perform_comparison(api_name, input_data) + self._perform_comparison(api_name, input_data, dtype) else: message += f"The data type {dtype} is not supported for new precision standard." 
diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py index 89c4401b2c..d16efcd616 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py @@ -30,7 +30,8 @@ from msprobe.pytorch.common.log import logger current_time = time.strftime("%Y%m%d%H%M%S") API_PRECISION_COMPARE_RESULT_FILE_NAME = "api_precision_compare_result_" + current_time + ".csv" API_PRECISION_COMPARE_DETAILS_FILE_NAME = "api_precision_compare_details_" + current_time + ".csv" -BENCHMARK_COMPARE_SUPPORT_LIST = ['torch.float16', 'torch.bfloat16', 'torch.float32'] +BENCHMARK_COMPARE_SUPPORT_LIST = ['torch.float16', 'torch.bfloat16', 'torch.float32', "torch.float8_e4m3fn", + "torch.float8_e5m2"] API_PRECISION_COMPARE_UNSUPPORT_LIST = ['torch.float64', 'torch.complex64', 'torch.complex128'] ULP_COMPARE_SUPPORT_LIST = ['torch.float16', 'torch.bfloat16', 'torch.float32'] BINARY_COMPARE_UNSUPPORT_LIST = BENCHMARK_COMPARE_SUPPORT_LIST + API_PRECISION_COMPARE_UNSUPPORT_LIST diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py index 82df8c54e8..f9a9c388ff 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py @@ -96,7 +96,8 @@ class StandardRegistry: The BINARY_COMPARE_UNSUPPORT_LIST should be defined and contain all data types that are not supported for binary comparison. 
""" - if dtype and dtype not in BINARY_COMPARE_UNSUPPORT_LIST: + + if dtype not in BINARY_COMPARE_UNSUPPORT_LIST: return CompareConst.BINARY_CONSISTENCY for name, category in self.api_standard_function_map.items(): if api_name in category: diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py index 9d89b2de32..3e0d6e57f8 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py @@ -20,9 +20,16 @@ import math import torch import numpy +try: + import torch_npu +except ImportError: + IS_GPU = True +else: + IS_GPU = False + from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import hf_32_standard_api from msprobe.pytorch.api_accuracy_checker.common.utils import check_object_type, get_full_data_path, \ - CompareException, get_module_and_atttribute_name, get_attribute + CompareException, get_module_and_atttribute_name, get_attribute, is_dtype_fp8, is_dtype_hif8 from msprobe.core.common.file_utils import FileChecker, load_npy from msprobe.pytorch.common.log import logger from msprobe.pytorch.common.utils import load_pt @@ -38,7 +45,10 @@ FLOAT_TYPE = [ 'torch.double', 'torch.float16', 'torch.half', - 'torch.bfloat16' + 'torch.bfloat16', + 'torch.float8_e4m3fn', + 'torch.float8_e5m2', + 'torch_npu.HiFloat8Tensor' ] NUMPY_TYPE = [ "numpy.int8", "numpy.int16", "numpy.int32", "numpy.int64", "numpy.uint8", "numpy.uint16", "numpy.uint32", @@ -69,10 +79,15 @@ def gen_data(info, api_name, need_grad, convert_type, real_data_path=None): if api_name in hf_32_standard_api and data.dtype == torch.float32: data = fp32_to_hf32_to_fp32(data) if info.get('requires_grad') and need_grad: + origin_dtype = data.dtype + if is_dtype_fp8(origin_dtype): + data = data.to(torch.float32) data.requires_grad_(True) temp_data = data * 1 data = temp_data.type_as(data) data.retain_grad() + if is_dtype_fp8(origin_dtype): + data = data.to(origin_dtype) elif data_type.startswith("numpy"): if data_type not in NUMPY_TYPE: raise Exception("{} is not supported now".format(data_type)) @@ -196,8 +211,13 @@ def gen_common_tensor(low_info, high_info, shape, data_dtype, convert_type): tensor = torch.full(shape, high, dtype=dtype) tensor[-1] = low return tensor + low_scale, high_scale = low, high - dtype_finfo = torch.finfo(dtype) + if is_dtype_hif8(dtype): + finfo_dtype = torch.float32 + else: + finfo_dtype = dtype + dtype_finfo = torch.finfo(finfo_dtype) #适配老版json high和low为inf或-inf的情况,取dtype的最大值或最小值进行放缩 if high == float(CompareConst.INF): high_scale = dtype_finfo.max @@ -209,7 +229,11 @@ def gen_common_tensor(low_info, high_info, shape, data_dtype, convert_type): low_scale = dtype_finfo.min scale = high_scale - low_scale - rand01 = torch.rand(shape, dtype=dtype) + if is_dtype_fp8(dtype) or is_dtype_hif8(dtype): + generate_dtype = torch.float32 + else: + generate_dtype = dtype + rand01 = torch.rand(shape, dtype=generate_dtype) tensor = rand01 * scale + low_scale elif 'int' in data_dtype or 'long' in data_dtype: low, high = int(low), int(high) @@ -236,6 +260,12 @@ def gen_common_tensor(low_info, high_info, shape, data_dtype, convert_type): if low_origin in [float(CompareConst.INF), float(CompareConst.NEG_INF)]: tmp_tensor[0] = low_origin data = tmp_tensor.reshape(shape) + if is_dtype_fp8(dtype): + data = data.to(dtype) + if is_dtype_hif8(dtype): + if IS_GPU: + raise 
CompareException("GPU does not support torch_npu.HiFloat8Tensor") + data = torch_npu.HiFloat8Tensor.to_hifloat8(data) return data diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py index 0f184d14b6..422840f7aa 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py @@ -125,7 +125,7 @@ def run_torch_api(api_full_name, api_info_dict, real_data_path): device_info_kwargs = kwargs.get(Const.DEVICE) if device_info_kwargs and device_info_kwargs.get(Const.VALUE): kwargs[Const.DEVICE] = current_device - npu_args, npu_kwargs = generate_device_params(args, kwargs, False, api_name) + npu_args, npu_kwargs, _ = generate_device_params(args, kwargs, False, api_name) if kwargs.get(Const.DEVICE): del kwargs[Const.DEVICE] cpu_exec_params = ExecParams(api_type, api_name, Const.CPU_LOWERCASE, args, kwargs, False, None) diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py index 082f391c95..54ddd1808b 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py @@ -269,7 +269,7 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict device_info_kwargs = kwargs.get(Const.DEVICE) if device_info_kwargs and device_info_kwargs.get(Const.VALUE): kwargs[Const.DEVICE] = current_device - device_args, device_kwargs = generate_device_params(args, kwargs, need_backward, api_name) + device_args, device_kwargs, is_fp8 = generate_device_params(args, kwargs, need_backward, api_name) if kwargs.get(Const.DEVICE): del kwargs[Const.DEVICE] cpu_params = generate_cpu_params(args, kwargs, need_backward, api_name) @@ -284,6 +284,8 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict device_exec_params = ExecParams(api_type, api_name, current_device, device_args, device_kwargs, is_autocast, autocast_dtype) device_out = exec_api(device_exec_params) + if is_fp8 and isinstance(device_out, torch.Tensor) and device_out.dtype == torch.float32: + device_out = device_out.to(torch.float16) current_path = os.path.dirname(os.path.realpath(__file__)) ut_setting_path = os.path.join(current_path, "torch_ut_setting.json") api_setting_dict = get_json_contents(ut_setting_path) diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py index 60557c77d7..9266cef1dd 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py @@ -16,7 +16,7 @@ import os from collections import namedtuple import re - +import numpy as np import torch try: import torch_npu @@ -33,6 +33,8 @@ from msprobe.core.common.log import logger from msprobe.core.common.utils import CompareException from msprobe.pytorch.hook_module.api_register import ApiTemplate, get_api_register from msprobe.pytorch.hook_module.wrap_aten import AtenOPTemplate +from msprobe.pytorch.common.utils import is_hifloat8_tensor +from msprobe.pytorch.api_accuracy_checker.common.utils import is_dtype_fp8 hf_32_standard_api = ["conv1d", 
"conv2d"] @@ -148,6 +150,10 @@ def raise_bench_data_dtype(api_name, arg, raise_dtype=None): 输出: arg: 转换dtype的标杆输入 ''' + if is_hifloat8_tensor(arg): + return hif8_to_fp32(arg) + if is_dtype_fp8(arg.dtype): + return fp8_to_fp32(arg) if api_name in hf_32_standard_api and arg.dtype == torch.float32: return arg if raise_dtype is None or arg.dtype not in PRECISION_MAPPING or raise_dtype == arg.dtype: @@ -156,13 +162,17 @@ def raise_bench_data_dtype(api_name, arg, raise_dtype=None): def generate_device_params(input_args, input_kwargs, need_backward, api_name): + is_fp8 = False def recursive_arg_to_device(arg_in, to_detach, depth=0): + nonlocal is_fp8 if depth > Const.MAX_DEPTH: logger.error("The depth of arg_in is too large, please check the arg_in.") raise CompareException(CompareException.RECURSION_LIMIT_ERROR) if isinstance(arg_in, (list, tuple)): return type(arg_in)(recursive_arg_to_device(arg, to_detach, depth=depth+1) for arg in arg_in) elif isinstance(arg_in, torch.Tensor): + if is_dtype_fp8(arg_in.dtype) or (not IS_GPU and isinstance(arg_in, torch_npu.HiFloat8Tensor)): + is_fp8 = True if need_backward and arg_in.requires_grad: arg_in = deal_detach(arg_in.clone(), to_detach).to(current_device).requires_grad_() temp_arg_in = arg_in * 1 @@ -178,7 +188,7 @@ def generate_device_params(input_args, input_kwargs, need_backward, api_name): device_args = recursive_arg_to_device(input_args, is_detach) device_kwargs = \ {key: recursive_arg_to_device(value, key != "out" and is_detach) for key, value in input_kwargs.items()} - return device_args, device_kwargs + return device_args, device_kwargs, is_fp8 def generate_cpu_params(input_args, input_kwargs, need_backward, api_name): @@ -260,3 +270,153 @@ def is_unsupported_api(api_name, is_overflow_check=False): if flag: logger.info(f"{split_name} api is not supported for run ut. 
SKIP.") return flag + + +def fp8_to_fp32(x: torch.Tensor) -> torch.Tensor: + """ + 将FP8格式的张量转换为FP32格式,保持原始FP8的表示范围 + 使用纯PyTorch操作替代NumPy位运算 + + 参数: + x: 输入的FP8张量,可以是torch.float8_e4m3fn或torch.float8_e5m2类型 + + 返回: + torch.Tensor: 转换后的FP32张量,保持原始FP8的表示范围 + """ + if x.dtype == torch.float8_e4m3fn: + # E4M3FN格式:1符号+4指数+3尾数,偏置7 + # 位布局:SEEEEEMM + + # 将FP8值视为无符号整数进行位操作 + x_int = x.view(torch.uint8) + + # 提取符号位、指数位和尾数位 + sign_bits = (x_int & 0x80) >> 7 # 最高位是符号位 + exp_bits = (x_int & 0x78) >> 3 # 接下来4位是指数位 + mantissa_bits = x_int & 0x07 # 最后3位是尾数位 + + # 处理规格化数和非规格化数 + is_normal = exp_bits != 0 + + # 计算FP32的指数部分(偏置127) + fp32_exp = torch.where( + is_normal, + (exp_bits - 7 + 127).to(torch.int32), # 规格化数:指数 = 原始指数 + 120 + torch.tensor(0, dtype=torch.int32, device=x.device) # 非规格化数:指数为0 + ) + + # 计算FP32的尾数部分 + # 规格化数:隐含1,尾数 = 1.0 + 原始尾数 * 2^(-3) + # 非规格化数:无隐含1,尾数 = 0.0 + 原始尾数 * 2^(-3) + fp32_mantissa = torch.where( + is_normal, + 1.0 + mantissa_bits.to(torch.float32) / 8.0, # 2^(-3) = 1/8 + mantissa_bits.to(torch.float32) / 8.0 + ) + + # 计算符号值 (-1)^sign + sign_value = torch.pow(-1.0, sign_bits.to(torch.float32)) + + # 计算最终FP32值 + # 规格化数:value = (-1)^sign * (1.0 + mantissa/8) * 2^(exp - 7) + # 非规格化数:value = (-1)^sign * (mantissa/8) * 2^(-6) + fp32_result = sign_value * fp32_mantissa * torch.pow(2.0, fp32_exp - 127) + + return fp32_result + + elif x.dtype == torch.float8_e5m2: + # E5M2格式:1符号+5指数+2尾数,偏置15 + # 位布局:SEEEEEEM + + # 将FP8值视为无符号整数进行位操作 + x_int = x.view(torch.uint8) + + # 提取符号位、指数位和尾数位 + sign_bits = (x_int & 0x80) >> 7 # 最高位是符号位 + exp_bits = (x_int & 0x7C) >> 2 # 接下来5位是指数位 + mantissa_bits = x_int & 0x03 # 最后2位是尾数位 + + # 处理规格化数和非规格化数 + is_normal = exp_bits != 0 + + # 计算FP32的指数部分(偏置127) + fp32_exp = torch.where( + is_normal, + (exp_bits - 15 + 127).to(torch.int32), # 规格化数:指数 = 原始指数 + 112 + torch.tensor(0, dtype=torch.int32, device=x.device) # 非规格化数:指数为0 + ) + + # 计算FP32的尾数部分 + # 规格化数:隐含1,尾数 = 1.0 + 原始尾数 * 2^(-2) + # 非规格化数:无隐含1,尾数 = 0.0 + 原始尾数 * 2^(-2) + fp32_mantissa = torch.where( + is_normal, + 1.0 + mantissa_bits.to(torch.float32) / 4.0, # 2^(-2) = 1/4 + mantissa_bits.to(torch.float32) / 4.0 + ) + + # 计算符号值 (-1)^sign + sign_value = torch.pow(-1.0, sign_bits.to(torch.float32)) + + # 计算最终FP32值 + fp32_result = sign_value * fp32_mantissa * torch.pow(2.0, fp32_exp - 127) + + return fp32_result + + else: + raise ValueError(f"Unsupported dtype: {x.dtype}. 
Expected torch.float8_e4m3fn or torch.float8_e5m2.") + +def hif8_to_fp32(x: torch_npu.HiFloat8Tensor) -> torch.Tensor: + """ + 将HiFloat8格式的张量转换为FP32格式,保持原始HiFloat8的表示范围 + 使用纯PyTorch操作替代NumPy位运算 + + 参数: + x: 输入的HiFloat8张量,可以是torch_npu.HiFloat8Tensor类型 + + 返回: + torch.Tensor: 转换后的FP32张量,保持原始HiFloat8的表示范围 + """ + x = np.array(x) # 确保输入是numpy数组 + + # 创建结果数组 + res = np.zeros_like(x, dtype=np.float32) + + # 获取矩阵形状并遍历每个元素 + M, N = x.shape + for i in range(M): + for j in range(N): + z = x[i, j] + + # 处理特殊值 + if np.isnan(z) or np.isinf(z): + res[i, j] = z + continue + + # 提取符号位 + s = 1.0 if z >= 0 else -1.0 + tmp = abs(z) + + # 处理零值 + if tmp == 0: + res[i, j] = 0.0 + continue + + # 确定指数范围和尾数位数 + E = np.floor(np.log2(tmp + 1e-100)) # 添加小常量避免log2(0) + absE = abs(E) + + # 根据指数范围确定尾数位数和还原规则 + if absE <= 3: # 3-bit Mantissa + mantissa = (tmp / (2.0 ** E)) * 8.0 # 还原尾数部分 + res[i, j] = s * (mantissa / 8.0) * (2.0 ** E) + elif absE <= 7: # 2-bit Mantissa + mantissa = (tmp / (2.0 ** E)) * 4.0 + res[i, j] = s * (mantissa / 4.0) * (2.0 ** E) + elif absE <= 15: # 1-bit Mantissa + mantissa = (tmp / (2.0 ** E)) * 2.0 + res[i, j] = s * (mantissa / 2.0) * (2.0 ** E) + else: # 0-bit Mantissa + res[i, j] = s * (2.0 ** E) + + return torch.from_numpy(res) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py index 13bf0a5b19..0ab34b9288 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py @@ -170,7 +170,7 @@ class TestRunUtMethods(unittest.TestCase): mocks['retain_grad'].return_value = None mocks['to'].return_value = mock_tensor - device_args, device_kwargs = generate_device_params([mock_tensor], {'inplace': False}, True, '') + device_args, device_kwargs, _ = generate_device_params([mock_tensor], {'inplace': False}, True, '') self.assertEqual(len(device_args), 1) self.assertEqual(device_args[0].dtype, torch.float32) self.assertTrue(device_args[0].requires_grad) -- Gitee From 2adac53fa52d2ff6a6120b859a0bf58bb3c8ad13 Mon Sep 17 00:00:00 2001 From: gitee Date: Wed, 2 Jul 2025 16:27:12 +0800 Subject: [PATCH 02/23] support fp8 --- .../api_accuracy_checker/common/utils.py | 40 +++++++++++++++---- .../api_accuracy_checker/compare/compare.py | 1 + .../precision_standard/standard_register.py | 12 +++--- 3 files changed, 40 insertions(+), 13 deletions(-) diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/utils.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/utils.py index 72c4745b5e..883aafda5a 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/utils.py @@ -266,13 +266,16 @@ def is_dtype_fp8(dtype): Function Description: Check if the data type is float8. Parameter: - dtype: Data type. + dtype: Data type (torch.dtype or string). Return: True or False. """ - if dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: - return True - return False + # 处理字符串类型的 dtype + if isinstance(dtype, str): + return dtype in ["float8_e4m3fn", "float8_e5m2"] + + # 处理 torch.dtype 类型 + return dtype in [torch.float8_e4m3fn, torch.float8_e5m2] def is_dtype_hif8(dtype): @@ -280,10 +283,31 @@ def is_dtype_hif8(dtype): Function Description: Check if the data type is HiFloat8Tensor. Parameter: - dtype: Data type. 
+        dtype: Data type (string).
     Return:
         True or False.
     """
-    if str(dtype) == "<class 'torch_npu.HiFloat8Tensor'>":
-        return True
-    return False
+    # Handle string dtypes
+    if isinstance(dtype, str):
+        dtype_str = dtype
+    # Handle class objects or dtype objects
+    else:
+        dtype_str = str(dtype)
+
+    # Check against the string representations of HiFloat8Tensor
+    return (
+        dtype_str == "<class 'torch_npu.HiFloat8Tensor'>" or
+        dtype_str == "torch_npu.HiFloat8Tensor"
+    )
+
+
+def is_dtype_fp8_or_hif8(dtype):
+    """
+    Function Description:
+        Check if the data type is FP8 (native or HiFloat8 variant).
+    Parameters:
+        dtype: Data type (torch.dtype, class, or string).
+    Returns:
+        True if dtype is FP8 or HiFloat8, False otherwise.
+    """
+    return is_dtype_fp8(dtype) or is_dtype_hif8(dtype)
diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py
index 058ac105b1..4bd46cc044 100644
--- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py
+++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py
@@ -332,6 +332,7 @@ class Comparator:
     def _compare_torch_tensor(self, api_name, bench_output, device_output, compare_column):
         cpu_shape = bench_output.shape
         npu_shape = device_output.shape
+        npu_dtype = device_output.dtype
 
         if npu_dtype == torch.bfloat16:
             bench_output = bench_output.to(torch.float32)
diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py
index f9a9c388ff..daabc383e9 100644
--- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py
+++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py
@@ -16,6 +16,7 @@
 # limitations under the License.
 
 from typing import Callable
+from msprobe.pytorch.api_accuracy_checker.common.utils import is_dtype_fp8_or_hif8
 from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import absolute_standard_api, binary_standard_api, \
     ulp_standard_api, thousandth_standard_api, accumulative_error_standard_api, BINARY_COMPARE_UNSUPPORT_LIST
 from msprobe.core.common.const import CompareConst
@@ -70,11 +71,11 @@ class StandardRegistry:
             raise ValueError("The function to be registered must be callable.")
         self.comparison_functions[standard] = func
 
-    def get_comparison_function(self, api_name, dtype=None):
+    def get_comparison_function(self, api_name, dtype):
         standard = self._get_standard_category(api_name, dtype)
         return self.comparison_functions.get(standard)
 
-    def _get_standard_category(self, api_name, dtype=None):
+    def _get_standard_category(self, api_name, dtype):
         """
         Determines the standard category for a given API name and data type.
@@ -84,7 +85,7 @@
         Args:
             api_name (str): The name of the API for which to determine the standard category.
-            dtype (type, optional): The data type to check against the BINARY_COMPARE_UNSUPPORT_LIST. Defaults to None.
+            dtype (type, optional): The data type to check against the BINARY_COMPARE_UNSUPPORT_LIST.
 
         Returns:
             str: The name of the standard category that matches the API name and data type, or 'benchmark' if no match
@@ -96,8 +97,9 @@
             The BINARY_COMPARE_UNSUPPORT_LIST should be defined and contain all data types that are not supported for
             binary comparison.
""" - - if dtype not in BINARY_COMPARE_UNSUPPORT_LIST: + if is_dtype_fp8_or_hif8(dtype): + return CompareConst.ULP_COMPARE + if str(dtype) not in BINARY_COMPARE_UNSUPPORT_LIST: return CompareConst.BINARY_CONSISTENCY for name, category in self.api_standard_function_map.items(): if api_name in category: -- Gitee From e6a24d05d0e48a15c17837e6fae330137413f7b7 Mon Sep 17 00:00:00 2001 From: gitee Date: Wed, 2 Jul 2025 16:28:51 +0800 Subject: [PATCH 03/23] support fp8 --- .../pytorch/api_accuracy_checker/run_ut/run_ut_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py index 9266cef1dd..3821ca526c 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py @@ -272,7 +272,7 @@ def is_unsupported_api(api_name, is_overflow_check=False): return flag -def fp8_to_fp32(x: torch.Tensor) -> torch.Tensor: +def fp8_to_fp32(x): """ 将FP8格式的张量转换为FP32格式,保持原始FP8的表示范围 使用纯PyTorch操作替代NumPy位运算 @@ -366,7 +366,7 @@ def fp8_to_fp32(x: torch.Tensor) -> torch.Tensor: else: raise ValueError(f"Unsupported dtype: {x.dtype}. Expected torch.float8_e4m3fn or torch.float8_e5m2.") -def hif8_to_fp32(x: torch_npu.HiFloat8Tensor) -> torch.Tensor: +def hif8_to_fp32(x): """ 将HiFloat8格式的张量转换为FP32格式,保持原始HiFloat8的表示范围 使用纯PyTorch操作替代NumPy位运算 -- Gitee From 8e35a9fb6c44bc3f03e0a1e448ebd3fd18b56378 Mon Sep 17 00:00:00 2001 From: gitee Date: Wed, 2 Jul 2025 16:54:29 +0800 Subject: [PATCH 04/23] bugfix --- .../api_accuracy_checker/compare/api_precision_compare.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py index 55e93d271c..6f2af45bec 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py @@ -175,6 +175,8 @@ def analyse_csv(npu_data, gpu_data, config): try: new_status = get_api_status(row_npu, row_gpu, api_name, compare_column, registry) except Exception as err: + import traceback + traceback.print_exc() logger.error(f"Get api status error: {str(err)}") compare_column.api_name = full_api_name_with_direction_status compare_column.compare_result = CompareConst.SKIP @@ -459,3 +461,7 @@ def _api_precision_compare_parser(parser): parser.add_argument("-o", "--out_path", dest="out_path", default="", type=str, help=" The api precision compare task result out path.", required=False) + + +if __name__ == '__main__': + _api_precision_compare() \ No newline at end of file -- Gitee From f3a327c886b2e67b984092536f7cf69e3d059c8f Mon Sep 17 00:00:00 2001 From: gitee Date: Wed, 2 Jul 2025 16:58:14 +0800 Subject: [PATCH 05/23] fix bug --- .../msprobe/pytorch/api_accuracy_checker/common/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/utils.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/utils.py index 883aafda5a..1f04342231 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/utils.py @@ -272,7 +272,7 @@ 
def is_dtype_fp8(dtype): """ # 处理字符串类型的 dtype if isinstance(dtype, str): - return dtype in ["float8_e4m3fn", "float8_e5m2"] + return dtype in ["torch.float8_e4m3fn", "torch.float8_e5m2"] # 处理 torch.dtype 类型 return dtype in [torch.float8_e4m3fn, torch.float8_e5m2] -- Gitee From 2bb439626235f99660947b1aebf014da6bc5e04d Mon Sep 17 00:00:00 2001 From: gitee Date: Wed, 9 Jul 2025 14:17:41 +0800 Subject: [PATCH 06/23] fix bug --- .../msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py | 1 + .../msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py index 54ddd1808b..d1a545ea04 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py @@ -149,6 +149,7 @@ def run_api_offline(config, compare, api_name_set): if config.save_error_data: do_save_error_data(api_full_name, data_info, config.error_data_path, is_fwd_success, is_bwd_success) except Exception as err: + import traceback;traceback.print_exc() if "expected scalar type Long" in str(err): logger.warning(f"API {api_name} not support int32 tensor in CPU, please add {api_name} to CONVERT_API " "'int32_to_int64' list in accuracy_tools/msprobe/core/common/const.py file.") diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py index 3821ca526c..dcf0c9a0a1 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py @@ -23,9 +23,11 @@ try: except ImportError: current_device = "cuda" from torch.cuda.amp import autocast + IS_GPU = True else: current_device = "npu" from torch_npu.npu.amp import autocast + IS_GPU = False from msprobe.core.common.const import FileCheckConst, Const, CompareConst from msprobe.core.common.file_utils import FileChecker -- Gitee From 0026c76fc1897bf85365a8fdc617eba0eab40688 Mon Sep 17 00:00:00 2001 From: gitee Date: Thu, 10 Jul 2025 09:37:47 +0800 Subject: [PATCH 07/23] fixbug --- .../pytorch/api_accuracy_checker/run_ut/data_generate.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py index 3e0d6e57f8..a56a0d4ff9 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py @@ -32,7 +32,7 @@ from msprobe.pytorch.api_accuracy_checker.common.utils import check_object_type, CompareException, get_module_and_atttribute_name, get_attribute, is_dtype_fp8, is_dtype_hif8 from msprobe.core.common.file_utils import FileChecker, load_npy from msprobe.pytorch.common.log import logger -from msprobe.pytorch.common.utils import load_pt +from msprobe.pytorch.common.utils import load_pt, is_hifloat8_tensor from msprobe.core.common.const import Const, FileCheckConst, CompareConst @@ -86,7 +86,7 @@ def gen_data(info, api_name, need_grad, convert_type, real_data_path=None): temp_data = data * 1 data = temp_data.type_as(data) data.retain_grad() - if is_dtype_fp8(origin_dtype): + if 
is_dtype_fp8(origin_dtype) and not is_hifloat8_tensor(data): data = data.to(origin_dtype) elif data_type.startswith("numpy"): if data_type not in NUMPY_TYPE: -- Gitee From f0a8bd59904088964779681225291c6144390a31 Mon Sep 17 00:00:00 2001 From: gitee Date: Thu, 10 Jul 2025 10:30:29 +0800 Subject: [PATCH 08/23] fix bug --- .../pytorch/api_accuracy_checker/run_ut/data_generate.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py index a56a0d4ff9..5896af2d7f 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py @@ -32,7 +32,7 @@ from msprobe.pytorch.api_accuracy_checker.common.utils import check_object_type, CompareException, get_module_and_atttribute_name, get_attribute, is_dtype_fp8, is_dtype_hif8 from msprobe.core.common.file_utils import FileChecker, load_npy from msprobe.pytorch.common.log import logger -from msprobe.pytorch.common.utils import load_pt, is_hifloat8_tensor +from msprobe.pytorch.common.utils import load_pt from msprobe.core.common.const import Const, FileCheckConst, CompareConst @@ -79,15 +79,17 @@ def gen_data(info, api_name, need_grad, convert_type, real_data_path=None): if api_name in hf_32_standard_api and data.dtype == torch.float32: data = fp32_to_hf32_to_fp32(data) if info.get('requires_grad') and need_grad: - origin_dtype = data.dtype + origin_dtype = info.get('dtype') if is_dtype_fp8(origin_dtype): data = data.to(torch.float32) data.requires_grad_(True) temp_data = data * 1 data = temp_data.type_as(data) data.retain_grad() - if is_dtype_fp8(origin_dtype) and not is_hifloat8_tensor(data): + if is_dtype_fp8(origin_dtype): data = data.to(origin_dtype) + if is_dtype_hif8(origin_dtype): + data = torch_npu.HiFloat8Tensor.to_hifloat8(data) elif data_type.startswith("numpy"): if data_type not in NUMPY_TYPE: raise Exception("{} is not supported now".format(data_type)) -- Gitee From 35705ae311298541a82836a2f149a97c6977ec97 Mon Sep 17 00:00:00 2001 From: gitee Date: Thu, 10 Jul 2025 10:43:00 +0800 Subject: [PATCH 09/23] fix bug --- .../msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py index dcf0c9a0a1..60e7921aa7 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py @@ -379,6 +379,7 @@ def hif8_to_fp32(x): 返回: torch.Tensor: 转换后的FP32张量,保持原始HiFloat8的表示范围 """ + x = x.cpu().numpy() x = np.array(x) # 确保输入是numpy数组 # 创建结果数组 -- Gitee From cb0547142c2e470275f9341586e313cd6b984533 Mon Sep 17 00:00:00 2001 From: gitee Date: Thu, 10 Jul 2025 11:30:56 +0800 Subject: [PATCH 10/23] fix bug --- .../pytorch/api_accuracy_checker/run_ut/data_generate.py | 7 +++++-- .../pytorch/api_accuracy_checker/run_ut/run_ut_utils.py | 6 +++++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py index 5896af2d7f..a22efc3fd4 100644 --- 
a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py @@ -32,7 +32,7 @@ from msprobe.pytorch.api_accuracy_checker.common.utils import check_object_type, CompareException, get_module_and_atttribute_name, get_attribute, is_dtype_fp8, is_dtype_hif8 from msprobe.core.common.file_utils import FileChecker, load_npy from msprobe.pytorch.common.log import logger -from msprobe.pytorch.common.utils import load_pt +from msprobe.pytorch.common.utils import load_pt, is_hifloat8_tensor from msprobe.core.common.const import Const, FileCheckConst, CompareConst @@ -79,7 +79,10 @@ def gen_data(info, api_name, need_grad, convert_type, real_data_path=None): if api_name in hf_32_standard_api and data.dtype == torch.float32: data = fp32_to_hf32_to_fp32(data) if info.get('requires_grad') and need_grad: - origin_dtype = info.get('dtype') + if is_hifloat8_tensor(data): + origin_dtype = info.get('dtype') + else: + origin_dtype = data.dtype if is_dtype_fp8(origin_dtype): data = data.to(torch.float32) data.requires_grad_(True) diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py index 60e7921aa7..fe8a2aab69 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py @@ -379,6 +379,7 @@ def hif8_to_fp32(x): 返回: torch.Tensor: 转换后的FP32张量,保持原始HiFloat8的表示范围 """ + requires_grad = x.requires_grad x = x.cpu().numpy() x = np.array(x) # 确保输入是numpy数组 @@ -422,4 +423,7 @@ def hif8_to_fp32(x): else: # 0-bit Mantissa res[i, j] = s * (2.0 ** E) - return torch.from_numpy(res) + res = torch.from_numpy(res) + if requires_grad: + res = res.requires_grad_() + return res -- Gitee From 9eef7ae0d13f0840adce398b46e749d74eb5351c Mon Sep 17 00:00:00 2001 From: gitee Date: Thu, 10 Jul 2025 14:40:09 +0800 Subject: [PATCH 11/23] fix bug --- .../msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py index fe8a2aab69..6f7de48942 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py @@ -380,7 +380,7 @@ def hif8_to_fp32(x): torch.Tensor: 转换后的FP32张量,保持原始HiFloat8的表示范围 """ requires_grad = x.requires_grad - x = x.cpu().numpy() + x = x.cpu().detach().numpy() x = np.array(x) # 确保输入是numpy数组 # 创建结果数组 -- Gitee From 285d0aaf1f86289ce71b05cd772bcbeb8544d425 Mon Sep 17 00:00:00 2001 From: gitee Date: Mon, 14 Jul 2025 09:43:26 +0800 Subject: [PATCH 12/23] fix bug --- .../run_ut/run_ut_utils.py | 75 ++++++++++--------- 1 file changed, 40 insertions(+), 35 deletions(-) diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py index 6f7de48942..db785952ff 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py @@ -383,45 +383,50 @@ def hif8_to_fp32(x): x = x.cpu().detach().numpy() x = np.array(x) # 确保输入是numpy数组 - # 
创建结果数组 + # 创建结果数组,保持与输入相同的形状 res = np.zeros_like(x, dtype=np.float32) - # 获取矩阵形状并遍历每个元素 - M, N = x.shape - for i in range(M): - for j in range(N): - z = x[i, j] - - # 处理特殊值 - if np.isnan(z) or np.isinf(z): - res[i, j] = z - continue - - # 提取符号位 - s = 1.0 if z >= 0 else -1.0 - tmp = abs(z) - - # 处理零值 - if tmp == 0: - res[i, j] = 0.0 - continue + # 获取输入张量的所有维度 + dimensions = x.shape + # 计算总元素数量 + total_elements = np.prod(dimensions) + + # 遍历每个元素 + for idx in range(total_elements): + # 将一维索引转换为多维索引 + multi_indices = np.unravel_index(idx, dimensions) + z = x[multi_indices] + + # 处理特殊值 + if np.isnan(z) or np.isinf(z): + res[multi_indices] = z + continue - # 确定指数范围和尾数位数 - E = np.floor(np.log2(tmp + 1e-100)) # 添加小常量避免log2(0) - absE = abs(E) + # 提取符号位 + s = 1.0 if z >= 0 else -1.0 + tmp = abs(z) + + # 处理零值 + if tmp == 0: + res[multi_indices] = 0.0 + continue - # 根据指数范围确定尾数位数和还原规则 - if absE <= 3: # 3-bit Mantissa - mantissa = (tmp / (2.0 ** E)) * 8.0 # 还原尾数部分 - res[i, j] = s * (mantissa / 8.0) * (2.0 ** E) - elif absE <= 7: # 2-bit Mantissa - mantissa = (tmp / (2.0 ** E)) * 4.0 - res[i, j] = s * (mantissa / 4.0) * (2.0 ** E) - elif absE <= 15: # 1-bit Mantissa - mantissa = (tmp / (2.0 ** E)) * 2.0 - res[i, j] = s * (mantissa / 2.0) * (2.0 ** E) - else: # 0-bit Mantissa - res[i, j] = s * (2.0 ** E) + # 确定指数范围和尾数位数 + E = np.floor(np.log2(tmp + 1e-100)) # 添加小常量避免log2(0) + absE = abs(E) + + # 根据指数范围确定尾数位数和还原规则 + if absE <= 3: # 3-bit Mantissa + mantissa = (tmp / (2.0 ** E)) * 8.0 # 还原尾数部分 + res[multi_indices] = s * (mantissa / 8.0) * (2.0 ** E) + elif absE <= 7: # 2-bit Mantissa + mantissa = (tmp / (2.0 ** E)) * 4.0 + res[multi_indices] = s * (mantissa / 4.0) * (2.0 ** E) + elif absE <= 15: # 1-bit Mantissa + mantissa = (tmp / (2.0 ** E)) * 2.0 + res[multi_indices] = s * (mantissa / 2.0) * (2.0 ** E) + else: # 0-bit Mantissa + res[multi_indices] = s * (2.0 ** E) res = torch.from_numpy(res) if requires_grad: -- Gitee From 538e646466b3ed52a87f89ad55d5a4359734af06 Mon Sep 17 00:00:00 2001 From: gitee Date: Tue, 15 Jul 2025 10:56:53 +0800 Subject: [PATCH 13/23] fix bug --- .../compare/api_precision_compare.py | 10 ++++- .../api_accuracy_checker/compare/compare.py | 39 ++++++++++++------- .../precision_standard/standard_register.py | 15 ++++--- .../api_accuracy_checker/run_ut/run_ut.py | 5 ++- .../run_ut/run_ut_utils.py | 3 +- .../tensor_transport_layer/device_dispatch.py | 4 +- 6 files changed, 50 insertions(+), 26 deletions(-) diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py index 6f2af45bec..965fe6315f 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py @@ -251,11 +251,19 @@ def get_api_status(row_npu, row_gpu, api_name, compare_column, registry): compare_column.api_name = full_api_name_with_direction_status dtype = row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE] input_data = PrecisionCompareInput(row_npu, row_gpu, dtype, compare_column) - comparison_func = registry.get_comparison_function(api_name, dtype) + in_dtype = get_in_dtype(row_npu) + comparison_func = registry.get_comparison_function(api_name, dtype, in_dtype) new_status = comparison_func(input_data) return new_status +def get_in_dtype(row_npu): + if row_npu[ApiPrecisionCompareColumn.REL_ERR_RATIO].isspace(): + return "torch.float32" + else: + return 
"torch.float8_e4m3fn" + + def print_test_success(api_full_name, forward_result, backward_result): is_fwd_success = (forward_result == CompareConst.PASS) is_bwd_success = (backward_result == CompareConst.PASS or backward_result == CompareConst.SPACE) diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py index 4bd46cc044..531679cf4d 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py @@ -180,7 +180,7 @@ class Comparator: write_csv(DETAIL_TEST_ROWS, detail_save_path) @recursion_depth_decorator("compare_core") - def _compare_core(self, api_name, bench_output, device_output): + def _compare_core(self, api_name, bench_output, device_output, is_fp8): compare_column = CompareColumn() if not isinstance(bench_output, type(device_output)): status = CompareConst.ERROR @@ -192,7 +192,7 @@ class Comparator: message = "bench and npu output dict keys are different." else: status, compare_column, message = self._compare_core(api_name, list(bench_output.values()), - list(device_output.values())) + list(device_output.values()), is_fp8) elif isinstance(bench_output, torch.Tensor): copy_bench_out = bench_output.detach().clone() copy_device_output = device_output.detach().clone() @@ -200,7 +200,7 @@ class Comparator: compare_column.npu_type = str(copy_device_output.dtype) compare_column.shape = tuple(device_output.shape) status, compare_column, message = self._compare_torch_tensor(api_name, copy_bench_out, copy_device_output, - compare_column) + compare_column, is_fp8) elif isinstance(bench_output, (bool, int, float, str)): compare_column.bench_type = str(type(bench_output)) compare_column.npu_type = str(type(device_output)) @@ -255,11 +255,12 @@ class Comparator: bench_output, device_output = data_info.bench_output, data_info.device_output bench_grad, device_grad = data_info.bench_grad, data_info.device_grad backward_message = data_info.backward_message + is_fp8 = data_info.is_fp8 if "dropout" in full_api_name: fwd_success_status, fwd_compare_alg_results = self._compare_dropout(bench_output, device_output) else: fwd_success_status, fwd_compare_alg_results = self._compare_core_wrapper(api_name, bench_output, - device_output) + device_output, is_fp8) if not (bench_grad and device_grad): bwd_success_status, bwd_compare_alg_results = (CompareConst.SPACE, []) else: @@ -267,7 +268,7 @@ class Comparator: bwd_success_status, bwd_compare_alg_results = self._compare_dropout(bench_grad[0], device_grad[0]) else: bwd_success_status, bwd_compare_alg_results = self._compare_core_wrapper(api_name, bench_grad, - device_grad) + device_grad, is_fp8) if backward_message: backward_column = CompareColumn() bwd_compare_alg_results = [backward_column.to_column_value(CompareConst.SKIP, backward_message)] @@ -297,7 +298,7 @@ class Comparator: registry.register(CompareConst.ACCUMULATIVE_ERROR_COMPARE, self._accumulative_error_compare) return registry - def _compare_core_wrapper(self, api_name, bench_output, device_output): + def _compare_core_wrapper(self, api_name, bench_output, device_output, is_fp8): detailed_result_total = [] test_final_success = CompareConst.PASS if isinstance(bench_output, (list, tuple)): @@ -308,12 +309,12 @@ class Comparator: else: device_output = device_output[:len(bench_output)] for b_out_i, n_out_i in zip(bench_output, device_output): - status_i, compare_result_i, message_i = 
self._compare_core(api_name, b_out_i, n_out_i) + status_i, compare_result_i, message_i = self._compare_core(api_name, b_out_i, n_out_i, is_fp8) status.append(status_i) compare_result.append(compare_result_i) message.append(message_i) else: - status, compare_result, message = self._compare_core(api_name, bench_output, device_output) + status, compare_result, message = self._compare_core(api_name, bench_output, device_output, is_fp8) if not isinstance(status, list): detailed_result_total.append(compare_result.to_column_value(status, message)) if status == CompareConst.ERROR: @@ -329,11 +330,15 @@ class Comparator: test_final_success = CompareConst.WARNING return test_final_success, detailed_result_total - def _compare_torch_tensor(self, api_name, bench_output, device_output, compare_column): + def _compare_torch_tensor(self, api_name, bench_output, device_output, compare_column, is_fp8): cpu_shape = bench_output.shape npu_shape = device_output.shape npu_dtype = device_output.dtype + if is_fp8: + in_dtype = torch.float8_e4m3fn + else: + in_dtype = torch.float32 if npu_dtype == torch.bfloat16: bench_output = bench_output.to(torch.float32) device_output = device_output.to(torch.float32) @@ -357,22 +362,28 @@ class Comparator: compare_column.error_rate = err_rate return status, compare_column, message else: + in_and_out_dtype = { + 'dtype': npu_dtype, + 'in_dtype': in_dtype + } status, compare_column, message = self._compare_float_tensor(api_name, bench_output, device_output, - compare_column, npu_dtype) + compare_column, in_and_out_dtype) return status, compare_column, message - def _perform_comparison(self, api_name, input_data, dtype): - comparison_func = self.registry.get_comparison_function(api_name, dtype) + def _perform_comparison(self, api_name, input_data, dtype, in_dtype): + comparison_func = self.registry.get_comparison_function(api_name, dtype, in_dtype) comparison_func(input_data) - def _compare_float_tensor(self, api_name, bench_output, device_output, compare_column, dtype): + def _compare_float_tensor(self, api_name, bench_output, device_output, compare_column, in_and_out_dtype): + dtype = in_and_out_dtype.get('dtype') + in_dtype = in_and_out_dtype.get('in_dtype') message = "" _, abs_bench_with_eps = get_abs_bench_with_eps(bench_output, dtype) abs_err = get_abs_err(bench_output, device_output) rel_err_orign = get_rel_err_origin(abs_err, abs_bench_with_eps) input_data = CompareInput(bench_output, device_output, compare_column, dtype, rel_err_orign) if str(dtype) in BENCHMARK_COMPARE_SUPPORT_LIST: - self._perform_comparison(api_name, input_data, dtype) + self._perform_comparison(api_name, input_data, dtype, in_dtype) else: message += f"The data type {dtype} is not supported for new precision standard." 
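The standard_register.py diff below completes this change: _get_standard_category now receives both dtypes and routes FP8/HiFloat8 outputs to ULP comparison and FP8/HiFloat8 inputs to the absolute-threshold standard before the binary-consistency shortcut. A hedged sketch of that precedence with plain strings and a made-up function name (the real code also consults the per-API standard map before defaulting to benchmark):

FP8_TAGS = ("torch.float8_e4m3fn", "torch.float8_e5m2", "torch_npu.HiFloat8Tensor")
# BINARY_COMPARE_UNSUPPORT_LIST from compare_utils.py, spelled out
BINARY_UNSUPPORTED = ("torch.float16", "torch.bfloat16", "torch.float32",
                      "torch.float8_e4m3fn", "torch.float8_e5m2",
                      "torch.float64", "torch.complex64", "torch.complex128")

def pick_standard(out_dtype, in_dtype):
    # Order matters: the FP8 checks must precede the binary shortcut,
    # otherwise an FP8 output would be compared bit-for-bit.
    if str(out_dtype) in FP8_TAGS:
        return "ulp_compare"
    if str(in_dtype) in FP8_TAGS:
        return "absolute_threshold"
    if str(out_dtype) not in BINARY_UNSUPPORTED:
        return "binary_consistency"
    return "benchmark"

print(pick_standard("torch.float8_e4m3fn", "torch.float32"))  # ulp_compare
print(pick_standard("torch.float16", "torch.float8_e4m3fn"))  # absolute_threshold
print(pick_standard("torch.int32", "torch.int32"))            # binary_consistency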
diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py index daabc383e9..c9a1752db7 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py @@ -71,11 +71,11 @@ class StandardRegistry: raise ValueError("The function to be registered must be callable.") self.comparison_functions[standard] = func - def get_comparison_function(self, api_name, dtype): - standard = self._get_standard_category(api_name, dtype) + def get_comparison_function(self, api_name, dtype, in_dtype): + standard = self._get_standard_category(api_name, dtype, in_dtype) return self.comparison_functions.get(standard) - def _get_standard_category(self, api_name, dtype): + def _get_standard_category(self, api_name, out_dtype, in_dtype): """ Determines the standard category for a given API name and data type. @@ -85,7 +85,8 @@ class StandardRegistry: Args: api_name (str): The name of the API for which to determine the standard category. - dtype (type, optional): The data type to check against the BINARY_COMPARE_UNSUPPORT_LIST. + out_dtype (type): The data type of the output tensor. + in_dtype (type): The data type of the input tensor. Returns: str: The name of the standard category that matches the API name and data type, or 'benchmark' if no match @@ -97,9 +98,11 @@ class StandardRegistry: The BINARY_COMPARE_UNSUPPORT_LIST should be defined and contain all data types that are not supported for binary comparison. """ - if is_dtype_fp8_or_hif8(dtype): + if is_dtype_fp8_or_hif8(out_dtype): return CompareConst.ULP_COMPARE - if str(dtype) not in BINARY_COMPARE_UNSUPPORT_LIST: + if is_dtype_fp8_or_hif8(in_dtype): + return CompareConst.ABSOLUTE_THRESHOLD + if str(out_dtype) not in BINARY_COMPARE_UNSUPPORT_LIST: return CompareConst.BINARY_CONSISTENCY for name, category in self.api_standard_function_map.items(): if api_name in category: diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py index d1a545ea04..2b9a502a21 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py @@ -315,7 +315,8 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict out = safe_get_value(out, 0, "out") device_out = safe_get_value(device_out, 0, "device_out") - return UtDataInfo(bench_grad_out, device_grad_out, device_out, out, bench_grad, in_fwd_data_list, backward_message) + return UtDataInfo(bench_grad_out, device_grad_out, device_out, out, bench_grad, in_fwd_data_list, backward_message, + rank=0, is_fp8=is_fp8) def run_torch_api_online(api_full_name, api_data, backward_content): @@ -330,7 +331,7 @@ def run_torch_api_online(api_full_name, api_data, backward_content): device_exec_params = ExecParams(api_type, api_name, current_device, args, kwargs, False, None) device_out = exec_api(device_exec_params) device_out = move2device_exec(device_out, "cpu") - return UtDataInfo(None, None, out, device_out, None, in_fwd_data_list, None, rank=api_data.rank) + return UtDataInfo(None, None, out, device_out, None, in_fwd_data_list, None, rank=api_data.rank, is_fp8=False) def check_need_grad(api_info_dict): diff --git 
a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py index db785952ff..0e739a6d6a 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py @@ -65,7 +65,7 @@ class BackwardMessage: class UtDataInfo: def __init__(self, bench_grad, device_grad, device_output, bench_output, grad_in, in_fwd_data_list, - backward_message, rank=0): + backward_message, rank=0, is_fp8=False): self.bench_grad = bench_grad self.device_grad = device_grad self.device_output = device_output @@ -74,6 +74,7 @@ class UtDataInfo: self.in_fwd_data_list = in_fwd_data_list self.backward_message = backward_message self.rank = rank + self.is_fp8 = is_fp8 def get_validated_result_csv_path(result_csv_path, mode): diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py index 6fc36bcdec..ed972db09e 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py @@ -106,7 +106,7 @@ def online_precision_compare(api_data, device, common_config, api_precision_csv_ cpu_args, cpu_kwargs = cpu_params.cpu_args, cpu_params.cpu_kwargs cpu_exec_params = ExecParams(api_type, api_name, Const.CPU_LOWERCASE, cpu_args, cpu_kwargs, False, None) cpu_out = exec_api(cpu_exec_params) - npu_data_info = UtDataInfo(None, None, npu_out, cpu_out, None, [], None, rank=api_data.rank) + npu_data_info = UtDataInfo(None, None, npu_out, cpu_out, None, [], None, rank=api_data.rank, is_fp8=False) npu_detail = compare.compare_output(api_full_name, npu_data_info, True) npu_data = pd.DataFrame(npu_detail, columns=DETAIL_TEST_ROWS[-1]) @@ -114,7 +114,7 @@ def online_precision_compare(api_data, device, common_config, api_precision_csv_ api_data_gpu = move2target_device(api_data, device) # args, kwargs -> gpu, result -> npu data_info = func(api_full_name, api_data_gpu, config.backward_content) gpu_out = data_info.bench_output - gpu_data_info = UtDataInfo(None, None, gpu_out, cpu_out, None, [], None, rank=api_data.rank) + gpu_data_info = UtDataInfo(None, None, gpu_out, cpu_out, None, [], None, rank=api_data.rank, is_fp8=False) gpu_detail = compare.compare_output(api_full_name, gpu_data_info, True) gpu_data = pd.DataFrame(gpu_detail, columns=DETAIL_TEST_ROWS[-1]) -- Gitee From b3eafabc23164dd9e1cdb2e8f197c489710d8ee4 Mon Sep 17 00:00:00 2001 From: gitee Date: Tue, 15 Jul 2025 15:25:28 +0800 Subject: [PATCH 14/23] fixbug --- .../api_accuracy_checker/common/utils.py | 6 ++++ .../api_accuracy_checker/compare/algorithm.py | 30 +++++++++---------- .../run_ut/data_generate.py | 4 +-- .../api_accuracy_checker/run_ut/run_ut.py | 1 - .../run_ut/run_ut_utils.py | 29 +++++++++--------- 5 files changed, 38 insertions(+), 32 deletions(-) diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/utils.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/utils.py index 1f04342231..8df1a6b034 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/utils.py @@ -311,3 +311,9 @@ def is_dtype_fp8_or_hif8(dtype): True 
if dtype is FP8 or HiFloat8, False otherwise. """ return is_dtype_fp8(dtype) or is_dtype_hif8(dtype) + + +def is_hifloat8_tensor(tensor): + if not IS_GPU and hasattr(torch_npu, "HiFloat8Tensor") and isinstance(tensor, torch_npu.HiFloat8Tensor): + return True + return False \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/algorithm.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/algorithm.py index 2b3a99fea7..2c4678aa16 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/algorithm.py @@ -233,26 +233,26 @@ def get_ulp_err(bench_output, device_output, dtype): def calc_ulp_err_fp8(bench_output, device_output): # compute ulp error of FP8 x = np.float64(bench_output) - H8 = np.float64(device_output) + hi_fp8 = np.float64(device_output) - EX = np.log2(abs(x) + 2**(-1000)) - EX[EX < -22] = -22 - E = np.floor(EX) # Exponent + ex = np.log2(abs(x) + 2**(-1000)) + ex[ex < -22] = -22 + exponent = np.floor(ex) # Exponent - Eabs = np.abs(E) - Wm = np.zeros_like(x) # Mantissa width Init - Wm[Eabs <= 15] = 1 - Wm[Eabs <= 7 ] = 2 - Wm[Eabs <= 3 ] = 3 - ulp_err = (H8 - x) * 2 ** (-E + Wm) # for Wm = 1~3 + eabs = np.abs(exponent) + wm = np.zeros_like(x) # Mantissa width Init + wm[eabs <= 15] = 1 + wm[eabs <= 7 ] = 2 + wm[eabs <= 3 ] = 3 + ulp_err = (hi_fp8 - x) * 2 ** (-exponent + wm) # for wm = 1~3 - S_EX = EX * np.where(x >= 0, 1, -1) - EH = np.log2(abs(H8) + 2**(-1000)) + s_ex = ex * np.where(x >= 0, 1, -1) + eh = np.log2(abs(hi_fp8) + 2**(-1000)) - S_EH = EH * np.where(H8 >= 0, 1, -1) - ulp_err1 = S_EH - S_EX # for Wm = 0 + s_eh = eh * np.where(hi_fp8 >= 0, 1, -1) + ulp_err1 = s_eh - s_ex # for wm = 0 - ulp_err[Wm == 0] = ulp_err1[Wm == 0] # Merge 2 cases + ulp_err[wm == 0] = ulp_err1[wm == 0] # Merge 2 cases return ulp_err diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py index a22efc3fd4..d882dd3c6a 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py @@ -29,10 +29,10 @@ else: from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import hf_32_standard_api from msprobe.pytorch.api_accuracy_checker.common.utils import check_object_type, get_full_data_path, \ - CompareException, get_module_and_atttribute_name, get_attribute, is_dtype_fp8, is_dtype_hif8 + CompareException, get_module_and_atttribute_name, get_attribute, is_dtype_fp8, is_dtype_hif8, is_hifloat8_tensor from msprobe.core.common.file_utils import FileChecker, load_npy from msprobe.pytorch.common.log import logger -from msprobe.pytorch.common.utils import load_pt, is_hifloat8_tensor +from msprobe.pytorch.common.utils import load_pt from msprobe.core.common.const import Const, FileCheckConst, CompareConst diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py index 2b9a502a21..2fdd4317f4 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py @@ -149,7 +149,6 @@ def run_api_offline(config, compare, api_name_set): if config.save_error_data: do_save_error_data(api_full_name, 
data_info, config.error_data_path, is_fwd_success, is_bwd_success)
     except Exception as err:
-        import traceback;traceback.print_exc()
         if "expected scalar type Long" in str(err):
             logger.warning(f"API {api_name} not support int32 tensor in CPU, please add {api_name} to CONVERT_API "
                            "'int32_to_int64' list in accuracy_tools/msprobe/core/common/const.py file.")
diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py
index 0e739a6d6a..c3329b427e 100644
--- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py
+++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py
@@ -35,8 +35,7 @@ from msprobe.core.common.log import logger
 from msprobe.core.common.utils import CompareException
 from msprobe.pytorch.hook_module.api_register import ApiTemplate, get_api_register
 from msprobe.pytorch.hook_module.wrap_aten import AtenOPTemplate
-from msprobe.pytorch.common.utils import is_hifloat8_tensor
-from msprobe.pytorch.api_accuracy_checker.common.utils import is_dtype_fp8
+from msprobe.pytorch.api_accuracy_checker.common.utils import is_dtype_fp8, is_hifloat8_tensor
 
 hf_32_standard_api = ["conv1d", "conv2d"]
 
@@ -166,6 +165,7 @@ def raise_bench_data_dtype(api_name, arg, raise_dtype=None):
 
 def generate_device_params(input_args, input_kwargs, need_backward, api_name):
     is_fp8 = False
+
     def recursive_arg_to_device(arg_in, to_detach, depth=0):
         nonlocal is_fp8
         if depth > Const.MAX_DEPTH:
@@ -369,6 +369,7 @@ def fp8_to_fp32(x):
     else:
         raise ValueError(f"Unsupported dtype: {x.dtype}. Expected torch.float8_e4m3fn or torch.float8_e5m2.")
 
+
 def hif8_to_fp32(x):
     """
     Convert a HiFloat8-format tensor to FP32 while keeping the original HiFloat8 representation range
@@ -413,21 +414,21 @@ def hif8_to_fp32(x):
             continue
 
         # Determine the exponent range and the mantissa width
-        E = np.floor(np.log2(tmp + 1e-100))  # add a small constant to avoid log2(0)
-        absE = abs(E)
+        exponent = np.floor(np.log2(tmp + 1e-100))  # add a small constant to avoid log2(0)
+        eabs = abs(exponent)
 
         # Choose the mantissa width and reconstruction rule from the exponent range
-        if absE <= 3:  # 3-bit Mantissa
-            mantissa = (tmp / (2.0 ** E)) * 8.0  # reconstruct the mantissa
-            res[multi_indices] = s * (mantissa / 8.0) * (2.0 ** E)
-        elif absE <= 7:  # 2-bit Mantissa
-            mantissa = (tmp / (2.0 ** E)) * 4.0
-            res[multi_indices] = s * (mantissa / 4.0) * (2.0 ** E)
-        elif absE <= 15:  # 1-bit Mantissa
-            mantissa = (tmp / (2.0 ** E)) * 2.0
-            res[multi_indices] = s * (mantissa / 2.0) * (2.0 ** E)
+        if eabs <= 3:  # 3-bit Mantissa
+            mantissa = (tmp / (2.0 ** exponent)) * 8.0  # reconstruct the mantissa
+            res[multi_indices] = s * (mantissa / 8.0) * (2.0 ** exponent)
+        elif eabs <= 7:  # 2-bit Mantissa
+            mantissa = (tmp / (2.0 ** exponent)) * 4.0
+            res[multi_indices] = s * (mantissa / 4.0) * (2.0 ** exponent)
+        elif eabs <= 15:  # 1-bit Mantissa
+            mantissa = (tmp / (2.0 ** exponent)) * 2.0
+            res[multi_indices] = s * (mantissa / 2.0) * (2.0 ** exponent)
         else:  # 0-bit Mantissa
-            res[multi_indices] = s * (2.0 ** E)
+            res[multi_indices] = s * (2.0 ** exponent)
 
     res = torch.from_numpy(res)
     if requires_grad:
-- Gitee
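Both functions touched above encode the same HiFloat8 property: the mantissa narrows as the exponent magnitude grows, so one ULP widens in steps rather than uniformly as in IEEE formats. The standalone sketch below restates that rule; hif8_mantissa_bits and hif8_ulp are invented names for illustration, not msprobe APIs:

import numpy as np

def hif8_mantissa_bits(value):
    # Mantissa width rule used above: 3 bits for |E| <= 3, 2 for |E| <= 7,
    # 1 for |E| <= 15, and 0 outside that range.
    exponent = np.floor(np.log2(abs(value) + 1e-100))
    eabs = abs(exponent)
    if eabs <= 3:
        return 3
    if eabs <= 7:
        return 2
    if eabs <= 15:
        return 1
    return 0

def hif8_ulp(value):
    # Gap between adjacent representable HiFloat8 values around `value`:
    # 2 ** (E - Wm), so large-magnitude values get much coarser steps.
    exponent = np.floor(np.log2(abs(value) + 1e-100))
    return 2.0 ** (exponent - hif8_mantissa_bits(value))

for v in (1.5, 100.0, 40000.0):
    print(v, hif8_mantissa_bits(v), hif8_ulp(v))  # 3/0.125, 2/16.0, 1/16384.0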
From 24557a11352e6bd0fdbd90e18db357502a07b820 Mon Sep 17 00:00:00 2001
From: gitee
Date: Tue, 15 Jul 2025 15:54:34 +0800
Subject: [PATCH 15/23] fix bug

---
 .../api_accuracy_checker/compare/algorithm.py |  4 +-
 .../compare/test_compare.py                   | 68 ++++++++++++++-----
 .../test_standard_register.py                 | 22 +++---
 3 files changed, 66 insertions(+), 28 deletions(-)

diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/algorithm.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/algorithm.py
index 2c4678aa16..c6e0fdd9c5 100644
--- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/algorithm.py
+++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/algorithm.py
@@ -242,8 +242,8 @@ def calc_ulp_err_fp8(bench_output, device_output):
     eabs = np.abs(exponent)
     wm = np.zeros_like(x)  # Mantissa width Init
     wm[eabs <= 15] = 1
-    wm[eabs <= 7 ] = 2
-    wm[eabs <= 3 ] = 3
+    wm[eabs <= 7] = 2
+    wm[eabs <= 3] = 3
     ulp_err = (hi_fp8 - x) * 2 ** (-exponent + wm)  # for wm = 1~3
 
     s_ex = ex * np.where(x >= 0, 1, -1)
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py
index fc05f9d469..ce49c7440b 100644
--- a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py
@@ -102,7 +102,7 @@ class TestCompare(unittest.TestCase):
                           '\nMax abs error is less than 0.001, consider as pass, skip other check and set to SPACE.\n']])
 
     def test_compare_core_different(self):
-        res = self.compare._compare_core('api', 1, 'str')
+        res = self.compare._compare_core('api', 1, 'str', False)
         self.assertEqual(res[0], 'error')
         self.assertEqual(res[2], 'bench and npu output type is different.')
 
@@ -112,7 +112,7 @@ class TestCompare(unittest.TestCase):
             'key1': 1,
             'key2': 2
         }
-        res = self.compare._compare_core('api', output_dict, output_dict)
+        res = self.compare._compare_core('api', output_dict, output_dict, False)
         self.assertEqual(res[0], 'error')
         self.assertEqual(res[2], "Unexpected output type in compare_core: ")
 
@@ -126,27 +126,27 @@ class TestCompare(unittest.TestCase):
             'key3': 3,
             'key4': 4
         }
-        res = self.compare._compare_core('api', bench_dict, device_dict)
+        res = self.compare._compare_core('api', bench_dict, device_dict, False)
         self.assertEqual(res[0], 'error')
         self.assertEqual(res[2], 'bench and npu output dict keys are different.')
 
     def test_compare_core_with_tensor(self):
         tensor = torch.tensor([1, 2, 3])
-        res = self.compare._compare_core('api', tensor, tensor)
+        res = self.compare._compare_core('api', tensor, tensor, False)
         self.assertEqual(res[0], 'pass')
         self.assertEqual(res[2], 'Compare algorithm is not supported for int64 data. 
Only judged by Error Rate.\n') def test_compare_core_with_buildin(self): interger = 1 - res = self.compare._compare_core('api', interger, interger) + res = self.compare._compare_core('api', interger, interger, False) self.assertEqual(res[0], 'pass') self.assertEqual(res[2], '') def test_compare_core_with_none(self): - res = self.compare._compare_core('api', None, None) + res = self.compare._compare_core('api', None, None, False) self.assertEqual(res[0], 'SKIP') self.assertEqual(res[2], 'Bench output is None, skip this test.') @@ -155,7 +155,7 @@ class TestCompare(unittest.TestCase): bench_out, npu_out = torch.randn(100, 100), torch.randn(100, 100) bench_grad, npu_grad = [torch.randn(100, 100)], [torch.randn(100, 100)] api_name = 'Functional.conv2d.0' - data_info = UtDataInfo(bench_grad, npu_grad, bench_out, npu_out, None, None, None) + data_info = UtDataInfo(bench_grad, npu_grad, bench_out, npu_out, None, None, None, False) is_fwd_success, is_bwd_success = self.compare.compare_output(api_name, data_info) self.assertFalse(is_fwd_success) # is_bwd_success should be checked @@ -270,27 +270,39 @@ class TestCompare(unittest.TestCase): cpu_output = torch.Tensor([1.0, 2.0, 3.0]) npu_output = torch.Tensor([1.0, 2.0, 3.0]) compare_column = CompareColumn() + in_and_out_dtype ={ + 'dtype': npu_output.dtype, + 'in_dtype': torch.float32 + } status, compare_column, message = self.compare._compare_float_tensor("conv2d", cpu_output.numpy(), npu_output.numpy(), - compare_column, npu_output.dtype) + compare_column, in_and_out_dtype) self.assertEqual(status, "pass") def test_compare_float_tensor_binary(self): cpu_output = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float16) npu_output = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32) compare_column = CompareColumn() + in_and_out_dtype ={ + 'dtype': npu_output.dtype, + 'in_dtype': torch.float32 + } status, compare_column, message = self.compare._compare_float_tensor("abs", cpu_output.numpy(), npu_output.numpy(), - compare_column, npu_output.dtype) + compare_column, in_and_out_dtype) self.assertEqual(status, "pass") def test_compare_float_tensor_absolute(self): cpu_output = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32) npu_output = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32) compare_column = CompareColumn() + in_and_out_dtype ={ + 'dtype': npu_output.dtype, + 'in_dtype': torch.float32 + } status, compare_column, message = self.compare._compare_float_tensor("mul", cpu_output.numpy(), npu_output.numpy(), - compare_column, npu_output.dtype) + compare_column, in_and_out_dtype) self.assertEqual(status, "pass") @@ -298,9 +310,13 @@ class TestCompare(unittest.TestCase): cpu_output = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32) npu_output = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32) compare_column = CompareColumn() + in_and_out_dtype ={ + 'dtype': npu_output.dtype, + 'in_dtype': torch.float32 + } status, compare_column, message = self.compare._compare_float_tensor("__matmul__", cpu_output.numpy(), npu_output.numpy(), - compare_column, npu_output.dtype) + compare_column, in_and_out_dtype) self.assertEqual(status, "pass") @@ -308,9 +324,13 @@ class TestCompare(unittest.TestCase): cpu_output = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float16) npu_output = torch.tensor([1.1, 2.1, 3.1], dtype=torch.float16) compare_column = CompareColumn() + in_and_out_dtype ={ + 'dtype': npu_output.dtype, + 'in_dtype': torch.float32 + } status, compare_column, message = self.compare._compare_float_tensor("__matmul__", cpu_output.numpy(), npu_output.numpy(), - 
compare_column, npu_output.dtype)
+                                                                             compare_column, in_and_out_dtype)
         self.assertEqual(status, "error")
 
@@ -318,9 +338,13 @@ class TestCompare(unittest.TestCase):
         cpu_output = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float16)
         npu_output = torch.tensor([1.0001, 2.0001, 3.0001], dtype=torch.float16)
         compare_column = CompareColumn()
+        in_and_out_dtype ={
+            'dtype': npu_output.dtype,
+            'in_dtype': torch.float32
+        }
         status, compare_column, message = self.compare._compare_float_tensor("__matmul__", cpu_output.numpy(),
                                                                              npu_output.numpy(),
-                                                                             compare_column, npu_output.dtype)
+                                                                             compare_column, in_and_out_dtype)
         self.assertEqual(status, "pass")
 
@@ -328,9 +352,13 @@ class TestCompare(unittest.TestCase):
         cpu_output = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float16)
         npu_output = torch.tensor([1.01, 2.01, 3.01], dtype=torch.float16)
         compare_column = CompareColumn()
+        in_and_out_dtype ={
+            'dtype': npu_output.dtype,
+            'in_dtype': torch.float32
+        }
         status, compare_column, message = self.compare._compare_float_tensor("__matmul__", cpu_output.numpy(),
                                                                              npu_output.numpy(),
-                                                                             compare_column, npu_output.dtype)
+                                                                             compare_column, in_and_out_dtype)
         self.assertEqual(status, "Warning")
 
@@ -338,9 +366,13 @@ class TestCompare(unittest.TestCase):
         cpu_output = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
         npu_output = torch.tensor([1.01, 2.01, 3.01], dtype=torch.float32)
         compare_column = CompareColumn()
+        in_and_out_dtype ={
+            'dtype': npu_output.dtype,
+            'in_dtype': torch.float32
+        }
         status, compare_column, message = self.compare._compare_float_tensor("__matmul__", cpu_output.numpy(),
                                                                              npu_output.numpy(),
-                                                                             compare_column, npu_output.dtype)
+                                                                             compare_column, in_and_out_dtype)
         self.assertEqual(status, "error")
 
@@ -348,9 +380,13 @@ class TestCompare(unittest.TestCase):
         cpu_output = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
         npu_output = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
         compare_column = CompareColumn()
+        in_and_out_dtype ={
+            'dtype': npu_output.dtype,
+            'in_dtype': torch.float32
+        }
         status, compare_column, message = self.compare._compare_float_tensor("__matmul__", cpu_output.numpy(),
                                                                              npu_output.numpy(),
-                                                                             compare_column, npu_output.dtype)
+                                                                             compare_column, in_and_out_dtype)
         self.assertEqual(status, "pass")
 
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/precision_standard/test_standard_register.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/precision_standard/test_standard_register.py
index 0a77634893..408d83e462 100644
--- a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/precision_standard/test_standard_register.py
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/precision_standard/test_standard_register.py
@@ -27,7 +27,7 @@ class TestStandardRegistry(unittest.TestCase):
         mock_func = Mock()
         self.registry.register("binary_consistency", mock_func)
         # use a data type that supports binary comparison
-        result = self.registry.get_comparison_function("abs", dtype='torch.int8')
+        result = self.registry.get_comparison_function("abs", dtype='torch.int8', in_dtype='torch.float32')
         self.assertEqual(result, mock_func)
 
     def test_get_comparison_function_absolute_threshold(self):
@@ -35,55 +35,57 @@
         mock_func = Mock()
         self.registry.register("absolute_threshold", mock_func)
         # assume 'test_api' is in the absolute_standard_api list
-        result = self.registry.get_comparison_function("mul")
+        result = self.registry.get_comparison_function("mul", dtype='torch.float32', in_dtype='torch.float32')
         self.assertEqual(result, mock_func)
 
     def test_get_comparison_function_ulp(self):
         """Test fetching the ULP comparison function"""
         mock_func = Mock()
         self.registry.register("ulp_compare", mock_func)
-        result = self.registry.get_comparison_function("matmul")
+        result = self.registry.get_comparison_function("matmul", dtype='torch.float32', in_dtype='torch.float32')
        self.assertEqual(result, mock_func)
 
     def test_get_comparison_function_thousandth(self):
         """Test fetching the thousandth-threshold comparison function"""
         mock_func = Mock()
         self.registry.register("thousandth_threshold", mock_func)
-        result = self.registry.get_comparison_function("conv2d")
+        result = self.registry.get_comparison_function("conv2d", dtype='torch.float32', in_dtype='torch.float32')
         self.assertEqual(result, mock_func)
 
     def test_get_comparison_function_benchmark(self):
         """Test fetching the default benchmark comparison function"""
         mock_func = Mock()
         self.registry.register("benchmark", mock_func)
-        result = self.registry.get_comparison_function("npu_fusion_attention")
+        result = self.registry.get_comparison_function("npu_fusion_attention", dtype='torch.float32',
+                                                       in_dtype='torch.float32')
         self.assertEqual(result, mock_func)
 
     def test_get_standard_category_binary(self):
         """Test fetching the binary consistency standard category"""
         dtype = 'torch.int8'
         self.assertNotIn(dtype, BINARY_COMPARE_UNSUPPORT_LIST)
-        category = self.registry._get_standard_category("abs", dtype)
+        category = self.registry._get_standard_category("abs", out_dtype=dtype, in_dtype='torch.float32')
         self.assertEqual(category, "binary_consistency")
 
     def test_get_standard_category_absolute(self):
         """Test fetching the absolute threshold standard category"""
-        category = self.registry._get_standard_category("mul")
+        category = self.registry._get_standard_category("mul", out_dtype='torch.float32', in_dtype='torch.float32')
         self.assertEqual(category, "absolute_threshold")
 
     def test_get_standard_category_default(self):
         """Test fetching the default benchmark standard category"""
-        category = self.registry._get_standard_category("unknown_api")
+        category = self.registry._get_standard_category("unknown_api", out_dtype='torch.float32',
+                                                        in_dtype='torch.float32')
         self.assertEqual(category, "benchmark")
 
     def test_get_standard_category_ulp(self):
         """Test fetching the ULP standard category"""
-        category = self.registry._get_standard_category("matmul")
+        category = self.registry._get_standard_category("matmul", out_dtype='torch.float32', in_dtype='torch.float32')
         self.assertEqual(category, "ulp_compare")
 
     def test_get_standard_category_thousandth(self):
         """Test fetching the thousandth-threshold standard category"""
-        category = self.registry._get_standard_category("conv2d")
+        category = self.registry._get_standard_category("conv2d", out_dtype='torch.float32', in_dtype='torch.float32')
         self.assertEqual(category, "thousandth_threshold")
-- Gitee
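Both test files above now thread two dtypes through the registry, because the standard category can depend on the output dtype and the input dtype at once: an FP8 result must reach its own ULP standard even for an API that would otherwise fall under benchmark comparison. The sketch below shows only that dispatch shape, with one plausible precedence; MiniStandardRegistry and its category tables are invustrations invented for this note, not the msprobe implementation:

class MiniStandardRegistry:
    """A toy stand-in for StandardRegistry; only the dispatch shape matches."""

    FP8_DTYPES = {"torch.float8_e4m3fn", "torch.float8_e5m2"}  # assumed spellings

    def __init__(self):
        self.functions = {}

    def register(self, category, func):
        self.functions[category] = func

    def _get_standard_category(self, api_name, out_dtype, in_dtype):
        # Assumed precedence: FP8 anywhere in the signature is routed to the
        # ULP standard before any per-API table is consulted.
        if out_dtype in self.FP8_DTYPES or in_dtype in self.FP8_DTYPES:
            return "ulp_compare"
        if out_dtype in ("torch.int8", "torch.bool"):
            return "binary_consistency"
        return "benchmark"

    def get_comparison_function(self, api_name, dtype, in_dtype):
        return self.functions[self._get_standard_category(api_name, dtype, in_dtype)]

registry = MiniStandardRegistry()
registry.register("ulp_compare", lambda data: "ulp")
registry.register("benchmark", lambda data: "benchmark")
print(registry.get_comparison_function("matmul", "torch.float8_e4m3fn", "torch.float8_e4m3fn")(None))  # ulp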
From 491df9b8357b60cba729d6fd88aea6253dff4f6b Mon Sep 17 00:00:00 2001
From: gitee
Date: Tue, 15 Jul 2025 16:06:57 +0800
Subject: [PATCH 16/23] fix bug

---
 .../api_accuracy_checker/compare/test_compare.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py
index ce49c7440b..9109e69389 100644
--- a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py
@@ -72,7 +72,7 @@ class TestCompare(unittest.TestCase):
     def test_compare_core_wrapper(self):
         dummy_input = torch.randn(100, 100)
         bench_out, npu_out = dummy_input, dummy_input
-        test_final_success, detailed_result_total = self.compare._compare_core_wrapper("api", bench_out, npu_out)
+        test_final_success, detailed_result_total = self.compare._compare_core_wrapper("api", bench_out, npu_out, False)
         actual_cosine_similarity = detailed_result_total[0][3]
         # use a small tolerance value
         tolerance = 1e-4
@@ -210,7 +210,7 @@ class TestCompare(unittest.TestCase):
         npu_output = torch.Tensor([1.0, 2.0, 3.0])
         compare_column = CompareColumn()
         status, compare_column, message = self.compare._compare_torch_tensor("api", cpu_output, npu_output,
-                                                                             compare_column)
+                                                                             compare_column, False)
         self.assertEqual(status, "pass")
 
     def test_compare_torch_tensor_bf16(self):
@@ -218,7 +218,7 @@ class TestCompare(unittest.TestCase):
         npu_output = torch.tensor([1.0, 2.0, 3.0], dtype=torch.bfloat16)
         compare_column = CompareColumn()
         status, compare_column, message = self.compare._compare_torch_tensor("api", cpu_output, npu_output,
-                                                                             compare_column)
+                                                                             compare_column, False)
         self.assertEqual(status, "pass")
 
     def test_compare_torch_tensor_different_shape(self):
@@ -226,7 +226,7 @@ class TestCompare(unittest.TestCase):
         npu_output = torch.Tensor([1.0, 2.0, 3.0])
         compare_column = CompareColumn()
         status, compare_column, message = self.compare._compare_torch_tensor("api", cpu_output, npu_output,
-                                                                             compare_column)
+                                                                             compare_column, False)
         self.assertEqual(status, "error")
 
     def test_compare_torch_tensor_different_dtype(self):
@@ -234,7 +234,7 @@ class TestCompare(unittest.TestCase):
         npu_output = torch.Tensor([1.0, 2.0, 3.0])
         compare_column = CompareColumn()
         status, compare_column, message = self.compare._compare_torch_tensor("api", cpu_output, npu_output,
-                                                                             compare_column)
+                                                                             compare_column, False)
         self.assertEqual(status, "error")
 
     def test_compare_torch_tensor_special_dtype(self):
@@ -242,7 +242,7 @@ class TestCompare(unittest.TestCase):
         npu_output = torch.Tensor([True, True, False])
         compare_column = CompareColumn()
         status, compare_column, message = self.compare._compare_torch_tensor("api", cpu_output, npu_output,
-                                                                             compare_column)
+                                                                             compare_column, False)
         self.assertEqual(status, "pass")
 
     def test_compare_builtin_type_pass_with_special_types(self):
-- Gitee
From 1f3f6b285f9b00ded559b731d3d35dd84dad83dd Mon Sep 17 00:00:00 2001
From: gitee
Date: Tue, 15 Jul 2025 16:24:36 +0800
Subject: [PATCH 17/23] fixbug

---
 .../pytorch_ut/api_accuracy_checker/compare/test_compare.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py
index 9109e69389..2ab38eb56d 100644
--- a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py
@@ -86,7 +86,7 @@ class TestCompare(unittest.TestCase):
         self.assertTrue(test_final_success)
 
         bench_out, npu_out = [dummy_input, dummy_input], [dummy_input, dummy_input]
-        test_final_success, detailed_result_total = self.compare._compare_core_wrapper("api", bench_out, npu_out)
+        test_final_success, detailed_result_total = self.compare._compare_core_wrapper("api", bench_out, npu_out, False)
         actual_cosine_similarity = detailed_result_total[0][3]
         self.assertTrue(np.isclose(actual_cosine_similarity, 1.0, atol=tolerance))
         actual_cosine_similarity = detailed_result_total[1][3]
-- Gitee
From 9cdc586e170bfbe86387f214812022a4121ebedc Mon Sep 17 00:00:00 2001
From: gitee
Date: Wed, 16 Jul 2025 10:11:42 +0800
Subject: [PATCH 18/23] bugfix

---
 .../pytorch/api_accuracy_checker/run_ut/data_generate.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py
index d882dd3c6a..221446c266 100644
--- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py
+++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py
@@ -29,7 +29,8 @@ else:
 
 from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import hf_32_standard_api
 from msprobe.pytorch.api_accuracy_checker.common.utils import check_object_type, get_full_data_path, \
-    CompareException, get_module_and_atttribute_name, get_attribute, is_dtype_fp8, is_dtype_hif8, is_hifloat8_tensor
+    CompareException, get_module_and_atttribute_name, get_attribute, is_dtype_fp8, is_dtype_hif8, is_hifloat8_tensor, \
+    is_dtype_fp8_or_hif8
 from msprobe.core.common.file_utils import FileChecker, load_npy
 from msprobe.pytorch.common.log import logger
 from msprobe.pytorch.common.utils import load_pt
@@ -71,6 +72,9 @@ def gen_data(info, api_name, need_grad, convert_type, real_data_path=None):
     data_type = info.get('type')
     data_path = info.get('datapath', info.get('data_name'))
     data_path = get_full_data_path(data_path, real_data_path)
+    dtype = info.get('dtype')
+    if is_dtype_fp8_or_hif8(dtype) and IS_GPU:
+        raise CompareException("float8 data types are not supported when the benchmark device is a GPU")
     if data_type in TENSOR_DATA_LIST:
         if data_path:
             data = gen_real_tensor(data_path, convert_type)
-- Gitee
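Patch 18 makes gen_data fail fast when a dump record carries an FP8 or HiFloat8 dtype but the benchmark device is a GPU, where the checker does not support replaying those types. A condensed sketch of that guard follows; check_fp8_replayable, the dtype spellings, and the exception type are assumptions for the example, the shipped code delegates to is_dtype_fp8_or_hif8:

FP8_DTYPE_STRS = {"torch.float8_e4m3fn", "torch.float8_e5m2"}  # assumed dump spellings

def check_fp8_replayable(info, is_gpu):
    # Reject FP8 dump records before any tensor is materialized when the
    # benchmark device is a GPU, mirroring the early check added to gen_data.
    dtype = info.get("dtype")
    if is_gpu and dtype in FP8_DTYPE_STRS:
        raise ValueError(f"{dtype} cannot be replayed on a GPU benchmark device")
    return dtype

print(check_fp8_replayable({"dtype": "torch.float32"}, is_gpu=True))  # torch.float32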
From 85f722f8addee78b17c5e3eac482785660005fa4 Mon Sep 17 00:00:00 2001
From: gitee
Date: Wed, 16 Jul 2025 15:34:00 +0800
Subject: [PATCH 19/23] support fp8 ulp

---
 .../precision_standard/base_standard.py | 16 +++++-----
 .../precision_standard/ulp_compare.py   | 32 +++++++++++++++----
 2 files changed, 33 insertions(+), 15 deletions(-)

diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py
index e3ff663758..0dcd413830 100644
--- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py
+++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py
@@ -133,16 +133,16 @@ class BasePrecisionCompare:
         return compare_result
 
     def _get_and_convert_values(self, column_name):
-        npu_value = self.row_npu.get(column_name)
-        gpu_value = self.row_gpu.get(column_name)
-        if npu_value is None:
-            raise ValueError(f"NPU value for column '{column_name}' is None.")
-        if gpu_value is None:
-            raise ValueError(f"GPU value for column '{column_name}' is None.")
-        npu_value = convert_str_to_float(npu_value)
-        gpu_value = convert_str_to_float(gpu_value)
+        npu_value = self._get_and_convert_value(self.row_npu, column_name, "NPU")
+        gpu_value = self._get_and_convert_value(self.row_gpu, column_name, "GPU")
         return npu_value, gpu_value
 
+    def _get_and_convert_value(self, row_data, column_name, device_type):
+        value = row_data.get(column_name)
+        if value is None:
+            raise ValueError(f"{device_type} value for column '{column_name}' is None.")
+        return convert_str_to_float(value)
+
     def _post_compare(self, metrics, inf_nan_consistency):
         metrics = self._get_status(metrics, inf_nan_consistency)
         metrics.update({'compare_algorithm': self.compare_algorithm})
diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py
index df181588ad..c22ee51da6 100644
--- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py
+++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py
@@ -25,6 +25,7 @@ from msprobe.core.common.const import Const, CompareConst
 from msprobe.pytorch.api_accuracy_checker.compare.algorithm import calc_ratio, get_ulp_err
 from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import ApiPrecisionCompareColumn, check_inf_or_nan, \
     is_inf_or_nan
+from msprobe.pytorch.api_accuracy_checker.common.utils import is_dtype_fp8_or_hif8
 
 UlpInfNanConsistency = namedtuple('UlpInfNanConsistency', ['mean_ulp_err_inf_nan_consistency',
@@ -144,7 +145,10 @@ class UlpPrecisionCompare(BasePrecisionCompare):
         mean_ulp_err = metrics.get(CompareConst.MEAN_ULP_ERR)
         ulp_err_proportion = metrics.get(CompareConst.ULP_ERR_PROPORTION)
         ulp_err_proportion_ratio = metrics.get(CompareConst.ULP_ERR_PROPORTION_RATIO)
-        if dtype == Const.TORCH_FLOAT32:
+        if is_dtype_fp8_or_hif8(dtype):
+            status, final_message = \
+                self._get_fp8_ulp_err_status(ulp_err_proportion)
+        elif dtype == Const.TORCH_FLOAT32:
             status, final_message = \
                 self._get_fp32_ulp_err_status(mean_ulp_err, ulp_err_proportion, ulp_err_proportion_ratio)
         else:
@@ -182,14 +186,28 @@ class UlpPrecisionCompare(BasePrecisionCompare):
             compare_message = "ERROR: the ULP error does not meet the standard\n"
             return CompareConst.ERROR, compare_message
 
+    def _get_fp8_ulp_err_status(self, ulp_err_proportion, ulp_err_proportion_ratio):
+        _, ulp_err_proportion_threshold, _ = StandardConfig.get_ulp_threshold(torch.float16)
+        if ulp_err_proportion < ulp_err_proportion_threshold:
+            return CompareConst.PASS, ""
+        compare_message = "ERROR: the ULP error does not meet the standard\n"
+        return CompareConst.ERROR, compare_message
+
     def _compute_ratio(self):
         compare_message = ""
-        mean_ulp_err, mean_ulp_err_inf_nan_consistency, mean_ulp_err_message = self._compute_mean_ulp_err()
-        compare_message += mean_ulp_err_message
-        npu_ulp_err_proportion, gpu_ulp_err_proportion = self._compute_ulp_err_proportion()
-        ulp_err_proportion_ratio, ulp_err_proportion_ratio_inf_nan_consistency, ulp_err_proportion_ratio_message = \
-            self._compute_ulp_err_proportion_ratio(npu_ulp_err_proportion, gpu_ulp_err_proportion, str(self.dtype))
-        compare_message += ulp_err_proportion_ratio_message
+        dtype = self.row_npu.get(ApiPrecisionCompareColumn.DEVICE_DTYPE)
+        if is_dtype_fp8_or_hif8(dtype):
+            mean_ulp_err = CompareConst.SPACE
+            ulp_err_proportion_ratio = CompareConst.SPACE
+            npu_ulp_err_proportion = self._get_and_convert_value(self.row_npu,
+                                                                 ApiPrecisionCompareColumn.ULP_ERR_PROPORTION, "NPU")
+        else:
+            mean_ulp_err, mean_ulp_err_inf_nan_consistency, mean_ulp_err_message = self._compute_mean_ulp_err()
+            compare_message += mean_ulp_err_message
+            npu_ulp_err_proportion, gpu_ulp_err_proportion = self._compute_ulp_err_proportion()
+            ulp_err_proportion_ratio, ulp_err_proportion_ratio_inf_nan_consistency, ulp_err_proportion_ratio_message = \
+                self._compute_ulp_err_proportion_ratio(npu_ulp_err_proportion, gpu_ulp_err_proportion, str(self.dtype))
+            compare_message += ulp_err_proportion_ratio_message
         metrics = {
             CompareConst.MEAN_ULP_ERR: mean_ulp_err,
             CompareConst.ULP_ERR_PROPORTION: npu_ulp_err_proportion,
-- Gitee
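Because no GPU baseline exists for FP8 outputs, the FP8 branch above grades an API on the NPU ULP-error proportion alone, borrowing the float16 proportion threshold from StandardConfig and blanking the other two metrics. A reduced sketch of that decision; the 0.001 default below is a placeholder for illustration, not the value StandardConfig actually returns:

def fp8_ulp_status(ulp_err_proportion, ulp_err_proportion_threshold=0.001):
    # No GPU baseline exists for FP8, so the mean ULP error and the NPU/GPU
    # proportion ratio are left blank and the NPU proportion alone decides.
    if ulp_err_proportion < ulp_err_proportion_threshold:
        return "pass", ""
    return "error", "ERROR: the ULP error does not meet the standard\n"

print(fp8_ulp_status(0.0005))  # ('pass', '')
print(fp8_ulp_status(0.05))    # ('error', 'ERROR: the ULP error does not meet the standard\n')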
From d1cce24929418c7a4e124c41d29afee0c9df31ec Mon Sep 17 00:00:00 2001
From: gitee
Date: Wed, 16 Jul 2025 15:36:55 +0800
Subject: [PATCH 20/23] FIX BUG

---
 .../api_accuracy_checker/precision_standard/ulp_compare.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py
index c22ee51da6..415d91daba 100644
--- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py
+++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py
@@ -201,6 +201,8 @@ class UlpPrecisionCompare(BasePrecisionCompare):
             ulp_err_proportion_ratio = CompareConst.SPACE
             npu_ulp_err_proportion = self._get_and_convert_value(self.row_npu,
                                                                  ApiPrecisionCompareColumn.ULP_ERR_PROPORTION, "NPU")
+            mean_ulp_err_inf_nan_consistency = True
+            ulp_err_proportion_ratio_inf_nan_consistency = True
         else:
             mean_ulp_err, mean_ulp_err_inf_nan_consistency, mean_ulp_err_message = self._compute_mean_ulp_err()
             compare_message += mean_ulp_err_message
-- Gitee

From 53333d656e0dc268f3c4cefcd79a6b9e3cd281e3 Mon Sep 17 00:00:00 2001
From: gitee
Date: Wed, 16 Jul 2025 15:37:56 +0800
Subject: [PATCH 21/23] FIX BUG

---
 .../api_accuracy_checker/precision_standard/ulp_compare.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py
index 415d91daba..824b9b29ed 100644
--- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py
+++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py
@@ -186,7 +186,7 @@ class UlpPrecisionCompare(BasePrecisionCompare):
         compare_message = "ERROR: the ULP error does not meet the standard\n"
         return CompareConst.ERROR, compare_message
 
-    def _get_fp8_ulp_err_status(self, ulp_err_proportion, ulp_err_proportion_ratio):
+    def _get_fp8_ulp_err_status(self, ulp_err_proportion):
         _, ulp_err_proportion_threshold, _ = StandardConfig.get_ulp_threshold(torch.float16)
         if ulp_err_proportion < ulp_err_proportion_threshold:
             return CompareConst.PASS, ""
-- Gitee

From c51f404aad19923d302b34fb970e40cd7535df34 Mon Sep 17 00:00:00 2001
From: gitee
Date: Fri, 18 Jul 2025 09:56:20 +0800
Subject: [PATCH 22/23] fix bug

---
 .../msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py
index c3329b427e..af6f1e327e 100644
--- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py
+++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py
@@ -174,7 +174,7 @@ def generate_device_params(input_args, input_kwargs, need_backward, api_name):
         if isinstance(arg_in, (list, tuple)):
             return type(arg_in)(recursive_arg_to_device(arg, to_detach, depth=depth+1) for arg in arg_in)
         elif isinstance(arg_in, torch.Tensor):
-            if is_dtype_fp8(arg_in.dtype) or (not IS_GPU and isinstance(arg_in, torch_npu.HiFloat8Tensor)):
+            if is_dtype_fp8(arg_in.dtype) or is_hifloat8_tensor(arg_in):
                 is_fp8 = True
             if need_backward and arg_in.requires_grad:
                 arg_in = deal_detach(arg_in.clone(), to_detach).to(current_device).requires_grad_()
-- Gitee
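Patch 22 swaps the inline torch_npu.HiFloat8Tensor isinstance check for is_hifloat8_tensor, whose hasattr guard keeps the probe from raising on torch_npu builds that predate HiFloat8Tensor. A self-contained sketch of the same pattern; the module handle and is_gpu are plain parameters here only so the example runs anywhere:

import torch

def is_hifloat8_tensor_sketch(tensor, npu_module, is_gpu):
    # Mirror the guarded pattern: bail out on GPU, then probe for the
    # HiFloat8Tensor attribute before ever calling isinstance().
    if is_gpu or npu_module is None:
        return False
    hif8_cls = getattr(npu_module, "HiFloat8Tensor", None)
    return hif8_cls is not None and isinstance(tensor, hif8_cls)

# On a machine without torch_npu the probe simply reports False:
print(is_hifloat8_tensor_sketch(torch.ones(2), None, is_gpu=False))  # False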
From 5398adb2a2abc2a608d08209e49b4cc942694739 Mon Sep 17 00:00:00 2001
From: jiangchangting1
Date: Wed, 23 Jul 2025 12:57:15 +0000
Subject: [PATCH 23/23] update
 debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py.

Signed-off-by: jiangchangting1
---
 .../api_accuracy_checker/precision_standard/ulp_compare.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py
index 824b9b29ed..2cfa6c2cff 100644
--- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py
+++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py
@@ -187,7 +187,7 @@ class UlpPrecisionCompare(BasePrecisionCompare):
         return CompareConst.ERROR, compare_message
 
     def _get_fp8_ulp_err_status(self, ulp_err_proportion):
-        _, ulp_err_proportion_threshold, _ = StandardConfig.get_ulp_threshold(torch.float16) 
+        _, ulp_err_proportion_threshold, _ = StandardConfig.get_ulp_threshold(torch.float16)
         if ulp_err_proportion < ulp_err_proportion_threshold:
             return CompareConst.PASS, ""
         compare_message = "ERROR: the ULP error does not meet the standard\n"
-- Gitee
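Taken together, the series threads an is_fp8 flag from data generation through comparison and grades FP8 outputs by ULP error. The snippet below replays that idea end to end for the standard torch.float8_e4m3fn dtype on CPU; it is simplified to normal numbers with a fixed 3-bit mantissa (the shipped calc_ulp_err_fp8 additionally models HiFloat8's tapered widths and the zero-mantissa branch) and assumes a PyTorch build that ships the float8 dtypes:

import numpy as np
import torch

# Cast a few reference values through float8_e4m3fn and back, then express
# the rounding error in ULPs of the FP8 grid around each reference value.
bench = np.array([1.5, 3.1415, 100.0, 448.0], dtype=np.float64)
device = torch.tensor(bench).to(torch.float8_e4m3fn).to(torch.float64).numpy()

exponent = np.floor(np.log2(np.abs(bench)))
ulp_err = (device - bench) * 2.0 ** (-(exponent - 3))  # one e4m3 ULP is 2**(E-3)
print(ulp_err)  # round-to-nearest casting stays within +-0.5 ULP here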