diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index ed7ce7e11d7b479b1100fd16f6deba2316430da0..856cb237ca7e1ce71da45938be2284c2e6133a2b 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -118,23 +118,33 @@ def generate_cpu_params(input_args, input_kwargs, need_backward, api_name): else: return arg_in - def recursive_find_dtypes(arg_in): + def is_tensor_with_raise_precision(arg_in, check_kwargs=False): + if arg_in.dtype in Const.RAISE_PRECISION: + return True + if check_kwargs and arg_in.dtype in [torch.half, torch.bfloat16]: + return True + return False + + def recursive_find_dtypes(arg_in, kwargs=None, check_kwargs=False): if isinstance(arg_in, (list, tuple)): - return set().union(*tuple(recursive_find_dtypes(arg) for arg in arg_in)) - elif isinstance(arg_in, torch.Tensor) and arg_in.dtype in Const.RAISE_PRECISION: + return set().union(*tuple(recursive_find_dtypes(arg, kwargs, check_kwargs=check_kwargs) for arg in arg_in)) + elif isinstance(arg_in, torch.Tensor) and is_tensor_with_raise_precision(arg_in, check_kwargs): return set([arg_in.dtype]) + elif isinstance(arg_in, dict) and check_kwargs: + return set().union(*tuple(recursive_find_dtypes(v, kwargs, check_kwargs=True) for v in arg_in.values())) return set() raise_dtype = None need_raise_dtypes = recursive_find_dtypes(input_args) + need_raise_dtypes.update(recursive_find_dtypes(input_kwargs, check_kwargs=True)) if len(need_raise_dtypes) == 1: - raise_dtype = Const.RAISE_PRECISION.get(need_raise_dtypes.pop()) + raise_dtype = Const.RAISE_PRECISION.get(need_raise_dtypes.pop(), torch.float32) elif len(need_raise_dtypes) >= 2: raise_dtype = torch.float32 is_detach = api_name not in not_detach_set cpu_args = recursive_arg_to_cpu(input_args, is_detach, raise_dtype=raise_dtype) - cpu_kwargs = {key: recursive_arg_to_cpu(value, key != "out" and is_detach) for key, value in input_kwargs.items()} + cpu_kwargs = {key: recursive_arg_to_cpu(value, key != "out" and is_detach, raise_dtype=raise_dtype) for key, value in input_kwargs.items()} return cpu_args, cpu_kwargs