From 335d6baf6ccb9e2e4979f6c025cb75c2bba98b13 Mon Sep 17 00:00:00 2001 From: wangchao Date: Mon, 31 Jul 2023 09:22:47 +0000 Subject: [PATCH 1/3] =?UTF-8?q?API=E6=A3=80=E6=B5=8B=E5=B7=A5=E5=85=B7?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E7=94=9F=E6=88=90=E6=A8=A1=E5=9D=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: wangchao --- .../run_ut/data_generate.py | 184 +++++++++++++++++- 1 file changed, 183 insertions(+), 1 deletion(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py index 5976d90049e..479036c2ec7 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py @@ -1 +1,183 @@ -# 用于解析落盘json,进行数据生成 \ No newline at end of file +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2022-2023. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import os +import torch +import numpy as np + +from ..common.utils import check_file_or_directory_path, print_warn_log + +TENSOR_DATA_LIST = ["torch.Tensor", "torch.nn.parameter.Parameter"] +FLOAT_TYPE = ['torch.float32', 'torch.float', 'torch.float64', 'torch.double', 'torch.float16', + 'torch.half', 'torch.bfloat16'] + + +def gen_data(info, need_grad): + """ + Function Description: + Based on arg basic information, generate arg data + Parameter: + info: arg basic information. Dict + need_grad: set Tensor grad for backward + """ + data_type = info.get('type') + data_path = info.get('datapath') + if data_type in TENSOR_DATA_LIST: + if data_path: + data = gen_real_tensor(data_path) + else: + data = gen_random_tensor(info) + if info.get('requires_grad') and need_grad: + data.requires_grad_(True) + data.retain_grad() + else: + data = info.get('value') + if info.get("type") == "slice": + data = slice(*data) + return data + + +def gen_real_tensor(data_path): + """ + Function Description: + Based on API data path, generate input parameters real data + Parameter: + data_path: API data path + """ + data_path = os.path.realpath(data_path) + check_file_or_directory_path(data_path) + data_np = np.load(data_path) + data = torch.from_numpy(data_np) + return data + + +def gen_random_tensor(info): + """ + Function Description: + Based on API MAX and MIN, generate input parameters random data + Parameter: + info: API data info + """ + low, high = info.get('Min'), info.get('Max') + data_dtype = info.get('dtype') + shape = tuple(info.get('shape')) + if data_dtype == "torch.bool": + data = gen_bool_tensor(low, high, shape) + else: + data = gen_common_tensor(low, high, shape, data_dtype) + return data + + +def gen_common_tensor(low, high, shape, data_dtype): + """ + Function Description: + Based on API basic information, generate int or float tensor + Parameter: + low: The minimum value in Tensor + high: The max value in Tensor + shape:The shape of Tensor + data_dtype: The data type of Tensor + """ + if data_dtype in FLOAT_TYPE: + scale = high - low + rand01 = torch.rand(shape, dtype=eval(data_dtype)) + tensor = rand01 * scale + low + tmp_tensor = tensor.reshape(-1) + tmp_tensor[0] = low + tmp_tensor[-1] = high + elif 'int' in data_dtype or 'long' in data_dtype: + low, high = int(low), int(high) + tensor = torch.randint(low, high + 1, shape, dtype=eval(data_dtype)) + else: + print_warn_log('Warning: Dtype is not supported: ' + data_dtype) + raise NotImplementedError() + tmp_tensor = tensor.reshape(-1) + tmp_tensor[0] = low + tmp_tensor[-1] = high + data = tmp_tensor.reshape(shape) + return data + + +def gen_bool_tensor(low, high, shape): + """ + Function Description: + Based on API basic information, generate bool tensor + Parameter: + low: The minimum value in Tensor + high: The max value in Tensor + shape:The shape of Tensor + """ + low, high = int(low), int(high) + tensor = torch.randint(low, high + 1, shape) + data = torch.gt(tensor, 0) + return data + + +def gen_args(args_info, need_grad=True): + """ + Function Description: + Based on API basic information, generate input parameters: args, for API forward running + Parameter: + api_info: API basic information. Dict + need_grad: set Tensor grad for backward + """ + args_result = [] + for arg in args_info: + if isinstance(arg, (list, tuple)): + data = gen_args(arg, need_grad) + elif isinstance(arg, dict): + data = gen_data(arg, need_grad) + else: + print_warn_log(f'Warning: {arg} is not supported') + raise NotImplementedError() + args_result.append(data) + return args_result + + +def gen_kwargs(api_info): + """ + Function Description: + Based on API basic information, generate input parameters: kwargs, for API forward running + Parameter: + api_info: API basic information. Dict + """ + kwargs_params = api_info.get("kwargs") + for key, value in kwargs_params.items(): + if value.get('type') in TENSOR_DATA_LIST: + kwargs_params[key] = gen_data(value, False) + else: + kwargs_params[key] = value.get('value') + return kwargs_params + + +def gen_api_params(api_info, need_grad=True): + """ + Function Description: + Based on API basic information, generate input parameters: args, kwargs, for API forward running + Parameter: + api_info: API basic information. Dict + need_grad: set grad for backward + """ + kwargs_params = gen_kwargs(api_info) + if "inplace" in kwargs_params: + need_grad = False + if api_info.get("args"): + args_params = gen_args(api_info.get("args"), need_grad) + else: + print_warn_log(f'Warning: No args in {api_info} ') + raise NotImplementedError() + return args_params, kwargs_params -- Gitee From 1ec523a24152c1a7ea33e3b7f1c3f03510ffac18 Mon Sep 17 00:00:00 2001 From: wangchao Date: Tue, 1 Aug 2023 07:21:58 +0000 Subject: [PATCH 2/3] =?UTF-8?q?=E5=B7=A5=E5=85=B7=E7=B1=BB=E6=96=B0?= =?UTF-8?q?=E5=A2=9E=E7=B1=BB=E5=9E=8B=E6=A0=A1=E9=AA=8C=E6=96=B9=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: wangchao --- .../api_accuracy_checker/common/utils.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/debug/accuracy_tools/api_accuracy_checker/common/utils.py b/debug/accuracy_tools/api_accuracy_checker/common/utils.py index 24959a62950..5a6e409aaee 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/utils.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/utils.py @@ -211,6 +211,21 @@ def check_mode_valid(mode): raise CompareException(CompareException.INVALID_DUMP_MODE, msg) +def check_object_type(check_object, allow_type): + """ + Function Description: + Check if the object belongs to a certain data type + Parameter: + check_object: the object to be checked + allow_type: legal data type + Exception Description: + when invalid data throw exception + """ + if not isinstance(check_object, allow_type): + print_error_log(f"{check_object} not of {allow_type} type") + raise CompareException(CompareException.INVALID_DATA_ERROR) + + def check_file_or_directory_path(path, isdir=False): """ Function Description: -- Gitee From 162930107a27aa340ff477709cbd9d62e1023997 Mon Sep 17 00:00:00 2001 From: wangchao Date: Tue, 1 Aug 2023 07:24:05 +0000 Subject: [PATCH 3/3] =?UTF-8?q?=E6=95=B0=E6=8D=AE=E7=94=9F=E6=88=90?= =?UTF-8?q?=E6=A8=A1=E5=9D=97=E6=96=B0=E5=A2=9E=E6=A0=A1=E9=AA=8C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: wangchao --- .../run_ut/data_generate.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py index 479036c2ec7..9c0e4e8ea6b 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py @@ -19,7 +19,8 @@ import os import torch import numpy as np -from ..common.utils import check_file_or_directory_path, print_warn_log +from ..common.utils import check_file_or_directory_path, check_object_type, print_warn_log, print_error_log, \ + CompareException TENSOR_DATA_LIST = ["torch.Tensor", "torch.nn.parameter.Parameter"] FLOAT_TYPE = ['torch.float32', 'torch.float', 'torch.float64', 'torch.double', 'torch.float16', @@ -34,6 +35,7 @@ def gen_data(info, need_grad): info: arg basic information. Dict need_grad: set Tensor grad for backward """ + check_object_type(info, dict) data_type = info.get('type') data_path = info.get('datapath') if data_type in TENSOR_DATA_LIST: @@ -60,6 +62,9 @@ def gen_real_tensor(data_path): """ data_path = os.path.realpath(data_path) check_file_or_directory_path(data_path) + if not data_path.endswith('.npy'): + print_error_log(f"The file: {data_path} is not a numpy file.") + raise CompareException.INVALID_FILE_ERROR data_np = np.load(data_path) data = torch.from_numpy(data_np) return data @@ -72,9 +77,13 @@ def gen_random_tensor(info): Parameter: info: API data info """ + check_object_type(info, dict) low, high = info.get('Min'), info.get('Max') data_dtype = info.get('dtype') shape = tuple(info.get('shape')) + if not isinstance(low, (int, float)) or not isinstance(high, (int, float)): + print_error_log(f'Data info Min: {low} , Max: {high}, info type must be int or float') + raise CompareException.INVALID_PARAM_ERROR if data_dtype == "torch.bool": data = gen_bool_tensor(low, high, shape) else: @@ -103,7 +112,7 @@ def gen_common_tensor(low, high, shape, data_dtype): low, high = int(low), int(high) tensor = torch.randint(low, high + 1, shape, dtype=eval(data_dtype)) else: - print_warn_log('Warning: Dtype is not supported: ' + data_dtype) + print_error_log('Dtype is not supported: ' + data_dtype) raise NotImplementedError() tmp_tensor = tensor.reshape(-1) tmp_tensor[0] = low @@ -132,9 +141,10 @@ def gen_args(args_info, need_grad=True): Function Description: Based on API basic information, generate input parameters: args, for API forward running Parameter: - api_info: API basic information. Dict + api_info: API basic information. List need_grad: set Tensor grad for backward """ + check_object_type(args_info, list) args_result = [] for arg in args_info: if isinstance(arg, (list, tuple)): @@ -155,6 +165,7 @@ def gen_kwargs(api_info): Parameter: api_info: API basic information. Dict """ + check_object_type(api_info, dict) kwargs_params = api_info.get("kwargs") for key, value in kwargs_params.items(): if value.get('type') in TENSOR_DATA_LIST: @@ -172,6 +183,7 @@ def gen_api_params(api_info, need_grad=True): api_info: API basic information. Dict need_grad: set grad for backward """ + check_object_type(api_info, dict) kwargs_params = gen_kwargs(api_info) if "inplace" in kwargs_params: need_grad = False -- Gitee