diff --git a/debug/accuracy_tools/api_accuracy_checker/common/utils.py b/debug/accuracy_tools/api_accuracy_checker/common/utils.py index 24959a62950f03974aaf2822dc5cc83747475cfd..5a6e409aaeef0af881858a301a5db4eb0ee63585 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/utils.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/utils.py @@ -211,6 +211,21 @@ def check_mode_valid(mode): raise CompareException(CompareException.INVALID_DUMP_MODE, msg) +def check_object_type(check_object, allow_type): + """ + Function Description: + Check if the object belongs to a certain data type + Parameter: + check_object: the object to be checked + allow_type: legal data type + Exception Description: + when invalid data throw exception + """ + if not isinstance(check_object, allow_type): + print_error_log(f"{check_object} not of {allow_type} type") + raise CompareException(CompareException.INVALID_DATA_ERROR) + + def check_file_or_directory_path(path, isdir=False): """ Function Description: diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py index 5976d90049e5f42e1ef0cb60e14037d5813e5c76..9c0e4e8ea6b8dcbcfbd163d706e1702bf86607ad 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py @@ -1 +1,195 @@ -# 用于解析落盘json,进行数据生成 \ No newline at end of file +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2022-2023. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import os +import torch +import numpy as np + +from ..common.utils import check_file_or_directory_path, check_object_type, print_warn_log, print_error_log, \ + CompareException + +TENSOR_DATA_LIST = ["torch.Tensor", "torch.nn.parameter.Parameter"] +FLOAT_TYPE = ['torch.float32', 'torch.float', 'torch.float64', 'torch.double', 'torch.float16', + 'torch.half', 'torch.bfloat16'] + + +def gen_data(info, need_grad): + """ + Function Description: + Based on arg basic information, generate arg data + Parameter: + info: arg basic information. Dict + need_grad: set Tensor grad for backward + """ + check_object_type(info, dict) + data_type = info.get('type') + data_path = info.get('datapath') + if data_type in TENSOR_DATA_LIST: + if data_path: + data = gen_real_tensor(data_path) + else: + data = gen_random_tensor(info) + if info.get('requires_grad') and need_grad: + data.requires_grad_(True) + data.retain_grad() + else: + data = info.get('value') + if info.get("type") == "slice": + data = slice(*data) + return data + + +def gen_real_tensor(data_path): + """ + Function Description: + Based on API data path, generate input parameters real data + Parameter: + data_path: API data path + """ + data_path = os.path.realpath(data_path) + check_file_or_directory_path(data_path) + if not data_path.endswith('.npy'): + print_error_log(f"The file: {data_path} is not a numpy file.") + raise CompareException.INVALID_FILE_ERROR + data_np = np.load(data_path) + data = torch.from_numpy(data_np) + return data + + +def gen_random_tensor(info): + """ + Function Description: + Based on API MAX and MIN, generate input parameters random data + Parameter: + info: API data info + """ + check_object_type(info, dict) + low, high = info.get('Min'), info.get('Max') + data_dtype = info.get('dtype') + shape = tuple(info.get('shape')) + if not isinstance(low, (int, float)) or not isinstance(high, (int, float)): + print_error_log(f'Data info Min: {low} , Max: {high}, info type must be int or float') + raise CompareException.INVALID_PARAM_ERROR + if data_dtype == "torch.bool": + data = gen_bool_tensor(low, high, shape) + else: + data = gen_common_tensor(low, high, shape, data_dtype) + return data + + +def gen_common_tensor(low, high, shape, data_dtype): + """ + Function Description: + Based on API basic information, generate int or float tensor + Parameter: + low: The minimum value in Tensor + high: The max value in Tensor + shape:The shape of Tensor + data_dtype: The data type of Tensor + """ + if data_dtype in FLOAT_TYPE: + scale = high - low + rand01 = torch.rand(shape, dtype=eval(data_dtype)) + tensor = rand01 * scale + low + tmp_tensor = tensor.reshape(-1) + tmp_tensor[0] = low + tmp_tensor[-1] = high + elif 'int' in data_dtype or 'long' in data_dtype: + low, high = int(low), int(high) + tensor = torch.randint(low, high + 1, shape, dtype=eval(data_dtype)) + else: + print_error_log('Dtype is not supported: ' + data_dtype) + raise NotImplementedError() + tmp_tensor = tensor.reshape(-1) + tmp_tensor[0] = low + tmp_tensor[-1] = high + data = tmp_tensor.reshape(shape) + return data + + +def gen_bool_tensor(low, high, shape): + """ + Function Description: + Based on API basic information, generate bool tensor + Parameter: + low: The minimum value in Tensor + high: The max value in Tensor + shape:The shape of Tensor + """ + low, high = int(low), int(high) + tensor = torch.randint(low, high + 1, shape) + data = torch.gt(tensor, 0) + return data + + +def gen_args(args_info, need_grad=True): + """ + Function Description: + Based on API basic information, generate input parameters: args, for API forward running + Parameter: + api_info: API basic information. List + need_grad: set Tensor grad for backward + """ + check_object_type(args_info, list) + args_result = [] + for arg in args_info: + if isinstance(arg, (list, tuple)): + data = gen_args(arg, need_grad) + elif isinstance(arg, dict): + data = gen_data(arg, need_grad) + else: + print_warn_log(f'Warning: {arg} is not supported') + raise NotImplementedError() + args_result.append(data) + return args_result + + +def gen_kwargs(api_info): + """ + Function Description: + Based on API basic information, generate input parameters: kwargs, for API forward running + Parameter: + api_info: API basic information. Dict + """ + check_object_type(api_info, dict) + kwargs_params = api_info.get("kwargs") + for key, value in kwargs_params.items(): + if value.get('type') in TENSOR_DATA_LIST: + kwargs_params[key] = gen_data(value, False) + else: + kwargs_params[key] = value.get('value') + return kwargs_params + + +def gen_api_params(api_info, need_grad=True): + """ + Function Description: + Based on API basic information, generate input parameters: args, kwargs, for API forward running + Parameter: + api_info: API basic information. Dict + need_grad: set grad for backward + """ + check_object_type(api_info, dict) + kwargs_params = gen_kwargs(api_info) + if "inplace" in kwargs_params: + need_grad = False + if api_info.get("args"): + args_params = gen_args(api_info.get("args"), need_grad) + else: + print_warn_log(f'Warning: No args in {api_info} ') + raise NotImplementedError() + return args_params, kwargs_params