From 20af32c2191c02469ef2a4a573010f20cb801282 Mon Sep 17 00:00:00 2001 From: l30036321 Date: Thu, 3 Aug 2023 14:30:51 +0800 Subject: [PATCH 1/2] add covert type --- .../api_accuracy_checker/common/utils.py | 21 ++++++++++++++++++- .../api_accuracy_checker/run_ut/run_ut.py | 7 ++++--- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/common/utils.py b/debug/accuracy_tools/api_accuracy_checker/common/utils.py index e62c9616ca7..b5be8b6e1e2 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/utils.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/utils.py @@ -77,6 +77,14 @@ class Const: API_PATTERN = r"^[A-Za-z0-9]+[_]+([A-Za-z0-9]+[_]*[A-Za-z0-9]+)[_]+[0-9]+[_]+[A-Za-z0-9]+" WRITE_FLAGS = os.O_WRONLY | os.O_CREAT WRITE_MODES = stat.S_IWUSR | stat.S_IRUSR + + CONVERT = { + "fp16_to_fp32": ["torch.float16", "torch.float32"] + } + + CONVERT_API = { + "fp16_to_fp32": ["conv2d", "batch_norm", "relu", "max_pool2d"] + } class CompareConst: """ @@ -536,4 +544,15 @@ def check_input_file_valid(input_path, max_file_size=MAX_JSON_FILE_SIZE): raise ValueError("The real path or file_name of input is too long.") if os.path.getsize(input_path) > max_file_size: - raise ValueError(f'The file is too large, exceeds {max_file_size // 1024 ** 2}MB') \ No newline at end of file + raise ValueError(f'The file is too large, exceeds {max_file_size // 1024 ** 2}MB') + + +def check_need_convert(api_name): + convert_type = None + for key, value in Const.CONVERT_API: + if api_name not in value: + continue + else: + convert_type = key + return convert_type + diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index 3a0c9c45155..c25b2e1f3b5 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -5,7 +5,7 @@ sys.path.append("..") import yaml import torch from data_generate import gen_api_params, gen_args -from common.utils import print_info_log, print_warn_log, get_json_contents +from common.utils import print_info_log, print_warn_log, get_json_contents, check_need_convert from compare.compare import Comparator cur_path = os.path.dirname(os.path.realpath(__file__)) @@ -76,7 +76,8 @@ def run_ut(forward_file, backward_file, out_path, save_error_data): def run_torch_api(api_full_name, api_setting_dict, backward_content, value): [api_type, api_name, _] = api_full_name.split("*") - args, kwargs = gen_api_params(value, api_name[-1] != "_") + convert_type = check_need_convert(api_name) + args, kwargs = gen_api_params(value, api_name[-1] != "_", convert_type) inplace = kwargs.get("inplace") if kwargs.get("inplace") else None need_backward = api_full_name in backward_content and api_name[-1] != "_" and inplace is not True if inplace or api_name[-1] == "_": @@ -153,4 +154,4 @@ def _run_ut(): if __name__ == '__main__': _run_ut() - print_info_log("UT task completed.") \ No newline at end of file + print_info_log("UT task completed.") -- Gitee From 0424b66560b7fe28d6d2a264ee7f6736a6853e6a Mon Sep 17 00:00:00 2001 From: l30036321 Date: Thu, 3 Aug 2023 14:37:31 +0800 Subject: [PATCH 2/2] add covert type --- debug/accuracy_tools/api_accuracy_checker/common/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/common/utils.py b/debug/accuracy_tools/api_accuracy_checker/common/utils.py index b5be8b6e1e2..fa4540c58ef 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/utils.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/utils.py @@ -83,7 +83,7 @@ class Const: } CONVERT_API = { - "fp16_to_fp32": ["conv2d", "batch_norm", "relu", "max_pool2d"] + "fp16_to_fp32": ["conv2d", "batch_norm", "relu", "max_pool2d", "interpolate"] } class CompareConst: -- Gitee