diff --git a/debug/accuracy_tools/api_accuracy_checker/common/utils.py b/debug/accuracy_tools/api_accuracy_checker/common/utils.py index f1c874407fcb8169524c72940b4785859b2455c6..8fc952ebffc9f9e795cc9988fbc5d44c279ca74f 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/utils.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/utils.py @@ -558,3 +558,32 @@ def check_need_convert(api_name): convert_type = key return convert_type +def api_info_preprocess(api_name, api_info_dict): + """ + Function Description: + Preprocesses the API information. + Parameter: + api_name: Name of the API. + api_info_dict: argument of the API. + Return api_info_dict: + convert_type: Type of conversion. + api_info_dict: Processed argument of the API. + """ + convert_type = check_need_convert(api_name) + if api_name == 'cross_entropy': + api_info_dict = cross_entropy_process(api_info_dict) + return convert_type, api_info_dict + +def cross_entropy_process(api_info_dict): + """ + Function Description: + Preprocesses the cross_entropy API information. + Parameter: + api_info_dict: argument of the API. + Return api_info_dict: + api_info_dict: Processed argument of the API. + """ + if 'args' in api_info_dict and len(api_info_dict['args']) > 1 and 'Min' in api_info_dict['args'][1]: + if api_info_dict['args'][1]['Min'] <= 0: + api_info_dict['args'][1]['Min'] = 0 #The second argument in cross_entropy should be -100 or not less than 0. + return api_info_dict \ No newline at end of file diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index e6162ca8b2611a3a166f8ae4304fdfcb82551bdc..26e6e352a65629fc06f85de329420ae6f25cd74b 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -5,7 +5,7 @@ import torch_npu import yaml import torch from api_accuracy_checker.run_ut.data_generate import gen_api_params, gen_args -from api_accuracy_checker.common.utils import print_info_log, print_warn_log, get_json_contents, check_need_convert, \ +from api_accuracy_checker.common.utils import print_info_log, print_warn_log, get_json_contents, api_info_preprocess, \ print_error_log from api_accuracy_checker.compare.compare import Comparator from api_accuracy_checker.hook_module.wrap_tensor import TensorOPTemplate @@ -77,15 +77,15 @@ def run_ut(forward_file, backward_file, out_path, save_error_data): compare.write_compare_csv() -def run_torch_api(api_full_name, api_setting_dict, backward_content, value): +def run_torch_api(api_full_name, api_setting_dict, backward_content, api_info_dict): [api_type, api_name, _] = api_full_name.split("*") - convert_type = check_need_convert(api_name) + convert_type, api_info_dict = api_info_preprocess(api_name, api_info_dict) need_grad = True - if value.get("kwargs") and "out" in value.get("kwargs"): + if api_info_dict.get("kwargs") and "out" in api_info_dict.get("kwargs"): need_grad = False if api_name[-1] == "_" or api_name in NO_GRAD_APIS: need_grad = False - args, kwargs = gen_api_params(value, need_grad, convert_type) + args, kwargs = gen_api_params(api_info_dict, need_grad, convert_type) inplace = kwargs.get("inplace") if kwargs.get("inplace") else None need_backward = api_full_name in backward_content and api_name[-1] != "_" and inplace is not True need_backward = need_backward and need_grad