diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/api_precision_standard.yaml b/debug/accuracy_tools/api_accuracy_checker/compare/api_precision_standard.yaml index 4033538b73e9b6a094bbf2c05e0c02bbed607c24..45051a3ddc55750d4e89fc34319b8e672405a6c3 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/api_precision_standard.yaml +++ b/debug/accuracy_tools/api_accuracy_checker/compare/api_precision_standard.yaml @@ -106,3 +106,4 @@ BinaryCompareStandard: - tril_ - triu - triu_ + - type_as diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index 856cb237ca7e1ce71da45938be2284c2e6133a2b..5891a0316339dd415cc9d05090056ac465f78f0a 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -39,6 +39,7 @@ RunUTConfig = namedtuple('RunUTConfig', ['forward_content', 'backward_content', 'save_error_data', 'is_continue_run_ut', 'real_data_path']) not_backward_list = ['repeat_interleave'] not_detach_set = {'resize_', 'resize_as_', 'set_', 'transpose_', 't_', 'squeeze_', 'unsqueeze_'} +not_raise_dtype_set = {'type_as'} tqdm_params = { 'smoothing': 0, # 平滑进度条的预计剩余时间,取值范围0到1 @@ -142,6 +143,7 @@ def generate_cpu_params(input_args, input_kwargs, need_backward, api_name): elif len(need_raise_dtypes) >= 2: raise_dtype = torch.float32 + raise_dtype = None if api_name in not_raise_dtype_set else raise_dtype is_detach = api_name not in not_detach_set cpu_args = recursive_arg_to_cpu(input_args, is_detach, raise_dtype=raise_dtype) cpu_kwargs = {key: recursive_arg_to_cpu(value, key != "out" and is_detach, raise_dtype=raise_dtype) for key, value in input_kwargs.items()}