diff --git a/debug/accuracy_tools/msprobe/core/common/const.py b/debug/accuracy_tools/msprobe/core/common/const.py
index 4af7b5a855b05318aa4b5ba8f361e0ad26067138..d69c8c50bb1b2273fda6922366916fa36017a6a5 100644
--- a/debug/accuracy_tools/msprobe/core/common/const.py
+++ b/debug/accuracy_tools/msprobe/core/common/const.py
@@ -154,7 +154,7 @@ class Const:
     STACK = "stack"
     ATEN = "Aten"
 
-    MODULE_WHITE_LIST = ["torch", "numpy"]
+    MODULE_WHITE_LIST = ["torch", "numpy", "torch_npu"]
     FUNC_SKIP_LIST = ["construct", "__call__"]
     FILE_SKIP_LIST = ["msprobe", "MindSpeed"]
 
diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/utils.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/utils.py
index 5724f626237af164c582d2165354d5ab35e3b839..8df1a6b034526b02e1e19b1a5bcba4db1957e8e8 100644
--- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/utils.py
+++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/utils.py
@@ -259,3 +259,61 @@ def get_attribute(module_name, attribute_name):
         logger.error(f"Failed to get attribute {attribute_name} from module {module_name}: {e}")
         raise CompareException(CompareException.INVALID_ATTRIBUTE_ERROR) from e
     return attribute
+
+
+def is_dtype_fp8(dtype):
+    """
+    Function Description:
+        Check if the data type is float8.
+    Parameter:
+        dtype: Data type (torch.dtype or string).
+    Return:
+        True or False.
+    """
+    # Handle dtypes passed as strings
+    if isinstance(dtype, str):
+        return dtype in ["torch.float8_e4m3fn", "torch.float8_e5m2"]
+
+    # Handle torch.dtype objects
+    return dtype in [torch.float8_e4m3fn, torch.float8_e5m2]
+
+
+def is_dtype_hif8(dtype):
+    """
+    Function Description:
+        Check if the data type is HiFloat8Tensor.
+    Parameter:
+        dtype: Data type (string or class object).
+    Return:
+        True or False.
+    """
+    # Handle dtypes passed as strings
+    if isinstance(dtype, str):
+        dtype_str = dtype
+    # Handle class objects or dtype objects
+    else:
+        dtype_str = str(dtype)
+
+    # Match either string representation of HiFloat8Tensor
+    return (
+        dtype_str == "<class 'torch_npu.HiFloat8Tensor'>" or
+        dtype_str == "torch_npu.HiFloat8Tensor"
+    )
+
+
+def is_dtype_fp8_or_hif8(dtype):
+    """
+    Function Description:
+        Check if the data type is FP8 (native or HiFloat8 variant).
+    Parameters:
+        dtype: Data type (torch.dtype, class, or string).
+    Returns:
+        True if dtype is FP8 or HiFloat8, False otherwise.
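+    Example:
+        >>> is_dtype_fp8_or_hif8("torch.float8_e4m3fn")
+        True
+        >>> is_dtype_fp8_or_hif8("torch.float32")
+        False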
+ """ + return is_dtype_fp8(dtype) or is_dtype_hif8(dtype) + + +def is_hifloat8_tensor(tensor): + if not IS_GPU and hasattr(torch_npu, "HiFloat8Tensor") and isinstance(tensor, torch_npu.HiFloat8Tensor): + return True + return False \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/algorithm.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/algorithm.py index ddee254c2b1085f9af96fe2774c53fb88c5821f4..c6e0fdd9c54a596534a330e60190fe7f0d522799 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/algorithm.py @@ -22,6 +22,7 @@ import numpy as np from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import ULP_PARAMETERS from msprobe.pytorch.api_accuracy_checker.precision_standard.standard_config import StandardConfig +from msprobe.pytorch.api_accuracy_checker.common.utils import is_dtype_fp8 from msprobe.core.common.const import CompareConst @@ -212,6 +213,8 @@ def check_norm_value(normal_value_mask, rel_err, rtol): def get_ulp_err(bench_output, device_output, dtype): + if is_dtype_fp8(dtype): + return calc_ulp_err_fp8(bench_output, device_output) parameters = ULP_PARAMETERS.get(dtype) min_eb = parameters.get('min_eb')[0] exponent_num = parameters.get('exponent_num')[0] @@ -227,6 +230,32 @@ def get_ulp_err(bench_output, device_output, dtype): return ulp_err +def calc_ulp_err_fp8(bench_output, device_output): + # compute ulp error of FP8 + x = np.float64(bench_output) + hi_fp8 = np.float64(device_output) + + ex = np.log2(abs(x) + 2**(-1000)) + ex[ex < -22] = -22 + exponent = np.floor(ex) # Exponent + + eabs = np.abs(exponent) + wm = np.zeros_like(x) # Mantissa width Init + wm[eabs <= 15] = 1 + wm[eabs <= 7] = 2 + wm[eabs <= 3] = 3 + ulp_err = (hi_fp8 - x) * 2 ** (-exponent + wm) # for wm = 1~3 + + s_ex = ex * np.where(x >= 0, 1, -1) + eh = np.log2(abs(hi_fp8) + 2**(-1000)) + + s_eh = eh * np.where(hi_fp8 >= 0, 1, -1) + ulp_err1 = s_eh - s_ex # for wm = 0 + + ulp_err[wm == 0] = ulp_err1[wm == 0] # Merge 2 cases + return ulp_err + + def calc_ulp_err(bench_output, device_output, eb, exponent_num, data_type): return (device_output.astype(data_type) - bench_output).astype(data_type) * \ np.exp2(-eb + exponent_num).astype(data_type) diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py index 55e93d271cec67334fe21c1f6466df2d0254a36b..965fe6315f272543bb912027c4cb4d120f5fa1b6 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py @@ -175,6 +175,8 @@ def analyse_csv(npu_data, gpu_data, config): try: new_status = get_api_status(row_npu, row_gpu, api_name, compare_column, registry) except Exception as err: + import traceback + traceback.print_exc() logger.error(f"Get api status error: {str(err)}") compare_column.api_name = full_api_name_with_direction_status compare_column.compare_result = CompareConst.SKIP @@ -249,11 +251,19 @@ def get_api_status(row_npu, row_gpu, api_name, compare_column, registry): compare_column.api_name = full_api_name_with_direction_status dtype = row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE] input_data = PrecisionCompareInput(row_npu, row_gpu, dtype, compare_column) - comparison_func = 
+def get_in_dtype(row_npu):
+    if row_npu[ApiPrecisionCompareColumn.REL_ERR_RATIO].isspace():
+        return "torch.float32"
+    return "torch.float8_e4m3fn"
+
+
 def print_test_success(api_full_name, forward_result, backward_result):
     is_fwd_success = (forward_result == CompareConst.PASS)
     is_bwd_success = (backward_result == CompareConst.PASS or backward_result == CompareConst.SPACE)
@@ -459,3 +469,7 @@ def _api_precision_compare_parser(parser):
     parser.add_argument("-o", "--out_path", dest="out_path", default="", type=str,
                         help=" The api precision compare task result out path.", required=False)
+
+
+if __name__ == '__main__':
+    _api_precision_compare()
diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py
index c12a54c18ad07ae302b41d12704dc82fec01b4c2..531679cf4d32abb03f4676abbdc2855ec4d41140 100644
--- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py
+++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py
@@ -180,7 +180,7 @@ class Comparator:
         write_csv(DETAIL_TEST_ROWS, detail_save_path)
 
     @recursion_depth_decorator("compare_core")
-    def _compare_core(self, api_name, bench_output, device_output):
+    def _compare_core(self, api_name, bench_output, device_output, is_fp8):
         compare_column = CompareColumn()
         if not isinstance(bench_output, type(device_output)):
             status = CompareConst.ERROR
@@ -192,7 +192,7 @@ class Comparator:
             message = "bench and npu output dict keys are different."
else: status, compare_column, message = self._compare_core(api_name, list(bench_output.values()), - list(device_output.values())) + list(device_output.values()), is_fp8) elif isinstance(bench_output, torch.Tensor): copy_bench_out = bench_output.detach().clone() copy_device_output = device_output.detach().clone() @@ -200,7 +200,7 @@ class Comparator: compare_column.npu_type = str(copy_device_output.dtype) compare_column.shape = tuple(device_output.shape) status, compare_column, message = self._compare_torch_tensor(api_name, copy_bench_out, copy_device_output, - compare_column) + compare_column, is_fp8) elif isinstance(bench_output, (bool, int, float, str)): compare_column.bench_type = str(type(bench_output)) compare_column.npu_type = str(type(device_output)) @@ -255,11 +255,12 @@ class Comparator: bench_output, device_output = data_info.bench_output, data_info.device_output bench_grad, device_grad = data_info.bench_grad, data_info.device_grad backward_message = data_info.backward_message + is_fp8 = data_info.is_fp8 if "dropout" in full_api_name: fwd_success_status, fwd_compare_alg_results = self._compare_dropout(bench_output, device_output) else: fwd_success_status, fwd_compare_alg_results = self._compare_core_wrapper(api_name, bench_output, - device_output) + device_output, is_fp8) if not (bench_grad and device_grad): bwd_success_status, bwd_compare_alg_results = (CompareConst.SPACE, []) else: @@ -267,7 +268,7 @@ class Comparator: bwd_success_status, bwd_compare_alg_results = self._compare_dropout(bench_grad[0], device_grad[0]) else: bwd_success_status, bwd_compare_alg_results = self._compare_core_wrapper(api_name, bench_grad, - device_grad) + device_grad, is_fp8) if backward_message: backward_column = CompareColumn() bwd_compare_alg_results = [backward_column.to_column_value(CompareConst.SKIP, backward_message)] @@ -297,7 +298,7 @@ class Comparator: registry.register(CompareConst.ACCUMULATIVE_ERROR_COMPARE, self._accumulative_error_compare) return registry - def _compare_core_wrapper(self, api_name, bench_output, device_output): + def _compare_core_wrapper(self, api_name, bench_output, device_output, is_fp8): detailed_result_total = [] test_final_success = CompareConst.PASS if isinstance(bench_output, (list, tuple)): @@ -308,12 +309,12 @@ class Comparator: else: device_output = device_output[:len(bench_output)] for b_out_i, n_out_i in zip(bench_output, device_output): - status_i, compare_result_i, message_i = self._compare_core(api_name, b_out_i, n_out_i) + status_i, compare_result_i, message_i = self._compare_core(api_name, b_out_i, n_out_i, is_fp8) status.append(status_i) compare_result.append(compare_result_i) message.append(message_i) else: - status, compare_result, message = self._compare_core(api_name, bench_output, device_output) + status, compare_result, message = self._compare_core(api_name, bench_output, device_output, is_fp8) if not isinstance(status, list): detailed_result_total.append(compare_result.to_column_value(status, message)) if status == CompareConst.ERROR: @@ -329,10 +330,15 @@ class Comparator: test_final_success = CompareConst.WARNING return test_final_success, detailed_result_total - def _compare_torch_tensor(self, api_name, bench_output, device_output, compare_column): + def _compare_torch_tensor(self, api_name, bench_output, device_output, compare_column, is_fp8): cpu_shape = bench_output.shape npu_shape = device_output.shape + npu_dtype = device_output.dtype + if is_fp8: + in_dtype = torch.float8_e4m3fn + else: + in_dtype = torch.float32 if npu_dtype == 
torch.bfloat16: bench_output = bench_output.to(torch.float32) device_output = device_output.to(torch.float32) @@ -356,22 +362,28 @@ class Comparator: compare_column.error_rate = err_rate return status, compare_column, message else: + in_and_out_dtype = { + 'dtype': npu_dtype, + 'in_dtype': in_dtype + } status, compare_column, message = self._compare_float_tensor(api_name, bench_output, device_output, - compare_column, npu_dtype) + compare_column, in_and_out_dtype) return status, compare_column, message - def _perform_comparison(self, api_name, input_data): - comparison_func = self.registry.get_comparison_function(api_name, None) + def _perform_comparison(self, api_name, input_data, dtype, in_dtype): + comparison_func = self.registry.get_comparison_function(api_name, dtype, in_dtype) comparison_func(input_data) - def _compare_float_tensor(self, api_name, bench_output, device_output, compare_column, dtype): + def _compare_float_tensor(self, api_name, bench_output, device_output, compare_column, in_and_out_dtype): + dtype = in_and_out_dtype.get('dtype') + in_dtype = in_and_out_dtype.get('in_dtype') message = "" _, abs_bench_with_eps = get_abs_bench_with_eps(bench_output, dtype) abs_err = get_abs_err(bench_output, device_output) rel_err_orign = get_rel_err_origin(abs_err, abs_bench_with_eps) input_data = CompareInput(bench_output, device_output, compare_column, dtype, rel_err_orign) if str(dtype) in BENCHMARK_COMPARE_SUPPORT_LIST: - self._perform_comparison(api_name, input_data) + self._perform_comparison(api_name, input_data, dtype, in_dtype) else: message += f"The data type {dtype} is not supported for new precision standard." diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py index 89c4401b2cac863bc609cce14a9f4c3ca03951b7..d16efcd616a26512b0d16b8b707f0658db54b1c2 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py @@ -30,7 +30,8 @@ from msprobe.pytorch.common.log import logger current_time = time.strftime("%Y%m%d%H%M%S") API_PRECISION_COMPARE_RESULT_FILE_NAME = "api_precision_compare_result_" + current_time + ".csv" API_PRECISION_COMPARE_DETAILS_FILE_NAME = "api_precision_compare_details_" + current_time + ".csv" -BENCHMARK_COMPARE_SUPPORT_LIST = ['torch.float16', 'torch.bfloat16', 'torch.float32'] +BENCHMARK_COMPARE_SUPPORT_LIST = ['torch.float16', 'torch.bfloat16', 'torch.float32', "torch.float8_e4m3fn", + "torch.float8_e5m2"] API_PRECISION_COMPARE_UNSUPPORT_LIST = ['torch.float64', 'torch.complex64', 'torch.complex128'] ULP_COMPARE_SUPPORT_LIST = ['torch.float16', 'torch.bfloat16', 'torch.float32'] BINARY_COMPARE_UNSUPPORT_LIST = BENCHMARK_COMPARE_SUPPORT_LIST + API_PRECISION_COMPARE_UNSUPPORT_LIST diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py index e3ff6637586dd7e6e6c1ea966e5ecd88adf08c11..0dcd413830f211b137e842c51bd8eed5f5bc09f4 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py @@ -133,16 +133,16 @@ class BasePrecisionCompare: return compare_result def _get_and_convert_values(self, column_name): - npu_value = 
self.row_npu.get(column_name) - gpu_value = self.row_gpu.get(column_name) - if npu_value is None: - raise ValueError(f"NPU value for column '{column_name}' is None.") - if gpu_value is None: - raise ValueError(f"GPU value for column '{column_name}' is None.") - npu_value = convert_str_to_float(npu_value) - gpu_value = convert_str_to_float(gpu_value) + npu_value = self._get_and_convert_value(self.row_npu, column_name, "NPU") + gpu_value = self._get_and_convert_value(self.row_gpu, column_name, "GPU") return npu_value, gpu_value + def _get_and_convert_value(self, row_data, column_name, device_type): + value = row_data.get(column_name) + if value is None: + raise ValueError(f"{device_type} value for column '{column_name}' is None.") + return convert_str_to_float(value) + def _post_compare(self, metrics, inf_nan_consistency): metrics = self._get_status(metrics, inf_nan_consistency) metrics.update({'compare_algorithm': self.compare_algorithm}) diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py index 82df8c54e87ea1627159a52aef2544028ab21b22..c9a1752db72ea4aa55725d2702281bc9c372115d 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py @@ -16,6 +16,7 @@ # limitations under the License. from typing import Callable +from msprobe.pytorch.api_accuracy_checker.common.utils import is_dtype_fp8_or_hif8 from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import absolute_standard_api, binary_standard_api, \ ulp_standard_api, thousandth_standard_api, accumulative_error_standard_api, BINARY_COMPARE_UNSUPPORT_LIST from msprobe.core.common.const import CompareConst @@ -70,11 +71,11 @@ class StandardRegistry: raise ValueError("The function to be registered must be callable.") self.comparison_functions[standard] = func - def get_comparison_function(self, api_name, dtype=None): - standard = self._get_standard_category(api_name, dtype) + def get_comparison_function(self, api_name, dtype, in_dtype): + standard = self._get_standard_category(api_name, dtype, in_dtype) return self.comparison_functions.get(standard) - def _get_standard_category(self, api_name, dtype=None): + def _get_standard_category(self, api_name, out_dtype, in_dtype): """ Determines the standard category for a given API name and data type. @@ -84,7 +85,8 @@ class StandardRegistry: Args: api_name (str): The name of the API for which to determine the standard category. - dtype (type, optional): The data type to check against the BINARY_COMPARE_UNSUPPORT_LIST. Defaults to None. + out_dtype (type): The data type of the output tensor. + in_dtype (type): The data type of the input tensor. Returns: str: The name of the standard category that matches the API name and data type, or 'benchmark' if no match @@ -96,7 +98,11 @@ class StandardRegistry: The BINARY_COMPARE_UNSUPPORT_LIST should be defined and contain all data types that are not supported for binary comparison. 
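+        FP8 routing:
+            An FP8/HiFloat8 output dtype is mapped to the ULP compare standard, and an
+            FP8/HiFloat8 input dtype to the absolute threshold standard, before any of
+            the rules below are applied.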
""" - if dtype and dtype not in BINARY_COMPARE_UNSUPPORT_LIST: + if is_dtype_fp8_or_hif8(out_dtype): + return CompareConst.ULP_COMPARE + if is_dtype_fp8_or_hif8(in_dtype): + return CompareConst.ABSOLUTE_THRESHOLD + if str(out_dtype) not in BINARY_COMPARE_UNSUPPORT_LIST: return CompareConst.BINARY_CONSISTENCY for name, category in self.api_standard_function_map.items(): if api_name in category: diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py index df181588ad01836186c82df6fc2d23eef63238f0..2cfa6c2cff806ed18b3064caa68e0d4f167a664f 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py @@ -25,6 +25,7 @@ from msprobe.core.common.const import Const, CompareConst from msprobe.pytorch.api_accuracy_checker.compare.algorithm import calc_ratio, get_ulp_err from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import ApiPrecisionCompareColumn, check_inf_or_nan, \ is_inf_or_nan +from msprobe.pytorch.api_accuracy_checker.common.utils import is_dtype_fp8_or_hif8 UlpInfNanConsistency = namedtuple('UlpInfNanConsistency', ['mean_ulp_err_inf_nan_consistency', @@ -144,7 +145,10 @@ class UlpPrecisionCompare(BasePrecisionCompare): mean_ulp_err = metrics.get(CompareConst.MEAN_ULP_ERR) ulp_err_proportion = metrics.get(CompareConst.ULP_ERR_PROPORTION) ulp_err_proportion_ratio = metrics.get(CompareConst.ULP_ERR_PROPORTION_RATIO) - if dtype == Const.TORCH_FLOAT32: + if is_dtype_fp8_or_hif8(dtype): + status, final_message = \ + self._get_fp8_ulp_err_status(ulp_err_proportion) + elif dtype == Const.TORCH_FLOAT32: status, final_message = \ self._get_fp32_ulp_err_status(mean_ulp_err, ulp_err_proportion, ulp_err_proportion_ratio) else: @@ -182,14 +186,30 @@ class UlpPrecisionCompare(BasePrecisionCompare): compare_message = "ERROR: ULP误差不满足标准\n" return CompareConst.ERROR, compare_message + def _get_fp8_ulp_err_status(self, ulp_err_proportion): + _, ulp_err_proportion_threshold, _ = StandardConfig.get_ulp_threshold(torch.float16) + if ulp_err_proportion < ulp_err_proportion_threshold: + return CompareConst.PASS, "" + compare_message = "ERROR: ULP误差不满足标准\n" + return CompareConst.ERROR, compare_message + def _compute_ratio(self): compare_message = "" - mean_ulp_err, mean_ulp_err_inf_nan_consistency, mean_ulp_err_message = self._compute_mean_ulp_err() - compare_message += mean_ulp_err_message - npu_ulp_err_proportion, gpu_ulp_err_proportion = self._compute_ulp_err_proportion() - ulp_err_proportion_ratio, ulp_err_proportion_ratio_inf_nan_consistency, ulp_err_proportion_ratio_message = \ - self._compute_ulp_err_proportion_ratio(npu_ulp_err_proportion, gpu_ulp_err_proportion, str(self.dtype)) - compare_message += ulp_err_proportion_ratio_message + dtype = self.row_npu.get(ApiPrecisionCompareColumn.DEVICE_DTYPE) + if is_dtype_fp8_or_hif8(dtype): + mean_ulp_err = CompareConst.SPACE + ulp_err_proportion_ratio = CompareConst.SPACE + npu_ulp_err_proportion = self._get_and_convert_value(self.row_npu, + ApiPrecisionCompareColumn.ULP_ERR_PROPORTION, "NPU") + mean_ulp_err_inf_nan_consistency = True + ulp_err_proportion_ratio_inf_nan_consistency = True + else: + mean_ulp_err, mean_ulp_err_inf_nan_consistency, mean_ulp_err_message = self._compute_mean_ulp_err() + compare_message += mean_ulp_err_message + npu_ulp_err_proportion, 
+            ulp_err_proportion_ratio, ulp_err_proportion_ratio_inf_nan_consistency, ulp_err_proportion_ratio_message = \
+                self._compute_ulp_err_proportion_ratio(npu_ulp_err_proportion, gpu_ulp_err_proportion, str(self.dtype))
+            compare_message += ulp_err_proportion_ratio_message
         metrics = {
             CompareConst.MEAN_ULP_ERR: mean_ulp_err,
             CompareConst.ULP_ERR_PROPORTION: npu_ulp_err_proportion,
diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py
index 9d89b2de32f70c6fa7abf38add49b58a13531d7a..221446c26632eadcd134f13ba5977e7d61e0c2d9 100644
--- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py
+++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py
@@ -20,9 +20,17 @@ import math
 import torch
 import numpy
 
+try:
+    import torch_npu
+except ImportError:
+    IS_GPU = True
+else:
+    IS_GPU = False
+
 from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import hf_32_standard_api
 from msprobe.pytorch.api_accuracy_checker.common.utils import check_object_type, get_full_data_path, \
-    CompareException, get_module_and_atttribute_name, get_attribute
+    CompareException, get_module_and_atttribute_name, get_attribute, is_dtype_fp8, is_dtype_hif8, is_hifloat8_tensor, \
+    is_dtype_fp8_or_hif8
 from msprobe.core.common.file_utils import FileChecker, load_npy
 from msprobe.pytorch.common.log import logger
 from msprobe.pytorch.common.utils import load_pt
@@ -38,7 +46,10 @@ FLOAT_TYPE = [
     'torch.double',
     'torch.float16',
     'torch.half',
-    'torch.bfloat16'
+    'torch.bfloat16',
+    'torch.float8_e4m3fn',
+    'torch.float8_e5m2',
+    'torch_npu.HiFloat8Tensor'
 ]
 NUMPY_TYPE = [
     "numpy.int8", "numpy.int16", "numpy.int32", "numpy.int64", "numpy.uint8", "numpy.uint16", "numpy.uint32",
@@ -61,6 +72,9 @@ def gen_data(info, api_name, need_grad, convert_type, real_data_path=None):
     data_type = info.get('type')
     data_path = info.get('datapath', info.get('data_name'))
     data_path = get_full_data_path(data_path, real_data_path)
+    dtype = info.get('dtype')
+    if is_dtype_fp8_or_hif8(dtype) and IS_GPU:
+        raise CompareException("float8 data types are not supported when running on GPU")
     if data_type in TENSOR_DATA_LIST:
         if data_path:
             data = gen_real_tensor(data_path, convert_type)
@@ -69,10 +83,20 @@
         if api_name in hf_32_standard_api and data.dtype == torch.float32:
             data = fp32_to_hf32_to_fp32(data)
         if info.get('requires_grad') and need_grad:
+            if is_hifloat8_tensor(data):
+                origin_dtype = info.get('dtype')
+            else:
+                origin_dtype = data.dtype
+            # float8 tensors do not support the autograd bookkeeping below, so
+            # they round-trip through float32 and are cast back afterwards.
+            if is_dtype_fp8(origin_dtype):
+                data = data.to(torch.float32)
             data.requires_grad_(True)
             temp_data = data * 1
             data = temp_data.type_as(data)
             data.retain_grad()
+            if is_dtype_fp8(origin_dtype):
+                data = data.to(origin_dtype)
+            if is_dtype_hif8(origin_dtype):
+                data = torch_npu.HiFloat8Tensor.to_hifloat8(data)
         elif data_type.startswith("numpy"):
             if data_type not in NUMPY_TYPE:
                 raise Exception("{} is not supported now".format(data_type))
@@ -196,8 +220,13 @@
         tensor = torch.full(shape, high, dtype=dtype)
         tensor[-1] = low
         return tensor
+
     low_scale, high_scale = low, high
-    dtype_finfo = torch.finfo(dtype)
+    if is_dtype_hif8(dtype):
+        finfo_dtype = torch.float32
+    else:
+        finfo_dtype = dtype
+    dtype_finfo = torch.finfo(finfo_dtype)
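+    # torch.finfo has no entry for the HiFloat8 tensor class, so float32 limits
+    # stand in when inf/-inf bounds from legacy json files have to be rescaled.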
#适配老版json high和low为inf或-inf的情况,取dtype的最大值或最小值进行放缩 if high == float(CompareConst.INF): high_scale = dtype_finfo.max @@ -209,7 +238,11 @@ def gen_common_tensor(low_info, high_info, shape, data_dtype, convert_type): low_scale = dtype_finfo.min scale = high_scale - low_scale - rand01 = torch.rand(shape, dtype=dtype) + if is_dtype_fp8(dtype) or is_dtype_hif8(dtype): + generate_dtype = torch.float32 + else: + generate_dtype = dtype + rand01 = torch.rand(shape, dtype=generate_dtype) tensor = rand01 * scale + low_scale elif 'int' in data_dtype or 'long' in data_dtype: low, high = int(low), int(high) @@ -236,6 +269,12 @@ def gen_common_tensor(low_info, high_info, shape, data_dtype, convert_type): if low_origin in [float(CompareConst.INF), float(CompareConst.NEG_INF)]: tmp_tensor[0] = low_origin data = tmp_tensor.reshape(shape) + if is_dtype_fp8(dtype): + data = data.to(dtype) + if is_dtype_hif8(dtype): + if IS_GPU: + raise CompareException("GPU does not support torch_npu.HiFloat8Tensor") + data = torch_npu.HiFloat8Tensor.to_hifloat8(data) return data diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py index 0f184d14b66d84607a6767ba9ef5210ff4fc5b69..422840f7aa5db972cfce0270fd3a978f56bb0e10 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py @@ -125,7 +125,7 @@ def run_torch_api(api_full_name, api_info_dict, real_data_path): device_info_kwargs = kwargs.get(Const.DEVICE) if device_info_kwargs and device_info_kwargs.get(Const.VALUE): kwargs[Const.DEVICE] = current_device - npu_args, npu_kwargs = generate_device_params(args, kwargs, False, api_name) + npu_args, npu_kwargs, _ = generate_device_params(args, kwargs, False, api_name) if kwargs.get(Const.DEVICE): del kwargs[Const.DEVICE] cpu_exec_params = ExecParams(api_type, api_name, Const.CPU_LOWERCASE, args, kwargs, False, None) diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py index 082f391c957578bad9b1dff546803aa7d4ce05b0..2fdd4317f4e0007d3f3d9e2a3e823e10c197103c 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py @@ -269,7 +269,7 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict device_info_kwargs = kwargs.get(Const.DEVICE) if device_info_kwargs and device_info_kwargs.get(Const.VALUE): kwargs[Const.DEVICE] = current_device - device_args, device_kwargs = generate_device_params(args, kwargs, need_backward, api_name) + device_args, device_kwargs, is_fp8 = generate_device_params(args, kwargs, need_backward, api_name) if kwargs.get(Const.DEVICE): del kwargs[Const.DEVICE] cpu_params = generate_cpu_params(args, kwargs, need_backward, api_name) @@ -284,6 +284,8 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict device_exec_params = ExecParams(api_type, api_name, current_device, device_args, device_kwargs, is_autocast, autocast_dtype) device_out = exec_api(device_exec_params) + if is_fp8 and isinstance(device_out, torch.Tensor) and device_out.dtype == torch.float32: + device_out = device_out.to(torch.float16) current_path = os.path.dirname(os.path.realpath(__file__)) ut_setting_path = 
os.path.join(current_path, "torch_ut_setting.json") api_setting_dict = get_json_contents(ut_setting_path) @@ -312,7 +314,8 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict out = safe_get_value(out, 0, "out") device_out = safe_get_value(device_out, 0, "device_out") - return UtDataInfo(bench_grad_out, device_grad_out, device_out, out, bench_grad, in_fwd_data_list, backward_message) + return UtDataInfo(bench_grad_out, device_grad_out, device_out, out, bench_grad, in_fwd_data_list, backward_message, + rank=0, is_fp8=is_fp8) def run_torch_api_online(api_full_name, api_data, backward_content): @@ -327,7 +330,7 @@ def run_torch_api_online(api_full_name, api_data, backward_content): device_exec_params = ExecParams(api_type, api_name, current_device, args, kwargs, False, None) device_out = exec_api(device_exec_params) device_out = move2device_exec(device_out, "cpu") - return UtDataInfo(None, None, out, device_out, None, in_fwd_data_list, None, rank=api_data.rank) + return UtDataInfo(None, None, out, device_out, None, in_fwd_data_list, None, rank=api_data.rank, is_fp8=False) def check_need_grad(api_info_dict): diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py index 60557c77d79c685dbcf8a312910816dcb06b2702..af6f1e327e606786de17f332911ede7b4b5ce2f2 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py @@ -16,16 +16,18 @@ import os from collections import namedtuple import re - +import numpy as np import torch try: import torch_npu except ImportError: current_device = "cuda" from torch.cuda.amp import autocast + IS_GPU = True else: current_device = "npu" from torch_npu.npu.amp import autocast + IS_GPU = False from msprobe.core.common.const import FileCheckConst, Const, CompareConst from msprobe.core.common.file_utils import FileChecker @@ -33,6 +35,7 @@ from msprobe.core.common.log import logger from msprobe.core.common.utils import CompareException from msprobe.pytorch.hook_module.api_register import ApiTemplate, get_api_register from msprobe.pytorch.hook_module.wrap_aten import AtenOPTemplate +from msprobe.pytorch.api_accuracy_checker.common.utils import is_dtype_fp8, is_hifloat8_tensor hf_32_standard_api = ["conv1d", "conv2d"] @@ -61,7 +64,7 @@ class BackwardMessage: class UtDataInfo: def __init__(self, bench_grad, device_grad, device_output, bench_output, grad_in, in_fwd_data_list, - backward_message, rank=0): + backward_message, rank=0, is_fp8=False): self.bench_grad = bench_grad self.device_grad = device_grad self.device_output = device_output @@ -70,6 +73,7 @@ class UtDataInfo: self.in_fwd_data_list = in_fwd_data_list self.backward_message = backward_message self.rank = rank + self.is_fp8 = is_fp8 def get_validated_result_csv_path(result_csv_path, mode): @@ -148,6 +152,10 @@ def raise_bench_data_dtype(api_name, arg, raise_dtype=None): 输出: arg: 转换dtype的标杆输入 ''' + if is_hifloat8_tensor(arg): + return hif8_to_fp32(arg) + if is_dtype_fp8(arg.dtype): + return fp8_to_fp32(arg) if api_name in hf_32_standard_api and arg.dtype == torch.float32: return arg if raise_dtype is None or arg.dtype not in PRECISION_MAPPING or raise_dtype == arg.dtype: @@ -156,13 +164,18 @@ def raise_bench_data_dtype(api_name, arg, raise_dtype=None): def generate_device_params(input_args, input_kwargs, need_backward, api_name): + is_fp8 = False + 
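+    # Flag any FP8/HiFloat8 input seen while moving args to the device;
+    # run_torch_api threads this flag into UtDataInfo so the comparison
+    # stage can select the FP8 path.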
     def recursive_arg_to_device(arg_in, to_detach, depth=0):
+        nonlocal is_fp8
         if depth > Const.MAX_DEPTH:
             logger.error("The depth of arg_in is too large, please check the arg_in.")
             raise CompareException(CompareException.RECURSION_LIMIT_ERROR)
         if isinstance(arg_in, (list, tuple)):
             return type(arg_in)(recursive_arg_to_device(arg, to_detach, depth=depth+1) for arg in arg_in)
         elif isinstance(arg_in, torch.Tensor):
+            if is_dtype_fp8(arg_in.dtype) or is_hifloat8_tensor(arg_in):
+                is_fp8 = True
             if need_backward and arg_in.requires_grad:
                 arg_in = deal_detach(arg_in.clone(), to_detach).to(current_device).requires_grad_()
                 temp_arg_in = arg_in * 1
@@ -178,7 +191,7 @@
     device_args = recursive_arg_to_device(input_args, is_detach)
     device_kwargs = \
         {key: recursive_arg_to_device(value, key != "out" and is_detach) for key, value in input_kwargs.items()}
-    return device_args, device_kwargs
+    return device_args, device_kwargs, is_fp8
 
 
 def generate_cpu_params(input_args, input_kwargs, need_backward, api_name):
@@ -260,3 +273,164 @@ def is_unsupported_api(api_name, is_overflow_check=False):
     if flag:
         logger.info(f"{split_name} api is not supported for run ut. SKIP.")
     return flag
+
+
+def fp8_to_fp32(x):
+    """
+    Convert an FP8 tensor to FP32 while preserving the exact FP8 values.
+    Implemented with pure PyTorch bit operations instead of NumPy.
+
+    Args:
+        x: input FP8 tensor, either torch.float8_e4m3fn or torch.float8_e5m2.
+
+    Returns:
+        torch.Tensor: the equivalent FP32 tensor.
+    """
+    if x.dtype == torch.float8_e4m3fn:
+        # E4M3FN layout: 1 sign + 4 exponent + 3 mantissa bits, bias 7 (SEEEEMMM).
+
+        # Reinterpret the FP8 payload as an unsigned byte for bit manipulation.
+        x_int = x.view(torch.uint8)
+
+        # Extract the sign, exponent and mantissa fields.
+        sign_bits = (x_int & 0x80) >> 7    # highest bit is the sign
+        exp_bits = (x_int & 0x78) >> 3     # next 4 bits are the exponent
+        mantissa_bits = x_int & 0x07       # last 3 bits are the mantissa
+
+        # Distinguish normal from subnormal numbers.
+        is_normal = exp_bits != 0
+
+        # FP32 exponent (bias 127). Subnormals use the fixed exponent 1 - 7 = -6,
+        # i.e. a biased value of 121. NaN bit patterns are not special-cased here.
+        fp32_exp = torch.where(
+            is_normal,
+            exp_bits.to(torch.int32) - 7 + 127,                    # normal: stored exponent + 120
+            torch.tensor(121, dtype=torch.int32, device=x.device)  # subnormal: 2 ** -6
+        )
+
+        # FP32 mantissa: normal numbers carry the implicit leading 1, subnormals do not.
+        fp32_mantissa = torch.where(
+            is_normal,
+            1.0 + mantissa_bits.to(torch.float32) / 8.0,  # 2 ** -3 = 1/8
+            mantissa_bits.to(torch.float32) / 8.0
+        )
+
+        # Sign factor (-1) ** sign.
+        sign_value = torch.pow(-1.0, sign_bits.to(torch.float32))
+
+        # Final value:
+        #   normal:    (-1)^sign * (1 + mantissa/8) * 2^(exponent - 7)
+        #   subnormal: (-1)^sign * (mantissa/8) * 2^(-6)
+        fp32_result = sign_value * fp32_mantissa * torch.pow(2.0, fp32_exp - 127)
+
+        return fp32_result
+
+    elif x.dtype == torch.float8_e5m2:
+        # E5M2 layout: 1 sign + 5 exponent + 2 mantissa bits, bias 15 (SEEEEEMM).
+
+        x_int = x.view(torch.uint8)
+
+        sign_bits = (x_int & 0x80) >> 7    # highest bit is the sign
+        exp_bits = (x_int & 0x7C) >> 2     # next 5 bits are the exponent
+        mantissa_bits = x_int & 0x03       # last 2 bits are the mantissa
+
+        is_normal = exp_bits != 0
+
+        # Subnormals use the fixed exponent 1 - 15 = -14, i.e. a biased value of 113.
+        # Inf/NaN bit patterns are not special-cased here.
+        fp32_exp = torch.where(
+            is_normal,
+            exp_bits.to(torch.int32) - 15 + 127,                   # normal: stored exponent + 112
+            torch.tensor(113, dtype=torch.int32, device=x.device)  # subnormal: 2 ** -14
+        )
+
+        fp32_mantissa = torch.where(
+            is_normal,
+            1.0 + mantissa_bits.to(torch.float32) / 4.0,  # 2 ** -2 = 1/4
+            mantissa_bits.to(torch.float32) / 4.0
+        )
+
+        sign_value = torch.pow(-1.0, sign_bits.to(torch.float32))
+
+        fp32_result = sign_value * fp32_mantissa * torch.pow(2.0, fp32_exp - 127)
+
+        return fp32_result
+
+    else:
+        raise ValueError(f"Unsupported dtype: {x.dtype}. Expected torch.float8_e4m3fn or torch.float8_e5m2.")
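+# Illustrative sanity check (assumes a PyTorch build that ships the float8
+# dtypes; not executed anywhere in this module):
+#
+#     x = torch.tensor([0.5, -1.5], dtype=torch.float8_e4m3fn)
+#     fp8_to_fp32(x)  # -> tensor([ 0.5000, -1.5000])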
ValueError(f"Unsupported dtype: {x.dtype}. Expected torch.float8_e4m3fn or torch.float8_e5m2.") + + +def hif8_to_fp32(x): + """ + 将HiFloat8格式的张量转换为FP32格式,保持原始HiFloat8的表示范围 + 使用纯PyTorch操作替代NumPy位运算 + + 参数: + x: 输入的HiFloat8张量,可以是torch_npu.HiFloat8Tensor类型 + + 返回: + torch.Tensor: 转换后的FP32张量,保持原始HiFloat8的表示范围 + """ + requires_grad = x.requires_grad + x = x.cpu().detach().numpy() + x = np.array(x) # 确保输入是numpy数组 + + # 创建结果数组,保持与输入相同的形状 + res = np.zeros_like(x, dtype=np.float32) + + # 获取输入张量的所有维度 + dimensions = x.shape + # 计算总元素数量 + total_elements = np.prod(dimensions) + + # 遍历每个元素 + for idx in range(total_elements): + # 将一维索引转换为多维索引 + multi_indices = np.unravel_index(idx, dimensions) + z = x[multi_indices] + + # 处理特殊值 + if np.isnan(z) or np.isinf(z): + res[multi_indices] = z + continue + + # 提取符号位 + s = 1.0 if z >= 0 else -1.0 + tmp = abs(z) + + # 处理零值 + if tmp == 0: + res[multi_indices] = 0.0 + continue + + # 确定指数范围和尾数位数 + exponent = np.floor(np.log2(tmp + 1e-100)) # 添加小常量避免log2(0) + eabs = abs(exponent) + + # 根据指数范围确定尾数位数和还原规则 + if eabs <= 3: # 3-bit Mantissa + mantissa = (tmp / (2.0 ** exponent)) * 8.0 # 还原尾数部分 + res[multi_indices] = s * (mantissa / 8.0) * (2.0 ** exponent) + elif eabs <= 7: # 2-bit Mantissa + mantissa = (tmp / (2.0 ** exponent)) * 4.0 + res[multi_indices] = s * (mantissa / 4.0) * (2.0 ** exponent) + elif eabs <= 15: # 1-bit Mantissa + mantissa = (tmp / (2.0 ** exponent)) * 2.0 + res[multi_indices] = s * (mantissa / 2.0) * (2.0 ** exponent) + else: # 0-bit Mantissa + res[multi_indices] = s * (2.0 ** exponent) + + res = torch.from_numpy(res) + if requires_grad: + res = res.requires_grad_() + return res diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py index 6fc36bcdecac81ae302ec9fd64079758f74e4071..ed972db09e29ea4f26fca950789d56fc90f51dfc 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py @@ -106,7 +106,7 @@ def online_precision_compare(api_data, device, common_config, api_precision_csv_ cpu_args, cpu_kwargs = cpu_params.cpu_args, cpu_params.cpu_kwargs cpu_exec_params = ExecParams(api_type, api_name, Const.CPU_LOWERCASE, cpu_args, cpu_kwargs, False, None) cpu_out = exec_api(cpu_exec_params) - npu_data_info = UtDataInfo(None, None, npu_out, cpu_out, None, [], None, rank=api_data.rank) + npu_data_info = UtDataInfo(None, None, npu_out, cpu_out, None, [], None, rank=api_data.rank, is_fp8=False) npu_detail = compare.compare_output(api_full_name, npu_data_info, True) npu_data = pd.DataFrame(npu_detail, columns=DETAIL_TEST_ROWS[-1]) @@ -114,7 +114,7 @@ def online_precision_compare(api_data, device, common_config, api_precision_csv_ api_data_gpu = move2target_device(api_data, device) # args, kwargs -> gpu, result -> npu data_info = func(api_full_name, api_data_gpu, config.backward_content) gpu_out = data_info.bench_output - gpu_data_info = UtDataInfo(None, None, gpu_out, cpu_out, None, [], None, rank=api_data.rank) + gpu_data_info = UtDataInfo(None, None, gpu_out, cpu_out, None, [], None, rank=api_data.rank, is_fp8=False) gpu_detail = compare.compare_output(api_full_name, gpu_data_info, True) gpu_data = pd.DataFrame(gpu_detail, columns=DETAIL_TEST_ROWS[-1]) diff --git 
a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py index fc05f9d469a811b41feda4baf8e05a61c63b7e6d..2ab38eb56dcc5581254faf8c00c3c211f607e644 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py @@ -72,7 +72,7 @@ class TestCompare(unittest.TestCase): def test_compare_core_wrapper(self): dummy_input = torch.randn(100, 100) bench_out, npu_out = dummy_input, dummy_input - test_final_success, detailed_result_total = self.compare._compare_core_wrapper("api", bench_out, npu_out) + test_final_success, detailed_result_total = self.compare._compare_core_wrapper("api", bench_out, npu_out, False) actual_cosine_similarity = detailed_result_total[0][3] # 设置一个小的公差值 tolerance = 1e-4 @@ -86,7 +86,7 @@ class TestCompare(unittest.TestCase): self.assertTrue(test_final_success) bench_out, npu_out = [dummy_input, dummy_input], [dummy_input, dummy_input] - test_final_success, detailed_result_total = self.compare._compare_core_wrapper("api", bench_out, npu_out) + test_final_success, detailed_result_total = self.compare._compare_core_wrapper("api", bench_out, npu_out, False) actual_cosine_similarity = detailed_result_total[0][3] self.assertTrue(np.isclose(actual_cosine_similarity, 1.0, atol=tolerance)) actual_cosine_similarity = detailed_result_total[1][3] @@ -102,7 +102,7 @@ class TestCompare(unittest.TestCase): '\nMax abs error is less than 0.001, consider as pass, skip other check and set to SPACE.\n']]) def test_compare_core_different(self): - res = self.compare._compare_core('api', 1, 'str') + res = self.compare._compare_core('api', 1, 'str', False) self.assertEqual(res[0], 'error') self.assertEqual(res[2], 'bench and npu output type is different.') @@ -112,7 +112,7 @@ class TestCompare(unittest.TestCase): 'key1': 1, 'key2': 2 } - res = self.compare._compare_core('api', output_dict, output_dict) + res = self.compare._compare_core('api', output_dict, output_dict, False) self.assertEqual(res[0], 'error') self.assertEqual(res[2], "Unexpected output type in compare_core: ") @@ -126,27 +126,27 @@ class TestCompare(unittest.TestCase): 'key3': 3, 'key4': 4 } - res = self.compare._compare_core('api', bench_dict, device_dict) + res = self.compare._compare_core('api', bench_dict, device_dict, False) self.assertEqual(res[0], 'error') self.assertEqual(res[2], 'bench and npu output dict keys are different.') def test_compare_core_with_tensor(self): tensor = torch.tensor([1, 2, 3]) - res = self.compare._compare_core('api', tensor, tensor) + res = self.compare._compare_core('api', tensor, tensor, False) self.assertEqual(res[0], 'pass') self.assertEqual(res[2], 'Compare algorithm is not supported for int64 data. 
Only judged by Error Rate.\n')
 
     def test_compare_core_with_buildin(self):
         interger = 1
-        res = self.compare._compare_core('api', interger, interger)
+        res = self.compare._compare_core('api', interger, interger, False)
         self.assertEqual(res[0], 'pass')
         self.assertEqual(res[2], '')
 
     def test_compare_core_with_none(self):
-        res = self.compare._compare_core('api', None, None)
+        res = self.compare._compare_core('api', None, None, False)
         self.assertEqual(res[0], 'SKIP')
         self.assertEqual(res[2], 'Bench output is None, skip this test.')
 
@@ -155,7 +155,7 @@
         bench_out, npu_out = torch.randn(100, 100), torch.randn(100, 100)
         bench_grad, npu_grad = [torch.randn(100, 100)], [torch.randn(100, 100)]
         api_name = 'Functional.conv2d.0'
-        data_info = UtDataInfo(bench_grad, npu_grad, bench_out, npu_out, None, None, None)
+        data_info = UtDataInfo(bench_grad, npu_grad, bench_out, npu_out, None, None, None, is_fp8=False)
         is_fwd_success, is_bwd_success = self.compare.compare_output(api_name, data_info)
         self.assertFalse(is_fwd_success)
         # is_bwd_success should be checked
@@ -210,7 +210,7 @@
         npu_output = torch.Tensor([1.0, 2.0, 3.0])
         compare_column = CompareColumn()
         status, compare_column, message = self.compare._compare_torch_tensor("api", cpu_output, npu_output,
-                                                                             compare_column)
+                                                                             compare_column, False)
         self.assertEqual(status, "pass")
 
     def test_compare_torch_tensor_bf16(self):
@@ -218,7 +218,7 @@
         npu_output = torch.tensor([1.0, 2.0, 3.0], dtype=torch.bfloat16)
         compare_column = CompareColumn()
         status, compare_column, message = self.compare._compare_torch_tensor("api", cpu_output, npu_output,
-                                                                             compare_column)
+                                                                             compare_column, False)
         self.assertEqual(status, "pass")
 
     def test_compare_torch_tensor_different_shape(self):
@@ -226,7 +226,7 @@
         npu_output = torch.Tensor([1.0, 2.0, 3.0])
         compare_column = CompareColumn()
         status, compare_column, message = self.compare._compare_torch_tensor("api", cpu_output, npu_output,
-                                                                             compare_column)
+                                                                             compare_column, False)
         self.assertEqual(status, "error")
 
     def test_compare_torch_tensor_different_dtype(self):
@@ -234,7 +234,7 @@
         npu_output = torch.Tensor([1.0, 2.0, 3.0])
         compare_column = CompareColumn()
         status, compare_column, message = self.compare._compare_torch_tensor("api", cpu_output, npu_output,
-                                                                             compare_column)
+                                                                             compare_column, False)
         self.assertEqual(status, "error")
 
     def test_compare_torch_tensor_special_dtype(self):
@@ -242,7 +242,7 @@
         npu_output = torch.Tensor([True, True, False])
         compare_column = CompareColumn()
         status, compare_column, message = self.compare._compare_torch_tensor("api", cpu_output, npu_output,
-                                                                             compare_column)
+                                                                             compare_column, False)
         self.assertEqual(status, "pass")
 
     def test_compare_builtin_type_pass_with_special_types(self):
@@ -270,27 +270,39 @@
         cpu_output = torch.Tensor([1.0, 2.0, 3.0])
         npu_output = torch.Tensor([1.0, 2.0, 3.0])
         compare_column = CompareColumn()
+        in_and_out_dtype = {
+            'dtype': npu_output.dtype,
+            'in_dtype': torch.float32
+        }
         status, compare_column, message = self.compare._compare_float_tensor("conv2d", cpu_output.numpy(),
                                                                              npu_output.numpy(),
-                                                                             compare_column, npu_output.dtype)
+                                                                             compare_column, in_and_out_dtype)
         self.assertEqual(status, "pass")
 
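+    # in_and_out_dtype carries the NPU output dtype (used for threshold lookup)
+    # and the input dtype (used for standard selection) expected by the new
+    # _compare_float_tensor signature.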
     def test_compare_float_tensor_binary(self):
         cpu_output = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float16)
         npu_output = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
         compare_column = CompareColumn()
+        in_and_out_dtype = {
+            'dtype': npu_output.dtype,
+            'in_dtype': torch.float32
+        }
         status, compare_column, message = self.compare._compare_float_tensor("abs", cpu_output.numpy(),
                                                                              npu_output.numpy(),
-                                                                             compare_column, npu_output.dtype)
+                                                                             compare_column, in_and_out_dtype)
         self.assertEqual(status, "pass")
 
     def test_compare_float_tensor_absolute(self):
         cpu_output = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
         npu_output = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
         compare_column = CompareColumn()
+        in_and_out_dtype = {
+            'dtype': npu_output.dtype,
+            'in_dtype': torch.float32
+        }
         status, compare_column, message = self.compare._compare_float_tensor("mul", cpu_output.numpy(),
                                                                              npu_output.numpy(),
-                                                                             compare_column, npu_output.dtype)
+                                                                             compare_column, in_and_out_dtype)
         self.assertEqual(status, "pass")
 
 
@@ -298,9 +310,13 @@
         cpu_output = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
         npu_output = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
         compare_column = CompareColumn()
+        in_and_out_dtype = {
+            'dtype': npu_output.dtype,
+            'in_dtype': torch.float32
+        }
         status, compare_column, message = self.compare._compare_float_tensor("__matmul__", cpu_output.numpy(),
                                                                              npu_output.numpy(),
-                                                                             compare_column, npu_output.dtype)
+                                                                             compare_column, in_and_out_dtype)
         self.assertEqual(status, "pass")
 
 
@@ -308,9 +324,13 @@
         cpu_output = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float16)
         npu_output = torch.tensor([1.1, 2.1, 3.1], dtype=torch.float16)
         compare_column = CompareColumn()
+        in_and_out_dtype = {
+            'dtype': npu_output.dtype,
+            'in_dtype': torch.float32
+        }
         status, compare_column, message = self.compare._compare_float_tensor("__matmul__", cpu_output.numpy(),
                                                                              npu_output.numpy(),
-                                                                             compare_column, npu_output.dtype)
+                                                                             compare_column, in_and_out_dtype)
         self.assertEqual(status, "error")
 
 
@@ -318,9 +338,13 @@
         cpu_output = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float16)
         npu_output = torch.tensor([1.0001, 2.0001, 3.0001], dtype=torch.float16)
         compare_column = CompareColumn()
+        in_and_out_dtype = {
+            'dtype': npu_output.dtype,
+            'in_dtype': torch.float32
+        }
         status, compare_column, message = self.compare._compare_float_tensor("__matmul__", cpu_output.numpy(),
                                                                              npu_output.numpy(),
-                                                                             compare_column, npu_output.dtype)
+                                                                             compare_column, in_and_out_dtype)
         self.assertEqual(status, "pass")
 
 
@@ -328,9 +352,13 @@
         cpu_output = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float16)
         npu_output = torch.tensor([1.01, 2.01, 3.01], dtype=torch.float16)
         compare_column = CompareColumn()
+        in_and_out_dtype = {
+            'dtype': npu_output.dtype,
+            'in_dtype': torch.float32
+        }
         status, compare_column, message = self.compare._compare_float_tensor("__matmul__", cpu_output.numpy(),
                                                                              npu_output.numpy(),
-                                                                             compare_column, npu_output.dtype)
+                                                                             compare_column, in_and_out_dtype)
         self.assertEqual(status, "Warning")
 
 
@@ -338,9 +366,13 @@
         cpu_output = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
         npu_output = torch.tensor([1.01, 2.01, 3.01], dtype=torch.float32)
         compare_column = CompareColumn()
+        in_and_out_dtype = {
+            'dtype': npu_output.dtype,
+            'in_dtype': torch.float32
+        }
         status, compare_column, message = self.compare._compare_float_tensor("__matmul__", cpu_output.numpy(),
                                                                              npu_output.numpy(),
-                                                                             compare_column, npu_output.dtype)
+                                                                             compare_column, in_and_out_dtype)
         self.assertEqual(status, "error")
 
@@ -348,9 +380,13 @@
         cpu_output = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
         npu_output = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
         compare_column = CompareColumn()
+        in_and_out_dtype = {
+            'dtype': npu_output.dtype,
+            'in_dtype': torch.float32
+        }
         status, compare_column, message = self.compare._compare_float_tensor("__matmul__", cpu_output.numpy(),
                                                                              npu_output.numpy(),
-                                                                             compare_column, npu_output.dtype)
+                                                                             compare_column, in_and_out_dtype)
         self.assertEqual(status, "pass")
 
 
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/precision_standard/test_standard_register.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/precision_standard/test_standard_register.py
index 0a776348933361e0934d6bef2dded9dca124bc02..408d83e462cb2b39d3db5902f1645e44a6271f26 100644
--- a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/precision_standard/test_standard_register.py
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/precision_standard/test_standard_register.py
@@ -27,7 +27,7 @@ class TestStandardRegistry(unittest.TestCase):
         mock_func = Mock()
         self.registry.register("binary_consistency", mock_func)
         # 使用支持二进制比较的数据类型
-        result = self.registry.get_comparison_function("abs", dtype='torch.int8')
+        result = self.registry.get_comparison_function("abs", dtype='torch.int8', in_dtype='torch.float32')
         self.assertEqual(result, mock_func)
 
     def test_get_comparison_function_absolute_threshold(self):
@@ -35,55 +35,57 @@
         mock_func = Mock()
         self.registry.register("absolute_threshold", mock_func)
         # 假设'test_api'在absolute_standard_api列表中
-        result = self.registry.get_comparison_function("mul")
+        result = self.registry.get_comparison_function("mul", dtype='torch.float32', in_dtype='torch.float32')
         self.assertEqual(result, mock_func)
 
     def test_get_comparison_function_ulp(self):
         """测试获取ULP比较函数"""
         mock_func = Mock()
         self.registry.register("ulp_compare", mock_func)
-        result = self.registry.get_comparison_function("matmul")
+        result = self.registry.get_comparison_function("matmul", dtype='torch.float32', in_dtype='torch.float32')
         self.assertEqual(result, mock_func)
 
     def test_get_comparison_function_thousandth(self):
         """测试获取双千比较函数"""
         mock_func = Mock()
         self.registry.register("thousandth_threshold", mock_func)
-        result = self.registry.get_comparison_function("conv2d")
+        result = self.registry.get_comparison_function("conv2d", dtype='torch.float32', in_dtype='torch.float32')
         self.assertEqual(result, mock_func)
 
     def test_get_comparison_function_benchmark(self):
         """测试获取默认benchmark比较函数"""
         mock_func = Mock()
         self.registry.register("benchmark", mock_func)
-        result = self.registry.get_comparison_function("npu_fusion_attention")
+        result = self.registry.get_comparison_function("npu_fusion_attention", dtype='torch.float32',
+                                                       in_dtype='torch.float32')
         self.assertEqual(result, mock_func)
 
     def test_get_standard_category_binary(self):
         """测试获取二进制一致性标准类别"""
         dtype = 'torch.int8'
         self.assertNotIn(dtype, BINARY_COMPARE_UNSUPPORT_LIST)
-        category = self.registry._get_standard_category("abs", dtype)
+        category = self.registry._get_standard_category("abs", out_dtype=dtype, in_dtype='torch.float32')
         self.assertEqual(category, "binary_consistency")
 
     def test_get_standard_category_absolute(self):
         """测试获取绝对阈值标准类别"""
-        category = self.registry._get_standard_category("mul")
+        category = self.registry._get_standard_category("mul", out_dtype='torch.float32', in_dtype='torch.float32')
self.assertEqual(category, "absolute_threshold") def test_get_standard_category_default(self): """测试获取默认benchmark标准类别""" - category = self.registry._get_standard_category("unknown_api") + category = self.registry._get_standard_category("unknown_api", out_dtype='torch.float32', + in_dtype='torch.float32') self.assertEqual(category, "benchmark") def test_get_standard_category_ulp(self): """测试获取ULP标准类别""" - category = self.registry._get_standard_category("matmul") + category = self.registry._get_standard_category("matmul", out_dtype='torch.float32', in_dtype='torch.float32') self.assertEqual(category, "ulp_compare") def test_get_standard_category_thousandth(self): """测试获取双千比对标准类别""" - category = self.registry._get_standard_category("conv2d") + category = self.registry._get_standard_category("conv2d", out_dtype='torch.float32', in_dtype='torch.float32') self.assertEqual(category, "thousandth_threshold") diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py index 13bf0a5b19c49576101c8f4daf0d609ee625aefe..0ab34b9288c3829ee3859fa551e703bf0d7f6eed 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py @@ -170,7 +170,7 @@ class TestRunUtMethods(unittest.TestCase): mocks['retain_grad'].return_value = None mocks['to'].return_value = mock_tensor - device_args, device_kwargs = generate_device_params([mock_tensor], {'inplace': False}, True, '') + device_args, device_kwargs, _ = generate_device_params([mock_tensor], {'inplace': False}, True, '') self.assertEqual(len(device_args), 1) self.assertEqual(device_args[0].dtype, torch.float32) self.assertTrue(device_args[0].requires_grad)