From 47cf5314f85de3c70e9c4e9622bda57b0158f4fb Mon Sep 17 00:00:00 2001 From: wangchao Date: Thu, 3 Aug 2023 15:21:45 +0800 Subject: [PATCH 1/2] support gen force convert data --- .../run_ut/data_generate.py | 55 ++++++++++++++----- 1 file changed, 40 insertions(+), 15 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py index 9c0e4e8ea6b..bbf8969f35c 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py @@ -19,7 +19,7 @@ import os import torch import numpy as np -from ..common.utils import check_file_or_directory_path, check_object_type, print_warn_log, print_error_log, \ +from ..common.utils import Const, check_file_or_directory_path, check_object_type, print_warn_log, print_error_log, \ CompareException TENSOR_DATA_LIST = ["torch.Tensor", "torch.nn.parameter.Parameter"] @@ -27,7 +27,7 @@ FLOAT_TYPE = ['torch.float32', 'torch.float', 'torch.float64', 'torch.double', ' 'torch.half', 'torch.bfloat16'] -def gen_data(info, need_grad): +def gen_data(info, need_grad, convert_type): """ Function Description: Based on arg basic information, generate arg data @@ -40,9 +40,9 @@ def gen_data(info, need_grad): data_path = info.get('datapath') if data_type in TENSOR_DATA_LIST: if data_path: - data = gen_real_tensor(data_path) + data = gen_real_tensor(data_path, convert_type) else: - data = gen_random_tensor(info) + data = gen_random_tensor(info, convert_type) if info.get('requires_grad') and need_grad: data.requires_grad_(True) data.retain_grad() @@ -53,7 +53,7 @@ def gen_data(info, need_grad): return data -def gen_real_tensor(data_path): +def gen_real_tensor(data_path, convert_type): """ Function Description: Based on API data path, generate input parameters real data @@ -67,10 +67,15 @@ def gen_real_tensor(data_path): raise CompareException.INVALID_FILE_ERROR data_np = np.load(data_path) data = torch.from_numpy(data_np) + if convert_type: + ori_dtype = Const.CONVERT.get(convert_type)[0] + dist_dtype = Const.CONVERT.get(convert_type)[1] + if str(data.dtype) == ori_dtype: + data = data.type(eval(dist_dtype)) return data -def gen_random_tensor(info): +def gen_random_tensor(info, convert_type): """ Function Description: Based on API MAX and MIN, generate input parameters random data @@ -87,11 +92,11 @@ def gen_random_tensor(info): if data_dtype == "torch.bool": data = gen_bool_tensor(low, high, shape) else: - data = gen_common_tensor(low, high, shape, data_dtype) + data = gen_common_tensor(low, high, shape, data_dtype, convert_type) return data -def gen_common_tensor(low, high, shape, data_dtype): +def gen_common_tensor(low, high, shape, data_dtype, convert_type): """ Function Description: Based on API basic information, generate int or float tensor @@ -101,6 +106,10 @@ def gen_common_tensor(low, high, shape, data_dtype): shape:The shape of Tensor data_dtype: The data type of Tensor """ + if convert_type: + ori_dtype = Const.CONVERT.get(convert_type)[0] + if ori_dtype == data_dtype: + data_dtype = Const.CONVERT.get(convert_type)[1] if data_dtype in FLOAT_TYPE: scale = high - low rand01 = torch.rand(shape, dtype=eval(data_dtype)) @@ -136,7 +145,7 @@ def gen_bool_tensor(low, high, shape): return data -def gen_args(args_info, need_grad=True): +def gen_args(args_info, need_grad=True, convert_type=None): """ Function Description: Based on API basic information, generate input parameters: args, for API forward running @@ -148,9 +157,9 @@ def gen_args(args_info, need_grad=True): args_result = [] for arg in args_info: if isinstance(arg, (list, tuple)): - data = gen_args(arg, need_grad) + data = gen_args(arg, need_grad, convert_type) elif isinstance(arg, dict): - data = gen_data(arg, need_grad) + data = gen_data(arg, need_grad, convert_type) else: print_warn_log(f'Warning: {arg} is not supported') raise NotImplementedError() @@ -158,7 +167,7 @@ def gen_args(args_info, need_grad=True): return args_result -def gen_kwargs(api_info): +def gen_kwargs(api_info, convert_type=None): """ Function Description: Based on API basic information, generate input parameters: kwargs, for API forward running @@ -168,6 +177,8 @@ def gen_kwargs(api_info): check_object_type(api_info, dict) kwargs_params = api_info.get("kwargs") for key, value in kwargs_params.items(): + if isinstance(value, (list, tuple)): + kwargs_params[key] = gen_list_kwargs(value, convert_type) if value.get('type') in TENSOR_DATA_LIST: kwargs_params[key] = gen_data(value, False) else: @@ -175,7 +186,18 @@ def gen_kwargs(api_info): return kwargs_params -def gen_api_params(api_info, need_grad=True): +def gen_list_kwargs(kwargs_item_value, convert_type): + kwargs_item_result = [] + for item in kwargs_item_value: + if item.get('type') in TENSOR_DATA_LIST: + item_value = gen_data(item, False, convert_type) + else: + item_value = item.get('value') + kwargs_item_result.append(item_value) + return kwargs_item_result + + +def gen_api_params(api_info, need_grad=True, convert_type=None): """ Function Description: Based on API basic information, generate input parameters: args, kwargs, for API forward running @@ -184,12 +206,15 @@ def gen_api_params(api_info, need_grad=True): need_grad: set grad for backward """ check_object_type(api_info, dict) + if convert_type and convert_type not in Const.CONVERT: + print_error_log(f"convert_type params not support {convert_type} ") + raise CompareException.INVALID_PARAM_ERROR kwargs_params = gen_kwargs(api_info) if "inplace" in kwargs_params: need_grad = False if api_info.get("args"): - args_params = gen_args(api_info.get("args"), need_grad) + args_params = gen_args(api_info.get("args"), need_grad, convert_type) else: - print_warn_log(f'Warning: No args in {api_info} ') + print_error_log(f'Warning: No args in {api_info} ') raise NotImplementedError() return args_params, kwargs_params -- Gitee From 845eea009a3490f2e9aa408e83dcc8ec04b08564 Mon Sep 17 00:00:00 2001 From: wangchao Date: Thu, 3 Aug 2023 15:30:17 +0800 Subject: [PATCH 2/2] support gen force convert data --- .../run_ut/data_generate.py | 20 ++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py index bbf8969f35c..0ca84b8971b 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py @@ -34,6 +34,7 @@ def gen_data(info, need_grad, convert_type): Parameter: info: arg basic information. Dict need_grad: set Tensor grad for backward + convert_type: convert ori_type to dist_type flag. """ check_object_type(info, dict) data_type = info.get('type') @@ -59,6 +60,7 @@ def gen_real_tensor(data_path, convert_type): Based on API data path, generate input parameters real data Parameter: data_path: API data path + convert_type: convert ori_type to dist_type flag. """ data_path = os.path.realpath(data_path) check_file_or_directory_path(data_path) @@ -81,6 +83,7 @@ def gen_random_tensor(info, convert_type): Based on API MAX and MIN, generate input parameters random data Parameter: info: API data info + convert_type: convert ori_type to dist_type flag. """ check_object_type(info, dict) low, high = info.get('Min'), info.get('Max') @@ -105,6 +108,7 @@ def gen_common_tensor(low, high, shape, data_dtype, convert_type): high: The max value in Tensor shape:The shape of Tensor data_dtype: The data type of Tensor + convert_type: convert ori_type to dist_type flag. """ if convert_type: ori_dtype = Const.CONVERT.get(convert_type)[0] @@ -152,6 +156,7 @@ def gen_args(args_info, need_grad=True, convert_type=None): Parameter: api_info: API basic information. List need_grad: set Tensor grad for backward + convert_type: convert ori_type to dist_type flag. """ check_object_type(args_info, list) args_result = [] @@ -173,20 +178,28 @@ def gen_kwargs(api_info, convert_type=None): Based on API basic information, generate input parameters: kwargs, for API forward running Parameter: api_info: API basic information. Dict + convert_type: convert ori_type to dist_type flag. """ check_object_type(api_info, dict) kwargs_params = api_info.get("kwargs") for key, value in kwargs_params.items(): if isinstance(value, (list, tuple)): kwargs_params[key] = gen_list_kwargs(value, convert_type) - if value.get('type') in TENSOR_DATA_LIST: - kwargs_params[key] = gen_data(value, False) + elif value.get('type') in TENSOR_DATA_LIST: + kwargs_params[key] = gen_data(value, False, convert_type) else: kwargs_params[key] = value.get('value') return kwargs_params def gen_list_kwargs(kwargs_item_value, convert_type): + """ + Function Description: + When kwargs value is list, generate the list of kwargs result + Parameter: + kwargs_item_value: kwargs value before to generate. List + convert_type: convert ori_type to dist_type flag. + """ kwargs_item_result = [] for item in kwargs_item_value: if item.get('type') in TENSOR_DATA_LIST: @@ -204,12 +217,13 @@ def gen_api_params(api_info, need_grad=True, convert_type=None): Parameter: api_info: API basic information. Dict need_grad: set grad for backward + convert_type: convert ori_type to dist_type flag. """ check_object_type(api_info, dict) if convert_type and convert_type not in Const.CONVERT: print_error_log(f"convert_type params not support {convert_type} ") raise CompareException.INVALID_PARAM_ERROR - kwargs_params = gen_kwargs(api_info) + kwargs_params = gen_kwargs(api_info, convert_type) if "inplace" in kwargs_params: need_grad = False if api_info.get("args"): -- Gitee