From d7f7e256ff7e51a4f14de1f253335dbcb45a8472 Mon Sep 17 00:00:00 2001 From: sunyiming Date: Mon, 11 Mar 2024 09:21:42 +0000 Subject: [PATCH 1/4] bugfix Signed-off-by: sunyiming --- .../api_accuracy_checker/run_ut/run_ut.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index ed7ce7e11..c88cd8bb9 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -118,23 +118,26 @@ def generate_cpu_params(input_args, input_kwargs, need_backward, api_name): else: return arg_in - def recursive_find_dtypes(arg_in): + def recursive_find_dtypes(arg_in, check_kwargs=False): if isinstance(arg_in, (list, tuple)): - return set().union(*tuple(recursive_find_dtypes(arg) for arg in arg_in)) - elif isinstance(arg_in, torch.Tensor) and arg_in.dtype in Const.RAISE_PRECISION: + return set().union(*tuple(recursive_find_dtypes(arg, check_kwargs=check_kwargs) for arg in arg_in)) + elif isinstance(arg_in, torch.Tensor) and (arg_in.dtype in Const.RAISE_PRECISION or (check_kwargs and arg_in.dtype == torch.half)): return set([arg_in.dtype]) + elif isinstance(arg_in, dict) and check_kwargs: + return set().union(*tuple(recursive_find_dtypes(v, check_kwargs=True) for v in arg_in.values())) return set() raise_dtype = None need_raise_dtypes = recursive_find_dtypes(input_args) + need_raise_dtypes.update(recursive_find_dtypes(input_kwargs, check_kwargs=True)) if len(need_raise_dtypes) == 1: - raise_dtype = Const.RAISE_PRECISION.get(need_raise_dtypes.pop()) + raise_dtype = Const.RAISE_PRECISION.get(need_raise_dtypes.pop(), torch.float32) elif len(need_raise_dtypes) >= 2: raise_dtype = torch.float32 is_detach = api_name not in not_detach_set cpu_args = recursive_arg_to_cpu(input_args, is_detach, raise_dtype=raise_dtype) - cpu_kwargs = {key: recursive_arg_to_cpu(value, key != "out" and is_detach) for key, value in input_kwargs.items()} + cpu_kwargs = {key: recursive_arg_to_cpu(value, key != "out" and is_detach, raise_dtype=raise_dtype) for key, value in input_kwargs.items()} return cpu_args, cpu_kwargs -- Gitee From 9692bb79b8a5de3d85d86b52859b7feedefb2933 Mon Sep 17 00:00:00 2001 From: sunyiming Date: Tue, 12 Mar 2024 01:04:18 +0000 Subject: [PATCH 2/4] clean code Signed-off-by: sunyiming --- .../accuracy_tools/api_accuracy_checker/run_ut/run_ut.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index c88cd8bb9..04fc0c0ee 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -118,10 +118,17 @@ def generate_cpu_params(input_args, input_kwargs, need_backward, api_name): else: return arg_in + def is_tensor_with_raise_precision(arg_in, check_kwargs=False): + if arg_in.dtype in Const.RAISE_PRECISION: + return True + if check_kwargs and arg_in.dtype == torch.half: + return True + return False + def recursive_find_dtypes(arg_in, check_kwargs=False): if isinstance(arg_in, (list, tuple)): return set().union(*tuple(recursive_find_dtypes(arg, check_kwargs=check_kwargs) for arg in arg_in)) - elif isinstance(arg_in, torch.Tensor) and (arg_in.dtype in Const.RAISE_PRECISION or (check_kwargs and arg_in.dtype == torch.half)): + elif isinstance(arg_in, torch.Tensor) and is_tensor_with_raise_precision(arg_in, check_kwargs): return set([arg_in.dtype]) elif isinstance(arg_in, dict) and check_kwargs: return set().union(*tuple(recursive_find_dtypes(v, check_kwargs=True) for v in arg_in.values())) -- Gitee From 18a36a16c3827741029bc0596846913809189492 Mon Sep 17 00:00:00 2001 From: sunyiming Date: Tue, 12 Mar 2024 01:36:11 +0000 Subject: [PATCH 3/4] update Signed-off-by: sunyiming --- debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index 04fc0c0ee..e05c622c5 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -121,7 +121,7 @@ def generate_cpu_params(input_args, input_kwargs, need_backward, api_name): def is_tensor_with_raise_precision(arg_in, check_kwargs=False): if arg_in.dtype in Const.RAISE_PRECISION: return True - if check_kwargs and arg_in.dtype == torch.half: + if check_kwargs and arg_in.dtype in [torch.half, torch.bfloat16]: return True return False -- Gitee From 0998236cb6597e88ab9576c55f97aceeb13d4639 Mon Sep 17 00:00:00 2001 From: sunyiming Date: Wed, 13 Mar 2024 09:13:46 +0000 Subject: [PATCH 4/4] update debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py. Signed-off-by: sunyiming --- debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index e05c622c5..856cb237c 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -125,13 +125,13 @@ def generate_cpu_params(input_args, input_kwargs, need_backward, api_name): return True return False - def recursive_find_dtypes(arg_in, check_kwargs=False): + def recursive_find_dtypes(arg_in, kwargs=None, check_kwargs=False): if isinstance(arg_in, (list, tuple)): - return set().union(*tuple(recursive_find_dtypes(arg, check_kwargs=check_kwargs) for arg in arg_in)) + return set().union(*tuple(recursive_find_dtypes(arg, kwargs, check_kwargs=check_kwargs) for arg in arg_in)) elif isinstance(arg_in, torch.Tensor) and is_tensor_with_raise_precision(arg_in, check_kwargs): return set([arg_in.dtype]) elif isinstance(arg_in, dict) and check_kwargs: - return set().union(*tuple(recursive_find_dtypes(v, check_kwargs=True) for v in arg_in.values())) + return set().union(*tuple(recursive_find_dtypes(v, kwargs, check_kwargs=True) for v in arg_in.values())) return set() raise_dtype = None -- Gitee