diff --git a/debug/accuracy_tools/api_accuracy_checker/common/utils.py b/debug/accuracy_tools/api_accuracy_checker/common/utils.py index f8648eb20ccd7572c39a752db287e8b63a0e096c..cc7cac770374d7a4dc8464f7a402a8688957a4c2 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/utils.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/utils.py @@ -91,9 +91,9 @@ class Const: WRITE_MODES = stat.S_IWUSR | stat.S_IRUSR RAISE_PRECISION = { - "torch.float16" : "torch.float32", - "torch.bfloat16" : "torch.float32", - "torch.float32" : "torch.float64" + torch.float16: torch.float32, + torch.bfloat16: torch.float32, + torch.float32: torch.float64 } CONVERT = { "int32_to_int64": ["torch.int32", "torch.int64"], diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index fb885168d662006b3fa639c2322113d6f1f6b386..7b974bea82c113e6dbadfc560adba1225a97eb54 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -55,6 +55,12 @@ def deal_detach(arg, to_detach=True): return arg.detach() if to_detach else arg +def deal_dtype(arg, raise_dtype=None): + if raise_dtype is None or arg.dtype not in Const.RAISE_PRECISION or raise_dtype == arg.dtype: + return arg + return arg.type(raise_dtype) + + def generate_device_params(input_args, input_kwargs, need_backward): def recursive_arg_to_device(arg_in, to_detach=True): if isinstance(arg_in, (list, tuple)): @@ -77,35 +83,36 @@ def generate_cpu_params(input_args, input_kwargs, need_backward): - first_dtype = None - - def recursive_arg_to_cpu(arg_in, to_detach=True): - nonlocal first_dtype + def recursive_arg_to_cpu(arg_in, to_detach=True, raise_dtype=None): if isinstance(arg_in, (list, tuple)): - return type(arg_in)(recursive_arg_to_cpu(arg, to_detach) for arg in arg_in) + return type(arg_in)(recursive_arg_to_cpu(arg, to_detach, raise_dtype=raise_dtype) for arg in arg_in) elif isinstance(arg_in, torch.Tensor): if need_backward and arg_in.requires_grad: - if str(arg_in.dtype) in Const.RAISE_PRECISION.keys() and arg_in.dtype != first_dtype: - arg_in = deal_detach(arg_in.clone().type(eval(Const.RAISE_PRECISION[str(arg_in.dtype)])), to_detach).requires_grad_() - if first_dtype is None: - first_dtype = arg_in.dtype - else: - arg_in = deal_detach(arg_in.clone(), to_detach).requires_grad_() + arg_in = deal_detach(deal_dtype(arg_in.clone(), raise_dtype), to_detach).requires_grad_() temp_arg_in = arg_in * 1 arg_in = temp_arg_in.type_as(arg_in) arg_in.retain_grad() return arg_in else: - if str(arg_in.dtype) in Const.RAISE_PRECISION.keys() and arg_in.dtype != first_dtype: - arg_in = deal_detach(arg_in.clone().type(eval(Const.RAISE_PRECISION[str(arg_in.dtype)])), to_detach) - if first_dtype is None: - first_dtype = arg_in.dtype - return arg_in - return deal_detach(arg_in.clone(), to_detach) + return deal_detach(deal_dtype(arg_in.clone(), raise_dtype=raise_dtype), to_detach) else: return arg_in - cpu_args = recursive_arg_to_cpu(input_args) + def recursive_find_dtypes(arg_in): + if isinstance(arg_in, (list, tuple)): + return set().union(*tuple(recursive_find_dtypes(arg) for arg in arg_in)) + elif isinstance(arg_in, torch.Tensor) and arg_in.dtype in Const.RAISE_PRECISION: + return set([arg_in.dtype]) + return set() + + raise_dtype = None + need_raise_dtypes = recursive_find_dtypes(input_args) + if len(need_raise_dtypes) == 1: + raise_dtype = Const.RAISE_PRECISION.get(need_raise_dtypes.pop()) + elif len(need_raise_dtypes) >= 2: + raise_dtype = torch.float32 + + cpu_args = recursive_arg_to_cpu(input_args, raise_dtype=raise_dtype) cpu_kwargs = {key: recursive_arg_to_cpu(value, key != "out") for key, value in input_kwargs.items()} return cpu_args, cpu_kwargs