From e083e8dde48dee0806dcc726584af201044b9b06 Mon Sep 17 00:00:00 2001 From: gitee Date: Tue, 16 Apr 2024 16:37:27 +0800 Subject: [PATCH] fix --- .../api_accuracy_checker/compare/api_precision_standard.yaml | 1 + debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/api_precision_standard.yaml b/debug/accuracy_tools/api_accuracy_checker/compare/api_precision_standard.yaml index 4033538b7..45051a3dd 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/api_precision_standard.yaml +++ b/debug/accuracy_tools/api_accuracy_checker/compare/api_precision_standard.yaml @@ -106,3 +106,4 @@ BinaryCompareStandard: - tril_ - triu - triu_ + - type_as diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index 856cb237c..5891a0316 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -39,6 +39,7 @@ RunUTConfig = namedtuple('RunUTConfig', ['forward_content', 'backward_content', 'save_error_data', 'is_continue_run_ut', 'real_data_path']) not_backward_list = ['repeat_interleave'] not_detach_set = {'resize_', 'resize_as_', 'set_', 'transpose_', 't_', 'squeeze_', 'unsqueeze_'} +not_raise_dtype_set = {'type_as'} tqdm_params = { 'smoothing': 0, # 平滑进度条的预计剩余时间,取值范围0到1 @@ -142,6 +143,7 @@ def generate_cpu_params(input_args, input_kwargs, need_backward, api_name): elif len(need_raise_dtypes) >= 2: raise_dtype = torch.float32 + raise_dtype = None if api_name in not_raise_dtype_set else raise_dtype is_detach = api_name not in not_detach_set cpu_args = recursive_arg_to_cpu(input_args, is_detach, raise_dtype=raise_dtype) cpu_kwargs = {key: recursive_arg_to_cpu(value, key != "out" and is_detach, raise_dtype=raise_dtype) for key, value in input_kwargs.items()} -- Gitee