diff --git a/debug/accuracy_tools/api_accuracy_checker/generate_op_script/op_generator.py b/debug/accuracy_tools/api_accuracy_checker/generate_op_script/op_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..7d3e2b226bde101c5aeba46f5353cbf3c53549d0 --- /dev/null +++ b/debug/accuracy_tools/api_accuracy_checker/generate_op_script/op_generator.py @@ -0,0 +1,265 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2024. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +import argparse +import json +import os +import math +import numpy as np +import torch +try: + import torch_npu +except ImportError: + pass + +from api_accuracy_checker.compare.compare_utils import BinaryStandardApi, AbsoluteStandardApi, ULPStandardApi + + +TENSOR_DATA_LIST = ["torch.Tensor"] +TORCH_BOOL_TYPE = ["torch.bool"] +TORCH_INT_TYPE = ["torch.uint8", "torch.int8", "torch.int16", "torch.short", "torch.int32", "torch.int", + "torch.int64", "torch.long"] +TORCH_FLOAT_TYPE = ["torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.float", + "torch.float64", "torch.double"] +TORCH_COMPLEX_TYPE = ["torch.complex32", "torch.chalf", "torch.complex64", "torch.cfloat", "torch.complex128", "torch.cdouble"] + + +def check_json(json_path): + json_file = os.path.realpath(json_path) + with open(json_file) as f: + json_content = json.load(f) + if not isinstance(json_content, dict): + raise ValueError("content of json file is not a dictionary!") + if len(list(json_content.items())) > 1: + raise ValueError("json file has more than one API, only one API is allowed!") + (api_full_name, api_info_dict) = list(json_content.items())[0] + (api_type, api_name, ordinal_number) = api_full_name.split(".", -1) + if api_type not in ("Functional", "Tensor", "Torch"): + raise ValueError("type {0} of API is not supported!".format(api_type)) + return (api_full_name, api_info_dict) + + +def check_user_settings(cmd_args): + iter_t = cmd_args.iter_times + if iter_t <= 0: + raise ValueError("iter_times should be an integer bigger than zero!") + (api_full_name, api_info_dict) = check_json(cmd_args.forward_json_path) + return api_full_name, api_info_dict + + +def get_compare_standard(api_name): + if api_name in BinaryStandardApi: + return "CompareStandard.BINARY_EQUALITY_STANDARD" + if api_name in AbsoluteStandardApi: + return "CompareStandard.ABSOLUTE_THRESHOLD_STANDARD" + if api_name in ULPStandardApi: + return "CompareStandard.ULP_ERROR_STANDARD" + return "CompareStandard.BENCHMARK_STANDARD" + + +def get_settings(cmd_args): + ''' + internal_settings contain all information needed for the operator program. + keys: + api_full_name: api_type.api_name.ordinal_number + api_type: type of API, one of torch.nn.functional, torch.Tensor or Torch + api_name: name of API + ordinal_number: how many times the same api has been called + direction_status: forward + random_seed: if mode is random_data, random seed is random_seed + iter_times: if mode is random_data, generate iter_times group of data; if mode is real_data, iter_times does not matter + args_element_assignment: code for args assignment + args_list_generator_device: code for generate args list on device + args_list_generator_bench: code for generate args list on bench + kwargs_value_assignment: code for kwargs assignment + kwargs_dict_generator_device: code for generate kwargs dict on device + kwargs_dict_generator_bench: code for generate kwargs dict on bench + ''' + api_full_name, api_info_dict = check_user_settings(cmd_args) + args_info = api_info_dict.get("args") + kwargs_info = api_info_dict.get("kwargs") + + internal_settings = {} + internal_settings["api_full_name"] = api_full_name + (api_type, api_name, ordinal_number) = api_full_name.split(".", -1) + if api_type == "Functional": + internal_settings["api_type"] = "torch.nn.functional" + elif api_type == "Tensor": + internal_settings["api_type"] = "torch.Tensor" + else: + internal_settings["api_type"] = "torch" + internal_settings["api_name"] = api_name + internal_settings["compare_standard"] = get_compare_standard(api_name) + internal_settings["ordinal_number"] = ordinal_number + internal_settings["direction_status"] = "forward" + internal_settings["random_seed"] = cmd_args.random_seed + if cmd_args.mode == "real_data": + internal_settings["iter_times"] = 1 + else: + internal_settings["iter_times"] = cmd_args.iter_times + internal_settings["args_element_assignment"] = generate_args_element_assignment_code(args_info) + internal_settings["args_list_generator_device"] = generate_args_list_device(args_info) + internal_settings["args_list_generator_bench"] = generate_args_list_bench(args_info) + internal_settings["kwargs_value_assignment"] = generate_kwargs_value_assignment_code(kwargs_info) + internal_settings["kwargs_dict_generator_device"] = generate_kwargs_dict_device(kwargs_info) + internal_settings["kwargs_dict_generator_bench"] = generate_kwargs_dict_bench(kwargs_info) + return internal_settings + + +def recursive_args_element_assignment(args_info, name_number): + args_element_assignment = "" + for index, arg in enumerate(args_info): + if isinstance(arg, (list, tuple)): + new_args_element_assignment = recursive_args_element_assignment(arg, name_number + "_" + str(index)) + args_element_assignment += new_args_element_assignment + else: + arg["parameter_name"] = "arg" + name_number + "_" + str(index) + args_element_assignment += " " + "arg_info" + name_number + "_" + str(index) + " = " + "{}".format(str(arg)) + "\n" + args_element_assignment += " " + "arg" + name_number + "_" + str(index) + " = " + "generate_data(arg_info" + name_number + "_" + str(index) + ")" + "\n" + return args_element_assignment + + +def generate_args_element_assignment_code(args_info): + args_element_assignment = recursive_args_element_assignment(args_info, "") + return args_element_assignment + + +def recursive_args_list(args_info, flag_device=False, flag_bench=False): + args_list_generator = "" + for index, arg in enumerate(args_info): + if isinstance(arg, (list, tuple)): + (left_bracket, right_bracket) = ("[", "]") if isinstance(arg, list) else ("(", ")") + args_list_generator += left_bracket + new_args_list_generator = recursive_args_list(arg, flag_device=flag_device, flag_bench=flag_bench) + args_list_generator += new_args_list_generator + args_list_generator += right_bracket + else: + args_list_generator += arg.get("parameter_name") + if arg.get("type") in TENSOR_DATA_LIST: + if flag_device: + args_list_generator += ".to(device)" + if flag_bench: + args_list_generator += '.to(torch.device("cpu"))' + args_list_generator += ".to(RAISE_PRECISION.get(str(" + arg.get("parameter_name") + ".dtype), " + arg.get("parameter_name") + ".dtype))" + args_list_generator += ", " + return args_list_generator + + +def generate_args_list_device(args_info): + args_list_generator_device = recursive_args_list(args_info, flag_device=True) + return args_list_generator_device + + +def generate_args_list_bench(args_info): + args_list_generator_bench = recursive_args_list(args_info, flag_bench=True) + return args_list_generator_bench + + +def recursive_kwargs_value_assignment(info, key_name, name_number): + kwargs_value_assignment = "" + if isinstance(info, dict): + if info.get("type") == "torch.device" or info.get("type") == "torch.dtype": + kwargs_value_assignment += " " + "kwarg_" + key_name + name_number + " = " + info.get("value") + else: + kwargs_value_assignment += " " + "kwarg_info_" + key_name + name_number + " = " + "{}".format(str(info)) + "\n" + kwargs_value_assignment += " " + "kwarg_" + key_name + name_number + " = " + "generate_data(kwarg_info_" + key_name + name_number + ")" + "\n" + info["parameter_name"] = "kwarg_" + key_name + name_number + else: + for index, arg in enumerate(info): + new_kwargs_value_assignment = recursive_kwargs_value_assignment(arg, key_name, name_number + "_" + str(index)) + kwargs_value_assignment += new_kwargs_value_assignment + return kwargs_value_assignment + + +def generate_kwargs_value_assignment_code(kwargs_info): + kwargs_value_assignment = "" + for key, value in kwargs_info.items(): + kwargs_value_assignment += recursive_kwargs_value_assignment(value, key, "") + return kwargs_value_assignment + + +def recursive_kwargs_dict(info, flag_device=False, flag_bench=False): + kwargs_dict_generator = "" + if isinstance(info, dict): + kwargs_dict_generator += info.get("parameter_name") + if info.get("type") in TENSOR_DATA_LIST: + if flag_device: + kwargs_dict_generator += ".to(device)" + if flag_bench: + kwargs_dict_generator += '.to(torch.device("cpu"))' + kwargs_dict_generator += ".to(RAISE_PRECISION.get(str(" + info.get("parameter_name") + ".dtype), " + info.get("parameter_name") + ".dtype))" + else: + (left_bracket, right_bracket) = ("[", "]") if isinstance(info, list) else ("(", ")") + kwargs_dict_generator += left_bracket + for arg in info: + kwargs_dict_generator += recursive_kwargs_dict(arg, flag_device=flag_device, flag_bench=flag_bench) + kwargs_dict_generator += ", " + kwargs_dict_generator += right_bracket + return kwargs_dict_generator + + +def generate_kwargs_dict_device(kwargs_info): + kwargs_dict_generator_device = "" + for key, value in kwargs_info.items(): + kwargs_dict_generator_device += '"' + key + '"' + ": " + kwargs_dict_generator_device += recursive_kwargs_dict(value, flag_device=True) + ", " + return kwargs_dict_generator_device + + +def generate_kwargs_dict_bench(kwargs_info): + kwargs_dict_generator_bench = "" + for key, value in kwargs_info.items(): + kwargs_dict_generator_bench += '"' + key + '"' + ": " + kwargs_dict_generator_bench += recursive_kwargs_dict(value, flag_bench=True) + ", " + return kwargs_dict_generator_bench + + +def op_generator_parser(parser): + parser.add_argument("-forward", "--forward_json_path", dest="forward_json_path", type=str, + help=" Path of forward API json file.", + required=True) + parser.add_argument("-m", "--mode", dest="mode", type=str, choices=("random_data", "real_data"), + help=" Execute mode, should be random_data or real_data.", + required=True) + parser.add_argument("-rs", "--random_seed", dest = "random_seed", type=int, default=1234, + help=" If mode is random_data, it is random seed.", + required=False) + parser.add_argument("-it", "--iter_times", dest="iter_times", type=int, default=5, + help=" If mode is random_data, generate iter_times group of data.", + required=False) + + +def main(): + parser = argparse.ArgumentParser() + op_generator_parser(parser) + cmd_args = parser.parse_args() + internal_settings = get_settings(cmd_args) + + template_path = os.path.join(os.path.dirname(__file__), "operator_replication.template") + operator_script_path = os.path.join(os.path.dirname(__file__), "{0}.py".format(internal_settings.get("api_full_name"))) + + try: + with open(template_path, 'r') as ftemp, open(operator_script_path, 'w') as fout: + code_template = ftemp.read() + fout.write(code_template.format(**internal_settings)) + except OSError: + print(f"Failed to open file. Please check file {template_path} or {operator_script_path}.") + + print(f"Generate operator script successfully and the name is {operator_script_path}.") + + +if __name__ == "__main__": + main() diff --git a/debug/accuracy_tools/api_accuracy_checker/generate_op_script/operator_replication.template b/debug/accuracy_tools/api_accuracy_checker/generate_op_script/operator_replication.template new file mode 100644 index 0000000000000000000000000000000000000000..7630839aa937c6d0419629b5e93c34b51b71f295 --- /dev/null +++ b/debug/accuracy_tools/api_accuracy_checker/generate_op_script/operator_replication.template @@ -0,0 +1,325 @@ +import json +import os +import math +from enum import Enum, auto +import torch +try: + import torch_npu +except ImportError: + pass + + +TENSOR_DATA_LIST = ["torch.Tensor", "torch.nn.parameter.Parameter"] +TORCH_BOOL_TYPE = ["torch.bool"] +TORCH_INT_TYPE = ["torch.uint8", "torch.int8", "torch.int16", "torch.short", "torch.int32", "torch.int", + "torch.int64", "torch.long"] +TORCH_FLOAT_TYPE = ["torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.float", + "torch.float64", "torch.double"] +TORCH_COMPLEX_TYPE = ["torch.complex32", "torch.chalf", "torch.complex64", "torch.cfloat", "torch.complex128", "torch.cdouble"] +RAISE_PRECISION = {{ + "torch.float16": torch.float32, + "torch.half": torch.float32, + "torch.bfloat16": torch.float32, + "torch.float32": torch.float64, + "torch.float": torch.float64 +}} + + +class CompareStandard(Enum): + BINARY_EQUALITY_STANDARD = auto() + ABSOLUTE_THRESHOLD_STANDARD = auto() + ULP_ERROR_STANDARD = auto() + BENCHMARK_STANDARD = auto() + + +def get_device(): + if torch.cuda.is_available(): + device = torch.device("cuda") + elif torch_npu.npu.is_available(): + device = torch.device("npu") + else: + raise Exception("Error: This device is not NPU or GPU!") + return device + + +def generate_bool_tensor(low, high, shape): + low, high = int(low), int(high) + tensor = torch.randint(low, high + 1, shape) + bool_tensor = torch.gt(tensor, 0) + return bool_tensor + + +def generate_numerical_tensor(low, high, shape, data_dtype): + if data_dtype in TORCH_FLOAT_TYPE: + scale = high - low + rand01 = torch.rand(shape, dtype=eval(data_dtype)) + tensor = rand01 * scale + low + elif data_dtype in TORCH_INT_TYPE: + low, high = int(low), int(high) + tensor = torch.randint(low, high + 1, shape, dtype=eval(data_dtype)) + else: + raise NotImplementedError(f"{{data_dtype}} is not supported!") + if torch.numel(tensor) == 0: + return tensor + tmp_tensor = tensor.reshape(-1) + tmp_tensor[0] = low + tmp_tensor[-1] = high + data = tmp_tensor.reshape(shape) + return data + + +def generate_random_tensor(info): + low, high = info.get('Min'), info.get('Max') + data_dtype = info.get('dtype') + shape = tuple(info.get('shape')) + if data_dtype == "torch.bool": + data = generate_bool_tensor(low, high, shape) + else: + data = generate_numerical_tensor(low, high, shape, data_dtype) + return data + + +def generate_real_tensor(data_path): + data_path = os.path.realpath(data_path) + data = torch.load(data_path) + return data + + +def generate_data(info): + data_type = info.get("type") + data_path = info.get("datapath") + if data_type in TENSOR_DATA_LIST: + if data_path: + data = generate_real_tensor(data_path) + else: + data = generate_random_tensor(info) + else: + data = info.get("value") + return data + + +def get_input(): +{args_element_assignment} + args_device = [{args_list_generator_device}] + args_bench = [{args_list_generator_bench}] +{kwargs_value_assignment} + kwargs_device = {{{kwargs_dict_generator_device}}} + kwargs_bench = {{{kwargs_dict_generator_bench}}} + return args_device, kwargs_device, args_bench, kwargs_bench + + +def exec_api_device(args, kwargs): + output_device = {api_type}.{api_name}(*args, **kwargs) + return output_device + + +def exec_api_bench(args, kwargs): + output_bench = {api_type}.{api_name}(*args, **kwargs) + return output_bench + + +def compute_inf_nan_proportion(inf_nan_mask, out_device, out_bench, abs_bench_with_eps, rtol): + out_bench = out_bench.to(out_device.dtype) + min = torch.finfo(out_device.dtype).min + max = torch.finfo(out_device.dtype).max + bench_clip = torch.clamp(out_bench, min=min, max=max) + device_clip = torch.clamp(out_device, min=min, max=max) + clipped_abs_ae = torch.abs(device_clip - bench_clip) + clipped_re = clipped_abs_ae / abs_bench_with_eps + pass_mask = torch.less_equal(clipped_re, rtol) + both_nan_mask = torch.logical_and(torch.isnan(out_device), torch.isnan(bench_clip)) + pass_mask = torch.logical_or(pass_mask, both_nan_mask) + not_pass_mask = torch.logical_not(pass_mask) + not_pass_mask = torch.logical_and(not_pass_mask, inf_nan_mask) + inf_nan_err_cnt = torch.sum(not_pass_mask) + return 0 if torch.sum(inf_nan_mask) == 0 else inf_nan_err_cnt / torch.sum(inf_nan_mask) + + +def compute_rmse(abs_err, normal_value_mask): + if torch.sum(normal_value_mask) == 0: + return 0 + else: + masked_ae = torch.where(normal_value_mask, abs_err, 0) + mse = torch.sum(torch.square(masked_ae)) / torch.sum(normal_value_mask) + rmse = torch.sqrt(mse) + return rmse + + +def compute_error_balance(out_device, out_bench): + larger_count = torch.sum(torch.greater(out_device - out_bench.to(out_device.dtype), 0)) + smaller_count = torch.sum(torch.less(out_device - out_bench.to(out_device.dtype), 0)) + total_count = torch.numel(out_bench) + error_balance = abs(larger_count - smaller_count) / total_count + return error_balance + + +def compare_tensor(out_device, out_bench, api_name): + if out_device.shape != out_bench.shape: + print("ERROR: shape of out_device and out_bench is not equal!") + return None + if torch.numel(out_bench) == 0: + print("Both out_device and out_bench have zero elements.") + return None + print(f"shape is {{out_bench.shape}}") + print(f"dtype of out_device is {{out_device.dtype}}") + print(f"dtype of out_bench is {{out_bench.dtype}}") + dtype_device = out_device.dtype + dtype_bench = out_bench.dtype + if str(dtype_device) in TORCH_FLOAT_TYPE and str(dtype_bench) in TORCH_FLOAT_TYPE \ + or str(dtype_device) in TORCH_INT_TYPE and str(dtype_bench) in TORCH_INT_TYPE \ + or str(dtype_device) in TORCH_BOOL_TYPE and str(dtype_bench) in TORCH_BOOL_TYPE: + out_device = out_device.to(torch.device("cpu")) + if str(dtype_device) in TORCH_BOOL_TYPE or str(dtype_device) in TORCH_INT_TYPE or compare_standard == CompareStandard.BINARY_EQUALITY_STANDARD: + print("compare standard: binary equality standard:") + error_number = torch.sum(out_device != out_bench).item() + error_rate = error_number / torch.numel(out_bench) + print(f"error rate is {{error_rate}}.") + else: + abs_err = torch.abs(out_device - out_bench) + abs_bench = torch.abs(out_bench) + if dtype_bench == torch.float32: + eps = 2 ** -23 + if dtype_bench == torch.float64: + eps = 2 ** -52 + abs_bench_with_eps = abs_bench + eps + rel_err = torch.abs(abs_err / abs_bench_with_eps) + device_finite_mask = torch.isfinite(out_device) + bench_finite_mask = torch.isfinite(out_bench.to(dtype_device)) + both_finite_mask = torch.logical_and(device_finite_mask, bench_finite_mask) + inf_nan_mask = torch.logical_not(both_finite_mask) + if compare_standard == CompareStandard.ABSOLUTE_THRESHOLD_STANDARD: + if dtype_device == torch.float16: + rtol, small_value, small_value_atol = 1.0e-3, 1.0e-3, 1.0e-5 + elif dtype_device == torch.bfloat16: + rtol, small_value, small_value_atol = 4.0e-3, 1.0e-3, 1.0e-5 + else: + rtol, small_value, small_value_atol = 1.0e-6, 1.0e-6, 1.0e-9 + small_value_mask = torch.less_equal(abs_bench, small_value) + small_value_mask = torch.logical_and(small_value_mask, both_finite_mask) + normal_value_mask = torch.logical_and(both_finite_mask, torch.logical_not(small_value_mask)) + inf_nan_proportion = compute_inf_nan_proportion(inf_nan_mask, out_device, out_bench, abs_bench_with_eps, rtol) + rel_err_mask = torch.greater(rel_err, rtol) + rel_err_mask = torch.logical_and(rel_err_mask, normal_value_mask) + if torch.sum(normal_value_mask) == 0: + rel_err_proportion = 0 + else: + rel_err_proportion = torch.sum(rel_err_mask) / torch.sum(normal_value_mask) + abs_err_mask = torch.greater(abs_err, small_value_atol) + abs_err_mask = torch.logical_and(abs_err_mask, small_value_mask) + if torch.sum(small_value_mask) == 0: + abs_err_proportion = 0 + else: + abs_err_proportion = torch.sum(abs_err_mask) / torch.sum(small_value_mask) + print("compare standard: absolute threshold standard") + print(f"relative error ratio is {{rel_err_proportion}}") + print(f"absolute error ratio is {{abs_err_proportion}}") + elif compare_standard == CompareStandard.ULP_ERROR_STANDARD: + if dtype_device == torch.float16: + min_eb, exponent_num = -14, 10 + elif dtype_device == torch.bfloat16: + min_eb, exponent_num = -126, 7 + else: + min_eb, exponent_num = -126, 23 + eb = torch.where(abs_bench == 0, torch.zeros(out_bench.shape), torch.floor(torch.log2(abs_bench))) + eb = torch.maximum(eb, min_eb * torch.ones(out_bench.shape)) + if dtype_device == torch.float32: + ulp_err = (out_device.to(torch.float64) - out_bench).to(torch.float64) * torch.exp2(-eb + exponent_num).to(torch.float64) + else: + ulp_err = (out_device.to(torch.float32) - out_bench).to(torch.float32) * torch.exp2(-eb + exponent_num).to(torch.float32) + ulp_err = torch.abs(ulp_err) + max_ulp_err = torch.max(ulp_err) + mean_ulp_err = torch.mean(ulp_err) + if dtype_device == torch.float32: + ulp_err_proportion = torch.sum(ulp_err > 32) / torch.numel(out_bench) + else: + ulp_err_proportion = torch.sum(ulp_err > 1) / torch.numel(out_bench) + print("compare standard: ulp error standard") + print(f"maximum ulp error is {{max_ulp_err}}") + print(f"mean ulp error is {{mean_ulp_err}}") + print(f"ulp error proportion is {{ulp_err_proportion}}") + else: + if dtype_device == torch.float16: + small_value, small_value_atol = 1.0e-3, 1.0e-5 + elif dtype_device == torch.bfloat16: + small_value, small_value_atol = 1.0e-3, 1.0e-5 + else: + small_value, small_value_atol = 1.0e-6, 1.0e-9 + small_value_mask = torch.less_equal(abs_bench, small_value) + small_value_mask = torch.logical_and(small_value_mask, both_finite_mask) + normal_value_mask = torch.logical_and(both_finite_mask, torch.logical_not(small_value_mask)) + abs_err_mask = torch.greater(abs_err, small_value_atol) + abs_err_mask = torch.logical_and(abs_err_mask, small_value_mask) + if torch.sum(small_value_mask) == 0: + small_value_err_proportion = 0 + else: + small_value_err_proportion = torch.sum(abs_err_mask) / torch.sum(small_value_mask) + rel_err = torch.where(normal_value_mask, rel_err, -1 * torch.ones(out_device.shape)) + if torch.max(rel_err) >= 0: + max_rel_err = torch.max(rel_err) + else: + max_rel_err = 0 + if torch.sum(normal_value_mask) == 0: + mean_rel_err = 0 + else: + mean_rel_err = torch.sum(torch.clamp(rel_err, min=0)) / torch.sum(normal_value_mask) + rmse = compute_rmse(abs_err, normal_value_mask) + error_balance = compute_error_balance(out_device, out_bench) + print("compare standard: benchmark standard") + print(f"small value error proportion is {{small_value_err_proportion}}") + print(f"maximum relative error is {{max_rel_err}}") + print(f"mean relative error is {{mean_rel_err}}") + print(f"root mean squared error is {{rmse}}") + print(f"error balance is {{error_balance}}") + else: + print(f"ERROR: out_device dtype is {{dtype_device}}, out_bench dtype is {{dtype_bench}}, not comparable.") + return None + + +def compare_element(out_device, out_bench, api_name): + if type(out_device) != type(out_bench): + print("ERROR: out_device and out_bench is not the same type!") + return None + if isinstance(out_bench, torch.Tensor): + print(f"data type: {{type(out_bench)}}") + compare_tensor(out_device, out_bench, api_name) + elif isinstance(out_bench, (bool, int, float, str)): + print(f"data type: {{type(out_bench)}}") + if out_device == out_bench: + print("PASS: out_device and out_bench equals.") + else: + print("ERROR: out_device and out_bench is not equal!") + else: + print(f"ERROR: comparison of type {{type(out_bench)}} is not supported.") + return None + + +def compare(out_device, out_bench, api_name): + print("Compare result:") + if type(out_device) != type(out_bench): + print("ERROR: out_device and out_bench is not the same type!") + print("Compare finished.") + return None + if isinstance(out_bench, (list, tuple)): + print(f"data type: {{type(out_bench)}}") + if len(out_device) != len(out_bench): + print("ERROR: len of out_device and out_bench is different!") + print("Compare finished.") + return None + for index, _ in enumerate(out_bench): + print(f"index {{index}}:") + compare_element(out_device[index], out_bench[index], api_name) + else: + compare_element(out_device, out_bench, api_name) + print("Compare finished.") + + +device = get_device() +api_name = "{api_name}" +compare_standard = {compare_standard} +torch.manual_seed({random_seed}) +for i in range({iter_times}): + print(f"iter: {{i}}:") + args_device, kwargs_device, args_bench, kwargs_bench = get_input() + output_device = exec_api_device(args_device, kwargs_device) + output_bench = exec_api_bench(args_bench, kwargs_bench) + compare(output_device, output_bench, api_name)