From 39b13279b342b009edabf2322d87b3554468a00f Mon Sep 17 00:00:00 2001 From: s30048155 Date: Wed, 9 Aug 2023 15:52:51 +0800 Subject: [PATCH 1/9] cross_entropy modify min value --- .../api_accuracy_checker/common/utils.py | 11 ++++++----- .../api_accuracy_checker/run_ut/run_ut.py | 2 +- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/common/utils.py b/debug/accuracy_tools/api_accuracy_checker/common/utils.py index f1c874407..ab947e9a1 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/utils.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/utils.py @@ -549,12 +549,13 @@ def check_input_file_valid(input_path, max_file_size=MAX_JSON_FILE_SIZE): raise ValueError(f'The file is too large, exceeds {max_file_size // 1024 ** 2}MB') -def check_need_convert(api_name): +def check_need_convert(api_name,value): convert_type = None - for key, value in Const.CONVERT_API.items(): - if api_name not in value: + for key, item in Const.CONVERT_API.items(): + if api_name not in item: continue else: convert_type = key - return convert_type - + if api_name=='cross_entropy' and value['args'][1]['Min'] <=0: + value['args'][1]['Min'] = 0 + return convert_type, value \ No newline at end of file diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index e6162ca8b..e008a4190 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -79,7 +79,7 @@ def run_ut(forward_file, backward_file, out_path, save_error_data): def run_torch_api(api_full_name, api_setting_dict, backward_content, value): [api_type, api_name, _] = api_full_name.split("*") - convert_type = check_need_convert(api_name) + convert_type,value = check_need_convert(api_name, value) need_grad = True if value.get("kwargs") and "out" in value.get("kwargs"): need_grad = False -- Gitee From aeadf0ffaf2a513c792feb9152c0b89258f82209 Mon Sep 17 00:00:00 2001 From: s30048155 Date: Wed, 9 Aug 2023 16:08:57 +0800 Subject: [PATCH 2/9] api_info_preprocess --- debug/accuracy_tools/api_accuracy_checker/common/utils.py | 6 +++++- debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py | 4 ++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/common/utils.py b/debug/accuracy_tools/api_accuracy_checker/common/utils.py index ab947e9a1..80be0f660 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/utils.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/utils.py @@ -549,13 +549,17 @@ def check_input_file_valid(input_path, max_file_size=MAX_JSON_FILE_SIZE): raise ValueError(f'The file is too large, exceeds {max_file_size // 1024 ** 2}MB') -def check_need_convert(api_name,value): +def check_need_convert(api_name): convert_type = None for key, item in Const.CONVERT_API.items(): if api_name not in item: continue else: convert_type = key + return convert_type + +def api_info_preprocess(api_name,value): + convert_type=check_need_convert(api_name) if api_name=='cross_entropy' and value['args'][1]['Min'] <=0: value['args'][1]['Min'] = 0 return convert_type, value \ No newline at end of file diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index e008a4190..a2850a334 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -5,7 +5,7 @@ import torch_npu import yaml import torch from api_accuracy_checker.run_ut.data_generate import gen_api_params, gen_args -from api_accuracy_checker.common.utils import print_info_log, print_warn_log, get_json_contents, check_need_convert, \ +from api_accuracy_checker.common.utils import print_info_log, print_warn_log, get_json_contents, api_info_preprocess, \ print_error_log from api_accuracy_checker.compare.compare import Comparator from api_accuracy_checker.hook_module.wrap_tensor import TensorOPTemplate @@ -79,7 +79,7 @@ def run_ut(forward_file, backward_file, out_path, save_error_data): def run_torch_api(api_full_name, api_setting_dict, backward_content, value): [api_type, api_name, _] = api_full_name.split("*") - convert_type,value = check_need_convert(api_name, value) + convert_type,value = api_info_preprocess(api_name, value) need_grad = True if value.get("kwargs") and "out" in value.get("kwargs"): need_grad = False -- Gitee From dbc42996fa49550903ae024e36b07f63f730478a Mon Sep 17 00:00:00 2001 From: s30048155 Date: Wed, 9 Aug 2023 16:09:57 +0800 Subject: [PATCH 3/9] change name --- debug/accuracy_tools/api_accuracy_checker/common/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/common/utils.py b/debug/accuracy_tools/api_accuracy_checker/common/utils.py index 80be0f660..76a6842e0 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/utils.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/utils.py @@ -551,8 +551,8 @@ def check_input_file_valid(input_path, max_file_size=MAX_JSON_FILE_SIZE): def check_need_convert(api_name): convert_type = None - for key, item in Const.CONVERT_API.items(): - if api_name not in item: + for key, value in Const.CONVERT_API.items(): + if api_name not in value: continue else: convert_type = key -- Gitee From fa137ae6af7f365aa09534c12a264e9d0e25052c Mon Sep 17 00:00:00 2001 From: s30048155 Date: Wed, 9 Aug 2023 16:16:07 +0800 Subject: [PATCH 4/9] add reason --- .../api_accuracy_checker/common/utils.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/common/utils.py b/debug/accuracy_tools/api_accuracy_checker/common/utils.py index 76a6842e0..769f236a3 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/utils.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/utils.py @@ -558,8 +558,18 @@ def check_need_convert(api_name): convert_type = key return convert_type -def api_info_preprocess(api_name,value): - convert_type=check_need_convert(api_name) - if api_name=='cross_entropy' and value['args'][1]['Min'] <=0: - value['args'][1]['Min'] = 0 +def api_info_preprocess(api_name, value): + """ + Function Description: + Preprocesses the API information. + Parameter: + api_name: Name of the API. + value: Value of the API. + Return Value: + convert_type: Type of conversion. + value: Processed value of the API. + """ + convert_type = check_need_convert(api_name) + if api_name == 'cross_entropy' and value['args'][1]['Min'] <=0: + value['args'][1]['Min'] = 0#The second value in cross_entropy only can be -100 or larger than 0. return convert_type, value \ No newline at end of file -- Gitee From b35b1db780fbad7434386caef8a2979e108369a2 Mon Sep 17 00:00:00 2001 From: s30048155 Date: Wed, 9 Aug 2023 16:17:20 +0800 Subject: [PATCH 5/9] space --- debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index a2850a334..b7709d0b7 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -79,7 +79,7 @@ def run_ut(forward_file, backward_file, out_path, save_error_data): def run_torch_api(api_full_name, api_setting_dict, backward_content, value): [api_type, api_name, _] = api_full_name.split("*") - convert_type,value = api_info_preprocess(api_name, value) + convert_type, value = api_info_preprocess(api_name, value) need_grad = True if value.get("kwargs") and "out" in value.get("kwargs"): need_grad = False -- Gitee From 25f0d930e5bcdac0da3420e45a671667fbd1b936 Mon Sep 17 00:00:00 2001 From: s30048155 Date: Wed, 9 Aug 2023 16:19:32 +0800 Subject: [PATCH 6/9] space --- debug/accuracy_tools/api_accuracy_checker/common/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/common/utils.py b/debug/accuracy_tools/api_accuracy_checker/common/utils.py index 769f236a3..b49892cce 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/utils.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/utils.py @@ -571,5 +571,5 @@ def api_info_preprocess(api_name, value): """ convert_type = check_need_convert(api_name) if api_name == 'cross_entropy' and value['args'][1]['Min'] <=0: - value['args'][1]['Min'] = 0#The second value in cross_entropy only can be -100 or larger than 0. + value['args'][1]['Min'] = 0 #The second value in cross_entropy should be -100 or not less than 0. return convert_type, value \ No newline at end of file -- Gitee From 55b11b12bcee1d5adbec55b4a7e284dec867519d Mon Sep 17 00:00:00 2001 From: s30048155 Date: Wed, 9 Aug 2023 16:23:06 +0800 Subject: [PATCH 7/9] argument --- .../api_accuracy_checker/common/utils.py | 14 +++++++------- .../api_accuracy_checker/run_ut/run_ut.py | 8 ++++---- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/common/utils.py b/debug/accuracy_tools/api_accuracy_checker/common/utils.py index b49892cce..d5a21ef76 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/utils.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/utils.py @@ -558,18 +558,18 @@ def check_need_convert(api_name): convert_type = key return convert_type -def api_info_preprocess(api_name, value): +def api_info_preprocess(api_name, argument): """ Function Description: Preprocesses the API information. Parameter: api_name: Name of the API. - value: Value of the API. - Return Value: + argument: argument of the API. + Return argument: convert_type: Type of conversion. - value: Processed value of the API. + argument: Processed argument of the API. """ convert_type = check_need_convert(api_name) - if api_name == 'cross_entropy' and value['args'][1]['Min'] <=0: - value['args'][1]['Min'] = 0 #The second value in cross_entropy should be -100 or not less than 0. - return convert_type, value \ No newline at end of file + if api_name == 'cross_entropy' and argument['args'][1]['Min'] <=0: + argument['args'][1]['Min'] = 0 #The second argument in cross_entropy should be -100 or not less than 0. + return convert_type, argument \ No newline at end of file diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index b7709d0b7..861f46b56 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -77,15 +77,15 @@ def run_ut(forward_file, backward_file, out_path, save_error_data): compare.write_compare_csv() -def run_torch_api(api_full_name, api_setting_dict, backward_content, value): +def run_torch_api(api_full_name, api_setting_dict, backward_content, argument): [api_type, api_name, _] = api_full_name.split("*") - convert_type, value = api_info_preprocess(api_name, value) + convert_type, argument = api_info_preprocess(api_name, argument) need_grad = True - if value.get("kwargs") and "out" in value.get("kwargs"): + if argument.get("kwargs") and "out" in argument.get("kwargs"): need_grad = False if api_name[-1] == "_" or api_name in NO_GRAD_APIS: need_grad = False - args, kwargs = gen_api_params(value, need_grad, convert_type) + args, kwargs = gen_api_params(argument, need_grad, convert_type) inplace = kwargs.get("inplace") if kwargs.get("inplace") else None need_backward = api_full_name in backward_content and api_name[-1] != "_" and inplace is not True need_backward = need_backward and need_grad -- Gitee From 9bf31976c60144ebe0c52a68af70cdd48c0165fe Mon Sep 17 00:00:00 2001 From: s30048155 Date: Wed, 9 Aug 2023 16:24:40 +0800 Subject: [PATCH 8/9] api_info_dict --- .../api_accuracy_checker/common/utils.py | 14 +++++++------- .../api_accuracy_checker/run_ut/run_ut.py | 8 ++++---- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/common/utils.py b/debug/accuracy_tools/api_accuracy_checker/common/utils.py index d5a21ef76..40f565c4f 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/utils.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/utils.py @@ -558,18 +558,18 @@ def check_need_convert(api_name): convert_type = key return convert_type -def api_info_preprocess(api_name, argument): +def api_info_preprocess(api_name, api_info_dict): """ Function Description: Preprocesses the API information. Parameter: api_name: Name of the API. - argument: argument of the API. - Return argument: + api_info_dict: argument of the API. + Return api_info_dict: convert_type: Type of conversion. - argument: Processed argument of the API. + api_info_dict: Processed argument of the API. """ convert_type = check_need_convert(api_name) - if api_name == 'cross_entropy' and argument['args'][1]['Min'] <=0: - argument['args'][1]['Min'] = 0 #The second argument in cross_entropy should be -100 or not less than 0. - return convert_type, argument \ No newline at end of file + if api_name == 'cross_entropy' and api_info_dict['args'][1]['Min'] <=0: + api_info_dict['args'][1]['Min'] = 0 #The second argument in cross_entropy should be -100 or not less than 0. + return convert_type, api_info_dict \ No newline at end of file diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index 861f46b56..26e6e352a 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -77,15 +77,15 @@ def run_ut(forward_file, backward_file, out_path, save_error_data): compare.write_compare_csv() -def run_torch_api(api_full_name, api_setting_dict, backward_content, argument): +def run_torch_api(api_full_name, api_setting_dict, backward_content, api_info_dict): [api_type, api_name, _] = api_full_name.split("*") - convert_type, argument = api_info_preprocess(api_name, argument) + convert_type, api_info_dict = api_info_preprocess(api_name, api_info_dict) need_grad = True - if argument.get("kwargs") and "out" in argument.get("kwargs"): + if api_info_dict.get("kwargs") and "out" in api_info_dict.get("kwargs"): need_grad = False if api_name[-1] == "_" or api_name in NO_GRAD_APIS: need_grad = False - args, kwargs = gen_api_params(argument, need_grad, convert_type) + args, kwargs = gen_api_params(api_info_dict, need_grad, convert_type) inplace = kwargs.get("inplace") if kwargs.get("inplace") else None need_backward = api_full_name in backward_content and api_name[-1] != "_" and inplace is not True need_backward = need_backward and need_grad -- Gitee From 9071fe62c0e882c8475daf5baaa46fee9de69866 Mon Sep 17 00:00:00 2001 From: s30048155 Date: Wed, 9 Aug 2023 17:12:19 +0800 Subject: [PATCH 9/9] update --- .../api_accuracy_checker/common/utils.py | 20 ++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/common/utils.py b/debug/accuracy_tools/api_accuracy_checker/common/utils.py index 40f565c4f..8fc952ebf 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/utils.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/utils.py @@ -570,6 +570,20 @@ def api_info_preprocess(api_name, api_info_dict): api_info_dict: Processed argument of the API. """ convert_type = check_need_convert(api_name) - if api_name == 'cross_entropy' and api_info_dict['args'][1]['Min'] <=0: - api_info_dict['args'][1]['Min'] = 0 #The second argument in cross_entropy should be -100 or not less than 0. - return convert_type, api_info_dict \ No newline at end of file + if api_name == 'cross_entropy': + api_info_dict = cross_entropy_process(api_info_dict) + return convert_type, api_info_dict + +def cross_entropy_process(api_info_dict): + """ + Function Description: + Preprocesses the cross_entropy API information. + Parameter: + api_info_dict: argument of the API. + Return api_info_dict: + api_info_dict: Processed argument of the API. + """ + if 'args' in api_info_dict and len(api_info_dict['args']) > 1 and 'Min' in api_info_dict['args'][1]: + if api_info_dict['args'][1]['Min'] <= 0: + api_info_dict['args'][1]['Min'] = 0 #The second argument in cross_entropy should be -100 or not less than 0. + return api_info_dict \ No newline at end of file -- Gitee