diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py index 9c0e4e8ea6b8dcbcfbd163d706e1702bf86607ad..0ca84b8971bc285b76948c50db168087d0a24248 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py @@ -19,7 +19,7 @@ import os import torch import numpy as np -from ..common.utils import check_file_or_directory_path, check_object_type, print_warn_log, print_error_log, \ +from ..common.utils import Const, check_file_or_directory_path, check_object_type, print_warn_log, print_error_log, \ CompareException TENSOR_DATA_LIST = ["torch.Tensor", "torch.nn.parameter.Parameter"] @@ -27,22 +27,23 @@ FLOAT_TYPE = ['torch.float32', 'torch.float', 'torch.float64', 'torch.double', ' 'torch.half', 'torch.bfloat16'] -def gen_data(info, need_grad): +def gen_data(info, need_grad, convert_type): """ Function Description: Based on arg basic information, generate arg data Parameter: info: arg basic information. Dict need_grad: set Tensor grad for backward + convert_type: convert ori_type to dist_type flag. """ check_object_type(info, dict) data_type = info.get('type') data_path = info.get('datapath') if data_type in TENSOR_DATA_LIST: if data_path: - data = gen_real_tensor(data_path) + data = gen_real_tensor(data_path, convert_type) else: - data = gen_random_tensor(info) + data = gen_random_tensor(info, convert_type) if info.get('requires_grad') and need_grad: data.requires_grad_(True) data.retain_grad() @@ -53,12 +54,13 @@ def gen_data(info, need_grad): return data -def gen_real_tensor(data_path): +def gen_real_tensor(data_path, convert_type): """ Function Description: Based on API data path, generate input parameters real data Parameter: data_path: API data path + convert_type: convert ori_type to dist_type flag. """ data_path = os.path.realpath(data_path) check_file_or_directory_path(data_path) @@ -67,15 +69,21 @@ def gen_real_tensor(data_path): raise CompareException.INVALID_FILE_ERROR data_np = np.load(data_path) data = torch.from_numpy(data_np) + if convert_type: + ori_dtype = Const.CONVERT.get(convert_type)[0] + dist_dtype = Const.CONVERT.get(convert_type)[1] + if str(data.dtype) == ori_dtype: + data = data.type(eval(dist_dtype)) return data -def gen_random_tensor(info): +def gen_random_tensor(info, convert_type): """ Function Description: Based on API MAX and MIN, generate input parameters random data Parameter: info: API data info + convert_type: convert ori_type to dist_type flag. """ check_object_type(info, dict) low, high = info.get('Min'), info.get('Max') @@ -87,11 +95,11 @@ def gen_random_tensor(info): if data_dtype == "torch.bool": data = gen_bool_tensor(low, high, shape) else: - data = gen_common_tensor(low, high, shape, data_dtype) + data = gen_common_tensor(low, high, shape, data_dtype, convert_type) return data -def gen_common_tensor(low, high, shape, data_dtype): +def gen_common_tensor(low, high, shape, data_dtype, convert_type): """ Function Description: Based on API basic information, generate int or float tensor @@ -100,7 +108,12 @@ def gen_common_tensor(low, high, shape, data_dtype): high: The max value in Tensor shape:The shape of Tensor data_dtype: The data type of Tensor + convert_type: convert ori_type to dist_type flag. """ + if convert_type: + ori_dtype = Const.CONVERT.get(convert_type)[0] + if ori_dtype == data_dtype: + data_dtype = Const.CONVERT.get(convert_type)[1] if data_dtype in FLOAT_TYPE: scale = high - low rand01 = torch.rand(shape, dtype=eval(data_dtype)) @@ -136,21 +149,22 @@ def gen_bool_tensor(low, high, shape): return data -def gen_args(args_info, need_grad=True): +def gen_args(args_info, need_grad=True, convert_type=None): """ Function Description: Based on API basic information, generate input parameters: args, for API forward running Parameter: api_info: API basic information. List need_grad: set Tensor grad for backward + convert_type: convert ori_type to dist_type flag. """ check_object_type(args_info, list) args_result = [] for arg in args_info: if isinstance(arg, (list, tuple)): - data = gen_args(arg, need_grad) + data = gen_args(arg, need_grad, convert_type) elif isinstance(arg, dict): - data = gen_data(arg, need_grad) + data = gen_data(arg, need_grad, convert_type) else: print_warn_log(f'Warning: {arg} is not supported') raise NotImplementedError() @@ -158,38 +172,63 @@ def gen_args(args_info, need_grad=True): return args_result -def gen_kwargs(api_info): +def gen_kwargs(api_info, convert_type=None): """ Function Description: Based on API basic information, generate input parameters: kwargs, for API forward running Parameter: api_info: API basic information. Dict + convert_type: convert ori_type to dist_type flag. """ check_object_type(api_info, dict) kwargs_params = api_info.get("kwargs") for key, value in kwargs_params.items(): - if value.get('type') in TENSOR_DATA_LIST: - kwargs_params[key] = gen_data(value, False) + if isinstance(value, (list, tuple)): + kwargs_params[key] = gen_list_kwargs(value, convert_type) + elif value.get('type') in TENSOR_DATA_LIST: + kwargs_params[key] = gen_data(value, False, convert_type) else: kwargs_params[key] = value.get('value') return kwargs_params -def gen_api_params(api_info, need_grad=True): +def gen_list_kwargs(kwargs_item_value, convert_type): + """ + Function Description: + When kwargs value is list, generate the list of kwargs result + Parameter: + kwargs_item_value: kwargs value before to generate. List + convert_type: convert ori_type to dist_type flag. + """ + kwargs_item_result = [] + for item in kwargs_item_value: + if item.get('type') in TENSOR_DATA_LIST: + item_value = gen_data(item, False, convert_type) + else: + item_value = item.get('value') + kwargs_item_result.append(item_value) + return kwargs_item_result + + +def gen_api_params(api_info, need_grad=True, convert_type=None): """ Function Description: Based on API basic information, generate input parameters: args, kwargs, for API forward running Parameter: api_info: API basic information. Dict need_grad: set grad for backward + convert_type: convert ori_type to dist_type flag. """ check_object_type(api_info, dict) - kwargs_params = gen_kwargs(api_info) + if convert_type and convert_type not in Const.CONVERT: + print_error_log(f"convert_type params not support {convert_type} ") + raise CompareException.INVALID_PARAM_ERROR + kwargs_params = gen_kwargs(api_info, convert_type) if "inplace" in kwargs_params: need_grad = False if api_info.get("args"): - args_params = gen_args(api_info.get("args"), need_grad) + args_params = gen_args(api_info.get("args"), need_grad, convert_type) else: - print_warn_log(f'Warning: No args in {api_info} ') + print_error_log(f'Warning: No args in {api_info} ') raise NotImplementedError() return args_params, kwargs_params