diff --git a/debug/accuracy_tools/api_accuracy_checker/common/utils.py b/debug/accuracy_tools/api_accuracy_checker/common/utils.py index f1c874407fcb8169524c72940b4785859b2455c6..4ee04396ba4f1f981488f0f713fa74fd03213e56 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/utils.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/utils.py @@ -81,11 +81,13 @@ class Const: WRITE_MODES = stat.S_IWUSR | stat.S_IRUSR CONVERT = { - "fp16_to_fp32": ["torch.float16", "torch.float32"] + "fp16_to_fp32": ["torch.float16", "torch.float32"], + "int32_to_int64": ["torch.int32", "torch.int64"] } CONVERT_API = { - "fp16_to_fp32": ["conv2d", "batch_norm", "relu", "max_pool2d", "interpolate", "group_norm", "layer_norm", "bmm", "tanh", "cross_entropy", "linear", "numel"] + "fp16_to_fp32": ["conv2d", "batch_norm", "relu", "max_pool2d", "interpolate", "group_norm", "layer_norm", "bmm", "tanh", "cross_entropy", "linear", "numel"], + "int32_to_int64": ["cross_entropy"] } class CompareConst: diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index 52d5d928d045913112f03a5093d936dc1627428b..e52d447a91eaa225345702e6306d2ba0784e275e 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -71,10 +71,13 @@ def run_ut(forward_file, backward_file, out_path, save_error_data): api_info_dict) compare.compare_output(api_full_name, out, npu_out, grad_out, npu_grad_out) except Exception as err: + [_, api_name, _] = api_full_name.split("*") if "not implemented for 'Half'" in str(err): - [_, api_name, _] = api_full_name.split("*") print_warn_log(f"API {api_name} not support half tensor in CPU, please add {api_name} to CONVERT_API " f"'fp16_to_fp32' list in accuracy_tools/api_accuracy_check/common/utils.py file.") + elif "expected scalar type Long" in str(err): + print_warn_log(f"API {api_name} not support int32 tensor in CPU, please add {api_name} to CONVERT_API " + f"'int32_to_int64' list in accuracy_tools/api_accuracy_check/common/utils.py file.") else: print_error_log(f"Run {api_full_name} UT Error: %s" % str(err))