From ac285e995a55fadd52b6ae32907cc9347146741f Mon Sep 17 00:00:00 2001 From: xieting Date: Tue, 12 Dec 2023 14:31:50 +0800 Subject: [PATCH 1/8] =?UTF-8?q?=E4=BF=AE=E6=94=B9CPU=E6=93=8D=E4=BD=9C?= =?UTF-8?q?=E6=95=B0=E6=B7=B7=E5=90=88=E7=B2=BE=E5=BA=A6=E6=83=85=E5=86=B5?= =?UTF-8?q?=E7=9A=84=E5=A4=84=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../api_accuracy_checker/common/utils.py | 7 ++-- .../api_accuracy_checker/run_ut/run_ut.py | 39 +++++++++++-------- 2 files changed, 26 insertions(+), 20 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/common/utils.py b/debug/accuracy_tools/api_accuracy_checker/common/utils.py index f8648eb20cc..b536625c23c 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/utils.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/utils.py @@ -91,9 +91,10 @@ class Const: WRITE_MODES = stat.S_IWUSR | stat.S_IRUSR RAISE_PRECISION = { - "torch.float16" : "torch.float32", - "torch.bfloat16" : "torch.float32", - "torch.float32" : "torch.float64" + torch.float16: torch.float32, + torch.bfloat16: torch.float32, + torch.float32: torch.float64, + torch.float64: torch.float64 } CONVERT = { "int32_to_int64": ["torch.int32", "torch.int64"], diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index fb885168d66..ec9ed26013d 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -55,6 +55,12 @@ def deal_detach(arg, to_detach=True): return arg.detach() if to_detach else arg +def deal_dtype(arg, raise_dtype=None): + if not Const.RAISE_PRECISION.get(arg.dtype) or raise_dtype is None or raise_dtype == arg.dtype: + return arg + return arg.type(raise_dtype) + + def generate_device_params(input_args, input_kwargs, need_backward): def recursive_arg_to_device(arg_in, to_detach=True): if isinstance(arg_in, (list, tuple)): @@ -77,35 +83,34 @@ def generate_device_params(input_args, input_kwargs, need_backward): def generate_cpu_params(input_args, input_kwargs, need_backward): - first_dtype = None - def recursive_arg_to_cpu(arg_in, to_detach=True): - nonlocal first_dtype + def recursive_arg_to_cpu(arg_in, to_detach=True, raise_dtype=None): if isinstance(arg_in, (list, tuple)): - return type(arg_in)(recursive_arg_to_cpu(arg, to_detach) for arg in arg_in) + return type(arg_in)(recursive_arg_to_cpu(arg, to_detach, raise_dtype=raise_dtype) for arg in arg_in) elif isinstance(arg_in, torch.Tensor): if need_backward and arg_in.requires_grad: - if str(arg_in.dtype) in Const.RAISE_PRECISION.keys() and arg_in.dtype != first_dtype: - arg_in = deal_detach(arg_in.clone().type(eval(Const.RAISE_PRECISION[str(arg_in.dtype)])), to_detach).requires_grad_() - if first_dtype is None: - first_dtype = arg_in.dtype - else: - arg_in = deal_detach(arg_in.clone(), to_detach).requires_grad_() + arg_in = deal_detach(deal_dtype(arg_in.clone(), raise_dtype), to_detach).requires_grad_() temp_arg_in = arg_in * 1 arg_in = temp_arg_in.type_as(arg_in) arg_in.retain_grad() return arg_in else: - if str(arg_in.dtype) in Const.RAISE_PRECISION.keys() and arg_in.dtype != first_dtype: - arg_in = deal_detach(arg_in.clone().type(eval(Const.RAISE_PRECISION[str(arg_in.dtype)])), to_detach) - if first_dtype is None: - first_dtype = arg_in.dtype - return arg_in - return deal_detach(arg_in.clone(), to_detach) + return deal_detach(deal_dtype(arg_in.clone(), raise_dtype=raise_dtype), to_detach) else: return arg_in - cpu_args = recursive_arg_to_cpu(input_args) + raise_dtype = None + need_raise_dtypes = set(input_arg.dtype for input_arg in input_args if input_arg.dtype in Const.RAISE_PRECISION) + if len(need_raise_dtypes) == 1: + raise_dtype = Const.RAISE_PRECISION.get(need_raise_dtypes.pop()) + elif len(need_raise_dtypes) >= 2: + raise_dtype = torch.float32 + for dtype in need_raise_dtypes: + if str(dtype).endswith("64"): + raise_dtype = torch.float64 + break + + cpu_args = recursive_arg_to_cpu(input_args, raise_dtype=raise_dtype) cpu_kwargs = {key: recursive_arg_to_cpu(value, key != "out") for key, value in input_kwargs.items()} return cpu_args, cpu_kwargs -- Gitee From 061b7f2dcf2c5e734259dd0fffb0ac69b27ea02f Mon Sep 17 00:00:00 2001 From: xieting Date: Tue, 12 Dec 2023 16:33:34 +0800 Subject: [PATCH 2/8] =?UTF-8?q?=E6=B7=BB=E5=8A=A0Tensor=E7=B1=BB=E5=9E=8B?= =?UTF-8?q?=E5=88=A4=E6=96=AD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../accuracy_tools/api_accuracy_checker/run_ut/run_ut.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index ec9ed26013d..2bcec0c21ed 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -56,7 +56,7 @@ def deal_detach(arg, to_detach=True): def deal_dtype(arg, raise_dtype=None): - if not Const.RAISE_PRECISION.get(arg.dtype) or raise_dtype is None or raise_dtype == arg.dtype: + if raise_dtype is None or not Const.RAISE_PRECISION.get(arg.dtype) or raise_dtype == arg.dtype: return arg return arg.type(raise_dtype) @@ -100,7 +100,11 @@ def generate_cpu_params(input_args, input_kwargs, need_backward): return arg_in raise_dtype = None - need_raise_dtypes = set(input_arg.dtype for input_arg in input_args if input_arg.dtype in Const.RAISE_PRECISION) + need_raise_dtypes = set( + input_arg.dtype + for input_arg in input_args + if isinstance(input_arg, torch.Tensor) and input_arg.dtype in Const.RAISE_PRECISION + ) if len(need_raise_dtypes) == 1: raise_dtype = Const.RAISE_PRECISION.get(need_raise_dtypes.pop()) elif len(need_raise_dtypes) >= 2: -- Gitee From 6cae1f3d7fc58d88048e7f5dd8d7fd30ba8d66ba Mon Sep 17 00:00:00 2001 From: xieting Date: Tue, 12 Dec 2023 17:59:44 +0800 Subject: [PATCH 3/8] =?UTF-8?q?=E5=A4=84=E7=90=86backward=E8=B0=83?= =?UTF-8?q?=E7=94=A8generate=5Fcpu=5Fparams=E5=8F=AA=E4=BC=A0=E5=85=A5Tens?= =?UTF-8?q?or=E8=80=8C=E4=B8=8D=E6=98=AFargs=E7=9A=84=E6=83=85=E5=86=B5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index 2bcec0c21ed..bcc4dc0be4f 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -102,7 +102,7 @@ def generate_cpu_params(input_args, input_kwargs, need_backward): raise_dtype = None need_raise_dtypes = set( input_arg.dtype - for input_arg in input_args + for input_arg in (input_args if isinstance(input_args, (list, tuple)) else [input_args]) if isinstance(input_arg, torch.Tensor) and input_arg.dtype in Const.RAISE_PRECISION ) if len(need_raise_dtypes) == 1: -- Gitee From 8289eab9e732cdd09f8fc6a5074829c6a6cb422d Mon Sep 17 00:00:00 2001 From: xieting Date: Wed, 13 Dec 2023 15:22:48 +0800 Subject: [PATCH 4/8] edit --- .../api_accuracy_checker/common/utils.py | 3 +-- .../api_accuracy_checker/run_ut/run_ut.py | 19 ++++++++----------- 2 files changed, 9 insertions(+), 13 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/common/utils.py b/debug/accuracy_tools/api_accuracy_checker/common/utils.py index b536625c23c..cc7cac77037 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/utils.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/utils.py @@ -93,8 +93,7 @@ class Const: RAISE_PRECISION = { torch.float16: torch.float32, torch.bfloat16: torch.float32, - torch.float32: torch.float64, - torch.float64: torch.float64 + torch.float32: torch.float64 } CONVERT = { "int32_to_int64": ["torch.int32", "torch.int64"], diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index bcc4dc0be4f..747659f95a8 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -56,7 +56,7 @@ def deal_detach(arg, to_detach=True): def deal_dtype(arg, raise_dtype=None): - if raise_dtype is None or not Const.RAISE_PRECISION.get(arg.dtype) or raise_dtype == arg.dtype: + if raise_dtype is None or arg.dtype not in Const.RAISE_PRECISION or raise_dtype == arg.dtype: return arg return arg.type(raise_dtype) @@ -83,7 +83,6 @@ def generate_device_params(input_args, input_kwargs, need_backward): def generate_cpu_params(input_args, input_kwargs, need_backward): - def recursive_arg_to_cpu(arg_in, to_detach=True, raise_dtype=None): if isinstance(arg_in, (list, tuple)): return type(arg_in)(recursive_arg_to_cpu(arg, to_detach, raise_dtype=raise_dtype) for arg in arg_in) @@ -99,20 +98,18 @@ def generate_cpu_params(input_args, input_kwargs, need_backward): else: return arg_in + def recursive_find_dtypes(arg_in): + if isinstance(arg_in, (list, tuple)): + return set().union(recursive_find_dtypes(arg) for arg in arg_in) + elif isinstance(arg_in, torch.Tensor) and arg_in.dtype in Const.RAISE_PRECISION: + return {arg_in.dtype} + raise_dtype = None - need_raise_dtypes = set( - input_arg.dtype - for input_arg in (input_args if isinstance(input_args, (list, tuple)) else [input_args]) - if isinstance(input_arg, torch.Tensor) and input_arg.dtype in Const.RAISE_PRECISION - ) + need_raise_dtypes = recursive_find_dtypes(input_args) if len(need_raise_dtypes) == 1: raise_dtype = Const.RAISE_PRECISION.get(need_raise_dtypes.pop()) elif len(need_raise_dtypes) >= 2: raise_dtype = torch.float32 - for dtype in need_raise_dtypes: - if str(dtype).endswith("64"): - raise_dtype = torch.float64 - break cpu_args = recursive_arg_to_cpu(input_args, raise_dtype=raise_dtype) cpu_kwargs = {key: recursive_arg_to_cpu(value, key != "out") for key, value in input_kwargs.items()} -- Gitee From 113b685283d8725f65c70e989512ef39aa2f37bb Mon Sep 17 00:00:00 2001 From: xieting Date: Wed, 13 Dec 2023 15:46:16 +0800 Subject: [PATCH 5/8] edit --- debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index 747659f95a8..5a600224ec8 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -100,7 +100,7 @@ def generate_cpu_params(input_args, input_kwargs, need_backward): def recursive_find_dtypes(arg_in): if isinstance(arg_in, (list, tuple)): - return set().union(recursive_find_dtypes(arg) for arg in arg_in) + return set().union(*tuple(recursive_find_dtypes(arg) for arg in arg_in)) elif isinstance(arg_in, torch.Tensor) and arg_in.dtype in Const.RAISE_PRECISION: return {arg_in.dtype} -- Gitee From 550c3c6d48dce7df20eecf6c8675405ee038b385 Mon Sep 17 00:00:00 2001 From: xieting Date: Wed, 13 Dec 2023 16:06:17 +0800 Subject: [PATCH 6/8] edit --- debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py | 1 + 1 file changed, 1 insertion(+) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index 5a600224ec8..61bf90be6f4 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -103,6 +103,7 @@ def generate_cpu_params(input_args, input_kwargs, need_backward): return set().union(*tuple(recursive_find_dtypes(arg) for arg in arg_in)) elif isinstance(arg_in, torch.Tensor) and arg_in.dtype in Const.RAISE_PRECISION: return {arg_in.dtype} + return set() raise_dtype = None need_raise_dtypes = recursive_find_dtypes(input_args) -- Gitee From df5ec378b636d1d26f7928e33544a3de06061a9c Mon Sep 17 00:00:00 2001 From: xieting Date: Wed, 13 Dec 2023 16:06:44 +0800 Subject: [PATCH 7/8] edit --- debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index 61bf90be6f4..7d10883975d 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -102,7 +102,7 @@ def generate_cpu_params(input_args, input_kwargs, need_backward): if isinstance(arg_in, (list, tuple)): return set().union(*tuple(recursive_find_dtypes(arg) for arg in arg_in)) elif isinstance(arg_in, torch.Tensor) and arg_in.dtype in Const.RAISE_PRECISION: - return {arg_in.dtype} + return set((arg_in.dtype, )) return set() raise_dtype = None -- Gitee From 66706dce8d61900af94f8ad27465ffeeb36c25ba Mon Sep 17 00:00:00 2001 From: xieting Date: Wed, 13 Dec 2023 16:07:12 +0800 Subject: [PATCH 8/8] edit --- debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index 7d10883975d..7b974bea82c 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -102,7 +102,7 @@ def generate_cpu_params(input_args, input_kwargs, need_backward): if isinstance(arg_in, (list, tuple)): return set().union(*tuple(recursive_find_dtypes(arg) for arg in arg_in)) elif isinstance(arg_in, torch.Tensor) and arg_in.dtype in Const.RAISE_PRECISION: - return set((arg_in.dtype, )) + return set([arg_in.dtype]) return set() raise_dtype = None -- Gitee