From 75e7fad1db3531ba5246aad2da5d77a185bddd25 Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Tue, 20 May 2025 10:19:11 +0800
Subject: [PATCH 01/13] Generate single-operator replication scripts
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../generate_op_script/config_op.json         |    9 +
 .../generate_op_script/op_generator.py        |  468 ++++
 .../operator_replication.template             | 2293 +++++++++++++++++
 3 files changed, 2770 insertions(+)
 create mode 100644 debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json
 create mode 100644 debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py
 create mode 100644 debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template

diff --git a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json
new file mode 100644
index 0000000000..68a47dc26c
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json
@@ -0,0 +1,9 @@
+{
+    "dump_json_path": "./dump.json",
+    "api_name": "Mint.split.1",
+    "extract_api_path": "Mint.split.1.json",
+    "propagation": "backward",
+    "data_mode": "random_data",
+    "random_seed": 1234,
+    "iter_times": 1
+}
\ No newline at end of file
diff --git a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py
new file mode 100644
index 0000000000..ecc6ac1c8f
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py
@@ -0,0 +1,468 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
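+
+# Usage sketch (paths are illustrative; the flags are defined by
+# _op_generator_parser below):
+#
+#     python op_generator.py -i ./config_op.json -o ./generated_scripts
+#
+# With the config above, the records of "Mint.split.1" are extracted from
+# ./dump.json into Mint.split.1.json, and operator_replication.template is
+# rendered into a standalone <api_full_name>.py replication script.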
+
+import argparse
+import json
+import os
+import re
+import string
+
+import mindspore
+import numpy as np
+import torch
+from mindspore._c_expression import typing
+
+from msprobe.core.common.file_utils import FileOpen, load_json, save_json, make_dir, change_mode
+from msprobe.core.common.utils import check_file_or_directory_path, check_op_str_pattern_valid, is_int
+from msprobe.core.common.const import Const, MonitorConst, MsgConst, FileCheckConst
+from msprobe.core.common.log import logger
+from msprobe.core.common.decorator import recursion_depth_decorator
+
+MINDSPORE_TENSOR_TYPE_STR = "mindspore.Tensor"
+BOOL_TYPE_STR = "bool"
+INT_TYPE_STR = "int"
+FLOAT_TYPE_STR = "float"
+SLICE_TYPE_STR = "slice"
+TUPLE_TYPE_STR = "tuple"
+STR_TYPE_STR = "str"
+MINDSPORE_DTYPE_TYPE_STR = "mindspore.dtype"
+TORCH_DTYPE_TYPE_STR = "torch.dtype"
+
+api_info_type_str_to_type = {
+    MINDSPORE_TENSOR_TYPE_STR: mindspore.Tensor,
+    BOOL_TYPE_STR: bool,
+    INT_TYPE_STR: int,
+    FLOAT_TYPE_STR: float,
+    SLICE_TYPE_STR: slice,
+    STR_TYPE_STR: str,
+    MINDSPORE_DTYPE_TYPE_STR: typing.Type,
+}
+type_to_api_info_type_str = {value: key for key, value in api_info_type_str_to_type.items()}
+
+TENSOR_DATA_LIST = ["torch.Tensor", "torch.nn.parameter.Parameter"]
+TORCH_BOOL_TYPE = ["torch.bool"]
+TORCH_INT_TYPE = ["torch.uint8", "torch.int8", "torch.int16", "torch.short", "torch.int32", "torch.int",
+                  "torch.int64", "torch.long"]
+TORCH_FLOAT_TYPE = ["torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.float",
+                    "torch.float64", "torch.double"]
+TORCH_COMPLEX_TYPE = ["torch.complex32", "torch.chalf", "torch.complex64", "torch.cfloat", "torch.complex128",
+                      "torch.cdouble"]
+OPERATOR_TYPE = ("Functional", "Tensor", "Torch", "Mint")
+
+API_INFO = 2
+FOUR_SEGMENT = 4
+FIVE_SEGMENT = 5
+DATA_NAME = "data_name"
+API_MAX_LENGTH = 30
+PROPAGATION_LIST = [Const.FORWARD, Const.BACKWARD]
+DATAMODE_LIST = ["random_data", "real_data"]
+ITER_MAX_TIMES = 1000
+
+
+class APIInfo:
+    def __init__(self, api_full_name, api_info_dict, backward_info=None):
+        self.api_full_name = api_full_name
+        self.api_info_dict = api_info_dict
+        self.backward_info = backward_info
+
+    @property
+    def api_type(self):
+        return self.api_full_name.split(Const.SEP, -1)[0]
+
+    @classmethod
+    def from_json(cls, json_content, propagation):
+        forward_name, forward_dict = list(json_content.items())[0]
+        forward_info = cls(api_full_name=forward_name, api_info_dict=forward_dict)
+
+        if propagation == Const.BACKWARD:
+            backward_name, backward_dict = list(json_content.items())[1]
+            backward_info = cls(api_full_name=backward_name, api_info_dict=backward_dict)
+            forward_info.backward_info = backward_info
+
+        if not forward_info.is_supported_type():
+            raise ValueError(f"type {forward_info.api_type} of API is not supported!")
+
+        return forward_info
+
+    def is_supported_type(self):
+        return self.api_type in OPERATOR_TYPE
+
+
+class CommonConfig:
+    def __init__(self, json_config):
+        self.dump_json_path = json_config.get('dump_json_path')
+        self.api_name = json_config.get('api_name')
+        self.extract_api_path = json_config.get('extract_api_path')
+        self.propagation = json_config.get('propagation')
+        self.data_mode = json_config.get('data_mode')
+        self.random_seed = json_config.get('random_seed')
+        self.iter_times = json_config.get('iter_times')
+        self._check_config()
+
+    def check_user_settings(self):
+        iter_t = self.iter_times
+        if iter_t <= 0:
+            raise ValueError("iter_times should be an integer bigger than zero!")
+        if iter_t > ITER_MAX_TIMES:
+            raise ValueError(f"iter_times should not be greater than {ITER_MAX_TIMES}!")
+
+        json_file = self.extract_api_path
+        propagation = self.propagation
+
+        json_content = load_json(json_file)
+
+        # ensure the dict is not empty
+        if not json_content:
+            raise ValueError('json file is empty!')
+
+        # ensure json_content is of type dict
+        if not isinstance(json_content, dict):
+            raise ValueError('content of json file is not a dict!')
+
+        # ensure the length of json_content is within allowed limits
+        if len(json_content) > API_INFO + 2:
+            raise ValueError('json file contains more than one API; it should only contain '
+                             'the forward and backward info of a single API')
+
+        # Retrieve the first API name and dictionary
+        forward_item = next(iter(json_content.items()), None)
+        if not forward_item or not isinstance(forward_item[1], dict) or not forward_item[1]:
+            raise ValueError('Invalid forward API data in json_content!')
+        # Keys added by the extractor (e.g. framework, real_data_path) must not affect the lookup
+        if propagation == Const.BACKWARD:
+            backward_item = list(json_content.items())[1]
+            if not isinstance(backward_item[1], dict) or not backward_item[1]:
+                raise ValueError('Invalid backward API data in json_content!')
+
+        return json_content
+
+    def _check_config(self):
+        if self.dump_json_path:
+            check_file_or_directory_path(self.dump_json_path)
+        if self.api_name:
+            check_op_str_pattern_valid(self.api_name)
+            if len(self.api_name) > API_MAX_LENGTH:
+                raise ValueError(f'API name {self.api_name} is too long!')
+        make_dir(os.path.dirname(self.extract_api_path))
+        if self.propagation and self.propagation not in PROPAGATION_LIST:
+            raise ValueError(f'propagation is invalid, it should be one of {PROPAGATION_LIST}')
+        if self.data_mode and self.data_mode not in DATAMODE_LIST:
+            raise ValueError(f'data_mode is invalid, it should be one of {DATAMODE_LIST}')
+        if not is_int(self.random_seed):
+            raise ValueError('random_seed is invalid, it should be an int')
+        if not is_int(self.iter_times):
+            raise ValueError('iter_times is invalid, it should be an int')
+
+
+class APIExtractor:
+    def __init__(self, api_name, dump_json_path, output_file):
+        self.api_name = api_name
+        self.dump_json_path = dump_json_path
+        self.output_file = output_file
+        self.data = None
+        self.framework = None
+        self.real_data_path = None
+
+    def extract_op(self):
+        self.data = load_json(self.dump_json_path)
+        # fetch the framework field
+        self.framework = self.data.get('framework', None)
+        new_data = {}
+        # match every entry whose name starts with api_name followed by a separator
+        extract_key_pattern = re.compile(rf"^{re.escape(self.api_name)}\..+")
+
+        self.real_data_path = self.data.get('dump_data_dir', '')
+
+        for key, value in self.data.get('data', {}).items():
+            if extract_key_pattern.match(key):
+                if self.real_data_path:
+                    value = self.load_real_data_path(value, self.real_data_path)
+                new_data[key] = value
+
+        if self.real_data_path is not None:
+            new_data['real_data_path'] = self.real_data_path
+
+        # carry the framework field over
+        if self.framework is not None:
+            new_data['framework'] = self.framework
+        if not new_data:
+            logger.warning(f"The api '{self.api_name}' does not exist in the file.")
+        else:
+            save_json(self.output_file, new_data, indent=4)
+            logger.info(
+                f"The api '{self.api_name}' has been successfully extracted and saved in: {self.output_file}")
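+
+    # Illustrative behavior of extract_key_pattern above: with
+    # api_name == "Mint.split.1" (as in config_op.json), the compiled pattern
+    # r"^Mint\.split\.1\..+" matches "Mint.split.1.forward" and
+    # "Mint.split.1.backward", but not "Mint.split.10.forward".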
+    def load_real_data_path(self, value, dump_data_dir):
+        parameters = [Const.INPUT_ARGS, Const.GRAD_INPUT, Const.INPUT, Const.OUTPUT, Const.GRAD_OUTPUT]
+        for parameter in parameters:
+            for v in value.get(parameter, []):
+                if v is not None:
+                    self.update_data_name(v, dump_data_dir)
+        return value
+
+    @recursion_depth_decorator("OpGenerator: APIExtractor.update_data_name")
+    def update_data_name(self, data, dump_data_dir):
+        if isinstance(data, list):
+            for item in data:
+                self.update_data_name(item, dump_data_dir)
+        elif DATA_NAME in data:
+            data[DATA_NAME] = os.path.join(dump_data_dir, data[DATA_NAME])
+
+
+class OperatorScriptGenerator:
+    def __init__(self, common_config, args_info_forward, kwargs_info_forward, args_info_backward):
+        self.common_config = common_config
+        self.args_info_forward = args_info_forward
+        self.kwargs_info_forward = kwargs_info_forward
+        self.args_info_backward = args_info_backward
+
+    @staticmethod
+    def extract_detailed_api_segments(full_api_name):
+        """
+        Function Description:
+            Split a full API name into its type, name and call ordinal.
+        Parameter:
+            full_api_name: Full name of the API. Example: Functional.matmul.0.forward
+        Return:
+            api_type: Type of the API. Example: Functional, Tensor, etc.
+            api_name: Name of the API. Example: matmul, mul, etc.
+            api_order: Ordinal number of the call. Example: 0, 1, etc.
+        """
+        api_parts = full_api_name.split(Const.SEP)
+        api_parts_length = len(api_parts)
+        api_type, api_name, api_order = None, None, None
+        if api_parts_length == FOUR_SEGMENT:
+            api_type, api_name, api_order, _ = api_parts
+        elif api_parts_length == FIVE_SEGMENT:
+            api_type, prefix, api_name, api_order, _ = api_parts
+            api_name = Const.SEP.join([prefix, api_name])
+        return api_type, api_name, api_order
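+
+    # Worked examples (illustrative) for extract_detailed_api_segments:
+    #     "Tensor.add.0.forward"   -> ("Tensor", "add", "0")
+    #     "Mint.nn.relu.0.forward" -> ("Mint", "nn.relu", "0")
+    # Any other segment count yields (None, None, None).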
+
+    def get_settings(self, api_full_name):
+        '''
+        internal_settings contain all information needed for the operator program.
+        keys:
+            propagation: Union["forward", "backward"]
+            api_full_name: api_type.api_name.ordinal_number
+            api_type: type of API, one of torch.nn.functional, torch.Tensor or torch
+            api_name: name of API
+            ordinal_number: how many times the same api has been called
+            direction_status: same as propagation
+            random_seed: if data_mode is random_data, the seed used to generate data
+            data_mode: Union["random_data", "real_data"]
+            iter_times: if data_mode is random_data, generate iter_times groups of data;
+                        if data_mode is real_data, fixed to 1
+            args_info_forward / kwargs_info_forward / args_info_backward: parsed API info
+        '''
+        # Generate an internal settings dictionary based on the user settings,
+        # including API name, type, comparison standard, random seed, number of iterations, etc.
+        internal_settings = {}
+        internal_settings["propagation"] = self.common_config.propagation
+        internal_settings["api_full_name"] = api_full_name
+        api_type, api_name, ordinal_number = self.extract_detailed_api_segments(api_full_name)
+        if api_type == "Functional":
+            internal_settings["api_type"] = "torch.nn.functional"
+        elif api_type == "Tensor":
+            internal_settings["api_type"] = "torch.Tensor"
+        else:
+            internal_settings["api_type"] = "torch"
+        internal_settings["api_name"] = api_name
+        internal_settings["ordinal_number"] = ordinal_number
+        internal_settings["direction_status"] = self.common_config.propagation
+        internal_settings["random_seed"] = self.common_config.random_seed
+        internal_settings["data_mode"] = self.common_config.data_mode
+        if self.common_config.data_mode == "real_data":
+            internal_settings["iter_times"] = 1
+        else:
+            internal_settings["iter_times"] = self.common_config.iter_times
+
+        internal_settings["args_info_forward"] = self.args_info_forward
+        internal_settings["kwargs_info_forward"] = self.kwargs_info_forward
+        internal_settings["args_info_backward"] = self.args_info_backward
+
+        return internal_settings
+
+    def generate_forward_inputs_code(self, args_info):
+        # Collect the arg_info_x names already defined by the args-assignment
+        # codegen and join them into a list.
+        names = []
+
+        def collect(info):
+            if isinstance(info, dict):
+                names.append(info["parameter_name"])
+            else:
+                for sub in info:
+                    collect(sub)
+
+        collect(args_info)
+
+        return (
+            "    forward_inputs = [\n"
+            "        ComputeElement(parameter=info)\n"
+            "        for info in (" + ", ".join(names) + ")\n"
+            "    ]\n"
+        )
+
+    def generate_kwargs_compute_element_dict_code(self):
+        # kwargs_device is assumed to already be a dict variable in the generated script
+        return (
+            "    # ---- build the ComputeElement dict for kwargs ----\n"
+            "    kwargs_compute_element_dict = {\n"
+            "        key_str: ComputeElement(compute_element_info=compute_element_info)\n"
+            "        for key_str, compute_element_info in kwargs_device.items()\n"
+            "    }\n"
+        )
+
+    def generate_gradient_inputs_code(self, args_info_backward):
+        # Likewise, collect the arg_info_x names for the backward gradients
+        names = []
+
+        def collect(info):
+            if isinstance(info, dict):
+                names.append(info["parameter_name"])
+            else:
+                for sub in info:
+                    collect(sub)
+
+        collect(args_info_backward)
+
+        return (
+            "    # ---- build the list of backward gradient ComputeElements ----\n"
+            "    gradient_inputs = [\n"
+            "        ComputeElement(parameter=info)\n"
+            "        for info in (" + ", ".join(names) + ")\n"
+            "    ]\n"
+        )
+
+
+def _op_generator_parser(parser):
+    parser.add_argument("-i", "--config_input", dest="config_input", type=str,
+                        help="Path of the config json file", required=True)
+    parser.add_argument("-o", "--api_output_path", dest="api_output_path", type=str,
+                        help="Path of the output directory for the generated operator script", required=True)
+
+
+def parse_json_config(json_file_path):
+    if not json_file_path:
+        raise Exception("config_input path can not be empty, please check.")
+    json_config = load_json(json_file_path)
+    common_config = CommonConfig(json_config)
+    return common_config
+
+
+def _run_operator_generate_command(cmd_args):
+    common_config = parse_json_config(cmd_args.config_input)
+
+    framework = None
+    real_data_path = None
+    if common_config.dump_json_path:
+        api_extract = APIExtractor(common_config.api_name, common_config.dump_json_path,
+                                   common_config.extract_api_path)
+        api_extract.extract_op()
+        framework = api_extract.framework
+        real_data_path = api_extract.real_data_path
+    check_file_or_directory_path(common_config.extract_api_path)
+    check_file_or_directory_path(cmd_args.api_output_path, isdir=True)
+    json_content = common_config.check_user_settings()
+    api_info = APIInfo.from_json(json_content, common_config.propagation)
+
+    if common_config.propagation == Const.BACKWARD:
+        # read and check json
+        api_full_name_forward, api_info_dict_forward = api_info.api_full_name, api_info.api_info_dict
+        api_full_name_backward, api_info_dict_backward = (api_info.backward_info.api_full_name,
+                                                          api_info.backward_info.api_info_dict)
+        args_info_forward = api_info_dict_forward.get(Const.INPUT_ARGS)
+        kwargs_info_forward = api_info_dict_forward.get(Const.INPUT_KWARGS)
+        args_info_backward = None
+        if Const.GRAD_INPUT in api_info_dict_backward:
+            args_info_backward = api_info_dict_backward.get(Const.GRAD_INPUT)
+        elif Const.INPUT in api_info_dict_backward:
+            args_info_backward = api_info_dict_backward.get(Const.INPUT)
+        op_generate = OperatorScriptGenerator(common_config, args_info_forward, kwargs_info_forward,
+                                              args_info_backward)
+        internal_settings = op_generate.get_settings(api_full_name_backward)
+        internal_settings['framework'] = framework
+        internal_settings['real_data_path'] = real_data_path
+    else:
+        # read and check json
+        api_full_name_forward, api_info_dict_forward = api_info.api_full_name, api_info.api_info_dict
+        args_info_forward = api_info_dict_forward.get(Const.INPUT_ARGS)
+        kwargs_info_forward = api_info_dict_forward.get(Const.INPUT_KWARGS)
+
+        op_generate = OperatorScriptGenerator(common_config, args_info_forward, kwargs_info_forward, None)
+        internal_settings = op_generate.get_settings(api_full_name_forward)
+        internal_settings['framework'] = framework
+        internal_settings['real_data_path'] = real_data_path
+
+    template_path = os.path.join(os.path.dirname(__file__), "operator_replication.template")
+    operator_script_path = os.path.join(cmd_args.api_output_path,
+                                        "{0}.py".format(internal_settings.get("api_full_name")))
+
+    class SafeDict(dict):
+        def __missing__(self, key):
+            # leave {key} in the output if it's not in the dict
+            return '{' + key + '}'
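+
+    # RobustFormatter below substitutes the fields it can resolve and writes
+    # every unknown or malformed placeholder back verbatim, so template code
+    # containing literal braces survives rendering. Illustrative behavior:
+    #     fmt.format("name={api_name}, keep={unknown}", api_name="split")
+    #     -> "name=split, keep={unknown}"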
+    class RobustFormatter(string.Formatter):
+        def vformat(self, format_string, args, kwargs):
+            result = []
+            # parse() splits the template into literal text and placeholders
+            for literal, field_name, format_spec, conversion in self.parse(format_string):
+                # emit the literal text
+                result.append(literal)
+                if field_name is None:
+                    continue
+                try:
+                    # resolve the field and format it normally
+                    obj, _ = self.get_field(field_name, args, kwargs)
+                    if conversion:
+                        obj = self.convert_field(obj, conversion)
+                    result.append(self.format_field(obj, format_spec))
+                except Exception:
+                    # on KeyError and ValueError alike, write {field_name[!conversion][:format_spec]} back verbatim
+                    placeholder = '{' + field_name
+                    if conversion:
+                        placeholder += '!' + conversion
+                    if format_spec:
+                        placeholder += ':' + format_spec
+                    placeholder += '}'
+                    result.append(placeholder)
+            return ''.join(result)
+
+    fmt = RobustFormatter()
+    with FileOpen(template_path, 'r') as ftemp, FileOpen(operator_script_path, 'w') as fout:
+        code_template = ftemp.read()
+        # use fmt.format here rather than format_map
+        fout.write(fmt.format(code_template, **internal_settings))
+
+    change_mode(operator_script_path, FileCheckConst.DATA_FILE_AUTHORITY)
+
+    logger.info(f"Generated operator script successfully: {operator_script_path}.")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    _op_generator_parser(parser)
+    cmd_args = parser.parse_args()
+    _run_operator_generate_command(cmd_args)
diff --git a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template
new file mode 100644
index 0000000000..22b30b86b0
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template
@@ -0,0 +1,2293 @@
+import os
+import re
+import stat
+import time
+import csv
+import gc
+import sys
+import logging
+import traceback
+from enum import Enum, auto
+from abc import ABC, abstractmethod
+from pathlib import Path
+
+import mindspore
+from mindspore import ops
+from tabulate import tabulate
+
+
+def error_log_with_exp(self, msg: str, exp: Exception):
+    """
+    msg: the error message to log
+    exp: the Exception instance to record
+    """
+    # pass the exception type, message and traceback to .error() via exc_info
+    self.error(msg, exc_info=(type(exp), exp, exp.__traceback__))
+
+
+# attach it to Logger
+logging.Logger.error_log_with_exp = error_log_with_exp
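+
+# Illustrative call (hypothetical): log a failure together with its traceback.
+#
+#     try:
+#         replicate()
+#     except ValueError as err:
+#         logger.error_log_with_exp("replication failed", err)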
+
+# 1. Basic configuration: log at INFO level, print to the console by default
+logging.basicConfig(level=logging.INFO,
+                    format='%(asctime)s [%(levelname)s] %(message)s',
+                    datefmt='%H:%M:%S')
+
+logger = logging.getLogger()
+
+
+# ======= constant classes =======
+
+class CodedException(Exception):
+    def __init__(self, code, error_info=''):
+        super().__init__()
+        self.code = code
+        self.error_info = self.err_strs.get(code) + error_info
+
+    def __str__(self):
+        return self.error_info
+
+
+class ApiAccuracyCheckerException(CodedException):
+    ParseJsonFailed = 0
+    UnsupportType = 1
+    WrongValue = 2
+    ApiWrong = 3
+    err_strs = {
+        ParseJsonFailed: "[msprobe] Api Accuracy Checker parse json failed: ",
+        UnsupportType: "[msprobe] Api Accuracy Checker get unsupported type: ",
+        WrongValue: "[msprobe] Api Accuracy Checker get wrong value: ",
+        ApiWrong: "[msprobe] Api Accuracy Checker something wrong with api: ",
+    }
+
+
+class FileCheckConst:
+    """
+    Class for file check const
+    """
+    READ_ABLE = "read"
+    WRITE_ABLE = "write"
+    READ_WRITE_ABLE = "read and write"
+    DIRECTORY_LENGTH = 4096
+    FILE_NAME_LENGTH = 255
+    FILE_VALID_PATTERN = r"^[a-zA-Z0-9_.:/-]+$"
+    FILE_PATTERN = r'^[a-zA-Z0-9_./-]+$'
+    PKL_SUFFIX = ".pkl"
+    NUMPY_SUFFIX = ".npy"
+    JSON_SUFFIX = ".json"
+    PT_SUFFIX = ".pt"
+    CSV_SUFFIX = ".csv"
+    XLSX_SUFFIX = ".xlsx"
+    YAML_SUFFIX = ".yaml"
+    IR_SUFFIX = ".ir"
+    ZIP_SUFFIX = ".zip"
+    SHELL_SUFFIX = ".sh"
+    MAX_PKL_SIZE = 1073741824  # 1 * 1024 * 1024 * 1024
+    MAX_NUMPY_SIZE = 10737418240  # 10 * 1024 * 1024 * 1024
+    MAX_JSON_SIZE = 1073741824  # 1 * 1024 * 1024 * 1024
+    MAX_PT_SIZE = 10737418240  # 10 * 1024 * 1024 * 1024
+    MAX_CSV_SIZE = 1073741824  # 1 * 1024 * 1024 * 1024
+    MAX_XLSX_SIZE = 1073741824  # 1 * 1024 * 1024 * 1024
+    MAX_YAML_SIZE = 1073741824  # 1 * 1024 * 1024 * 1024
+    MAX_IR_SIZE = 1073741824  # 1 * 1024 * 1024 * 1024
+    MAX_ZIP_SIZE = 10737418240  # 10 * 1024 * 1024 * 1024
+    MAX_FILE_IN_ZIP_SIZE = 1073741824  # 1 * 1024 * 1024 * 1024
+    COMMOM_FILE_SIZE = 1048576  # 1 * 1024 * 1024
+    DIR = "dir"
+    FILE = "file"
+    DATA_DIR_AUTHORITY = 0o750
+    DATA_FILE_AUTHORITY = 0o640
+    FILE_SIZE_DICT = {
+        PKL_SUFFIX: MAX_PKL_SIZE,
+        NUMPY_SUFFIX: MAX_NUMPY_SIZE,
+        JSON_SUFFIX: MAX_JSON_SIZE,
+        PT_SUFFIX: MAX_PT_SIZE,
+        CSV_SUFFIX: MAX_CSV_SIZE,
+        XLSX_SUFFIX: MAX_XLSX_SIZE,
+        YAML_SUFFIX: MAX_YAML_SIZE,
+        IR_SUFFIX: MAX_IR_SIZE,
+        ZIP_SUFFIX: MAX_ZIP_SIZE
+    }
+    CSV_BLACK_LIST = r'^[+-=%@\+\-=%@]|;[+-=%@\+\-=%@]'
+
+
+class Const:
+    MAX_DEPTH = 10
+    PT_FRAMEWORK = "pytorch"
+    MS_FRAMEWORK = "mindspore"
+    MT_FRAMEWORK = "mindtorch"
+    SEP = "."
+    KWARGS = 'kwargs'
+    INPUT = 'input'
+    OUTPUT = 'output'
+    INPUT_ARGS = 'input_args'
+    INPUT_KWARGS = 'input_kwargs'
+    GRAD_INPUT = 'grad_input'
+    GRAD_OUTPUT = 'grad_output'
+    BACKWARD = 'backward'
+    FORWARD = 'forward'
+
+
+class CompareConst:
+    # compare result data
+    PASS = 'pass'
+    WARNING = 'Warning'
+    ERROR = 'error'
+    TRUE = 'TRUE'
+    FALSE = 'FALSE'
+    SKIP = 'SKIP'
+
+    # compare result column name
+    COSINE = "Cosine"
+    EUC_DIST = "EucDist"
+    MAX_ABS_ERR = "MaxAbsErr"
+    MAX_RELATIVE_ERR = "MaxRelativeErr"
+    MIN_RELATIVE_ERR = "MinRelativeErr"
+    MEAN_RELATIVE_ERR = "MeanRelativeErr"
+    NORM_RELATIVE_ERR = "NormRelativeErr"
+
+    # accuracy standards
+    COS_THRESHOLD = 0.99
+    MAX_ABS_ERR_THRESHOLD = 0.001
+    MAX_RELATIVE_ERR_THRESHOLD = 0.001
+    COS_MAX_THRESHOLD = 0.9
+    MAX_ABS_ERR_MAX_THRESHOLD = 1
+
+
+class MsCompareConst:
+    # api_info field
+    MINT = "Mint"
+    MINT_FUNCTIONAL = "MintFunctional"
+    TENSOR_API = "Tensor"
+    FUNCTIONAL_API = "Functional"
+    FUSION_API = "FUSION"
+
+    API_NAME_STR_LENGTH = 4
+    MAX_RECURSION_DEPTH = 20
+
+    # Mindtorch api_info field
+    MINDTORCH_TENSOR = "Tensor"
+    MINDTORCH = "Torch"
+    MINDTORCH_FUNC = "Functional"
+    MINDTORCH_NPU = "NPU"
+    MINDTORCH_DIST = "Distributed"
+
+    MT_VALID_API_TYPES = [
+        MINDTORCH, MINDTORCH_FUNC, MINDTORCH_TENSOR
+    ]
+    SUPPORTED_FUSION_LIST = ["flash_attention_score"]
+
+    TASK_FIELD = "task"
+    STATISTICS_TASK = "statistics"
+    FRAMEWORK = "framework"
+    TENSOR_TASK = "tensor"
+    DUMP_DATA_DIR_FIELD = "dump_data_dir"
+    DATA_FIELD = "data"
+
+    # supported api yaml
+    SUPPORTED_API_LIST_FILE = "checker_support_api.yaml"
+    SUPPORTED_TENSOR_LIST_KEY = "tensor"
+
+    # detail_csv
+    DETAIL_CSV_API_NAME = "API Name"
+    DETAIL_CSV_BENCH_DTYPE = "Bench Dtype"
+    DETAIL_CSV_TESTED_DTYPE = "Tested Dtype"
+    DETAIL_CSV_SHAPE = "Shape"
+    DETAIL_CSV_PASS_STATUS = "Status"
+    DETAIL_CSV_MESSAGE = "Message"
+    DETAIL_CSV_FILE_NAME = "accuracy_checking_details"
+
+    # result_csv
+    RESULT_CSV_FORWARD_TEST_SUCCESS = "Forward Test Success"
+    RESULT_CSV_BACKWARD_TEST_SUCCESS = "Backward Test Success"
+    RESULT_CSV_FILE_NAME = "accuracy_checking_result"
+
+    EPSILON = 1e-8
+
+    class ProcessStatus:
+        SUCCESS = "success"
+        API_NOT_FOUND = "api_not_found"
+        EXCEPTION_SKIP = "exception_skip"
+
+
+# ======= mindtorch support ========
+import torch as mindtorch
+from torch import Tensor as mindtorch_tensor
+import torch.nn.functional as mindtorch_func
+import torch.distributed as mindtorch_dist
+
+is_valid_pt_mt_env = True
+
+
+def is_mindtorch():
+    mindtorch_check_result = False
+    try:
+        import torch as test_torch
+        from mindspore import Tensor as MindsporeTensor
+    except ImportError:
+        return mindtorch_check_result
+    tensor = test_torch.tensor(0.0)
+    if isinstance(tensor, MindsporeTensor):
+        mindtorch_check_result = True
+
+    return mindtorch_check_result
+
+
+def remove_torch_related_paths():
+    removed_paths = []
+    if not is_mindtorch():
+        return
+    try:
+        import torch as remove_torch
+        torch_file = remove_torch.__file__
+    except ImportError:
+        return
+
+    torch_dir = os.path.dirname(torch_file)
+
+    torch_dir_path = Path(torch_dir).resolve()
+    parent_dir = torch_dir_path.parent
+
+    paths_to_remove = [str(parent_dir)]
+
+    for path in paths_to_remove:
+        try:
+            path_resolved = str(Path(path).resolve())
+        except Exception as error:
+            logger.debug(f"Failed to resolve path {path}: {error}")
+            continue
+
+        if path_resolved in sys.path:
+            index = sys.path.index(path_resolved)
+            removed_paths.append((path_resolved, index))
+            sys.path.pop(index)
+
+    return
+
+
+def clear_torch_from_sys_modules():
+    modules_to_remove = []
+    for module in sys.modules:
+        if module == "torch" or module.startswith("torch."):
+            modules_to_remove.append(module)
+
+    for module in modules_to_remove:
+        del sys.modules[module]
+
+
+def set_pt_mt_env_invalid():
+    global is_valid_pt_mt_env
+    is_valid_pt_mt_env = False
+
+
+def delete_torch_paths():
+    if not is_mindtorch():
+        set_pt_mt_env_invalid()
+
+    clear_torch_from_sys_modules()
+
+    for count_delete_env_path in range(MsCompareConst.MAX_RECURSION_DEPTH):
+        if not is_mindtorch():
+            break
+
+        remove_torch_related_paths()
+
+        clear_torch_from_sys_modules()
+
+        if count_delete_env_path >= MsCompareConst.MAX_RECURSION_DEPTH - 1:
+            raise Exception(f"Please check if you have a valid PyTorch and MindTorch environment, and ensure "
+                            f"the PYTHONPATH environment variable depth does not exceed "
+                            f"{MsCompareConst.MAX_RECURSION_DEPTH}.")
+
+
+if not is_mindtorch():
+    set_pt_mt_env_invalid()
+else:
+    initial_sys_path = sys.path.copy()
+    delete_torch_paths()
+
+    gc.collect()
+
+    import torch
+
+    if is_mindtorch():
+        set_pt_mt_env_invalid()
+
+    sys.path = initial_sys_path
+
+
+if not is_valid_pt_mt_env:
+    import torch
+
+
+# ======= constants =======
+import numpy as np
+from mindspore._c_expression import typing
+from mindspore.common import dtype as mstype
+
+TENSOR_DATA_LIST = ["torch.Tensor", "torch.nn.parameter.Parameter"]
+TORCH_BOOL_TYPE = ["torch.bool"]
+TORCH_INT_TYPE = ["torch.uint8", "torch.int8", "torch.int16", "torch.short", "torch.int32", "torch.int",
+                  "torch.int64", "torch.long"]
+TORCH_FLOAT_TYPE = ["torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.float",
+                    "torch.float64", "torch.double"]
+TORCH_COMPLEX_TYPE = ["torch.complex32", "torch.chalf", "torch.complex64", "torch.cfloat", "torch.complex128",
+                      "torch.cdouble"]
+RAISE_PRECISION = {{
+    "torch.float16": torch.float32,
+    "torch.half": torch.float32,
+    "torch.bfloat16": torch.float32,
+    "torch.float32": torch.float64,
+    "torch.float": torch.float64
+}}
+THOUSANDTH_THRESHOLDING = 0.001
+BACKWARD = 'backward'
+DIR = "dir"
+FILE = "file"
+READ_ABLE = "read"
+WRITE_ABLE = "write"
+READ_WRITE_ABLE = "read and write"
+DIRECTORY_LENGTH = 4096
+FILE_NAME_LENGTH = 255
+SOFT_LINK_ERROR = "soft link detected"
+FILE_PERMISSION_ERROR = "file permission error"
+INVALID_FILE_ERROR = "invalid file"
+ILLEGAL_PATH_ERROR = "illegal file path"
+ILLEGAL_PARAM_ERROR = "illegal open mode"
+FILE_TOO_LARGE_ERROR = "file too large"
+FILE_VALID_PATTERN = r"^[a-zA-Z0-9_.:/-]+$"
+FILE_SIZE_DICT = {{
+    ".pkl": 1073741824,  # 1 * 1024 * 1024 * 1024
+    ".npy": 10737418240,  # 10 * 1024 * 1024 * 1024
+    ".json": 1073741824,  # 1 * 1024 * 1024 * 1024
+    ".pt": 10737418240,  # 10 * 1024 * 1024 * 1024
+    ".csv": 1073741824,  # 1 * 1024 * 1024 * 1024
+    ".xlsx": 1073741824,  # 1 * 1024 * 1024 * 1024
+    ".yaml": 1073741824,  # 1 * 1024 * 1024 * 1024
+    ".ir": 1073741824  # 1 * 1024 * 1024 * 1024
+}}
+COMMOM_FILE_SIZE = 1048576  # 1 * 1024 * 1024
+
+
+INT8 = "Int8"
+UINT8 = "UInt8"
+INT16 = "Int16"
+UINT16 = "UInt16"
+INT32 = "Int32"
+UINT32 = "UInt32"
+INT64 = "Int64"
+UINT64 = "UInt64"
+FLOAT16 = "Float16"
+FLOAT32 = "Float32"
+FLOAT64 = "Float64"
+BOOL = "Bool"
+BFLOAT16 = "BFloat16"
+INT4 = "Int4"
+
+dtype_str_to_ms_dtype = {
+    INT8: mstype.int8,
+    UINT8: mstype.uint8,
+    INT16: mstype.int16,
+    UINT16: mstype.uint16,
+    INT32: mstype.int32,
+    UINT32: mstype.uint32,
+    INT64: mstype.int64,
+    UINT64: mstype.uint64,
+    FLOAT16: mstype.float16,
+    FLOAT32: mstype.float32,
+    FLOAT64: mstype.float64,
+    BOOL: mstype.bool_,
+    BFLOAT16: mstype.bfloat16,
+    INT4: mstype.qint4x2
+}
+ms_dtype_to_dtype_str = {value: key for key, value in dtype_str_to_ms_dtype.items()}
+
+dtype_str_to_np_dtype = {
+    INT8: np.int8,
+    UINT8: np.uint8,
+    INT16: np.int16,
+    UINT16: np.uint16,
+    INT32: np.int32,
+    UINT32: np.uint32,
+    INT64: np.int64,
+    UINT64: np.uint64,
+    FLOAT16: np.float16,
+    FLOAT32: np.float32,
+    FLOAT64: np.float64,
+    BOOL: np.bool_
+}
+np_dtype_to_dtype_str = {value: key for key, value in dtype_str_to_np_dtype.items()}
+
+dtype_str_to_torch_dtype = {
+    INT8: torch.int8,
+    UINT8: torch.uint8,
+    INT16: torch.int16,
+    INT32: torch.int32,
+    INT64: torch.int64,
+    FLOAT16: torch.float16,
+    FLOAT32: torch.float32,
+    FLOAT64: torch.float64,
+    BOOL: torch.bool,
+    BFLOAT16: torch.bfloat16,
+}
+torch_dtype_to_dtype_str = {value: key for key, value in dtype_str_to_torch_dtype.items()}
+
+dtype_str_to_mindtorch_dtype = {
+    INT8: mindtorch.int8,
+    UINT8: mindtorch.uint8,
+    INT16: mindtorch.int16,
+    INT32: mindtorch.int32,
+    INT64: mindtorch.int64,
+    FLOAT16: mindtorch.float16,
+    FLOAT32: mindtorch.float32,
+    FLOAT64: mindtorch.float64,
+    BOOL: mindtorch.bool,
+    BFLOAT16: mindtorch.bfloat16,
+}
+mindtorch_dtype_to_dtype_str = {value: key for key, value in dtype_str_to_mindtorch_dtype.items()}
+
+MINDSPORE_TENSOR_TYPE_STR = "mindspore.Tensor"
+BOOL_TYPE_STR = "bool"
+INT_TYPE_STR = "int"
+FLOAT_TYPE_STR = "float"
+SLICE_TYPE_STR = "slice"
+TUPLE_TYPE_STR = "tuple"
+STR_TYPE_STR = "str"
+MINDSPORE_DTYPE_TYPE_STR = "mindspore.dtype"
+TORCH_DTYPE_TYPE_STR = "torch.dtype"
+
+api_info_type_str_to_type = {
+    MINDSPORE_TENSOR_TYPE_STR: mindspore.Tensor,
+    BOOL_TYPE_STR: bool,
+    INT_TYPE_STR: int,
+    FLOAT_TYPE_STR: float,
+    SLICE_TYPE_STR: slice,
+    STR_TYPE_STR: str,
+    MINDSPORE_DTYPE_TYPE_STR: typing.Type,
+}
+type_to_api_info_type_str = {value: key for key, value in api_info_type_str_to_type.items()}
+
+DEFAULT_CONSTRUCT_NP_FLOAT_DTYPE = np.float64
+DEFAULT_CONSTRUCT_NP_INT_DTYPE = np.float64
+DEFAULT_CONSTRUCT_NP_UINT_DTYPE = np.float64
+
+float_dtype_str_list = [
+    FLOAT16,
+    FLOAT32,
+    FLOAT64,
+    BFLOAT16,
+]
+
+int_dtype_str_list = [
+    INT8,
+    INT16,
+    INT32,
+    INT64,
+    BOOL,
+    INT4,
+]
+
+uint_dtype_str_list = [
+    UINT8,
+    UINT16,
+    UINT32,
+    UINT64,
+]
+
+
+# ======= comparison classes =======
+
+class CompareResult:
+    def __init__(self, compare_value, pass_status, err_msg):
+        self.compare_value = compare_value
+        self.pass_status = pass_status
+        self.err_msg = err_msg
+
+
+class BaseCompareAlgorithm(ABC):
+    def __init__(self) -> None:
+        super().__init__()
+        self.compare_algorithm_name = None
+        self.err_msg_mapping = {
+            CompareConst.COSINE: {
+                CompareConst.PASS: "",
+                CompareConst.ERROR: f"cosine similarity is less than threshold: {CompareConst.COS_THRESHOLD} ",
+                CompareConst.SKIP: "two inputs are not valid for computing cosine similarity, skip comparing ",
+            },
+            CompareConst.MAX_ABS_ERR: {
+                CompareConst.PASS: "",
+                CompareConst.ERROR: "max absolute difference is greater than " \
+                                    f"threshold: {CompareConst.MAX_ABS_ERR_THRESHOLD} ",
+                CompareConst.SKIP: "two inputs are not valid for computing max absolute difference, skip comparing ",
+            },
+            CompareConst.MAX_RELATIVE_ERR: {
+                CompareConst.PASS: "",
+                CompareConst.ERROR: "",
+                CompareConst.SKIP: "",
+            },
+        }
+
+    def __call__(self, bench_compute_element, tested_compute_element):
+        '''
+        Args:
+            bench_compute_element: ComputeElement
+            tested_compute_element: ComputeElement
+
+        Return:
+            compare_result: CompareResult
+        '''
+        if self.check_validity(bench_compute_element, tested_compute_element):
+            compare_value = self.run_compare(bench_compute_element, tested_compute_element)
+            pass_status = self.check_pass(compare_value)
+        else:
+            logger.warning(f"not suitable for computing {self.compare_algorithm_name}, skip this.")
+            compare_value = None
+            pass_status = CompareConst.SKIP
+
+        err_msg = self.err_msg_mapping.get(self.compare_algorithm_name).get(pass_status)
+
+        compare_result = CompareResult(compare_value, pass_status, err_msg)
+        return compare_result
+
+    @staticmethod
+    def convert_to_np_float64_ndarray(tensor):
+        if isinstance(tensor, mindspore.Tensor):
+            ndarray = tensor.astype(mindspore.float64).numpy()
+        elif isinstance(tensor, torch.Tensor):
+            ndarray = tensor.to(torch.float64, copy=True).numpy()
+        else:
+            err_msg = "BaseCompareAlgorithm.convert_to_np_float64_ndarray failed: " \
+                      "input is not mindspore.Tensor or torch.Tensor"
+            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))
+            raise ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType, err_msg)
+        return ndarray
+
+    @staticmethod
+    def check_two_tensor(bench_compute_element, tested_compute_element):
+        bench_parameter = bench_compute_element.get_parameter()
+        tested_parameter = tested_compute_element.get_parameter()
+
+        bench_is_tensor = isinstance(bench_parameter, (mindspore.Tensor, torch.Tensor))
+        tested_is_tensor = isinstance(tested_parameter, (mindspore.Tensor, torch.Tensor))
+        shape_same = bench_compute_element.get_shape() == tested_compute_element.get_shape()
+        return bench_is_tensor and tested_is_tensor and shape_same
+
+    @abstractmethod
+    def check_validity(self, bench_compute_element, tested_compute_element):
+        '''
+        Args:
+            bench_compute_element: ComputeElement
+            tested_compute_element: ComputeElement
+
+        Return:
+            check_res: boolean
+        '''
+        raise NotImplementedError
+
+    @abstractmethod
+    def run_compare(self, bench_compute_element, tested_compute_element):
+        '''
+        Args:
+            bench_compute_element: ComputeElement
+            tested_compute_element: ComputeElement
+
+        Return:
+            compare_value: float/int
+        '''
+        raise NotImplementedError
+
+    @abstractmethod
+    def check_pass(self, compare_value):
+        '''
+        Args:
+            compare_value: float/int
+
+        Return:
+            pass_status: str
+        '''
+        raise NotImplementedError
+
+
+class CosineSimilarityCompareAlgorithm(BaseCompareAlgorithm):
+    def __init__(self) -> None:
+        super().__init__()
+        self.compare_algorithm_name = CompareConst.COSINE
+
+    def check_validity(self, bench_compute_element, tested_compute_element):
+        return self.check_two_tensor(bench_compute_element, tested_compute_element)
+
+    def run_compare(self, bench_compute_element, tested_compute_element):
+        bench_ndarray = self.convert_to_np_float64_ndarray(bench_compute_element.get_parameter())
+        tested_ndarray = self.convert_to_np_float64_ndarray(tested_compute_element.get_parameter())
+
+        bench_norm = np.linalg.norm(bench_ndarray)
+        tested_norm = np.linalg.norm(tested_ndarray)
+        dot_product = np.dot(bench_ndarray.flatten(), tested_ndarray.flatten())
+        cosine_similarity = (MsCompareConst.EPSILON + dot_product) / (MsCompareConst.EPSILON + bench_norm * tested_norm)
+        return cosine_similarity
+
+    def check_pass(self, compare_value):
+        if compare_value > CompareConst.COS_THRESHOLD:
+            return CompareConst.PASS
+        else:
+            return CompareConst.ERROR
+
+
+class MaxAbsoluteDiffCompareAlgorithm(BaseCompareAlgorithm):
+    def __init__(self) -> None:
+        super().__init__()
+        self.compare_algorithm_name = CompareConst.MAX_ABS_ERR
+
+    def check_validity(self, bench_compute_element, tested_compute_element):
+        return self.check_two_tensor(bench_compute_element, tested_compute_element)
+
+    def run_compare(self, bench_compute_element, tested_compute_element):
+        bench_ndarray = self.convert_to_np_float64_ndarray(bench_compute_element.get_parameter())
+        tested_ndarray = self.convert_to_np_float64_ndarray(tested_compute_element.get_parameter())
+
+        max_absolute_diff = np.max(np.abs(bench_ndarray - tested_ndarray))
+        return max_absolute_diff
+
+    def check_pass(self, compare_value):
+        if compare_value < CompareConst.MAX_ABS_ERR_THRESHOLD:
+            return CompareConst.PASS
+        else:
+            return CompareConst.ERROR
+
+
+class MaxRelativeDiffCompareAlgorithm(BaseCompareAlgorithm):
+    def __init__(self) -> None:
+        super().__init__()
+        self.compare_algorithm_name = CompareConst.MAX_RELATIVE_ERR
+
+    def check_validity(self, bench_compute_element, tested_compute_element):
+        return self.check_two_tensor(bench_compute_element, tested_compute_element)
+
+    def run_compare(self, bench_compute_element, tested_compute_element):
+        bench_ndarray = self.convert_to_np_float64_ndarray(bench_compute_element.get_parameter())
+        tested_ndarray = self.convert_to_np_float64_ndarray(tested_compute_element.get_parameter())
+
+        abs_diff = np.abs(bench_ndarray - tested_ndarray)
+        bench_ndarray_nonzero = np.abs(bench_ndarray) + (bench_ndarray == 0) * MsCompareConst.EPSILON
+        max_relative_diff = np.max(abs_diff / bench_ndarray_nonzero)
+        return max_relative_diff
+
+    def check_pass(self, compare_value):
+        if compare_value < CompareConst.MAX_RELATIVE_ERR_THRESHOLD:
+            return CompareConst.PASS
+        else:
+            return CompareConst.ERROR
+
+
+compare_algorithms = {
+    CompareConst.COSINE: CosineSimilarityCompareAlgorithm(),
+    CompareConst.MAX_ABS_ERR: MaxAbsoluteDiffCompareAlgorithm(),
+    CompareConst.MAX_RELATIVE_ERR: MaxRelativeDiffCompareAlgorithm(),
+}
+
+
+class CompareStandard(Enum):
+    BINARY_EQUALITY_STANDARD = auto()
+    ABSOLUTE_THRESHOLD_STANDARD = auto()
+    ULP_ERROR_STANDARD = auto()
+    BENCHMARK_STANDARD = auto()
+    THOUSANDTH_STANDARD = auto()
+
+
+# ======== file operation classes ==========
+
+from collections import defaultdict
+from functools import wraps
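+
+# Illustrative use of compare_algorithms (hypothetical variable names;
+# ComputeElement is defined later in this script): each entry is a callable
+# that takes a bench and a tested ComputeElement and returns a CompareResult.
+#
+#     for name, algorithm in compare_algorithms.items():
+#         result = algorithm(bench_element, tested_element)
+#         print(name, result.compare_value, result.pass_status)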
+
+
+def check_and_get_from_json_dict(dict_instance, key, key_description, accepted_type=None, accepted_value=None):
+    '''
+    Args:
+        dict_instance: dict, dict parsed from input json
+        key: str
+        key_description: str
+        accepted_type: tuple
+        accepted_value: Union[tuple, list]
+
+    Return:
+        value, the corresponding value of "key" in "dict_instance"
+
+    Exception:
+        raise ApiAccuracyCheckerException.ParseJsonFailed error when
+        1. dict_instance is not a dict
+        2. value is None
+        3. value is not an accepted type
+        4. value is not an accepted value
+    '''
+    if not isinstance(dict_instance, dict):
+        error_info = "check_and_get_from_json_dict failed: input is not a dict"
+        raise ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed, error_info)
+    value = dict_instance.get(key)
+    if value is None:
+        error_info = f"check_and_get_from_json_dict failed: {key_description} is missing"
+        raise ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed, error_info)
+    elif accepted_type is not None and not isinstance(value, accepted_type):
+        error_info = f"check_and_get_from_json_dict failed: {key_description} is not accepted type: {accepted_type}"
+        raise ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed, error_info)
+    elif accepted_value is not None and value not in accepted_value:
+        error_info = f"check_and_get_from_json_dict failed: {key_description} is not accepted value: {accepted_value}"
+        raise ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed, error_info)
+    return value
+
+
+def convert_to_tuple(args):
+    if isinstance(args, (tuple, list)):
+        return tuple(args)
+    else:
+        input_list = [args]
+        return tuple(input_list)
+
+
+def trim_output_compute_element_list(compute_element_list, forward_or_backward):
+    '''
+    Args:
+        compute_element_list: List[ComputeElement]
+        forward_or_backward: str, Union["forward", "backward"]
+    '''
+    trimmed_list = []
+    for compute_element in compute_element_list:
+        if compute_element.get_parameter() is None or \
+                (forward_or_backward == Const.BACKWARD and compute_element.get_dtype() not in float_dtype_str_list):
+            # trim case: 1. parameter is None. 2. backward output has non float parameter
+            continue
+        trimmed_list.append(compute_element)
+    return trimmed_list
+
+
+# track the recursion depth of utility functions
+recursion_depth = defaultdict(int)
+
+
+def recursion_depth_decorator(func_info, max_depth=Const.MAX_DEPTH):
+    """Decorate a function so that an exception carrying the function info is raised
+    when its recursion depth exceeds max_depth."""
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            func_id = id(func)
+            recursion_depth[func_id] += 1
+            if recursion_depth[func_id] > max_depth:
+                recursion_depth[func_id] -= 1
+                raise Exception(f"call stack of {func_info} exceeds the recursion limit: {max_depth}")
+            try:
+                result = func(*args, **kwargs)
+            finally:
+                recursion_depth[func_id] -= 1
+            return result
+
+        return wrapper
+
+    return decorator
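+
+
+# Usage sketch (hypothetical function; the real call sites are the recursive
+# file helpers below):
+#
+#     @recursion_depth_decorator("template.walk_tree", max_depth=16)
+#     def walk_tree(node):
+#         for child in node.children:
+#             walk_tree(child)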
+
+
+class FileChecker:
+    """
+    The class for checking files.
+
+    Attributes:
+        file_path: The file or directory path to be verified.
+        path_type: file or directory
+        ability(str): FileCheckConst.WRITE_ABLE or FileCheckConst.READ_ABLE to set file has writability or readability
+        file_type(str): The correct file type for file
+    """
+
+    def __init__(self, file_path, path_type, ability=None, file_type=None, is_script=True):
+        self.file_path = file_path
+        self.path_type = self._check_path_type(path_type)
+        self.ability = ability
+        self.file_type = file_type
+        self.is_script = is_script
+
+    @staticmethod
+    def _check_path_type(path_type):
+        if path_type not in [FileCheckConst.DIR, FileCheckConst.FILE]:
+            logger.error(f'The path_type must be {FileCheckConst.DIR} or {FileCheckConst.FILE}.')
+            raise FileCheckException(FileCheckException.ILLEGAL_PARAM_ERROR)
+        return path_type
+
+    def common_check(self):
+        """
+        Check basic file permissions: soft links, path length, existence,
+        read/write permission, owner and special characters.
+        Note: suffix validity is not a common check; use the dedicated
+        interfaces for it.
+        """
+        check_path_exists(self.file_path)
+        check_link(self.file_path)
+        self.file_path = os.path.realpath(self.file_path)
+        check_path_length(self.file_path)
+        check_path_type(self.file_path, self.path_type)
+        self.check_path_ability()
+        if self.is_script:
+            check_path_owner_consistent(self.file_path)
+        check_path_pattern_valid(self.file_path)
+        check_common_file_size(self.file_path)
+        check_file_suffix(self.file_path, self.file_type)
+        if self.path_type == FileCheckConst.FILE:
+            check_dirpath_before_read(self.file_path)
+        return self.file_path
+
+    def check_path_ability(self):
+        if self.ability == FileCheckConst.WRITE_ABLE:
+            check_path_writability(self.file_path)
+        if self.ability == FileCheckConst.READ_ABLE:
+            check_path_readability(self.file_path)
+        if self.ability == FileCheckConst.READ_WRITE_ABLE:
+            check_path_readability(self.file_path)
+            check_path_writability(self.file_path)
+
+
+class FileOpen:
+    """
+    The class for opening files in a safe way.
+
+    Attributes:
+        file_path: The file or directory path to be opened.
+        mode(str): The file open mode
+    """
+    SUPPORT_READ_MODE = ["r", "rb"]
+    SUPPORT_WRITE_MODE = ["w", "wb", "a", "ab"]
+    SUPPORT_READ_WRITE_MODE = ["r+", "rb+", "w+", "wb+", "a+", "ab+"]
+
+    def __init__(self, file_path, mode, encoding='utf-8'):
+        self.file_path = file_path
+        self.mode = mode
+        self.encoding = encoding
+        self._handle = None
+
+    def __enter__(self):
+        self.check_file_path()
+        binary_mode = "b"
+        if binary_mode not in self.mode:
+            self._handle = open(self.file_path, self.mode, encoding=self.encoding)
+        else:
+            self._handle = open(self.file_path, self.mode)
+        return self._handle
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if self._handle:
+            self._handle.close()
+
+    def check_file_path(self):
+        support_mode = self.SUPPORT_READ_MODE + self.SUPPORT_WRITE_MODE + self.SUPPORT_READ_WRITE_MODE
+        if self.mode not in support_mode:
+            logger.error("File open does not support %s mode" % self.mode)
+            raise FileCheckException(FileCheckException.ILLEGAL_PARAM_ERROR)
+        check_link(self.file_path)
+        self.file_path = os.path.realpath(self.file_path)
+        check_path_length(self.file_path)
+        self.check_ability_and_owner()
+        check_path_pattern_valid(self.file_path)
+        if os.path.exists(self.file_path):
+            check_common_file_size(self.file_path)
+            check_dirpath_before_read(self.file_path)
+
+    def check_ability_and_owner(self):
+        if self.mode in self.SUPPORT_READ_MODE:
+            check_path_exists(self.file_path)
+            check_path_readability(self.file_path)
+            check_path_owner_consistent(self.file_path)
+        if self.mode in self.SUPPORT_WRITE_MODE and os.path.exists(self.file_path):
+            check_path_writability(self.file_path)
+            check_path_owner_consistent(self.file_path)
+        if self.mode in self.SUPPORT_READ_WRITE_MODE and os.path.exists(self.file_path):
+            check_path_readability(self.file_path)
+            check_path_writability(self.file_path)
+            check_path_owner_consistent(self.file_path)
+
+
+def check_link(path):
+    abs_path = os.path.abspath(path)
+    if os.path.islink(abs_path):
+        logger.error('The file path {} is a soft link.'.format(path))
+        raise FileCheckException(FileCheckException.SOFT_LINK_ERROR)
+
+
+def check_path_length(path, name_length=None):
+    file_max_name_length = name_length if name_length else FileCheckConst.FILE_NAME_LENGTH
+    if len(path) > FileCheckConst.DIRECTORY_LENGTH or \
+            len(os.path.basename(path)) > file_max_name_length:
+        logger.error('The file path length exceeds limit.')
+        raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR)
+
+
+def check_path_exists(path):
+    if not os.path.exists(path):
+        logger.error('The file path %s does not exist.' % path)
+        raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR)
+
+
+def check_path_readability(path):
+    if not os.access(path, os.R_OK):
+        logger.error('The file path %s is not readable.' % path)
+        raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR)
+
+
+def check_path_writability(path):
+    if not os.access(path, os.W_OK):
+        logger.error('The file path %s is not writable.' % path)
+        raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR)
+
+
+def check_path_executable(path):
+    if not os.access(path, os.X_OK):
+        logger.error('The file path %s is not executable.' % path)
+        raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR)
+
+
+def check_other_user_writable(path):
+    st = os.stat(path)
+    if st.st_mode & 0o002:
+        logger.error('The file path %s may be insecure because other users have write permissions.' % path)
+        raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR)
+
+
+def check_path_owner_consistent(path):
+    file_owner = os.stat(path).st_uid
+    if file_owner != os.getuid() and os.getuid() != 0:
+        logger.error('The file path %s may be insecure because it does not belong to you.' % path)
+        raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR)
+
+
+def check_path_pattern_valid(path):
+    if not re.match(FileCheckConst.FILE_VALID_PATTERN, path):
+        logger.error('The file path %s contains special characters.' % (path))
+        raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR)
+
+
+def check_file_size(file_path, max_size):
+    try:
+        file_size = os.path.getsize(file_path)
+    except OSError as os_error:
+        logger.error(f'Failed to open "{file_path}". {str(os_error)}')
+        raise FileCheckException(FileCheckException.INVALID_FILE_ERROR) from os_error
+    if file_size >= max_size:
+        logger.error(f'The size ({file_size}) of {file_path} exceeds ({max_size}) bytes, tools not support.')
+        raise FileCheckException(FileCheckException.FILE_TOO_LARGE_ERROR)
+
+
+def check_common_file_size(file_path):
+    if os.path.isfile(file_path):
+        for suffix, max_size in FileCheckConst.FILE_SIZE_DICT.items():
+            if file_path.endswith(suffix):
+                check_file_size(file_path, max_size)
+                return
+        check_file_size(file_path, FileCheckConst.COMMOM_FILE_SIZE)
+
+
+def check_file_suffix(file_path, file_suffix):
+    if file_suffix:
+        if not file_path.endswith(file_suffix):
+            logger.error(f"The {file_path} should be a {file_suffix} file!")
+            raise FileCheckException(FileCheckException.INVALID_FILE_ERROR)
+
+
+def check_path_type(file_path, file_type):
+    if file_type == FileCheckConst.FILE:
+        if not os.path.isfile(file_path):
+            logger.error(f"The {file_path} should be a file!")
+            raise FileCheckException(FileCheckException.INVALID_FILE_ERROR)
+    if file_type == FileCheckConst.DIR:
+        if not os.path.isdir(file_path):
+            logger.error(f"The {file_path} should be a directory!")
+            raise FileCheckException(FileCheckException.INVALID_FILE_ERROR)
+
+
+def check_others_writable(directory):
+    dir_stat = os.stat(directory)
+    is_writable = (
+        bool(dir_stat.st_mode & stat.S_IWGRP) or  # writable by group
+        bool(dir_stat.st_mode & stat.S_IWOTH)  # writable by other users
+    )
+    return is_writable
+
+
+def make_dir(dir_path):
+    check_path_before_create(dir_path)
+    dir_path = os.path.realpath(dir_path)
+    if os.path.isdir(dir_path):
+        return
+    try:
+        os.makedirs(dir_path, mode=FileCheckConst.DATA_DIR_AUTHORITY, exist_ok=True)
+    except OSError as ex:
+        raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR,
+                                 f"Failed to create {dir_path}. "
+                                 f"Please check the path permission or disk space. {str(ex)}") from ex
+    file_check = FileChecker(dir_path, FileCheckConst.DIR)
+    file_check.common_check()
+
+
+@recursion_depth_decorator('msprobe.core.common.file_utils.create_directory', max_depth=16)
+def create_directory(dir_path):
+    """
+    Function Description:
+        creating a safe directory with specified permissions
+    Parameter:
+        dir_path: directory path
+    Exception Description:
+        when invalid data throw exception
+    """
+    check_link(dir_path)
+    check_path_before_create(dir_path)
+    dir_path = os.path.realpath(dir_path)
+    parent_dir = os.path.dirname(dir_path)
+    if not os.path.isdir(parent_dir):
+        create_directory(parent_dir)
+    make_dir(dir_path)
+
+
+def check_path_before_create(path):
+    check_link(path)
+    if path_len_exceeds_limit(path):
+        raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR, 'The file path length exceeds limit.')
+
+    if not re.match(FileCheckConst.FILE_PATTERN, os.path.realpath(path)):
+        raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR,
+                                 'The file path {} contains special characters.'.format(path))
+
+
+def check_dirpath_before_read(path):
+    path = os.path.realpath(path)
+    dirpath = os.path.dirname(path)
+    # warn when the parent directory is writable by other users, since files
+    # read from such a directory may have been tampered with
+    if check_others_writable(dirpath):
+        logger.warning(f"The directory {dirpath} is writable by others; reading from it may be insecure.")
+
+
+def check_file_or_directory_path(path, isdir=False):
+    """
+    Function Description:
+        check whether the path is valid
+    Parameter:
+        path: the path to check
+        isdir: the path is dir or file
+    Exception Description:
+        when invalid data throw exception
+    """
+    if isdir:
+        path_checker = FileChecker(path, FileCheckConst.DIR, FileCheckConst.WRITE_ABLE)
+    else:
+        path_checker = FileChecker(path, FileCheckConst.FILE, FileCheckConst.READ_ABLE)
+    path_checker.common_check()
+
+
+def change_mode(path, mode):
+    if not os.path.exists(path) or os.path.islink(path):
+        return
+    try:
+        os.chmod(path, mode)
+    except PermissionError as ex:
+        raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR,
+                                 'Failed to change {} authority. {}'.format(path, str(ex))) from ex
+
+
+@recursion_depth_decorator('msprobe.core.common.file_utils.recursive_chmod')
+def recursive_chmod(path):
+    """
+    Recursively change the permissions of a directory, its subdirectories and files:
+    files to 640, directories to 750.
+
+    :param path: path of the directory whose permissions should be changed
+    """
+    for _, dirs, files in os.walk(path):
+        for file_name in files:
+            file_path = os.path.join(path, file_name)
+            change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY)
+        for dir_name in dirs:
+            dir_path = os.path.join(path, dir_name)
+            change_mode(dir_path, FileCheckConst.DATA_DIR_AUTHORITY)
+            recursive_chmod(dir_path)
+
+
+def path_len_exceeds_limit(file_path):
+    return len(os.path.realpath(file_path)) > FileCheckConst.DIRECTORY_LENGTH or \
+        len(os.path.basename(file_path)) > FileCheckConst.FILE_NAME_LENGTH
+
+
+def check_file_type(path):
+    """
+    Function Description:
+        determine if it is a file or a directory
+    Parameter:
+        path: path
+    Exception Description:
+        when neither a file nor a directory throw exception
+    """
+    if os.path.isdir(path):
+        return FileCheckConst.DIR
+    elif os.path.isfile(path):
+        return FileCheckConst.FILE
+    else:
+        logger.error(f'{path} does not exist, please check!')
+        raise FileCheckException(FileCheckException.INVALID_FILE_ERROR)
+
+
+def load_npy(filepath):
+    check_file_or_directory_path(filepath)
+    try:
+        npy = np.load(filepath, allow_pickle=False)
+    except Exception as e:
+        logger.error(f"The numpy file failed to load. Please check the path: {filepath}.")
+        raise RuntimeError(f"Load numpy file {filepath} failed.") from e
+    return npy
+
+
+def write_csv(data, filepath, mode="a+", malicious_check=False):
+    def csv_value_is_valid(value: str) -> bool:
+        if not isinstance(value, str):
+            return True
+        try:
+            # -1.00 or +1.00 should be considered as digit numbers
+            float(value)
+        except ValueError:
+            # otherwise, they will be considered as formula injections
+            return not bool(re.compile(FileCheckConst.CSV_BLACK_LIST).search(value))
+        return True
+
+    if malicious_check:
+        for row in data:
+            for cell in row:
+                if not csv_value_is_valid(cell):
+                    raise RuntimeError(f"Malicious value [{cell}] is not allowed "
+                                       f"to be written into the csv: {filepath}.")
+
+    check_path_before_create(filepath)
+    file_path = os.path.realpath(filepath)
+    try:
+        with FileOpen(filepath, mode, encoding='utf-8-sig') as f:
+            writer = csv.writer(f)
+            writer.writerows(data)
+    except Exception as e:
+        logger.error(f'Save csv file "{os.path.basename(file_path)}" failed')
+        raise RuntimeError(f"Save csv file {file_path} failed.") from e
+    change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
+
+
+def write_csv_header(csv_path, header_func):
+    """Write the CSV header on the first write."""
+    header = header_func()  # fetch the header
+    logger.debug(f"Writing CSV header: {header}")
+    write_csv([header], csv_path, mode="a+")
+
+
+def get_result_csv_header():
+    """Return the header of the result CSV file."""
+    return [
+        MsCompareConst.DETAIL_CSV_API_NAME,
+        MsCompareConst.RESULT_CSV_FORWARD_TEST_SUCCESS,
+        MsCompareConst.RESULT_CSV_BACKWARD_TEST_SUCCESS,
+        MsCompareConst.DETAIL_CSV_MESSAGE,
+    ]
+
+
+def get_detail_csv_header():
+    """Return the header of the detail CSV file."""
+    detail_csv_header_basic_info = [
+        MsCompareConst.DETAIL_CSV_API_NAME,
+        MsCompareConst.DETAIL_CSV_BENCH_DTYPE,
+        MsCompareConst.DETAIL_CSV_TESTED_DTYPE,
+        MsCompareConst.DETAIL_CSV_SHAPE,
+    ]
+    detail_csv_header_compare_result = list(compare_algorithms.keys())
+    detail_csv_header_status = [
+        MsCompareConst.DETAIL_CSV_PASS_STATUS,
+        MsCompareConst.DETAIL_CSV_MESSAGE,
+    ]
+    return detail_csv_header_basic_info + detail_csv_header_compare_result + detail_csv_header_status
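+
+
+# Illustrative guard in write_csv above (hypothetical call): with
+# malicious_check=True, a cell such as "=SUM(A1:A9)" matches
+# FileCheckConst.CSV_BLACK_LIST and raises RuntimeError instead of being
+# written, blocking CSV formula injection.
+#
+#     write_csv([["api", "=SUM(A1:A9)"]], "./details.csv", malicious_check=True)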
detail_csv_header_compare_result + detail_csv_header_status
+
+
+def check_csv_header(headers, required_constants, csv_path):
+    """Check that the CSV header contains every required constant; return True when it does."""
+    missing_constants = [const for const in required_constants if not any(const in header for header in headers)]
+
+    if missing_constants:
+        raise MsprobeBaseException(
+            MsprobeBaseException.MISSING_HEADER_ERROR,
+            f"{csv_path} is missing the required header fields: {missing_constants}"
+        )
+    return True
+
+
+def add_time_as_suffix(name):
+    return '{}_{}.csv'.format(name, time.strftime("%Y%m%d%H%M%S", time.localtime(time.time())))
+
+
+# ======== Result persistence ========
+
+
+class DataManager:
+    def __init__(self, csv_dir, result_csv_path):
+        self.results = {}
+        self.results_exception_skip = {}
+        self.is_first_write = True  # controls the one-off header writing
+        self.csv_dir = csv_dir
+        self.api_names_set = set()  # API names that have already been recorded
+        # when a result_csv_path is passed in, resume from the previous run
+        if result_csv_path:
+            self.resume_from_last_csv(result_csv_path)
+            self.initialize_api_names_set(result_csv_path)
+        else:
+            # otherwise derive fresh, timestamped output paths under csv_dir
+            self.result_out_path = os.path.join(self.csv_dir, add_time_as_suffix(MsCompareConst.RESULT_CSV_FILE_NAME))
+            self.detail_out_path = os.path.join(
+                self.csv_dir,
+                os.path.basename(self.result_out_path).replace("result", "details")
+            )
+
+        if self.detail_out_path and os.path.exists(self.detail_out_path):
+            check_file_or_directory_path(self.detail_out_path)
+
+        if self.result_out_path and os.path.exists(self.result_out_path):
+            check_file_or_directory_path(self.result_out_path)
+
+    def initialize_api_names_set(self, result_csv_path):
+        """Read the existing result CSV and collect the API names it already contains."""
+        csv_data = read_csv(result_csv_path, as_pd=False)
+
+        # headers is empty when the file is empty
+        headers = csv_data[0] if csv_data else []
+
+        if check_csv_header(headers, get_result_csv_header(), result_csv_path):
+            # locate the "API Name" column
+            api_name_index = None
+            for i, header in enumerate(headers):
+                if MsCompareConst.DETAIL_CSV_API_NAME in header:  # the header row may carry a BOM, so match by substring
+                    api_name_index = i
+                    break
+
+            if api_name_index is None:
+                logger.warning(f"{result_csv_path} No column contains 'API Name'.")
+                return
+
+            # collect the API name of every data row, skipping the header row
+            for row in csv_data[1:]:
+                if row and len(row) > api_name_index:
+                    api_name = row[api_name_index]
+                    if api_name:
+                        self.api_names_set.add(api_name)
+
+        logger.debug(f"Initialized API names set from existing CSV: {self.api_names_set}")
+
+    def is_unique_api(self, api_name):
+        """Return False if the API name was seen before; otherwise record it and return True."""
+        if api_name in self.api_names_set:
+            return False
+        self.api_names_set.add(api_name)
+        return True
+
+    def resume_from_last_csv(self, result_csv_path):
+        """Resume from the result_csv_path written by the previous run."""
+        last_dir = os.path.dirname(result_csv_path)
+
+        # reuse the previous directory and output paths so later writes append to them
+        self.csv_dir = last_dir
+        self.detail_out_path = os.path.join(last_dir, os.path.basename(result_csv_path).replace("result", "details"))
+        if self.detail_out_path and os.path.exists(self.detail_out_path):
+            check_file_or_directory_path(self.detail_out_path)
+        self.result_out_path = result_csv_path
+        self.is_first_write = False
+
+    def save_results(self, api_name_str):
+        """Write the detail and result summaries, then clear the cached results."""
+        if self.is_first_write:
+            logger.info("Writing CSV headers for the first time.")
+            write_csv_header(self.detail_out_path, get_detail_csv_header)
+            write_csv_header(self.result_out_path, get_result_csv_header)
+            self.is_first_write = False  # avoid writing the headers again
+
+        logger.debug("Starting to write detailed output to CSV.")
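+        # Each call appends one API's rows: details first, then the summary row,
+        # and finally the in-memory cache is cleared so iterations do not accumulate.
+        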
self.to_detail_csv(self.detail_out_path)
+        logger.debug(f"Detailed output for {api_name_str} written to {self.detail_out_path}.")
+
+        logger.debug("Starting to write result summary to CSV.")
+        self.to_result_csv(self.result_out_path)
+        logger.debug(f"Result summary for {api_name_str} written to {self.result_out_path}.")
+
+        # clear the cached records so the next call starts fresh
+        self.clear_results()
+
+    def record(self, output_list):
+        if output_list is None:
+            return
+        for output in output_list:
+            api_real_name, forward_or_backward, basic_info, compare_result_dict = output
+            key = (api_real_name, forward_or_backward)
+            if key not in self.results:
+                self.results[key] = []
+            self.results[key].append((basic_info, compare_result_dict))
+        logger.debug(f"Complete self.results after recording: {self.results}")
+
+    def record_exception_skip(self, api_name, forward_or_backward, err_msg):
+        '''
+        record exception-skip information into self.results_exception_skip:
+        dict{str: dict{"forward": str/None, "backward": str/None}};
+        the key is the api_name, the values are err_msg strings
+        '''
+        if api_name not in self.results_exception_skip:
+            self.results_exception_skip[api_name] = {Const.FORWARD: None, Const.BACKWARD: None}
+        self.results_exception_skip[api_name][forward_or_backward] = err_msg
+
+    def clear_results(self):
+        """Clear the cached results and exception-skip records."""
+        logger.debug("Clearing self.results data.")
+        self.results.clear()
+        self.results_exception_skip.clear()
+
+    def to_detail_csv(self, csv_path):
+        logger.debug("Preparing detail CSV headers and rows.")
+        detail_csv = []
+
+        detail_csv_header_compare_result = list(compare_algorithms.keys())
+
+        for _, results in self.results.items():
+            for res in results:
+                basic_info, compare_result_dict = res
+                csv_row_basic_info = [
+                    basic_info.api_name,
+                    basic_info.bench_dtype,
+                    basic_info.tested_dtype,
+                    basic_info.shape
+                ]
+                csv_row_compare_result = [
+                    compare_result_dict.get(algorithm_name).compare_value
+                    for algorithm_name in detail_csv_header_compare_result
+                ]
+                csv_row_status = [basic_info.status, basic_info.err_msg]
+                csv_row = csv_row_basic_info + csv_row_compare_result + csv_row_status
+                detail_csv.append(csv_row)
+                logger.debug(f"Detail CSV row added: {csv_row}")
+
+        logger.debug(f"Writing detail CSV to {csv_path}.")
+        write_csv(detail_csv, csv_path, mode="a+")
+        logger.debug(f"Detail CSV written successfully to {csv_path}.")
+
+    def to_result_csv(self, csv_path):
+        '''
+        depends on both self.results and self.results_exception_skip
+        '''
+        logger.debug("Preparing result CSV data.")
+        result_csv = []
+
+        result_csv_dict = {}
+        for key, results in self.results.items():
+            api_real_name, forward_or_backward = key
+            pass_status = CompareConst.PASS
+            overall_err_msg = ""
+
+            for res in results:
+                basic_info, _ = res
+                if basic_info.status != CompareConst.PASS:
+                    pass_status = CompareConst.ERROR
+                    overall_err_msg += basic_info.err_msg
+
+            overall_err_msg = "" if pass_status == CompareConst.PASS else overall_err_msg
+
+            if api_real_name not in result_csv_dict:
+                result_csv_dict[api_real_name] = ResultCsvEntry()
+            if forward_or_backward == Const.FORWARD:
+                result_csv_dict[api_real_name].forward_pass_status = pass_status
+                result_csv_dict[api_real_name].forward_err_msg = overall_err_msg
+            else:
+                result_csv_dict[api_real_name].backward_pass_status = pass_status
+                result_csv_dict[api_real_name].backward_err_msg = overall_err_msg
+
+        for api_name, entry in result_csv_dict.items():
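+            # an API passes a direction only when every output slot passed; the
+            # concatenated forward+backward messages are kept only on failure
+            overall_err_msg = "" if 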
(entry.forward_pass_status == CompareConst.PASS and + entry.backward_pass_status == CompareConst.PASS) else \ + entry.forward_err_msg + entry.backward_err_msg + row = [ + api_name, + entry.forward_pass_status, + entry.backward_pass_status, + overall_err_msg + ] + # change row if this api has exception_skip information + if api_name in self.results_exception_skip: + if self.results_exception_skip[api_name][Const.FORWARD] is not None: + row[1] = CompareConst.SKIP + row[-1] += self.results_exception_skip[api_name][Const.FORWARD] + if self.results_exception_skip[api_name][Const.BACKWARD] is not None: + row[2] = CompareConst.SKIP + row[-1] += self.results_exception_skip[api_name][Const.BACKWARD] + del self.results_exception_skip[api_name] + result_csv.append(row) + logger.debug(f"Result CSV row added: {row}") + for api_name in self.results_exception_skip: + current_exception_skip = self.results_exception_skip[api_name] + forward_status = None + backward_status = None + err_msg = "" + if current_exception_skip[Const.FORWARD] is not None: + forward_status = CompareConst.SKIP + err_msg += current_exception_skip[Const.FORWARD] + if current_exception_skip[Const.BACKWARD] is not None: + backward_status = CompareConst.SKIP + err_msg += current_exception_skip[Const.BACKWARD] + row = [api_name, forward_status, backward_status, err_msg] + result_csv.append(row) + + write_csv(result_csv, csv_path, mode="a+") + logger.debug(f"Result CSV written successfully to {csv_path}.") + + # 设置标记为 False,防止后续重复添加表头 + self.is_first_write = False + + +# ======== 输入类型类 ======= +class GlobalContext: + def __init__(self): + self.is_constructed = True + self.dump_data_dir = "" + self.framework = Const.MS_FRAMEWORK + + def init(self, is_constructed, dump_data_dir, framework): + self.is_constructed = is_constructed + self.dump_data_dir = dump_data_dir + self.framework = framework + + def get_dump_data_dir(self): + return self.dump_data_dir + + def get_is_constructed(self): + return self.is_constructed + + def get_framework(self): + return self.framework + + +global_context = GlobalContext() + + + +class ApiInputAggregation: + def __init__(self, inputs, kwargs, gradient_inputs) -> None: + """ + Args: + inputs: List[ComputeElement] + kwargs: dict{str: ComputeElement} + gradient_inputs: Union[List[ComputeElement], None] + """ + self.inputs = inputs + self.kwargs = kwargs + self.gradient_inputs = gradient_inputs + + +api_parent_module_mapping = { + (MsCompareConst.MINT, Const.MS_FRAMEWORK): mindspore.mint, + (MsCompareConst.MINT, Const.PT_FRAMEWORK): torch, + (MsCompareConst.MINT_FUNCTIONAL, Const.MS_FRAMEWORK): mindspore.mint.nn.functional, + (MsCompareConst.MINT_FUNCTIONAL, Const.PT_FRAMEWORK): torch.nn.functional, + (MsCompareConst.TENSOR_API, Const.MS_FRAMEWORK): mindspore.Tensor, + (MsCompareConst.TENSOR_API, Const.PT_FRAMEWORK): torch.Tensor, + (MsCompareConst.MINDTORCH_TENSOR, Const.MT_FRAMEWORK): mindtorch_tensor, + (MsCompareConst.MINDTORCH_TENSOR, Const.PT_FRAMEWORK): torch.Tensor, + (MsCompareConst.MINDTORCH, Const.MT_FRAMEWORK): mindtorch, + (MsCompareConst.MINDTORCH, Const.PT_FRAMEWORK): torch, + (MsCompareConst.MINDTORCH_FUNC, Const.MT_FRAMEWORK): mindtorch_func, + (MsCompareConst.MINDTORCH_FUNC, Const.PT_FRAMEWORK): torch.nn.functional, + (MsCompareConst.MINDTORCH_DIST, Const.MT_FRAMEWORK): mindtorch_dist, + (MsCompareConst.MINDTORCH_DIST, Const.PT_FRAMEWORK): torch.distributed, + (MsCompareConst.FUNCTIONAL_API, Const.MS_FRAMEWORK): mindspore.ops + +} + + +api_parent_module_str_mapping = { + (MsCompareConst.MINT, 
Const.MS_FRAMEWORK): "mindspore.mint", + (MsCompareConst.MINT, Const.PT_FRAMEWORK): "torch", + (MsCompareConst.MINT_FUNCTIONAL, Const.MS_FRAMEWORK): "mindspore.mint.nn.functional", + (MsCompareConst.MINT_FUNCTIONAL, Const.PT_FRAMEWORK): "torch.nn.functional", + (MsCompareConst.TENSOR_API, Const.MS_FRAMEWORK): "mindspore.Tensor", + (MsCompareConst.TENSOR_API, Const.PT_FRAMEWORK): "torch.Tensor", + (MsCompareConst.MINDTORCH_TENSOR, Const.MT_FRAMEWORK): "mindtorch_tensor", + (MsCompareConst.MINDTORCH_TENSOR, Const.PT_FRAMEWORK): "torch.Tensor", + (MsCompareConst.MINDTORCH, Const.MT_FRAMEWORK): "mindtorch", + (MsCompareConst.MINDTORCH, Const.PT_FRAMEWORK): "torch", + (MsCompareConst.MINDTORCH_FUNC, Const.MT_FRAMEWORK): "mindtorch_func", + (MsCompareConst.MINDTORCH_FUNC, Const.PT_FRAMEWORK): "torch.nn.functional", + (MsCompareConst.MINDTORCH_DIST, Const.MT_FRAMEWORK): "mindtorch_dist", + (MsCompareConst.MINDTORCH_DIST, Const.PT_FRAMEWORK): "torch.distributed", + (MsCompareConst.FUNCTIONAL_API, Const.MS_FRAMEWORK): "mindspore.ops" +} + + +class ApiRunner: + def __call__(self, api_input_aggregation, api_name_str, forward_or_backward=Const.FORWARD, + api_platform=Const.MS_FRAMEWORK): + ''' + Args: + api_input_aggregation: ApiInputAggregation + api_name_str: str, e.g. "MintFunctional.relu.0" + forward_or_backward: str, Union["forward", "backward"] + api_platform: str, Union["mindspore", "torch", "mindtorch"] + + Return: + outputs: list[ComputeElement] + + Description: + run mindspore.mint/torch api + ''' + + api_type_str, api_sub_name = self.get_info_from_name(api_name_str, api_platform) + api_instance = self.get_api_instance(api_type_str, api_sub_name, api_platform) + + return self.run_api(api_instance, api_input_aggregation, forward_or_backward, api_platform) + + @staticmethod + def get_info_from_name(api_name_str, api_platform=Const.MS_FRAMEWORK): + """ + Args: + api_name_str: str, the trimmed key of data dict in api_info.json. e.g. "MintFunctional.relu.0" + api_platform: str, the platform for the API, which can be either "mindspore" or "mindtorch". + It specifies which framework is being used. Default is "mindspore". + Return: + api_type_str: str, Union["MintFunctional", "Mint", "Tensor", "Torch", "Functional"] + api_sub_name: str, e.g. 
"relu" + """ + api_name_list = api_name_str.split(Const.SEP) + if len(api_name_list) != 3: + err_msg = f"ApiRunner.get_info_from_name failed: api_name_str: {api_name_str} is not in defined format" + logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue)) + api_type_str, api_sub_name = api_name_list[0], api_name_list[1] + if api_type_str not in [MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL, MsCompareConst.TENSOR_API, + MsCompareConst.FUNCTIONAL_API] \ + and api_platform == Const.MS_FRAMEWORK: + err_msg = f"ApiRunner.get_info_from_name failed: not mint, mint.nn.functional or Tensor api" + logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue)) + + if api_type_str not in MsCompareConst.MT_VALID_API_TYPES and api_platform == Const.MT_FRAMEWORK: + err_msg = f"ApiRunner.get_info_from_name failed: not torch, functional or Tensor api" + logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue)) + return api_type_str, api_sub_name + + @staticmethod + def get_api_instance(api_type_str, api_sub_name, api_platform): + """ + Args: + api_type_str: str, Union["MintFunctional", "Mint", "Tensor", "Functional"] + api_sub_name: str, e.g. "relu" + api_platform: str: Union["mindspore", "pytorch"] + + Return: + api_instance: function object + + Description: + get mindspore.mint/torch api function + mindspore.mint.{api_sub_name} <--> torch.{api_sub_name} + mindspore.mint.nn.functional.{api_sub_name} <--> torch.nn.functional.{api_sub_name} + """ + + api_parent_module = api_parent_module_mapping.get((api_type_str, api_platform)) + api_parent_module_str = api_parent_module_str_mapping.get((api_type_str, api_platform)) + full_api_name = api_parent_module_str + Const.SEP + api_sub_name + + if not hasattr(api_parent_module, api_sub_name): + err_msg = f"ApiRunner.get_api_instance failed: {full_api_name} is not found" + logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.ApiWrong)) + + api_instance = getattr(api_parent_module, api_sub_name) + if not callable(api_instance): + err_msg = f"ApiRunner.get_api_instance failed: {full_api_name} is not callable" + logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.ApiWrong)) + + return api_instance + + @staticmethod + def run_api(api_instance, api_input_aggregation, forward_or_backward, api_platform): + inputs = tuple(compute_element.get_parameter(get_origin=False, tensor_platform=api_platform) + for compute_element in api_input_aggregation.inputs) + kwargs = {key: value.get_parameter(get_origin=False, tensor_platform=api_platform) + for key, value in api_input_aggregation.kwargs.items()} + gradient_inputs = api_input_aggregation.gradient_inputs + + if forward_or_backward == Const.FORWARD: + forward_result = api_instance(*inputs, **kwargs) # can be single tensor or tuple + forward_result_tuple = convert_to_tuple(forward_result) + res_compute_element_list = [ComputeElement(parameter=api_res) for api_res in forward_result_tuple] + if api_platform == Const.MS_FRAMEWORK or api_platform == Const.MT_FRAMEWORK: + return res_compute_element_list, inputs, kwargs, forward_result_tuple + else: + if gradient_inputs is None: + err_msg = f"ApiRunner.run_api failed: run backward api but gradient_inputs is missing" + logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue)) + gradient_inputs = 
tuple(compute_element.get_parameter(get_origin=False, tensor_platform=api_platform) + for compute_element in gradient_inputs) + if api_platform == Const.MS_FRAMEWORK or api_platform == Const.MT_FRAMEWORK: + if len(gradient_inputs) == 1: + gradient_inputs = gradient_inputs[0] + + def api_with_kwargs(*forward_inputs): + return api_instance(*forward_inputs, **kwargs) + + grad_func = ops.GradOperation(get_all=True, sens_param=True)(api_with_kwargs) + backward_result = grad_func(*inputs, gradient_inputs) # can be single tensor or tuple + backward_result_tuple = convert_to_tuple(backward_result) + res_compute_element_list = [ComputeElement(parameter=api_res) for api_res in backward_result_tuple] + return res_compute_element_list, gradient_inputs, backward_result_tuple + else: + # set requires_grad + requires_grad_index = [] + for index, tensor in enumerate(inputs): + if isinstance(tensor, torch.Tensor) and \ + torch_dtype_to_dtype_str.get(tensor.dtype) in float_dtype_str_list: + setattr(tensor, "requires_grad", True) + requires_grad_index.append(index) + forward_results = api_instance(*inputs, **kwargs) + forward_results = convert_to_tuple(forward_results) + for forward_res, gradient_in in zip(forward_results, gradient_inputs): + forward_res.backward(gradient_in) + backward_result_list = [] + for index in requires_grad_index: + backward_result_list.append(getattr(inputs[index], "grad")) + res_compute_element_list = [ComputeElement(parameter=api_res) for api_res in backward_result_list] + + return res_compute_element_list + + +api_runner = ApiRunner() + +# ======== 数据结构类 ======== + +class ResultCsvEntry: + def __init__(self) -> None: + self.forward_pass_status = None + self.backward_pass_status = None + self.forward_err_msg = "" + self.backward_err_msg = "" + self.overall_err_msg = None + +class ProcessResultPacket: + def __init__(self, process_status, result, err_msg) -> None: + self.process_status = process_status + self.result = result + self.err_msg = err_msg + +class MstensorMetaData: + def __init__(self, dtype_str, npy_path, maximum, minimum, shape) -> None: + self.dtype_str = dtype_str + self.npy_path = npy_path + self.maximum = maximum + self.minimum = minimum + self.shape = shape + + +class DtypeMetaData: + def __init__(self, dtype_str) -> None: + self.dtype_str = dtype_str + + +class ComputeElement: + def __init__(self, compute_element_info=None, parameter=None): + self.supported_parameter_type = tuple(type_to_api_info_type_str.keys()) + tuple([torch.Tensor, tuple]) + if parameter is not None: + self._init_with_parameter(parameter) + elif isinstance(compute_element_info, (list, dict)): + self._init_from_compute_element_info(compute_element_info) + elif compute_element_info is None: + self._init_from_null_compute_element_info() + else: + pass + logger.error_log_with_exp( + "ComputeElement.__init__ failed: not init with parameter or compute_element info is not (list, dict)", + ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType)) + + @staticmethod + def transfer_to_torch_tensor(ms_tensor): + ''' + Args: + ms_tensor: mindspore.Tensor + Return: + torch_tensor: torch.Tensor + ''' + ms_dtype = ms_tensor.dtype + dtype_str = ms_dtype_to_dtype_str.get(ms_dtype) + if dtype_str not in dtype_str_to_torch_dtype: + err_msg = f"ComputeElement.transfer_to_torch_tensor failed: no matching torch dtype for {dtype_str}" + logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType)) + else: + torch_dtype = 
dtype_str_to_torch_dtype.get(dtype_str) + + if dtype_str in int_dtype_str_list: + middle_dtype = mindspore.int64 + else: + middle_dtype = mindspore.float64 + np_ndarray = ms_tensor.astype(middle_dtype).numpy() + torch_tensor = torch.from_numpy(np_ndarray).to(torch_dtype) + return torch_tensor + + @staticmethod + def transfer_to_mindtorch_tensor(ms_tensor): + """ + Args: + ms_tensor: mindspore.Tensor + Return: + mindtorch_tensor: mindtorch.Tensor + """ + + ms_dtype = ms_tensor.dtype + + dtype_str = ms_dtype_to_dtype_str.get(ms_dtype) + + if dtype_str not in dtype_str_to_mindtorch_dtype: + err_msg = f"ComputeElement.transfer_to_mindtorch_tensor failed: no matching mindtorch dtype for {dtype_str}" + logger.error_log_with_exp(err_msg, + ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType)) + else: + mindtorch_dtype = dtype_str_to_mindtorch_dtype.get(dtype_str) + + if dtype_str in int_dtype_str_list: + middle_dtype = mindspore.int64 + else: + middle_dtype = mindspore.float64 + + np_ndarray = ms_tensor.astype(middle_dtype).numpy() + + mindtorch_tensor = mindtorch.from_numpy(np_ndarray).to(ms_dtype) + + return mindtorch_tensor + + @staticmethod + def transfer_to_mindspore_tensor(torch_tensor): + ''' + Args: + torch_tensor: torch.Tensor + + Return: + ms_tensor: mindspore.Tensor + ''' + torch_dtype = torch_tensor.dtype + dtype_str = torch_dtype_to_dtype_str.get(torch_dtype) + if dtype_str not in dtype_str_to_ms_dtype: + err_msg = \ + f"ComputeElement._transfer_to_mindspore_tensor failed: no matching mindspore dtype for {dtype_str}" + logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType)) + else: + ms_dtype = dtype_str_to_ms_dtype.get(dtype_str) + + if dtype_str in int_dtype_str_list: + middle_dtype = torch.int64 + else: + middle_dtype = torch.float64 + np_ndarray = torch_tensor.to(middle_dtype, copy=True).numpy() + ms_tensor = mindspore.Tensor.from_numpy(np_ndarray).astype(ms_dtype) + return ms_tensor + + @staticmethod + def convert_inf_to_real_num(value, dtype_str): + if value == float("inf"): + np_dtype = dtype_str_to_np_dtype.get(dtype_str, DEFAULT_CONSTRUCT_NP_FLOAT_DTYPE) + value = np.finfo(np_dtype).max + elif value == float("-inf"): + np_dtype = dtype_str_to_np_dtype.get(dtype_str, DEFAULT_CONSTRUCT_NP_FLOAT_DTYPE) + value = np.finfo(np_dtype).min + return value + + def get_parameter(self, get_origin=True, tensor_platform=Const.MS_FRAMEWORK): + ''' + Args: + get_origin: boolean + tensor_platform: str, Union["mindspore", "pytorch"] + + Return: + parameter: Union[int, float, str, slice, tuple, torch.Tensor, mindspore.Tensor] + ''' + if self.parameter is None: + return self.parameter + if isinstance(self.parameter, tuple): + return tuple([compute_element.get_parameter(get_origin=get_origin, tensor_platform=tensor_platform) + for compute_element in self.parameter]) + elif isinstance(self.parameter, self.supported_parameter_type): + parameter_tmp = self.parameter + elif isinstance(self.parameter, DtypeMetaData): + if tensor_platform == Const.MS_FRAMEWORK: + parameter_tmp = dtype_str_to_ms_dtype.get(self.parameter.dtype_str) + elif tensor_platform == Const.PT_FRAMEWORK: + parameter_tmp = dtype_str_to_torch_dtype.get(self.parameter.dtype_str) + elif tensor_platform == Const.MT_FRAMEWORK: + parameter_tmp = dtype_str_to_mindtorch_dtype.get(self.parameter.dtype_str) + + elif isinstance(self.parameter, MstensorMetaData): + mstensor_meta_data = self.parameter + ms_dtype = dtype_str_to_ms_dtype.get(mstensor_meta_data.dtype_str) + if 
global_context.get_is_constructed(): + np_dtype = dtype_str_to_np_dtype.get(mstensor_meta_data.dtype_str, DEFAULT_CONSTRUCT_NP_FLOAT_DTYPE) + ndarray = self._construct_ndarray(mstensor_meta_data.shape, mstensor_meta_data.maximum, + mstensor_meta_data.minimum, np_dtype) + else: + ndarray = load_npy(mstensor_meta_data.npy_path) + parameter_tmp = mindspore.Tensor(ndarray, dtype=ms_dtype) + else: + err_msg = "ComputeElement.get_parameter failed: self.parameter type is not in " \ + "(int, float, str, slice, bool, torch.Tensor, mindspore.Tensor, MstensorMetaData)" + logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType)) + + # if necessary, do transfer + if not get_origin and isinstance(parameter_tmp, mindspore.Tensor) and tensor_platform == Const.PT_FRAMEWORK: + parameter = self.transfer_to_torch_tensor(parameter_tmp) + elif not get_origin and isinstance(parameter_tmp, mindspore.Tensor) and tensor_platform == Const.MT_FRAMEWORK: + parameter = self.transfer_to_mindtorch_tensor(parameter_tmp) + elif not get_origin and isinstance(parameter_tmp, torch.Tensor) and tensor_platform == Const.MS_FRAMEWORK: + parameter = self.transfer_to_mindspore_tensor(parameter_tmp) + else: + parameter = parameter_tmp + + return parameter + + def get_shape(self): + return self.shape + + def get_dtype(self): + return self.dtype_str + + def _construct_ndarray(self, shape, maximum, minimum, np_dtype): + shape = tuple(shape) + np.random.seed(42) + if np_dtype == np.bool_: + ndarray = np.random.rand(*shape) > 0.5 + else: + maximum = self.convert_inf_to_real_num(maximum, np_dtype) + minimum = self.convert_inf_to_real_num(minimum, np_dtype) + ndarray = np.random.uniform(minimum, maximum, shape).astype(np_dtype) + return ndarray + + def _init_from_null_compute_element_info(self): + self.parameter = None + self.shape = tuple() + self.dtype = "None" + + def _init_from_compute_element_info(self, compute_element_info): + ''' + Args: + compute_element_info: Union[list, dict] + + Return: + void + + init member attributes: self.shape, self.dtype_str, self.parameter + ''' + if isinstance(compute_element_info, list): + self.shape = tuple() + self.dtype_str = TUPLE_TYPE_STR + self.parameter = tuple([ComputeElement(compute_element_info=sub_info) + for sub_info in compute_element_info]) + else: + type_str = check_and_get_from_json_dict(compute_element_info, "type", "type field in api_info.json", + accepted_type=str, accepted_value=api_info_type_str_to_type.keys()) + self.shape = tuple() + self.dtype_str = type_str + if type_str == MINDSPORE_TENSOR_TYPE_STR: + self._init_from_mstensor_compute_element_info(compute_element_info) + else: + value = check_and_get_from_json_dict(compute_element_info, "value", "value field in api_info.json") + if type_str == MINDSPORE_DTYPE_TYPE_STR: + self.parameter = DtypeMetaData(value) + elif type_str == SLICE_TYPE_STR: + self.parameter = slice(*tuple(value)) + else: # type_str in ("str", "int", "float", "bool") + self.parameter = value + + def _init_from_mstensor_compute_element_info(self, compute_element_info): + ''' + do not load real tensor, only record meta data + ''' + dtype_str = check_and_get_from_json_dict(compute_element_info, "dtype", "dtype field in api_info.json", + accepted_type=str, accepted_value=dtype_str_to_ms_dtype.keys()) + shape = check_and_get_from_json_dict(compute_element_info, "shape", "shape field in api_info.json", + accepted_type=(list,)) + if global_context.get_is_constructed(): + maximum = 
check_and_get_from_json_dict(compute_element_info, "Max", "Max field in api_info.json", + accepted_type=(int, float)) + minimum = check_and_get_from_json_dict(compute_element_info, "Min", "Min field in api_info.json", + accepted_type=(int, float)) + + npy_path = None + else: + maximum, minimum = None, None + data_name = check_and_get_from_json_dict(compute_element_info, "data_name", + "data_name field in api_info.json", accepted_type=(str,)) + npy_path = os.path.join(global_context.get_dump_data_dir(), data_name) + mstensor_meta_data = MstensorMetaData(dtype_str, npy_path, maximum, minimum, shape) + self.parameter = mstensor_meta_data + self.dtype_str = dtype_str + self.shape = tuple(shape) + + def _init_with_parameter(self, parameter): + self.parameter = parameter + print(f"parameter:{parameter}") + print(f"self.supported_parameter_type:{self.supported_parameter_type}") + if isinstance(parameter, dict): + # 这里假设 dict 中有 'type'、'shape'、'dtype' 等字段 + return self._init_from_compute_element_info(parameter) + self.shape = tuple() + if not isinstance(parameter, self.supported_parameter_type): + err_msg = "ComputeElement._init_with_parameter failed: " \ + "parameter type is not in (int, float, str, slice, bool, torch.Tensor, mindspore.Tensor)" + logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType)) + if isinstance(parameter, mindspore.Tensor): + self.shape = tuple(parameter.shape) + self.dtype_str = ms_dtype_to_dtype_str.get(parameter.dtype) + elif isinstance(parameter, torch.Tensor): + self.shape = tuple(parameter.shape) + self.dtype_str = torch_dtype_to_dtype_str.get(parameter.dtype) + elif isinstance(parameter, typing.Type): + self.dtype_str = MINDSPORE_DTYPE_TYPE_STR + self.parameter = DtypeMetaData(ms_dtype_to_dtype_str.get(parameter)) + elif isinstance(parameter, torch.dtype): + self.dtype_str = TORCH_DTYPE_TYPE_STR + self.parameter = DtypeMetaData(torch_dtype_to_dtype_str.get(parameter)) + elif isinstance(parameter, tuple): + self.dtype_str = TUPLE_TYPE_STR + self.parameter = tuple([ComputeElement(parameter=param) for param in parameter]) + else: + self.dtype_str = type_to_api_info_type_str.get(type(parameter)) + print(f"self.dtype_str{self.dtype_str}") + +class BasicInfoAndStatus: + def __init__(self, api_name, bench_dtype, tested_dtype, shape, status, err_msg) -> None: + self.api_name = api_name + self.bench_dtype = bench_dtype + self.tested_dtype = tested_dtype + self.shape = shape + self.status = status + self.err_msg = err_msg + + + + +# ======== 划分类 ======= + + +def generate_bool_tensor(low, high, shape): + low, high = int(low), int(high) + tensor = torch.randint(low, high + 1, shape) + bool_tensor = torch.gt(tensor, 0) + return bool_tensor + + +def generate_numerical_tensor(low, high, shape, data_dtype): + if data_dtype in TORCH_FLOAT_TYPE: + scale = high - low + rand01 = torch.rand(shape, dtype=eval(data_dtype)) + tensor = rand01 * scale + low + elif data_dtype in TORCH_INT_TYPE: + low, high = int(low), int(high) + tensor = torch.randint(low, high + 1, shape, dtype=eval(data_dtype)) + else: + raise NotImplementedError(f"{{data_dtype}} is not supported!") + if torch.numel(tensor) == 0: + return tensor + tmp_tensor = tensor.reshape(-1) + tmp_tensor[0] = low + tmp_tensor[-1] = high + data = tmp_tensor.reshape(shape) + return data + + +def generate_random_tensor(info): + low, high = info.get('Min'), info.get('Max') + data_dtype = info.get('dtype') + shape = tuple(info.get('shape')) + if data_dtype == "torch.bool": + data = 
generate_bool_tensor(low, high, shape)
+    else:
+        data = generate_numerical_tensor(low, high, shape, data_dtype)
+    return data
+
+
+def generate_real_tensor(data_path):
+    data_path = os.path.realpath(data_path)
+    data = load_pt(data_path, to_cpu = True)
+    return data
+
+
+def generate_data(info):
+    data_type = info.get("type")
+    data_path = info.get("data_name")
+    data_grad = info.get("requires_grad")
+    if data_type in TENSOR_DATA_LIST:
+        if data_path:
+            data = generate_real_tensor(data_path)
+        else:
+            data = generate_random_tensor(info)
+    else:
+        data = info.get("value")
+    if data_grad == True:
+        data.requires_grad_(True)
+    return data
+
+
+def get_input(propagation):
+    args_info_forward = {args_info_forward}
+    kwargs_info_forward = {kwargs_info_forward}
+    args_info_backward = {args_info_backward}
+    forward_inputs = [ComputeElement(compute_element_info=compute_element_info)
+                      for compute_element_info in args_info_forward]
+    kwargs_compute_element_dict = {
+        key_str: ComputeElement(compute_element_info=compute_element_info)
+        for key_str, compute_element_info in kwargs_info_forward.items()
+    }
+    if args_info_backward:
+        gradient_inputs = [ComputeElement(compute_element_info=compute_element_info)
+                           for compute_element_info in args_info_backward]
+    else:
+        gradient_inputs = None
+    return ApiInputAggregation(
+        forward_inputs,
+        kwargs_compute_element_dict,
+        gradient_inputs
+    )
+
+
+def exec_api(args, kwargs, args_grad_input, propagation):
+    output = {api_type}.{api_name}(*args, **kwargs)
+    if propagation == BACKWARD:
+        args_input_tensor = [tensor for tensor in args if isinstance(tensor, torch.Tensor) and tensor.requires_grad]
+        args_input_tensor.extend(
+            [value for value in kwargs.values() if isinstance(value, torch.Tensor) and value.requires_grad])
+        output_backward = torch.autograd.grad(outputs=output, inputs=args_input_tensor, grad_outputs=args_grad_input)
+        return output_backward
+    return output
+
+
+def compute_inf_nan_proportion(inf_nan_mask, out_device, out_bench, abs_bench_with_eps, rtol):
+    out_bench = out_bench.to(out_device.dtype)
+    # finite representable range of the tested dtype (kept out of the built-in min/max names)
+    dtype_min = torch.finfo(out_device.dtype).min
+    dtype_max = torch.finfo(out_device.dtype).max
+    bench_clip = torch.clamp(out_bench, min=dtype_min, max=dtype_max)
+    device_clip = torch.clamp(out_device, min=dtype_min, max=dtype_max)
+    clipped_abs_ae = torch.abs(device_clip - bench_clip)
+    clipped_re = clipped_abs_ae / abs_bench_with_eps
+    pass_mask = torch.less_equal(clipped_re, rtol)
+    both_nan_mask = torch.logical_and(torch.isnan(out_device), torch.isnan(bench_clip))
+    pass_mask = torch.logical_or(pass_mask, both_nan_mask)
+    not_pass_mask = torch.logical_not(pass_mask)
+    not_pass_mask = torch.logical_and(not_pass_mask, inf_nan_mask)
+    inf_nan_err_cnt = torch.sum(not_pass_mask)
+    return 0 if torch.sum(inf_nan_mask) == 0 else inf_nan_err_cnt / torch.sum(inf_nan_mask)
+
+
+def compute_rmse(abs_err, normal_value_mask):
+    if torch.sum(normal_value_mask) == 0:
+        return 0
+    else:
+        masked_ae = torch.where(normal_value_mask, abs_err, 0)
+        mse = torch.sum(torch.square(masked_ae)) / torch.sum(normal_value_mask)
+        rmse = torch.sqrt(mse)
+        return rmse
+
+
+def compute_error_balance(out_device, out_bench):
+    larger_count = torch.sum(torch.greater(out_device - out_bench.to(out_device.dtype), 0))
+    smaller_count = torch.sum(torch.less(out_device - out_bench.to(out_device.dtype), 0))
+    if torch.numel(out_bench) == 0:
+        raise ZeroDivisionError(f"ERROR: please check torch.numel out_bench, its value is {{torch.numel(out_bench)}}")
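+    # Worked example (illustrative values): out_device=[1.1, 2.6, 3.5] vs
+    # out_bench=[1.0, 2.5, 3.0] gives diffs [0.1, 0.1, 0.5], so larger_count=3,
+    # smaller_count=0 and error_balance=|3-0|/3=1.0, i.e. the tested output is
+    # biased high; a well-balanced error pattern stays near 0.
+    error_balance = abs(larger_count - smaller_count) / 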
torch.numel(out_bench)
+    return error_balance
+
+
+# Run-and-compare helper
+def run_and_compare_helper(api_name_str, api_input_aggregation, forward_or_backward):
+    """
+    Args:
+        api_name_str: str
+        api_input_aggregation: ApiInputAggregation
+        forward_or_backward: str: Union["forward", "backward"]
+
+    Return:
+        output_list: List[tuple(str, str, BasicInfoAndStatus, dict{str: CompareResult})]
+
+    Description:
+        get mindspore api output, run torch api and get output.
+        compare output.
+        record compare result.
+    """
+    # get output
+    if forward_or_backward == Const.FORWARD:
+        tested_outputs, inputs, kwargs, forward_result_tuple = api_runner(api_input_aggregation, api_name_str,
+                                                                          forward_or_backward,
+                                                                          global_context.get_framework())
+        print(f"inputs:{inputs}")
+        print(f"kwargs:{kwargs}")
+        print(f"forward_result_tuple:{forward_result_tuple}")
+    elif forward_or_backward == Const.BACKWARD:
+        tested_outputs, gradient_inputs, backward_result_tuple = api_runner(api_input_aggregation, api_name_str,
+                                                                            forward_or_backward,
+                                                                            global_context.get_framework())
+        print(f"gradient_inputs:{gradient_inputs}")
+        print(f"backward_result_tuple:{backward_result_tuple}")
+    else:
+        tested_outputs = api_runner(api_input_aggregation, api_name_str,
+                                    forward_or_backward, global_context.get_framework())
+
+    bench_outputs = api_runner(api_input_aggregation, api_name_str, forward_or_backward, Const.PT_FRAMEWORK)
+
+    tested_outputs = trim_output_compute_element_list(tested_outputs, forward_or_backward)
+    bench_outputs = trim_output_compute_element_list(bench_outputs, forward_or_backward)
+
+    # compare output
+    output_list = []
+    for i, (bench_out, tested_out) in enumerate(zip(bench_outputs, tested_outputs)):
+        api_name_with_slot = Const.SEP.join([api_name_str, forward_or_backward, Const.OUTPUT, str(i)])
+        bench_dtype = bench_out.get_dtype()
+        tested_dtype = tested_out.get_dtype()
+        shape = bench_out.get_shape()
+
+        compare_result_dict = dict()
+        for compare_algorithm_name, compare_algorithm in compare_algorithms.items():
+            compare_result = compare_algorithm(bench_out, tested_out)
+            compare_result_dict[compare_algorithm_name] = compare_result
+
+        if compare_result_dict.get(CompareConst.COSINE).pass_status == CompareConst.PASS and \
+                compare_result_dict.get(CompareConst.MAX_ABS_ERR).pass_status == CompareConst.PASS:
+            status = CompareConst.PASS
+            err_msg = ""
+        else:
+            status = CompareConst.ERROR
+            err_msg = (compare_result_dict.get(CompareConst.COSINE).err_msg +
+                       compare_result_dict.get(CompareConst.MAX_ABS_ERR).err_msg)
+
+        basic_info_status = \
+            BasicInfoAndStatus(api_name_with_slot, bench_dtype, tested_dtype, shape, status, err_msg)
+        output_list.append(tuple([api_name_str, forward_or_backward, basic_info_status, compare_result_dict]))
+    return output_list
+
+
+if __name__ == "__main__":
+    framework = "{framework}"
+    dump_data_dir = "{dump_data_dir}"
+    api_name = "{api_name}"
+    api_name_str = "{api_full_name}"
+    propagation = "{propagation}"
+    data_mode = "{data_mode}"
+    torch.manual_seed({random_seed})
+
+    # make sure the output directory exists before DataManager derives its csv paths
+    create_directory("./op_result_output")
+    data_manager = DataManager("./op_result_output", None)
+
+    print("Before init:",
+          "is_constructed =", global_context.get_is_constructed(),
+          "dump_data_dir =", global_context.get_dump_data_dir(),
+          "framework =", global_context.get_framework())
+
+    is_constructed = data_mode == "random_data"
+    global_context.init(is_constructed, dump_data_dir, framework)
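+    # random_data mode (is_constructed=True) rebuilds inputs from the recorded
+    # shape/dtype/Max/Min statistics, while real_data mode replays the exact
+    # tensors dumped under dump_data_dir
+    print(" After init:",
+          "is_constructed =", 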
global_context.get_is_constructed(), + "dump_data_dir =", global_context.get_dump_data_dir(), + "framework =", global_context.get_framework()) + + for i in range({iter_times}): + print(f"iter: {{i}}:") + if propagation == BACKWARD: + + + backward_inputs_aggregation = get_input(propagation) + + backward_output_list = run_and_compare_helper(api_name_str, backward_inputs_aggregation, + Const.BACKWARD) + process_result_packet = ProcessResultPacket(process_status=MsCompareConst.ProcessStatus.SUCCESS, + result=backward_output_list, err_msg="") + + + if process_result_packet.process_status is MsCompareConst.ProcessStatus.SUCCESS: + data_manager.record(process_result_packet.result) + elif process_result_packet.process_status == MsCompareConst.ProcessStatus.EXCEPTION_SKIP: + data_manager.record_exception_skip(api_name_str, Const.BACKWARD, process_result_packet.err_msg) + + data_manager.save_results(api_name_str) + else: + forward_inputs_aggregation = get_input(propagation) + + forward_output_list = run_and_compare_helper(api_name_str, forward_inputs_aggregation, + Const.FORWARD) + process_result_packet = ProcessResultPacket(process_status=MsCompareConst.ProcessStatus.SUCCESS, + result=forward_output_list, err_msg="") + + + if process_result_packet.process_status is MsCompareConst.ProcessStatus.SUCCESS: + data_manager.record(process_result_packet.result) + elif process_result_packet.process_status == MsCompareConst.ProcessStatus.EXCEPTION_SKIP: + data_manager.record_exception_skip(api_name_str, Const.FORWARD, process_result_packet.err_msg) + + data_manager.save_results(api_name_str) + + print("Compare finished.") -- Gitee From 7a078b8345a2819290eda5a762559a76411427f7 Mon Sep 17 00:00:00 2001 From: yangxinxian <947098055@qq.com> Date: Tue, 20 May 2025 10:23:45 +0800 Subject: [PATCH 02/13] =?UTF-8?q?=E5=88=A0=E9=99=A4=E4=B8=8D=E9=9C=80?= =?UTF-8?q?=E8=A6=81=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../generate_op_script/op_generator.py | 36 ++++--------------- 1 file changed, 6 insertions(+), 30 deletions(-) diff --git a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py index ecc6ac1c8f..a7235bc322 100644 --- a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py @@ -35,35 +35,6 @@ from msprobe.core.common.log import logger from msprobe.core.common.file_utils import make_dir, change_mode from msprobe.core.common.decorator import recursion_depth_decorator -MINDSPORE_TENSOR_TYPE_STR = "mindspore.Tensor" -BOOL_TYPE_STR = "bool" -INT_TYPE_STR = "int" -FLOAT_TYPE_STR = "float" -SLICE_TYPE_STR = "slice" -TUPLE_TYPE_STR = "tuple" -STR_TYPE_STR = "str" -MINDSPORE_DTYPE_TYPE_STR = "mindspore.dtype" -TORCH_DTYPE_TYPE_STR = "torch.dtype" - -api_info_type_str_to_type = { - MINDSPORE_TENSOR_TYPE_STR: mindspore.Tensor, - BOOL_TYPE_STR: bool, - INT_TYPE_STR: int, - FLOAT_TYPE_STR: float, - SLICE_TYPE_STR: slice, - STR_TYPE_STR: str, - MINDSPORE_DTYPE_TYPE_STR: typing.Type, -} -type_to_api_info_type_str = {value: key for key, value in api_info_type_str_to_type.items()} - -TENSOR_DATA_LIST = ["torch.Tensor", "torch.nn.parameter.Parameter"] -TORCH_BOOL_TYPE = ["torch.bool"] -TORCH_INT_TYPE = ["torch.uint8", "torch.int8", "torch.int16", "torch.short", "torch.int32", 
"torch.int", - "torch.int64", "torch.long"] -TORCH_FLOAT_TYPE = ["torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.float", - "torch.float64", "torch.double"] -TORCH_COMPLEX_TYPE = ["torch.complex32", "torch.chalf", "torch.complex64", "torch.cfloat", "torch.complex128", - "torch.cdouble"] OPERATOR_TYPE = ("Functional", "Tensor", "Torch", "Mint") API_INFO = 2 @@ -147,7 +118,12 @@ class CommonConfig: forward_item = next(iter(json_content.items()), None) if not forward_item or not isinstance(forward_item[1], dict) or not forward_item[1]: raise ValueError(f'Invalid forward API data in json_content!') - # 需要去除掉影响key + + # if propagation is backward, ensure json file contains forward and backward info 需要去除掉影响key + # if propagation == Const.BACKWARD and len(json_content) < API_INFO + 2: + # raise ValueError(f'Backward propagation requires contains forward and backward info!') + + # if propagation is backward, ensure it has valid data if propagation == Const.BACKWARD: backward_item = list(json_content.items())[1] if not isinstance(backward_item[1], dict) or not backward_item[1]: -- Gitee From 83c17f0dfac6487d3562649eb50e1564c85754e9 Mon Sep 17 00:00:00 2001 From: yangxinxian <947098055@qq.com> Date: Tue, 20 May 2025 10:30:18 +0800 Subject: [PATCH 03/13] Update operator_replication.template --- .../operator_replication.template | 105 +++--------------- 1 file changed, 17 insertions(+), 88 deletions(-) diff --git a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template index 22b30b86b0..978ea0d07d 100644 --- a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template @@ -2,25 +2,25 @@ import os import re import stat import time -from enum import Enum, auto -from abc import ABC, abstractmethod -import csv - -import gc import sys -from pathlib import Path -import mindspore -from mindspore import ops - - -from tabulate import tabulate - +import gc import logging - import traceback +import csv +from enum import Enum, auto +from abc import ABC, abstractmethod +from pathlib import Path +from collections import defaultdict +from functools import wraps +import numpy as np +import mindspore +from mindspore import ops +from mindspore._c_expression import typing +from mindspore.common import dtype as mstype +# ===== Logging Setup ===== def error_log_with_exp(self, msg: str, exp: Exception): """ msg: 你的错误提示 @@ -29,11 +29,8 @@ def error_log_with_exp(self, msg: str, exp: Exception): # 将 Exception 的类型、消息和 traceback 通过 exc_info 参数一并传给 .error() self.error(msg, exc_info=(type(exp), exp, exp.__traceback__)) -# 把它挂到 Logger 上 logging.Logger.error_log_with_exp = error_log_with_exp - - # 1. 
基本配置:设置日志级别为 INFO,默认输出到控制台 logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] %(message)s', @@ -42,7 +39,8 @@ logging.basicConfig(level=logging.INFO, logger = logging.getLogger() -# ======= 常数类 ======= + +# ===== Exception Classes ===== class CodedException(Exception): def __init__(self, code, error_info=''): @@ -66,7 +64,7 @@ class ApiAccuracyCheckerException(CodedException): ApiWrong: "[msprobe] Api Accuracy Checker something wrong with api: ", } - +# ======= Constants ======= class FileCheckConst: """ Class for file check const @@ -327,12 +325,8 @@ if not is_valid_pt_mt_env: # ======= 常数类 ======= -import numpy as np -from mindspore._c_expression import typing -from mindspore.common import dtype as mstype -TENSOR_DATA_LIST = ["torch.Tensor", "torch.nn.parameter.Parameter"] TORCH_BOOL_TYPE = ["torch.bool"] TORCH_INT_TYPE = ["torch.uint8", "torch.int8", "torch.int16", "torch.short", "torch.int32", "torch.int", "torch.int64", "torch.long"] @@ -716,10 +710,6 @@ class CompareStandard(Enum): # ======== 文件操作类 ========== -from collections import defaultdict -from functools import wraps - - def check_and_get_from_json_dict(dict_instance, key, key_description, accepted_type=None, accepted_value=None): ''' Args: @@ -2031,68 +2021,7 @@ class BasicInfoAndStatus: -# ======== 划分类 ======= - - -def generate_bool_tensor(low, high, shape): - low, high = int(low), int(high) - tensor = torch.randint(low, high + 1, shape) - bool_tensor = torch.gt(tensor, 0) - return bool_tensor - - -def generate_numerical_tensor(low, high, shape, data_dtype): - if data_dtype in TORCH_FLOAT_TYPE: - scale = high - low - rand01 = torch.rand(shape, dtype=eval(data_dtype)) - tensor = rand01 * scale + low - elif data_dtype in TORCH_INT_TYPE: - low, high = int(low), int(high) - tensor = torch.randint(low, high + 1, shape, dtype=eval(data_dtype)) - else: - raise NotImplementedError(f"{{data_dtype}} is not supported!") - if torch.numel(tensor) == 0: - return tensor - tmp_tensor = tensor.reshape(-1) - tmp_tensor[0] = low - tmp_tensor[-1] = high - data = tmp_tensor.reshape(shape) - return data - - -def generate_random_tensor(info): - low, high = info.get('Min'), info.get('Max') - data_dtype = info.get('dtype') - shape = tuple(info.get('shape')) - if data_dtype == "torch.bool": - data = generate_bool_tensor(low, high, shape) - else: - data = generate_numerical_tensor(low, high, shape, data_dtype) - return data - - -def generate_real_tensor(data_path): - data_path = os.path.realpath(data_path) - data = load_pt(data_path, to_cpu = True) - return data - - -def generate_data(info): - data_type = info.get("type") - data_path = info.get("data_name") - data_grad = info.get("requires_grad") - if data_type in TENSOR_DATA_LIST: - if data_path: - data = generate_real_tensor(data_path) - else: - data = generate_random_tensor(info) - else: - data = info.get("value") - if data_grad == True: - data.requires_grad_(True) - return data - - +# ======== 获取输入 ======= def get_input(propagation): args_info_forward = {args_info_forward} kwargs_info_forward = {kwargs_info_forward} -- Gitee From df46c6979eb23bcec3c54074d72fe12d1b71c3f7 Mon Sep 17 00:00:00 2001 From: yangxinxian <947098055@qq.com> Date: Tue, 20 May 2025 10:33:08 +0800 Subject: [PATCH 04/13] Update operator_replication.template --- .../operator_replication.template | 19 ++----------------- 1 file changed, 2 insertions(+), 17 deletions(-) diff --git a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template 
b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template index 978ea0d07d..63f688c810 100644 --- a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template @@ -324,23 +324,8 @@ if not is_valid_pt_mt_env: -# ======= 常数类 ======= - - -TORCH_BOOL_TYPE = ["torch.bool"] -TORCH_INT_TYPE = ["torch.uint8", "torch.int8", "torch.int16", "torch.short", "torch.int32", "torch.int", - "torch.int64", "torch.long"] -TORCH_FLOAT_TYPE = ["torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.float", - "torch.float64", "torch.double"] -TORCH_COMPLEX_TYPE = ["torch.complex32", "torch.chalf", "torch.complex64", "torch.cfloat", "torch.complex128", "torch.cdouble"] -RAISE_PRECISION = {{ - "torch.float16": torch.float32, - "torch.half": torch.float32, - "torch.bfloat16": torch.float32, - "torch.float32": torch.float64, - "torch.float": torch.float64 -}} -THOUSANDTH_THRESHOLDING = 0.001 +# ======= 常数类(支持msadapter) ======= + BACKWARD = 'backward' DIR = "dir" FILE = "file" -- Gitee From 5839153cbfe49a7cefd3ff8be994c6d141358929 Mon Sep 17 00:00:00 2001 From: yangxinxian <947098055@qq.com> Date: Tue, 20 May 2025 14:28:06 +0800 Subject: [PATCH 05/13] Update operator_replication.template --- .../generate_op_script/operator_replication.template | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template index 63f688c810..7c7fb33477 100644 --- a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template @@ -1893,7 +1893,7 @@ class ComputeElement: def _construct_ndarray(self, shape, maximum, minimum, np_dtype): shape = tuple(shape) - np.random.seed(42) + np.random.seed({random_seed}) if np_dtype == np.bool_: ndarray = np.random.rand(*shape) > 0.5 else: @@ -2147,7 +2147,7 @@ def run_and_compare_helper(api_name_str, api_input_aggregation, forward_or_backw if __name__ == "__main__": framework = "{framework}" - dump_data_dir = "{dump_data_dir}" + dump_data_dir = "{real_data_path}" api_name = "{api_name}" api_name_str = "{api_full_name}" propagation = "{propagation}" -- Gitee From 43a1266e03df53c5a60838ddefef6cae82c1f329 Mon Sep 17 00:00:00 2001 From: yangxinxian <947098055@qq.com> Date: Tue, 20 May 2025 14:29:12 +0800 Subject: [PATCH 06/13] Update op_generator.py --- .../generate_op_script/op_generator.py | 33 +++++++++---------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py index a7235bc322..0573e37821 100644 --- a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py @@ -45,6 +45,9 @@ API_MAX_LENGTH = 30 PROPAGATION_LIST = [Const.FORWARD, Const.BACKWARD] DATAMODE_LIST = ["random_data", "real_data"] ITER_MAX_TIMES = 1000 +FRAMEWORK = 'framework' +REAL_DATA_PATH 
= 'real_data_path'
+EXCLUDED = {FRAMEWORK, REAL_DATA_PATH}
 
 
 class APIInfo:
@@ -108,10 +111,10 @@ class CommonConfig:
             raise ValueError(f'content of json file is not a dict!')
 
         # ensure the length of json_content is within allowed limits
-        print(f"json_content:{json_content}")
-        print(f"len(json_content):{len(json_content)}")
-        if len(json_content) > API_INFO + 2:
+
+        filtered = {k: v for k, v in json_content.items() if k not in EXCLUDED}
+
+        if len(filtered) > API_INFO:
             raise ValueError(f'json file has more than one API, the API only contains forward and backward info')
 
         # Retrieve the first API name and dictionary
@@ -119,5 +122,9 @@ class CommonConfig:
         if not forward_item or not isinstance(forward_item[1], dict) or not forward_item[1]:
             raise ValueError(f'Invalid forward API data in json_content!')
 
+        # if propagation is backward, ensure the json file contains both forward and backward info
+        if propagation == Const.BACKWARD and len(filtered) < API_INFO:
+            raise ValueError('Backward propagation requires both forward and backward info!')
+
         # if propagation is backward, ensure it has valid data
         if propagation == Const.BACKWARD:
@@ -147,7 +164,7 @@ class APIExtractor:
     def extract_op(self):
         self.data = load_json(self.dump_json_path)
         # 拿到 framework
-        self.framework = self.data.get('framework', None)
+        self.framework = self.data.get(FRAMEWORK, None)
         # print(f"self.data:{self.data}")
         new_data = {}
         extract_key_pattern = re.compile(f"^{re.escape(self.api_name)}\..+")  # 修改为只要包含或等于apiname即可,不需要是只包含
@@ -178,11 +181,11 @@ class APIExtractor:
             new_data[key] = value
 
         if self.real_data_path is not None:
-            new_data['real_data_path'] = self.real_data_path
+            new_data[REAL_DATA_PATH] = self.real_data_path
 
         # 把 framework 加进去
         if self.framework is not None:
-            new_data['framework'] = self.framework
+            new_data[FRAMEWORK] = self.framework
         if not new_data:
             logger.warning(f"Warning: The api '{self.api_name}' does not exist in the file.")
         else:
@@ -284,8 +287,6 @@ class OperatorScriptGenerator:
         return internal_settings
 
     def generate_forward_inputs_code(self, args_info):
-        # 先把 generate_args_element_assignment_code 里已定义的 arg_info_x 变量名
-        # 取出来,拼成列表
         names = []
 
        def collect(info):
@@ -304,7 +305,6 @@ class OperatorScriptGenerator:
         )
 
     def generate_kwargs_compute_element_dict_code(self):
-        # 我们这里假定 kwargs_device 已经是一个 dict 变量
         return (
             "    # ---- 构造 kwargs 对应的 ComputeElement 字典 ----\n"
             "    kwargs_compute_element_dict = {\n"
@@ -314,7 +314,6 @@ class OperatorScriptGenerator:
         )
 
     def generate_gradient_inputs_code(self, args_info_backward):
-        # 同理收集反向梯度的 arg_info_x
         names = []
 
        def collect(info):
@@ -375,8 +374,8 @@ def _run_operator_generate_commond(cmd_args):
         args_info_backward = api_info_dict_backward.get(Const.INPUT)
         op_generate = OperatorScriptGenerator(common_config, args_info_forward, kwargs_info_forward, args_info_backward)
         internal_settings = op_generate.get_settings(api_full_name_backward)
-        internal_settings['framework'] = framework
-        internal_settings['real_data_path'] = real_data_path
+        internal_settings[FRAMEWORK] = framework
+        internal_settings[REAL_DATA_PATH] = real_data_path
     else:
         # read and check json
         api_full_name_forward, api_info_dict_forward = api_info.api_full_name, api_info.api_info_dict
@@ -388,8 +387,8 @@
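+        # internal_settings carries every value substituted into
+        # operator_replication.template ({framework}, {real_data_path}, {api_name},
+        # {propagation}, {data_mode}, {random_seed}, {iter_times}, ...), so both
+        # branches fill the same keys before the template is rendered.
         op_generate = 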
OperatorScriptGenerator(common_config, args_info_forward, kwargs_info_forward, None) internal_settings = op_generate.get_settings(api_full_name_forward) - internal_settings['framework'] = framework - internal_settings['real_data_path'] = real_data_path + internal_settings[FRAMEWORK] = framework + internal_settings[REAL_DATA_PATH] = real_data_path template_path = os.path.join(os.path.dirname(__file__), "operator_replication.template") operator_script_path = os.path.join(cmd_args.api_output_path, -- Gitee From 3dbffed197cafb41c93e99dda8b67e64b32402f5 Mon Sep 17 00:00:00 2001 From: yangxinxian <947098055@qq.com> Date: Tue, 20 May 2025 14:30:22 +0800 Subject: [PATCH 07/13] Update operator_replication.template --- .../operator_replication.template | 215 ++++-------------- 1 file changed, 44 insertions(+), 171 deletions(-) diff --git a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template index 7c7fb33477..096dad7809 100644 --- a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template @@ -2,25 +2,25 @@ import os import re import stat import time -import sys -import gc -import logging -import traceback -import csv - from enum import Enum, auto from abc import ABC, abstractmethod -from pathlib import Path -from collections import defaultdict -from functools import wraps +import csv -import numpy as np +import gc +import sys +from pathlib import Path import mindspore from mindspore import ops -from mindspore._c_expression import typing -from mindspore.common import dtype as mstype -# ===== Logging Setup ===== + +from tabulate import tabulate + +import logging + +import traceback + + + def error_log_with_exp(self, msg: str, exp: Exception): """ msg: 你的错误提示 @@ -29,8 +29,11 @@ def error_log_with_exp(self, msg: str, exp: Exception): # 将 Exception 的类型、消息和 traceback 通过 exc_info 参数一并传给 .error() self.error(msg, exc_info=(type(exp), exp, exp.__traceback__)) +# 把它挂到 Logger 上 logging.Logger.error_log_with_exp = error_log_with_exp + + # 1. 
基本配置:设置日志级别为 INFO,默认输出到控制台 logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] %(message)s', @@ -39,8 +42,7 @@ logging.basicConfig(level=logging.INFO, logger = logging.getLogger() - -# ===== Exception Classes ===== +# ======= 常数类 ======= class CodedException(Exception): def __init__(self, code, error_info=''): @@ -64,7 +66,7 @@ class ApiAccuracyCheckerException(CodedException): ApiWrong: "[msprobe] Api Accuracy Checker something wrong with api: ", } -# ======= Constants ======= + class FileCheckConst: """ Class for file check const @@ -324,8 +326,27 @@ if not is_valid_pt_mt_env: -# ======= 常数类(支持msadapter) ======= +# ======= 常数类 ======= +import numpy as np +from mindspore._c_expression import typing +from mindspore.common import dtype as mstype + +TENSOR_DATA_LIST = ["torch.Tensor", "torch.nn.parameter.Parameter"] +TORCH_BOOL_TYPE = ["torch.bool"] +TORCH_INT_TYPE = ["torch.uint8", "torch.int8", "torch.int16", "torch.short", "torch.int32", "torch.int", + "torch.int64", "torch.long"] +TORCH_FLOAT_TYPE = ["torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.float", + "torch.float64", "torch.double"] +TORCH_COMPLEX_TYPE = ["torch.complex32", "torch.chalf", "torch.complex64", "torch.cfloat", "torch.complex128", "torch.cdouble"] +RAISE_PRECISION = {{ + "torch.float16": torch.float32, + "torch.half": torch.float32, + "torch.bfloat16": torch.float32, + "torch.float32": torch.float64, + "torch.float": torch.float64 +}} +THOUSANDTH_THRESHOLDING = 0.001 BACKWARD = 'backward' DIR = "dir" FILE = "file" @@ -695,6 +716,10 @@ class CompareStandard(Enum): # ======== 文件操作类 ========== +from collections import defaultdict +from functools import wraps + + def check_and_get_from_json_dict(dict_instance, key, key_description, accepted_type=None, accepted_value=None): ''' Args: @@ -990,16 +1015,6 @@ def check_path_type(file_path, file_type): logger.error(f"The {file_path} should be a dictionary!") raise FileCheckException(FileCheckException.INVALID_FILE_ERROR) - -def check_others_writable(directory): - dir_stat = os.stat(directory) - is_writable = ( - bool(dir_stat.st_mode & stat.S_IWGRP) or # 组可写 - bool(dir_stat.st_mode & stat.S_IWOTH) # 其他用户可写 - ) - return is_writable - - def make_dir(dir_path): check_path_before_create(dir_path) dir_path = os.path.realpath(dir_path) @@ -1078,46 +1093,10 @@ def change_mode(path, mode): 'Failed to change {} authority. 
{}'.format(path, str(ex))) from ex -@recursion_depth_decorator('msprobe.core.common.file_utils.recursive_chmod') -def recursive_chmod(path): - """ - 递归地修改目录及其子目录和文件的权限,文件修改为640,路径修改为750 - - :param path: 要修改权限的目录路径 - """ - for _, dirs, files in os.walk(path): - for file_name in files: - file_path = os.path.join(path, file_name) - change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY) - for dir_name in dirs: - dir_path = os.path.join(path, dir_name) - change_mode(dir_path, FileCheckConst.DATA_DIR_AUTHORITY) - recursive_chmod(dir_path) - - def path_len_exceeds_limit(file_path): return len(os.path.realpath(file_path)) > FileCheckConst.DIRECTORY_LENGTH or \ len(os.path.basename(file_path)) > FileCheckConst.FILE_NAME_LENGTH - -def check_file_type(path): - """ - Function Description: - determine if it is a file or a directory - Parameter: - path: path - Exception Description: - when neither a file nor a directory throw exception - """ - if os.path.isdir(path): - return FileCheckConst.DIR - elif os.path.isfile(path): - return FileCheckConst.FILE - else: - logger.error(f'{path} does not exist, please check!') - raise FileCheckException(FileCheckException.INVALID_FILE_ERROR) - - def load_npy(filepath): check_file_or_directory_path(filepath) try: @@ -1127,56 +1106,6 @@ def load_npy(filepath): raise RuntimeError(f"Load numpy file {filepath} failed.") from e return npy - -def check_file_or_directory_path(path, isdir=False): - """ - Function Description: - check whether the path is valid - Parameter: - path: the path to check - isdir: the path is dir or file - Exception Description: - when invalid data throw exception - """ - if isdir: - path_checker = FileChecker(path, DIR, WRITE_ABLE) - else: - path_checker = FileChecker(path, FILE, READ_ABLE) - path_checker.common_check() - - -def change_mode(path, mode): - if not os.path.exists(path) or os.path.islink(path): - return - try: - os.chmod(path, mode) - except PermissionError as ex: - raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR, - 'Failed to change {} authority. 
{}'.format(path, str(ex))) from ex - - -@recursion_depth_decorator('msprobe.core.common.file_utils.recursive_chmod') -def recursive_chmod(path): - """ - 递归地修改目录及其子目录和文件的权限,文件修改为640,路径修改为750 - - :param path: 要修改权限的目录路径 - """ - for _, dirs, files in os.walk(path): - for file_name in files: - file_path = os.path.join(path, file_name) - change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY) - for dir_name in dirs: - dir_path = os.path.join(path, dir_name) - change_mode(dir_path, FileCheckConst.DATA_DIR_AUTHORITY) - recursive_chmod(dir_path) - - -def path_len_exceeds_limit(file_path): - return len(os.path.realpath(file_path)) > FileCheckConst.DIRECTORY_LENGTH or \ - len(os.path.basename(file_path)) > FileCheckConst.FILE_NAME_LENGTH - - def write_csv(data, filepath, mode="a+", malicious_check=False): def csv_value_is_valid(value: str) -> bool: if not isinstance(value, str): @@ -2006,7 +1935,8 @@ class BasicInfoAndStatus: -# ======== 获取输入 ======= +# ======== api执行类 ======= + def get_input(propagation): args_info_forward = {args_info_forward} kwargs_info_forward = {kwargs_info_forward} @@ -2028,54 +1958,6 @@ def get_input(propagation): gradient_inputs ) - - -def exec_api(args, kwargs, args_grad_input, propagation): - output = {api_type}.{api_name}(*args, **kwargs) - if propagation == BACKWARD: - args_input_tensor = [tensor for tensor in args if isinstance(tensor, torch.Tensor) and tensor.requires_grad] - args_input_tensor.extend( - [value for value in kwargs.values() if isinstance(value, torch.Tensor) and value.requires_grad]) - output_backward = torch.autograd.grad(outputs=output, inputs=args_input_tensor, grad_outputs=args_grad_input) - return output_backward - return output - -def compute_inf_nan_proportion(inf_nan_mask, out_device, out_bench, abs_bench_with_eps, rtol): - out_bench = out_bench.to(out_device.dtype) - min = torch.finfo(out_device.dtype).min - max = torch.finfo(out_device.dtype).max - bench_clip = torch.clamp(out_bench, min=min, max=max) - device_clip = torch.clamp(out_device, min=min, max=max) - clipped_abs_ae = torch.abs(device_clip - bench_clip) - clipped_re = clipped_abs_ae / abs_bench_with_eps - pass_mask = torch.less_equal(clipped_re, rtol) - both_nan_mask = torch.logical_and(torch.isnan(out_device), torch.isnan(bench_clip)) - pass_mask = torch.logical_or(pass_mask, both_nan_mask) - not_pass_mask = torch.logical_not(pass_mask) - not_pass_mask = torch.logical_and(not_pass_mask, inf_nan_mask) - inf_nan_err_cnt = torch.sum(not_pass_mask) - return 0 if torch.sum(inf_nan_mask) == 0 else inf_nan_err_cnt / torch.sum(inf_nan_mask) - - -def compute_rmse(abs_err, normal_value_mask): - if torch.sum(normal_value_mask) == 0: - return 0 - else: - masked_ae = torch.where(normal_value_mask, abs_err, 0) - mse = torch.sum(torch.square(masked_ae)) / torch.sum(normal_value_mask) - rmse = torch.sqrt(mse) - return rmse - - -def compute_error_balance(out_device, out_bench): - larger_count = torch.sum(torch.greater(out_device - out_bench.to(out_device.dtype), 0)) - smaller_count = torch.sum(torch.less(out_device - out_bench.to(out_device.dtype), 0)) - if torch.numel(out_bench) == 0: - raise ZeroDivisionError(f"ERROR: please check torch.numel out_bench, its value is {{torch.numel(out_bench)}}") - error_balance = abs(larger_count - smaller_count) / torch.numel(out_bench) - return error_balance - - # 运行和比对函数 def run_and_compare_helper(api_name_str, api_input_aggregation, forward_or_backward): """ @@ -2157,17 +2039,8 @@ if __name__ == "__main__": data_manager = DataManager("./op_result_output", 
None)
     create_directory("./op_result_output")
 
-    print("Before init:",
-          "is_constructed =", global_context.get_is_constructed(),
-          "dump_data_dir =", global_context.get_dump_data_dir(),
-          "framework =", global_context.get_framework())
-
     is_constructed = data_mode == "random_data"
     global_context.init(is_constructed, dump_data_dir, framework)
-    print(" After init:",
-          "is_constructed =", global_context.get_is_constructed(),
-          "dump_data_dir =", global_context.get_dump_data_dir(),
-          "framework =", global_context.get_framework())
 
     for i in range({iter_times}):
         print(f"iter: {{i}}:")
-- 
Gitee

From ba9c69cd137cbbb3f6ebb530647cfb311435d2ec Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Tue, 20 May 2025 15:11:07 +0800
Subject: [PATCH 08/13] =?UTF-8?q?=E4=BF=AE=E6=94=B9cleancode?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../docs/32.generate_operator_PyTorch.md      | 107 ++++++++++++++
 .../generate_op_script/op_generator.py        | 137 ++++++++++--------
 .../operator_replication.template             |   3 +-
 3 files changed, 185 insertions(+), 62 deletions(-)
 create mode 100644 debug/accuracy_tools/msprobe/docs/32.generate_operator_PyTorch.md

diff --git a/debug/accuracy_tools/msprobe/docs/32.generate_operator_PyTorch.md b/debug/accuracy_tools/msprobe/docs/32.generate_operator_PyTorch.md
new file mode 100644
index 0000000000..d44989ac35
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/docs/32.generate_operator_PyTorch.md
@@ -0,0 +1,107 @@
+# Single-Operator API Script Generator
+
+## 1 Introduction
+
+The single-operator API script generator extracts suspicious operators from dump data, replays each one as a standalone API, and reports the accuracy comparison results for that single API. Specifically, the tool extracts the forward and backward information of a suspicious API from the dump data, rebuilds the forward and backward passes of the single API from that information, and finally compares the NPU/GPU results against the CPU results using the **new accuracy-standard comparison method** (see note a), reporting the comparison results under the different comparison methods. The tool supports both a **random-data mode and a real-data mode** (see note b).
+
+a. When generating the single-API script, you may either let the tool construct random inputs for the dump data, or replay the single API with the real input data. Random-data mode (corresponding to task: "statistics") is fast and gives quick results, but its data precision is low, so it can only roughly locate accuracy issues; real-data mode (corresponding to task: "tensor") is slightly slower than random-data mode, but its data precision is high, so accuracy issues can be judged precisely.
+
+## 2 Usage
+
+### Prerequisites
+1. msprobe is installed. See [msprobe installation](./01.installation.md).
+2. The training process has been dumped and a dump.json file is available.
+   See [data dump in MindSpore scenarios](./06.data_dump_MindSpore.md) or [data dump in MSAdapter scenarios](./29.data_dump_MSAdapter.md); note that level="L1" must be configured.
+
+3. An operator is suspected of an accuracy problem and its name is known, e.g. Mint.split.1, Functional.softmax.3, Tensor.add.0 or Torch.matmul.5.
+
+### 2.1 Configure config_op.json
+The single-API replay is configured as follows (using the split operator as an example):
+```
+{
+    "dump_json_path": "./dump.json",
+    "api_name": "Mint.split.1",
+    "extract_api_path": "Mint.split.1.json",
+    "propagation": "backward",
+    "data_mode": "random_data",
+    "random_seed": 42,
+    "iter_times": 1
+}
+```
+**Configuration parameters**
+
+ | Parameter | Description | Required |
+ | ---------------------------- |----------------------------------------------------------------------------| ---------------------------------- |
+ | dump_json_path | Path to dump.json, which holds the information of all dumped operators; can be omitted if the suspicious operator has already been extracted and saved. | No |
+ | api_name | Operator name, e.g. Functional.softmax.3, Tensor.add.0, Torch.matmul.5. | No |
+ | extract_api_path | Path of the JSON file the suspicious operator is extracted into. | Yes |
+ | propagation | Whether to replay the operator's forward or backward pass; defaults to forward. | No |
+ | data_mode | Whether to replay the operator with random data (random_data) or real data (real_data); defaults to random_data. | No |
+ | random_seed | Effective only in random_data mode; the manually chosen random seed, defaults to 1234. | No |
+ | iter_times | Effective only in random_data mode; how many times the single API is run. For security reasons, at most 1000. | No |
+
+### 2.2 Run the command to generate the single-API script
+Once config_op.json is ready, run:
+```
+msprobe -f mindspore op_generate -i ./config_op.json -o ./
+```
+or
+
+change into the generate_op_script folder of mstt
+```
+cd mstt/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script
+```
+and run
+```
+python op_generator.py -i ./config_op.json -o ./
+```
+**Arguments**
+ | Argument | Description | Required |
+ | ---------------------------- | ------------------------------------------------------------ | ---------------------------------- |
+ | -i or --config_input | Path to config_op.json | Yes |
+ | -o or --api_output_path | Output path of the single-API script | Yes |
+
+### 2.3 Run the single-API script
+After op_generator.py finishes, a single-API script named api_name.py is generated under the given path, e.g. Mint.split.1.forward.py, Functional.softmax.3.backward.py, Tensor.add.0.forward.py or Torch.matmul.5.backward.py.
+
+Run the single-API script to obtain the comparison results under the different comparison methods:
+```
+python api_name.py
+```
+
+**Output description**
+
+The single-operator script writes an `accuracy_checking_result_{timestamp}.csv` and an `accuracy_checking_details_{timestamp}.csv` file with the following contents:
+
+`accuracy_checking_details_{timestamp}.csv`
+
+| Field | Meaning |
+| ------------------- | ------------------------------------------------------------ |
+| API Name | API name. |
+| Bench Dtype | Data type of the benchmark API data. |
+| Tested Dtype | Data type of the tested API data. |
+| Shape | Shape information of the API data. |
+| Cosine | Cosine similarity between the tested data and the benchmark data. |
+| MaxAbsErr | Maximum absolute error between the tested data and the benchmark data. |
+| MaxRelativeErr | Maximum relative error between the tested data and the benchmark data. |
+| Status | Check status of the API: pass means the test passed, error means it did not. |
+| Message | Additional information. |
+
+Note: PyTorch cannot back-propagate through tensors with integer dtypes, while MindSpore can. The backward check therefore only compares outputs whose dtype is floating point.
+
+`accuracy_checking_result_{timestamp}.csv`
+
+| Field | Meaning |
+| --------------------- | ----------------- |
+| API Name | API name. |
+| Forward Test Success | Whether the forward API passed the test: pass or error. |
+| Backward Test Success | Whether the backward API passed the test: pass or error; an empty cell means the API has no backward output. |
+| Message | Additional information. |
+
+Whether Forward Test Success and Backward Test Success pass is determined by the cosine similarity and maximum absolute error results in `accuracy_checking_details_{timestamp}.csv`; the exact rules are described in [4.1 API check metrics](#41-api-check-metrics).
+Note that in `accuracy_checking_details_{timestamp}.csv` the forward (or backward) pass of one API may have several outputs, one row per output; the API is marked pass in `accuracy_checking_result_{timestamp}.csv` only if all of its rows are pass, and error as soon as a single row is error.
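+
+This aggregation rule can be pictured with a short sketch (illustrative only, not the tool's actual implementation; the row layout here is an assumption for the example):
+
+```
+# Hypothetical rows from accuracy_checking_details_{timestamp}.csv:
+# (API Name, direction, Status)
+detail_rows = [
+    ("Mint.split.1", "forward", "pass"),
+    ("Mint.split.1", "forward", "pass"),
+    ("Mint.split.1", "backward", "error"),
+]
+
+def aggregate(rows):
+    summary = {}
+    for api_name, direction, status in rows:
+        key = (api_name, direction)
+        # a single error row marks the whole direction as error
+        if summary.get(key) == "error" or status == "error":
+            summary[key] = "error"
+        else:
+            summary[key] = "pass"
+    return summary
+
+print(aggregate(detail_rows))
+# {('Mint.split.1', 'forward'): 'pass', ('Mint.split.1', 'backward'): 'error'}
+```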
+
+### 4.1 API check metrics
+
+ - The API check metrics judge whether an API meets the accuracy standard from the cosine similarity and maximum absolute error values recorded in `accuracy_checking_details_{timestamp}.csv`.
+ - An output is marked "pass" when its cosine similarity is greater than 0.99 and its maximum absolute error is less than 0.0001; otherwise it is marked "error".
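+
+For reference, the two metrics can be reproduced with a few lines of NumPy (a simplified sketch of the thresholds above; the tool's real implementation additionally protects against zero norms and dtype edge cases):
+
+```
+import numpy as np
+
+def check_metrics(bench, tested, cos_threshold=0.99, abs_err_threshold=0.0001):
+    bench = bench.astype(np.float64).ravel()
+    tested = tested.astype(np.float64).ravel()
+    eps = np.finfo(np.float64).eps  # avoid dividing by zero for all-zero inputs
+    cosine = (np.dot(bench, tested) + eps) / (np.linalg.norm(bench) * np.linalg.norm(tested) + eps)
+    max_abs_err = np.max(np.abs(bench - tested))
+    return "pass" if cosine > cos_threshold and max_abs_err < abs_err_threshold else "error"
+
+print(check_metrics(np.ones((2, 2)), np.ones((2, 2)) * (1 + 1e-6)))  # pass
+```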
\ No newline at end of file

diff --git a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py
index 0573e37821..fa99916f8c 100644
--- a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py
+++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py
@@ -1,6 +1,6 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
+# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,24 +15,35 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import mindspore
-import numpy as np
-from mindspore._c_expression import typing
+# standard library
 import argparse
 import json
+import math
 import os
 import re
 import string
 
-import math
+# third-party libraries
+import mindspore
+from mindspore._c_expression import typing
 import numpy as np
 import torch
 
-from msprobe.core.common.file_utils import FileOpen, load_json, save_json
-from msprobe.core.common.utils import check_file_or_directory_path, check_op_str_pattern_valid, is_int
+# application modules
+from msprobe.core.common.file_utils import (
+    FileOpen,
+    load_json,
+    save_json,
+    make_dir,
+    change_mode,
+)
+from msprobe.core.common.utils import (
+    check_file_or_directory_path,
+    check_op_str_pattern_valid,
+    is_int,
+)
 from msprobe.core.common.const import Const, MonitorConst, MsgConst, FileCheckConst
 from msprobe.core.common.log import logger
-from msprobe.core.common.file_utils import make_dir, change_mode
 from msprobe.core.common.decorator import recursion_depth_decorator
 
 OPERATOR_TYPE = ("Functional", "Tensor", "Torch", "Mint")
@@ -111,9 +122,12 @@ class CommonConfig:
             raise ValueError(f'content of json file is not a dict!')
 
         # ensure the length of json_content is within allowed limits
+        print(f"json_content:{json_content}")
+        print(f"len(json_content):{len(json_content)}")
         filtered = {k: v for k, v in json_content.items() if k not in EXCLUED}
-
+        print(f"filtered:{filtered}")
+        print(f"len(filtered):{len(filtered)}")
         if len(filtered) > API_INFO:
             raise ValueError(f'json file has more than one API, the API only contains forward and backward info')
 
@@ -165,19 +179,16 @@ class APIExtractor:
         self.data = load_json(self.dump_json_path)
         # fetch the framework field
         self.framework = self.data.get(FRAMEWORK, None)
-        # print(f"self.data:{self.data}")
+
         new_data = {}
         extract_key_pattern = re.compile(f"^{re.escape(self.api_name)}\..+")  # match any key that starts with or equals api_name
         self.real_data_path = self.data.get('dump_data_dir', '')
         for key, value in self.data.get('data', {}).items():
-            print(f"key:{key}")
             if extract_key_pattern.match(key):
                 if self.real_data_path:
-                    print(f"self.real_data_path:{self.real_data_path}")
                     value = self.load_real_data_path(value, self.real_data_path)
-                print(f"value:{value}")
                 new_data[key] = value
 
         if self.real_data_path is not None:
@@ -239,6 +250,57 @@ class OperatorScriptGenerator:
             api_name = Const.SEP.join([prefix, api_name])
         return api_type, api_name, api_order
 
+    @staticmethod
+    def generate_forward_inputs_code(args_info):
+        names = []
+
+        def collect(info):
+            if isinstance(info, dict):
+                names.append(info["parameter_name"])
+            else:
+                for sub in info:
+                    collect(sub)
+
+        collect(args_info)
+
+        return (
+            "    forward_inputs = [\n"
+            "        ComputeElement(parameter=info)\n"
+            "        for info in (" + ", ".join(names) + ")\n"
+            "    ]\n"
+        )
+
+    @staticmethod
+    def generate_kwargs_compute_element_dict_code():
+        return (
+            "    # 
---- 构造 kwargs 对应的 ComputeElement 字典 ----\n" + " kwargs_compute_element_dict = {\n" + " key_str: ComputeElement(compute_element_info=compute_element_info)\n" + " for key_str, compute_element_info in kwargs_device.items()\n" + " }\n" + ) + + @staticmethod + def generate_gradient_inputs_code(args_info_backward): + names = [] + + def collect(info): + if isinstance(info, dict): + names.append(info["parameter_name"]) + else: + for sub in info: + collect(sub) + + collect(args_info_backward) + + return ( + " # —— 构造反向梯度 ComputeElement 列表 —— #\n" + " gradient_inputs = [\n" + " ComputeElement(parameter=info)\n" + " for info in (" + ", ".join(names) + ")\n" + " ]\n" + ) + def get_settings(self, api_full_name): ''' internal_settings contain all information needed for the operator program. @@ -286,52 +348,6 @@ class OperatorScriptGenerator: return internal_settings - def generate_forward_inputs_code(self, args_info): - names = [] - - def collect(info): - if isinstance(info, dict): - names.append(info["parameter_name"]) - else: - for sub in info: collect(sub) - - collect(args_info) - - return ( - " forward_inputs = [\n" - " ComputeElement(parameter=info)\n" - " for info in (" + ", ".join(names) + ")\n" - " ]\n" - ) - - def generate_kwargs_compute_element_dict_code(self): - return ( - " # ---- 构造 kwargs 对应的 ComputeElement 字典 ----\n" - " kwargs_compute_element_dict = {\n" - " key_str: ComputeElement(compute_element_info=compute_element_info)\n" - " for key_str, compute_element_info in kwargs_device.items()\n" - " }\n" - ) - - def generate_gradient_inputs_code(self, args_info_backward): - names = [] - - def collect(info): - if isinstance(info, dict): - names.append(info["parameter_name"]) - else: - for sub in info: collect(sub) - - collect(args_info_backward) - - return ( - " # —— 构造反向梯度 ComputeElement 列表 —— #\n" - " gradient_inputs = [\n" - " ComputeElement(parameter=info)\n" - " for info in (" + ", ".join(names) + ")\n" - " ]\n" - ) - def _op_generator_parser(parser): parser.add_argument("-i", "--config_input", dest="config_input", type=str, @@ -379,11 +395,10 @@ def _run_operator_generate_commond(cmd_args): else: # read and check json api_full_name_forward, api_info_dict_forward = api_info.api_full_name, api_info.api_info_dict - print(f"api_full_name_forward:{api_full_name_forward},api_info_dict_forward:{api_info_dict_forward}") + args_info_forward = api_info_dict_forward.get(Const.INPUT_ARGS) kwargs_info_forward = api_info_dict_forward.get(Const.INPUT_KWARGS) - print(f"args_info_forward:{args_info_forward},kwargs_info_forward:{kwargs_info_forward}") op_generate = OperatorScriptGenerator(common_config, args_info_forward, kwargs_info_forward, None) internal_settings = op_generate.get_settings(api_full_name_forward) diff --git a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template index 096dad7809..36149399d2 100644 --- a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template @@ -2031,7 +2031,8 @@ if __name__ == "__main__": framework = "{framework}" dump_data_dir = "{real_data_path}" api_name = "{api_name}" - api_name_str = "{api_full_name}" + api_full_name = "{api_full_name}" + api_name_str = ".".join(api_full_name.split(".")[:3]) propagation = "{propagation}" data_mode = 
"{data_mode}" torch.manual_seed({random_seed}) -- Gitee From 71819da207d4621053499fead67abfeff545dfa3 Mon Sep 17 00:00:00 2001 From: yangxinxian <947098055@qq.com> Date: Tue, 20 May 2025 17:52:09 +0800 Subject: [PATCH 09/13] Update op_generator.py --- .../api_accuracy_checker/generate_op_script/op_generator.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py index fa99916f8c..c75452e482 100644 --- a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py @@ -122,12 +122,9 @@ class CommonConfig: raise ValueError(f'content of json file is not a dict!') # ensure the length of json_content is within allowed limits - print(f"json_content:{json_content}") - print(f"len(json_content):{len(json_content)}") filtered = {k: v for k, v in json_content.items() if k not in EXCLUED} - print(f"filtered:{filtered}") - print(f"len(filtered):{len(filtered)}") + if len(filtered) > API_INFO: raise ValueError(f'json file has more than one API, the API only contains forward and backward info') -- Gitee From 22776579b8219a08dbeb667177650b07b7a2c4ec Mon Sep 17 00:00:00 2001 From: yangxinxian <947098055@qq.com> Date: Tue, 20 May 2025 17:53:14 +0800 Subject: [PATCH 10/13] Update op_generator.py --- .../api_accuracy_checker/generate_op_script/op_generator.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py index c75452e482..e5a273dcec 100644 --- a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py @@ -18,16 +18,10 @@ # 标准库 import argparse import json -import math import os import re import string -# 第三方库 -import mindspore -from mindspore._c_expression import typing -import numpy as np -import torch # 应用程序自定义模块 from msprobe.core.common.file_utils import ( -- Gitee From 5ec19e708afcbc5f2f8a17925e97a6237b14e1bc Mon Sep 17 00:00:00 2001 From: yangxinxian <947098055@qq.com> Date: Tue, 20 May 2025 17:53:20 +0800 Subject: [PATCH 11/13] Update op_generator.py --- .../api_accuracy_checker/generate_op_script/op_generator.py | 1 - 1 file changed, 1 deletion(-) diff --git a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py index e5a273dcec..df1c6317e6 100644 --- a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py @@ -22,7 +22,6 @@ import os import re import string - # 应用程序自定义模块 from msprobe.core.common.file_utils import ( FileOpen, -- Gitee From 56fdf0cd0cee4ea8b9a4a98c8b7df806efb1e9ac Mon Sep 17 00:00:00 2001 From: yangxinxian <947098055@qq.com> Date: Tue, 20 May 2025 18:13:32 +0800 Subject: [PATCH 12/13] =?UTF-8?q?=E6=A3=80=E8=A7=86=E6=84=8F=E8=A7=81?= =?UTF-8?q?=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 
Content-Transfer-Encoding: 8bit

---
 .../api_accuracy_checker/generate_op_script/op_generator.py | 6 ++----
 .../generate_op_script/operator_replication.template        | 1 -
 2 files changed, 2 insertions(+), 5 deletions(-)

diff --git a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py
index df1c6317e6..38304d5250 100644
--- a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py
+++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py
@@ -96,10 +96,8 @@ class CommonConfig:
 
     def check_user_settings(self):
         iter_t = self.iter_times
-        if iter_t <= 0:
-            raise ValueError("iter_times should be an integer bigger than zero!")
-        if iter_t > ITER_MAX_TIMES:
-            raise ValueError("iter_times should not be greater than 1000!")
+        if iter_t <= 0 or iter_t > ITER_MAX_TIMES:
+            raise ValueError(f"iter_times should be in the range 1 to {ITER_MAX_TIMES}.")
 
         json_file = self.extract_api_path
         propagation = self.propagation
diff --git a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template
index 36149399d2..f2042a5f7e 100644
--- a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template
+++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template
@@ -1295,7 +1295,6 @@ class DataManager:
             if key not in self.results:
                 self.results[key] = []
             self.results[key].append((basic_info, compare_result_dict))
-            # logger.debug(f"Updated self.results for key {key}: {self.results[key]}")
         logger.debug(f"Complete self.results after recording: {self.results}")
 
     def record_exception_skip(self, api_name, forward_or_backward, err_msg):
-- 
Gitee

From 21b6aa46968e6aa7e8cfe4f3f6b22017d0acaf1a Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Tue, 20 May 2025 19:23:12 +0800
Subject: [PATCH 13/13] =?UTF-8?q?=E5=88=86=E6=89=B9=E4=B8=8A=E5=BA=93?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../operator_replication.template             | 1865 -----------------
 1 file changed, 1865 deletions(-)

diff --git a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template
index f2042a5f7e..2e26f606ac 100644
--- a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template
+++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template
@@ -213,1868 +213,3 @@ class MsCompareConst:
     API_NOT_FOUND = "api_not_found"
     EXCEPTION_SKIP = "exception_skip"
 
-# ======= mindtorch支持 ========
-import torch as mindtorch
-from torch import Tensor as mindtorch_tensor
-import torch.nn.functional as mindtorch_func
-import torch.distributed as mindtorch_dist
-
-is_valid_pt_mt_env = True
-
-
-def is_mindtorch():
-    mindtorch_check_result = False
-    try:
-        import torch as test_torch
-        from mindspore import Tensor as MindsporeTensor
-    except ImportError:
-        return mindtorch_check_result
-    tensor = test_torch.tensor(0.0)
-    if isinstance(tensor, MindsporeTensor):
-        mindtorch_check_result = True
-
- return mindtorch_check_result - - -def remove_torch_related_paths(): - removed_paths = [] - if not is_mindtorch(): - return - try: - import torch as remove_torch - torch_file = remove_torch.__file__ - except ImportError: - return - - torch_dir = os.path.dirname(torch_file) - - torch_dir_path = Path(torch_dir).resolve() - parent_dir = torch_dir_path.parent - - paths_to_remove = [str(parent_dir)] - - for path in paths_to_remove: - try: - path_resolved = str(Path(path).resolve()) - except Exception as error: - logger.debug(f"Failed to resolve path {path}: {error}") - - - if path_resolved in sys.path: - index = sys.path.index(path_resolved) - removed_paths.append((path_resolved, index)) - sys.path.pop(index) - - return - - -def clear_torch_from_sys_modules(): - modules_to_remove = [] - for module in sys.modules: - if module == "torch" or module.startswith("torch."): - modules_to_remove.append(module) - - for module in modules_to_remove: - del sys.modules[module] - - -def set_pt_mt_env_invalid(): - global is_valid_pt_mt_env - is_valid_pt_mt_env = False - - -def delete_torch_paths(): - - if not is_mindtorch(): - set_pt_mt_env_invalid() - - clear_torch_from_sys_modules() - - for count_delete_env_path in range(MsCompareConst.MAX_RECURSION_DEPTH): - if not is_mindtorch(): - break - - remove_torch_related_paths() - - clear_torch_from_sys_modules() - - if count_delete_env_path >= MsCompareConst.MAX_RECURSION_DEPTH - 1: - raise Exception(f"Please check if you have a valid PyTorch and MindTorch environment, and ensure " - f"the PYTHONPATH environment variable depth does not exceed {Const.MAX_RECURSION_DEPTH}.") - - -if not is_mindtorch(): - set_pt_mt_env_invalid() - -else: - initial_sys_path = sys.path.copy() - delete_torch_paths() - - gc.collect() - - import torch - - if is_mindtorch(): - set_pt_mt_env_invalid() - - sys.path = initial_sys_path - - - -if not is_valid_pt_mt_env: - import torch - - - -# ======= 常数类 ======= -import numpy as np -from mindspore._c_expression import typing -from mindspore.common import dtype as mstype - - -TENSOR_DATA_LIST = ["torch.Tensor", "torch.nn.parameter.Parameter"] -TORCH_BOOL_TYPE = ["torch.bool"] -TORCH_INT_TYPE = ["torch.uint8", "torch.int8", "torch.int16", "torch.short", "torch.int32", "torch.int", - "torch.int64", "torch.long"] -TORCH_FLOAT_TYPE = ["torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.float", - "torch.float64", "torch.double"] -TORCH_COMPLEX_TYPE = ["torch.complex32", "torch.chalf", "torch.complex64", "torch.cfloat", "torch.complex128", "torch.cdouble"] -RAISE_PRECISION = {{ - "torch.float16": torch.float32, - "torch.half": torch.float32, - "torch.bfloat16": torch.float32, - "torch.float32": torch.float64, - "torch.float": torch.float64 -}} -THOUSANDTH_THRESHOLDING = 0.001 -BACKWARD = 'backward' -DIR = "dir" -FILE = "file" -READ_ABLE = "read" -WRITE_ABLE = "write" -READ_WRITE_ABLE = "read and write" -DIRECTORY_LENGTH = 4096 -FILE_NAME_LENGTH = 255 -SOFT_LINK_ERROR = "检测到软链接" -FILE_PERMISSION_ERROR = "文件权限错误" -INVALID_FILE_ERROR = "无效文件" -ILLEGAL_PATH_ERROR = "非法文件路径" -ILLEGAL_PARAM_ERROR = "非法打开方式" -FILE_TOO_LARGE_ERROR = "文件过大" -FILE_VALID_PATTERN = r"^[a-zA-Z0-9_.:/-]+$" -FILE_SIZE_DICT = {{ - ".pkl": 1073741824, # 1 * 1024 * 1024 * 1024 - ".npy": 10737418240, # 10 * 1024 * 1024 * 1024 - ".json": 1073741824, # 1 * 1024 * 1024 * 1024 - ".pt": 10737418240, # 10 * 1024 * 1024 * 1024 - ".csv": 1073741824, # 1 * 1024 * 1024 * 1024 - ".xlsx": 1073741824, # 1 * 1024 * 1024 * 1024 - ".yaml": 1073741824, # 1 * 1024 * 1024 * 1024 - 
".ir": 1073741824 # 1 * 1024 * 1024 * 1024 -}} -COMMOM_FILE_SIZE = 1048576 # 1 * 1024 * 1024 - - -INT8 = "Int8" -UINT8 = "UInt8" -INT16 = "Int16" -UINT16 = "UInt16" -INT32 = "Int32" -UINT32 = "UInt32" -INT64 = "Int64" -UINT64 = "UInt64" -FLOAT16 = "Float16" -FLOAT32 = "Float32" -FLOAT64 = "Float64" -BOOL = "Bool" -BFLOAT16 = "BFloat16" -INT4 = "Int4" - -dtype_str_to_ms_dtype = { - INT8: mstype.int8, - UINT8: mstype.uint8, - INT16: mstype.int16, - UINT16: mstype.uint16, - INT32: mstype.int32, - UINT32: mstype.uint32, - INT64: mstype.int64, - UINT64: mstype.uint64, - FLOAT16: mstype.float16, - FLOAT32: mstype.float32, - FLOAT64: mstype.float64, - BOOL: mstype.bool_, - BFLOAT16: mstype.bfloat16, - INT4: mstype.qint4x2 -} -ms_dtype_to_dtype_str = {value: key for key, value in dtype_str_to_ms_dtype.items()} - -dtype_str_to_np_dtype = { - INT8: np.int8, - UINT8: np.uint8, - INT16: np.int16, - UINT16: np.uint16, - INT32: np.int32, - UINT32: np.uint32, - INT64: np.int64, - UINT64: np.uint64, - FLOAT16: np.float16, - FLOAT32: np.float32, - FLOAT64: np.float64, - BOOL: np.bool_ -} -np_dtype_to_dtype_str = {value: key for key, value in dtype_str_to_np_dtype.items()} - -dtype_str_to_torch_dtype = { - INT8: torch.int8, - UINT8: torch.uint8, - INT16: torch.int16, - INT32: torch.int32, - INT64: torch.int64, - FLOAT16: torch.float16, - FLOAT32: torch.float32, - FLOAT64: torch.float64, - BOOL: torch.bool, - BFLOAT16: torch.bfloat16, -} -torch_dtype_to_dtype_str = {value: key for key, value in dtype_str_to_torch_dtype.items()} - - -dtype_str_to_mindtorch_dtype = { - INT8: mindtorch.int8, - UINT8: mindtorch.uint8, - INT16: mindtorch.int16, - INT32: mindtorch.int32, - INT64: mindtorch.int64, - FLOAT16: mindtorch.float16, - FLOAT32: mindtorch.float32, - FLOAT64: mindtorch.float64, - BOOL: mindtorch.bool, - BFLOAT16: mindtorch.bfloat16, -} -mindtorch_dtype_to_dtype_str = {value: key for key, value in dtype_str_to_mindtorch_dtype.items()} - -MINDSPORE_TENSOR_TYPE_STR = "mindspore.Tensor" -BOOL_TYPE_STR = "bool" -INT_TYPE_STR = "int" -FLOAT_TYPE_STR = "float" -SLICE_TYPE_STR = "slice" -TUPLE_TYPE_STR = "tuple" -STR_TYPE_STR = "str" -MINDSPORE_DTYPE_TYPE_STR = "mindspore.dtype" -TORCH_DTYPE_TYPE_STR = "torch.dtype" - -api_info_type_str_to_type = { - MINDSPORE_TENSOR_TYPE_STR: mindspore.Tensor, - BOOL_TYPE_STR: bool, - INT_TYPE_STR: int, - FLOAT_TYPE_STR: float, - SLICE_TYPE_STR: slice, - STR_TYPE_STR: str, - MINDSPORE_DTYPE_TYPE_STR: typing.Type, -} -type_to_api_info_type_str = {value: key for key, value in api_info_type_str_to_type.items()} - -DEFAULT_CONSTRUCT_NP_FLOAT_DTYPE = np.float64 -DEFAULT_CONSTRUCT_NP_INT_DTYPE = np.float64 -DEFAULT_CONSTRUCT_NP_UINT_DTYPE = np.float64 - -float_dtype_str_list = [ - FLOAT16, - FLOAT32, - FLOAT64, - BFLOAT16, -] - -int_dtype_str_list = [ - INT8, - INT16, - INT32, - INT64, - BOOL, - INT4, -] - -uint_dtype_str_list = [ - UINT8, - UINT16, - UINT32, - UINT64, -] - - - - - -# ======= 比对类 ======= - - - -class CompareResult: - def __init__(self, compare_value, pass_status, err_msg): - self.compare_value = compare_value - self.pass_status = pass_status - self.err_msg = err_msg - - -class BaseCompareAlgorithm(ABC): - def __init__(self) -> None: - super().__init__() - self.compare_algorithm_name = None - self.err_msg_mapping = { - CompareConst.COSINE: { - CompareConst.PASS: "", - CompareConst.ERROR: f"cosine similarity is less than threshold: {CompareConst.COS_THRESHOLD} ", - CompareConst.SKIP: "two inputs are not valid for computing cosine similarity, skip comparing ", - }, - 
CompareConst.MAX_ABS_ERR: { - CompareConst.PASS: "", - CompareConst.ERROR: "max absolute difference is greater than " \ - f"threshold: {CompareConst.MAX_ABS_ERR_THRESHOLD} ", - CompareConst.SKIP: "two inputs are not valid for computing max absolute difference, skip comparing ", - }, - CompareConst.MAX_RELATIVE_ERR: { - CompareConst.PASS: "", - CompareConst.ERROR: "", - CompareConst.SKIP: "", - }, - } - - def __call__(self, bench_compute_element, tested_compute_element): - ''' - Args: - bench_compute_element: ComputeElement - tested_compute_element: ComputeElement - - Return: - compare_result: CompareResult - ''' - if self.check_validity(bench_compute_element, tested_compute_element): - compare_value = self.run_compare(bench_compute_element, tested_compute_element) - pass_status = self.check_pass(compare_value) - else: - logger.warning(f"not suitable for computing {self.compare_algorithm_name}, skip this.") - compare_value = None - pass_status = CompareConst.SKIP - - err_msg = self.err_msg_mapping.get(self.compare_algorithm_name).get(pass_status) - - compare_result = CompareResult(compare_value, pass_status, err_msg) - return compare_result - - @staticmethod - def convert_to_np_float64_ndarray(tensor): - if isinstance(tensor, mindspore.Tensor): - ndarray = tensor.astype(mindspore.float64).numpy() - elif isinstance(tensor, torch.Tensor): - ndarray = tensor.to(torch.float64, copy=True).numpy() - else: - err_msg = "BaseCompareAlgorithm.convert_to_np_float64_ndarray failed: " \ - "input is not mindspore.Tensor or torch.Tensor" - logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType)) - return ndarray - - @staticmethod - def check_two_tensor(bench_compute_element, tested_compute_element): - bench_parameter = bench_compute_element.get_parameter() - tested_parameter = tested_compute_element.get_parameter() - - bench_is_tensor = isinstance(bench_parameter, (mindspore.Tensor, torch.Tensor)) - tested_is_tensor = isinstance(tested_parameter, (mindspore.Tensor, torch.Tensor)) - shape_same = bench_compute_element.get_shape() == tested_compute_element.get_shape() - return bench_is_tensor and tested_is_tensor and shape_same - - @abstractmethod - def check_validity(self, bench_compute_element, tested_compute_element): - ''' - Args: - bench_compute_element: ComputeElement - tested_compute_element: ComputeElement - - Return: - check_res: boolean - ''' - raise NotImplementedError - - @abstractmethod - def run_compare(self, bench_compute_element, tested_compute_element): - ''' - Args: - bench_compute_element: ComputeElement - tested_compute_element: ComputeElement - - Return: - compare_value: float/int - ''' - raise NotImplementedError - - @abstractmethod - def check_pass(self, compare_value): - ''' - Args: - compare_value: float/int - - Return: - pass_status: str - ''' - raise NotImplementedError - - -class CosineSimilarityCompareAlgorithm(BaseCompareAlgorithm): - def __init__(self) -> None: - super().__init__() - self.compare_algorithm_name = CompareConst.COSINE - - def check_validity(self, bench_compute_element, tested_compute_element): - return self.check_two_tensor(bench_compute_element, tested_compute_element) - - def run_compare(self, bench_compute_element, tested_compute_element): - bench_ndarray = self.convert_to_np_float64_ndarray(bench_compute_element.get_parameter()) - tested_ndarray = self.convert_to_np_float64_ndarray(tested_compute_element.get_parameter()) - - bench_norm = np.linalg.norm(bench_ndarray) - tested_norm = 
np.linalg.norm(tested_ndarray) - dot_product = np.dot(bench_ndarray.flatten(), tested_ndarray.flatten()) - cosine_similarity = (MsCompareConst.EPSILON + dot_product) / (MsCompareConst.EPSILON + bench_norm * tested_norm) - return cosine_similarity - - def check_pass(self, compare_value): - if compare_value > CompareConst.COS_THRESHOLD: - return CompareConst.PASS - else: - return CompareConst.ERROR - - -class MaxAbsoluteDiffCompareAlgorithm(BaseCompareAlgorithm): - def __init__(self) -> None: - super().__init__() - self.compare_algorithm_name = CompareConst.MAX_ABS_ERR - - def check_validity(self, bench_compute_element, tested_compute_element): - return self.check_two_tensor(bench_compute_element, tested_compute_element) - - def run_compare(self, bench_compute_element, tested_compute_element): - bench_ndarray = self.convert_to_np_float64_ndarray(bench_compute_element.get_parameter()) - tested_ndarray = self.convert_to_np_float64_ndarray(tested_compute_element.get_parameter()) - - max_absolute_diff = np.max(np.abs(bench_ndarray - tested_ndarray)) - return max_absolute_diff - - def check_pass(self, compare_value): - if compare_value < CompareConst.MAX_ABS_ERR_THRESHOLD: - return CompareConst.PASS - else: - return CompareConst.ERROR - - -class MaxRelativeDiffCompareAlgorithm(BaseCompareAlgorithm): - def __init__(self) -> None: - super().__init__() - self.compare_algorithm_name = CompareConst.MAX_RELATIVE_ERR - - def check_validity(self, bench_compute_element, tested_compute_element): - return self.check_two_tensor(bench_compute_element, tested_compute_element) - - def run_compare(self, bench_compute_element, tested_compute_element): - bench_ndarray = self.convert_to_np_float64_ndarray(bench_compute_element.get_parameter()) - tested_ndarray = self.convert_to_np_float64_ndarray(tested_compute_element.get_parameter()) - - abs_diff = np.abs(bench_ndarray - tested_ndarray) - bench_ndarray_nonzero = np.abs(bench_ndarray) + (bench_ndarray == 0) * MsCompareConst.EPSILON - max_relative_diff = np.max(abs_diff / bench_ndarray_nonzero) - return max_relative_diff - - def check_pass(self, compare_value): - if compare_value < CompareConst.MAX_RELATIVE_ERR_THRESHOLD: - return CompareConst.PASS - else: - return CompareConst.ERROR - - -compare_algorithms = { - CompareConst.COSINE: CosineSimilarityCompareAlgorithm(), - CompareConst.MAX_ABS_ERR: MaxAbsoluteDiffCompareAlgorithm(), - CompareConst.MAX_RELATIVE_ERR: MaxRelativeDiffCompareAlgorithm(), -} - - - -class CompareStandard(Enum): - BINARY_EQUALITY_STANDARD = auto() - ABSOLUTE_THRESHOLD_STANDARD = auto() - ULP_ERROR_STANDARD = auto() - BENCHMARK_STANDARD = auto() - THOUSANDTH_STANDARD = auto() - - -class CompareStandard(Enum): - BINARY_EQUALITY_STANDARD = auto() - ABSOLUTE_THRESHOLD_STANDARD = auto() - ULP_ERROR_STANDARD = auto() - BENCHMARK_STANDARD = auto() - THOUSANDTH_STANDARD = auto() - - -# ======== 文件操作类 ========== - -from collections import defaultdict -from functools import wraps - - -def check_and_get_from_json_dict(dict_instance, key, key_description, accepted_type=None, accepted_value=None): - ''' - Args: - dict_instance: dict, dict parsed from input json - key: str - key_description: str - accepted_type: tuple - accepted_value: Union[tuple, list] - - Return: - value, the corresponding value of "key" in "dict_instance" - - Exception: - raise ApiAccuracyCheckerException.ParseJsonFailed error when - 1. dict_instance is not a dict - 2. value is None - 3. value is not accepted type - 4. 
value is not accepted value - ''' - if not isinstance(dict_instance, dict): - error_info = "check_and_get_from_json_dict failed: input is not a dict" - raise ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed, error_info) - value = dict_instance.get(key) - if value is None: - error_info = f"check_and_get_from_json_dict failed: {key_description} is missing" - raise ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed, error_info) - elif accepted_type is not None and not isinstance(value, accepted_type): - error_info = f"check_and_get_from_json_dict failed: {key_description} is not accepted type: {accepted_type}" - raise ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed, error_info) - elif accepted_value is not None and value not in accepted_value: - error_info = f"check_and_get_from_json_dict failed: {key_description} is not accepted value: {accepted_value}" - raise ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed, error_info) - return value - - -def convert_to_tuple(args): - if isinstance(args, (tuple, list)): - return tuple(args) - else: - input_list = [args] - return tuple(input_list) - - -def trim_output_compute_element_list(compute_element_list, forward_or_backward): - ''' - Args: - compute_element_list: List[ComputeElement] - forward_or_backward: str, Union["forward", "backward"] - ''' - trimmed_list = [] - for compute_element in compute_element_list: - if compute_element.get_parameter() is None or \ - (forward_or_backward == Const.BACKWARD and compute_element.get_dtype() not in float_dtype_str_list): - # trim case: 1. parameter is None. 2. backward output has non float parameter - continue - trimmed_list.append(compute_element) - return trimmed_list - - - - -# 记录工具函数递归的深度 -recursion_depth = defaultdict(int) - - -def recursion_depth_decorator(func_info, max_depth=Const.MAX_DEPTH): - """装饰一个函数,当函数递归调用超过限制时,抛出异常并打印函数信息。""" - def decorator(func): - @wraps(func) - def wrapper(*args, **kwargs): - func_id = id(func) - recursion_depth[func_id] += 1 - - try: - result = func(*args, **kwargs) - finally: - recursion_depth[func_id] -= 1 - return result - - return wrapper - - return decorator - - - -class FileChecker: - """ - The class for check file. - - Attributes: - file_path: The file or dictionary path to be verified. 
- path_type: file or dictionary - ability(str): FileCheckConst.WRITE_ABLE or FileCheckConst.READ_ABLE to set file has writability or readability - file_type(str): The correct file type for file - """ - - def __init__(self, file_path, path_type, ability=None, file_type=None, is_script=True): - self.file_path = file_path - self.path_type = self._check_path_type(path_type) - self.ability = ability - self.file_type = file_type - self.is_script = is_script - - @staticmethod - def _check_path_type(path_type): - if path_type not in [FileCheckConst.DIR, FileCheckConst.FILE]: - logger.error(f'The path_type must be {FileCheckConst.DIR} or {FileCheckConst.FILE}.') - raise FileCheckException(FileCheckException.ILLEGAL_PARAM_ERROR) - return path_type - - def common_check(self): - """ - 功能:用户校验基本文件权限:软连接、文件长度、是否存在、读写权限、文件属组、文件特殊字符 - 注意:文件后缀的合法性,非通用操作,可使用其他独立接口实现 - """ - check_path_exists(self.file_path) - check_link(self.file_path) - self.file_path = os.path.realpath(self.file_path) - check_path_length(self.file_path) - check_path_type(self.file_path, self.path_type) - self.check_path_ability() - if self.is_script: - check_path_owner_consistent(self.file_path) - check_path_pattern_valid(self.file_path) - check_common_file_size(self.file_path) - check_file_suffix(self.file_path, self.file_type) - if self.path_type == FileCheckConst.FILE: - check_dirpath_before_read(self.file_path) - return self.file_path - - def check_path_ability(self): - if self.ability == FileCheckConst.WRITE_ABLE: - check_path_writability(self.file_path) - if self.ability == FileCheckConst.READ_ABLE: - check_path_readability(self.file_path) - if self.ability == FileCheckConst.READ_WRITE_ABLE: - check_path_readability(self.file_path) - check_path_writability(self.file_path) - - -class FileOpen: - """ - The class for open file by a safe way. - - Attributes: - file_path: The file or dictionary path to be opened. 
- mode(str): The file open mode - """ - SUPPORT_READ_MODE = ["r", "rb"] - SUPPORT_WRITE_MODE = ["w", "wb", "a", "ab"] - SUPPORT_READ_WRITE_MODE = ["r+", "rb+", "w+", "wb+", "a+", "ab+"] - - def __init__(self, file_path, mode, encoding='utf-8'): - self.file_path = file_path - self.mode = mode - self.encoding = encoding - self._handle = None - - def __enter__(self): - self.check_file_path() - binary_mode = "b" - if binary_mode not in self.mode: - self._handle = open(self.file_path, self.mode, encoding=self.encoding) - else: - self._handle = open(self.file_path, self.mode) - return self._handle - - def __exit__(self, exc_type, exc_val, exc_tb): - if self._handle: - self._handle.close() - - def check_file_path(self): - support_mode = self.SUPPORT_READ_MODE + self.SUPPORT_WRITE_MODE + self.SUPPORT_READ_WRITE_MODE - if self.mode not in support_mode: - logger.error("File open not support %s mode" % self.mode) - raise FileCheckException(FileCheckException.ILLEGAL_PARAM_ERROR) - check_link(self.file_path) - self.file_path = os.path.realpath(self.file_path) - check_path_length(self.file_path) - self.check_ability_and_owner() - check_path_pattern_valid(self.file_path) - if os.path.exists(self.file_path): - check_common_file_size(self.file_path) - check_dirpath_before_read(self.file_path) - - def check_ability_and_owner(self): - if self.mode in self.SUPPORT_READ_MODE: - check_path_exists(self.file_path) - check_path_readability(self.file_path) - check_path_owner_consistent(self.file_path) - if self.mode in self.SUPPORT_WRITE_MODE and os.path.exists(self.file_path): - check_path_writability(self.file_path) - check_path_owner_consistent(self.file_path) - if self.mode in self.SUPPORT_READ_WRITE_MODE and os.path.exists(self.file_path): - check_path_readability(self.file_path) - check_path_writability(self.file_path) - check_path_owner_consistent(self.file_path) - - -def check_link(path): - abs_path = os.path.abspath(path) - if os.path.islink(abs_path): - logger.error('The file path {} is a soft link.'.format(path)) - raise FileCheckException(FileCheckException.SOFT_LINK_ERROR) - - -def check_path_length(path, name_length=None): - file_max_name_length = name_length if name_length else FileCheckConst.FILE_NAME_LENGTH - if len(path) > FileCheckConst.DIRECTORY_LENGTH or \ - len(os.path.basename(path)) > file_max_name_length: - logger.error('The file path length exceeds limit.') - raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR) - - -def check_path_exists(path): - if not os.path.exists(path): - logger.error('The file path %s does not exist.' % path) - raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR) - - -def check_path_readability(path): - if not os.access(path, os.R_OK): - logger.error('The file path %s is not readable.' % path) - raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR) - - -def check_path_writability(path): - if not os.access(path, os.W_OK): - logger.error('The file path %s is not writable.' % path) - raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR) - - -def check_path_executable(path): - if not os.access(path, os.X_OK): - logger.error('The file path %s is not executable.' % path) - raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR) - - -def check_other_user_writable(path): - st = os.stat(path) - if st.st_mode & 0o002: - logger.error('The file path %s may be insecure because other users have write permissions. 
' % path) - raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR) - - -def check_path_owner_consistent(path): - file_owner = os.stat(path).st_uid - if file_owner != os.getuid() and os.getuid() != 0: - logger.error('The file path %s may be insecure because is does not belong to you.' % path) - raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR) - - -def check_path_pattern_valid(path): - if not re.match(FileCheckConst.FILE_VALID_PATTERN, path): - logger.error('The file path %s contains special characters.' % (path)) - raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR) - - -def check_file_size(file_path, max_size): - try: - file_size = os.path.getsize(file_path) - except OSError as os_error: - logger.error(f'Failed to open "{file_path}". {str(os_error)}') - raise FileCheckException(FileCheckException.INVALID_FILE_ERROR) from os_error - if file_size >= max_size: - logger.error(f'The size ({file_size}) of {file_path} exceeds ({max_size}) bytes, tools not support.') - raise FileCheckException(FileCheckException.FILE_TOO_LARGE_ERROR) - - -def check_common_file_size(file_path): - if os.path.isfile(file_path): - for suffix, max_size in FileCheckConst.FILE_SIZE_DICT.items(): - if file_path.endswith(suffix): - check_file_size(file_path, max_size) - return - check_file_size(file_path, FileCheckConst.COMMOM_FILE_SIZE) - - -def check_file_suffix(file_path, file_suffix): - if file_suffix: - if not file_path.endswith(file_suffix): - logger.error(f"The {file_path} should be a {file_suffix} file!") - raise FileCheckException(FileCheckException.INVALID_FILE_ERROR) - - -def check_path_type(file_path, file_type): - if file_type == FileCheckConst.FILE: - if not os.path.isfile(file_path): - logger.error(f"The {file_path} should be a file!") - raise FileCheckException(FileCheckException.INVALID_FILE_ERROR) - if file_type == FileCheckConst.DIR: - if not os.path.isdir(file_path): - logger.error(f"The {file_path} should be a dictionary!") - raise FileCheckException(FileCheckException.INVALID_FILE_ERROR) - -def make_dir(dir_path): - check_path_before_create(dir_path) - dir_path = os.path.realpath(dir_path) - if os.path.isdir(dir_path): - return - try: - os.makedirs(dir_path, mode=FileCheckConst.DATA_DIR_AUTHORITY, exist_ok=True) - except OSError as ex: - raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR, - f"Failed to create {dir_path}. " - f"Please check the path permission or disk space. 
{str(ex)}") from ex - file_check = FileChecker(dir_path, FileCheckConst.DIR) - file_check.common_check() - - - - -@recursion_depth_decorator('msprobe.core.common.file_utils.create_directory', max_depth=16) -def create_directory(dir_path): - """ - Function Description: - creating a safe directory with specified permissions - Parameter: - dir_path: directory path - Exception Description: - when invalid data throw exception - """ - check_link(dir_path) - check_path_before_create(dir_path) - dir_path = os.path.realpath(dir_path) - parent_dir = os.path.dirname(dir_path) - if not os.path.isdir(parent_dir): - create_directory(parent_dir) - make_dir(dir_path) - - -def check_path_before_create(path): - check_link(path) - if path_len_exceeds_limit(path): - raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR, 'The file path length exceeds limit.') - - if not re.match(FileCheckConst.FILE_PATTERN, os.path.realpath(path)): - raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR, - 'The file path {} contains special characters.'.format(path)) - - -def check_dirpath_before_read(path): - path = os.path.realpath(path) - dirpath = os.path.dirname(path) - - -def check_file_or_directory_path(path, isdir=False): - """ - Function Description: - check whether the path is valid - Parameter: - path: the path to check - isdir: the path is dir or file - Exception Description: - when invalid data throw exception - """ - if isdir: - path_checker = FileChecker(path, FileCheckConst.DIR, FileCheckConst.WRITE_ABLE) - else: - path_checker = FileChecker(path, FileCheckConst.FILE, FileCheckConst.READ_ABLE) - path_checker.common_check() - - -def change_mode(path, mode): - if not os.path.exists(path) or os.path.islink(path): - return - try: - os.chmod(path, mode) - except PermissionError as ex: - raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR, - 'Failed to change {} authority. {}'.format(path, str(ex))) from ex - - -def path_len_exceeds_limit(file_path): - return len(os.path.realpath(file_path)) > FileCheckConst.DIRECTORY_LENGTH or \ - len(os.path.basename(file_path)) > FileCheckConst.FILE_NAME_LENGTH - -def load_npy(filepath): - check_file_or_directory_path(filepath) - try: - npy = np.load(filepath, allow_pickle=False) - except Exception as e: - logger.error(f"The numpy file failed to load. 
Please check the path: {filepath}.") - raise RuntimeError(f"Load numpy file {filepath} failed.") from e - return npy - -def write_csv(data, filepath, mode="a+", malicious_check=False): - def csv_value_is_valid(value: str) -> bool: - if not isinstance(value, str): - return True - try: - # -1.00 or +1.00 should be considered as digit numbers - float(value) - except ValueError: - # otherwise, they will be considered as formular injections - return not bool(re.compile(FileCheckConst.CSV_BLACK_LIST).search(value)) - return True - - if malicious_check: - for row in data: - for cell in row: - if not csv_value_is_valid(cell): - raise RuntimeError(f"Malicious value [{cell}] is not allowed " - f"to be written into the csv: {filepath}.") - - check_path_before_create(filepath) - file_path = os.path.realpath(filepath) - try: - with FileOpen(filepath, mode, encoding='utf-8-sig') as f: - writer = csv.writer(f) - writer.writerows(data) - except Exception as e: - logger.error(f'Save csv file "{os.path.basename(file_path)}" failed') - raise RuntimeError(f"Save csv file {file_path} failed.") from e - change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY) - print(f"file_path:{file_path}") - - - -def write_csv_header(csv_path, header_func): - """如果是第一次写入,则写入 CSV 表头""" - header = header_func() # 获取表头 - logger.debug(f"Writing CSV header: {header}") - write_csv([header], csv_path, mode="a+") - - -def get_result_csv_header(): - """获取结果 CSV 文件的表头""" - return [ - MsCompareConst.DETAIL_CSV_API_NAME, - MsCompareConst.RESULT_CSV_FORWARD_TEST_SUCCESS, - MsCompareConst.RESULT_CSV_BACKWARD_TEST_SUCCESS, - MsCompareConst.DETAIL_CSV_MESSAGE, - ] - - -def get_detail_csv_header(): - """获取详细 CSV 文件的表头""" - detail_csv_header_basic_info = [ - MsCompareConst.DETAIL_CSV_API_NAME, - MsCompareConst.DETAIL_CSV_BENCH_DTYPE, - MsCompareConst.DETAIL_CSV_TESTED_DTYPE, - MsCompareConst.DETAIL_CSV_SHAPE, - ] - detail_csv_header_compare_result = list(compare_algorithms.keys()) - detail_csv_header_status = [ - MsCompareConst.DETAIL_CSV_PASS_STATUS, - MsCompareConst.DETAIL_CSV_MESSAGE, - ] - return detail_csv_header_basic_info + detail_csv_header_compare_result + detail_csv_header_status - - -def check_csv_header(headers, required_constants, csv_path): - """校验 CSV 文件表头是否包含所有必需的常量""" - missing_constants = [const for const in required_constants if not any(const in header for header in headers)] - - if missing_constants: - raise MsprobeBaseException( - MsprobeBaseException.MISSING_HEADER_ERROR, - f"{csv_path} 缺少以下必需的表头字段: {missing_constants}" - ) -def add_time_as_suffix(name): - return '{}_{}.csv'.format(name, time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))) - - -# ======= 结果落盘管理类 ======== - - - - -class DataManager: - def __init__(self, csv_dir, result_csv_path): - self.results = {} - self.results_exception_skip = {} - self.is_first_write = True # 标记用于添加表头 - self.csv_dir = csv_dir - self.api_names_set = set() # 存储已经出现的 API 名称的集合 - # 如果传入了 result_csv_path,则启用断点续检 - if result_csv_path: - self.resume_from_last_csv(result_csv_path) - self.initialize_api_names_set(result_csv_path) - else: - # 默认情况下,设置输出路径为空,等待首次写入时初始化 - self.result_out_path = os.path.join(self.csv_dir, add_time_as_suffix(MsCompareConst.RESULT_CSV_FILE_NAME)) - self.detail_out_path = os.path.join( - self.csv_dir, - os.path.basename(self.result_out_path).replace("result", "details") - ) - - if self.detail_out_path and os.path.exists(self.detail_out_path): - check_file_or_directory_path(self.detail_out_path) - - if self.result_out_path and 
-
-
-# ======== result persistence management ========
-class DataManager:
-    def __init__(self, csv_dir, result_csv_path):
-        self.results = {}
-        self.results_exception_skip = {}
-        self.is_first_write = True  # marks whether the headers still need to be written
-        self.csv_dir = csv_dir
-        self.api_names_set = set()  # set of API names that have already been recorded
-        # if result_csv_path is provided, resume checking from the previous run
-        if result_csv_path:
-            self.resume_from_last_csv(result_csv_path)
-            self.initialize_api_names_set(result_csv_path)
-        else:
-            # by default, derive fresh output paths; the files are created on first write
-            self.result_out_path = os.path.join(self.csv_dir, add_time_as_suffix(MsCompareConst.RESULT_CSV_FILE_NAME))
-            self.detail_out_path = os.path.join(
-                self.csv_dir,
-                os.path.basename(self.result_out_path).replace("result", "details")
-            )
-
-        if self.detail_out_path and os.path.exists(self.detail_out_path):
-            check_file_or_directory_path(self.detail_out_path)
-
-        if self.result_out_path and os.path.exists(self.result_out_path):
-            check_file_or_directory_path(self.result_out_path)
-
-    def initialize_api_names_set(self, result_csv_path):
-        """Read the existing result CSV and collect the API names it already contains."""
-        csv_data = read_csv(result_csv_path, as_pd=False)
-
-        # headers is empty when the file is empty
-        headers = csv_data[0] if csv_data else []
-
-        # validate the header with the shared helper
-        if check_csv_header(headers, get_result_csv_header(), result_csv_path):
-            # locate the "API Name" column; the header row may carry a byte order mark,
-            # so match by containment rather than equality
-            api_name_index = None
-            for i, header in enumerate(headers):
-                if MsCompareConst.DETAIL_CSV_API_NAME in header:
-                    api_name_index = i
-                    break
-
-            if api_name_index is None:
-                logger.warning(f"{result_csv_path} No column contains 'API Name'.")
-                return
-
-            # collect the API name of every data row, skipping the header row
-            for row in csv_data[1:]:
-                if row and len(row) > api_name_index:
-                    api_name = row[api_name_index]
-                    if api_name:
-                        self.api_names_set.add(api_name)
-
-        logger.debug(f"Initialized API names set from existing CSV: {self.api_names_set}")
-
-    def is_unique_api(self, api_name):
-        """Return False if the API name was seen before; otherwise record it and return True."""
-        if api_name in self.api_names_set:
-            return False
-        self.api_names_set.add(api_name)
-        return True
-
-    def resume_from_last_csv(self, result_csv_path):
-        """Resume from the result_csv_path produced by the previous run."""
-        # reuse the directory of the previous run
-        last_dir = os.path.dirname(result_csv_path)
-
-        # set the current directory and output paths before the first write
-        self.csv_dir = last_dir
-        self.detail_out_path = os.path.join(last_dir, os.path.basename(result_csv_path).replace("result", "details"))
-        if self.detail_out_path and os.path.exists(self.detail_out_path):
-            check_file_or_directory_path(self.detail_out_path)
-        self.result_out_path = result_csv_path
-        self.is_first_write = False
-
-    def save_results(self, api_name_str):
-        """Write the detail output and the result summary, then clear the recorded results."""
-        if self.is_first_write:
-            logger.info("Writing CSV headers for the first time.")
-            write_csv_header(self.detail_out_path, get_detail_csv_header)
-            write_csv_header(self.result_out_path, get_result_csv_header)
-            self.is_first_write = False  # avoid writing the headers twice
-
-        logger.debug("Starting to write detailed output to CSV.")
-        self.to_detail_csv(self.detail_out_path)
-        logger.debug(f"Detailed output for {api_name_str} written to {self.detail_out_path}.")
-
-        logger.debug("Starting to write result summary to CSV.")
-        self.to_result_csv(self.result_out_path)
-        logger.debug(f"Result summary for {api_name_str} written to {self.result_out_path}.")
-
-        # clear the records, ready for the next call
-        self.clear_results()
-
-    def record(self, output_list):
-        if output_list is None:
-            return
-        for output in output_list:
-            api_real_name, forward_or_backward, basic_info, compare_result_dict = output
-            key = (api_real_name, forward_or_backward)
-            if key not in self.results:
-                self.results[key] = []
-            self.results[key].append((basic_info, compare_result_dict))
-        logger.debug(f"Complete self.results after recording: {self.results}")
-
-    def record_exception_skip(self, api_name, forward_or_backward, err_msg):
-        '''
-        record exception_skip information into self.results_exception_skip.
-        self.results_exception_skip: dict{str: dict{"forward": str/None, "backward": str/None}}
-        the key is the api_name, the values are err_msg strings
-        '''
-        if api_name not in self.results_exception_skip:
-            self.results_exception_skip[api_name] = {Const.FORWARD: None, Const.BACKWARD: None}
-        self.results_exception_skip[api_name][forward_or_backward] = err_msg
-
-    def clear_results(self):
-        """Clear the recorded results."""
-        logger.debug("Clearing self.results data.")
-        self.results.clear()
-        self.results_exception_skip.clear()
-
-    def to_detail_csv(self, csv_path):
-        logger.debug("Preparing detail CSV headers and rows.")
-        detail_csv = []
-
-        detail_csv_header_compare_result = list(compare_algorithms.keys())
-
-        for _, results in self.results.items():
-            for res in results:
-                basic_info, compare_result_dict = res
-                csv_row_basic_info = [
-                    basic_info.api_name,
-                    basic_info.bench_dtype,
-                    basic_info.tested_dtype,
-                    basic_info.shape
-                ]
-                csv_row_compare_result = [
-                    compare_result_dict.get(algorithm_name).compare_value
-                    for algorithm_name in detail_csv_header_compare_result
-                ]
-                csv_row_status = [basic_info.status, basic_info.err_msg]
-                csv_row = csv_row_basic_info + csv_row_compare_result + csv_row_status
-                detail_csv.append(csv_row)
-                logger.debug(f"Detail CSV row added: {csv_row}")
-
-        logger.debug(f"Writing detail CSV to {csv_path}.")
-        write_csv(detail_csv, csv_path, mode="a+")
-        logger.debug(f"Detail CSV written successfully to {csv_path}.")
-
-    def to_result_csv(self, csv_path):
-        '''
-        depends on both self.results and self.results_exception_skip
-        '''
-        logger.debug("Preparing result CSV data.")
-        result_csv = []
-
-        result_csv_dict = {}
-        for key, results in self.results.items():
-            api_real_name, forward_or_backward = key
-            pass_status = CompareConst.PASS
-            overall_err_msg = ""
-
-            for res in results:
-                basic_info, _ = res
-                if basic_info.status != CompareConst.PASS:
-                    pass_status = CompareConst.ERROR
-                    overall_err_msg += basic_info.err_msg
-
-            overall_err_msg = "" if pass_status == CompareConst.PASS else overall_err_msg
-
-            if api_real_name not in result_csv_dict:
-                result_csv_dict[api_real_name] = ResultCsvEntry()
-            if forward_or_backward == Const.FORWARD:
-                result_csv_dict[api_real_name].forward_pass_status = pass_status
-                result_csv_dict[api_real_name].forward_err_msg = overall_err_msg
-            else:
-                result_csv_dict[api_real_name].backward_pass_status = pass_status
-                result_csv_dict[api_real_name].backward_err_msg = overall_err_msg
-
-        for api_name, entry in result_csv_dict.items():
-            overall_err_msg = "" if (entry.forward_pass_status == CompareConst.PASS and
-                                     entry.backward_pass_status == CompareConst.PASS) else \
-                entry.forward_err_msg + entry.backward_err_msg
-            row = [
-                api_name,
-                entry.forward_pass_status,
-                entry.backward_pass_status,
-                overall_err_msg
-            ]
-            # adjust the row if this api carries exception_skip information
-            if api_name in self.results_exception_skip:
-                if self.results_exception_skip[api_name][Const.FORWARD] is not None:
-                    row[1] = CompareConst.SKIP
-                    row[-1] += self.results_exception_skip[api_name][Const.FORWARD]
-                if self.results_exception_skip[api_name][Const.BACKWARD] is not None:
-                    row[2] = CompareConst.SKIP
-                    row[-1] += self.results_exception_skip[api_name][Const.BACKWARD]
-                del self.results_exception_skip[api_name]
-            result_csv.append(row)
-            logger.debug(f"Result CSV row added: {row}")
-        for api_name in self.results_exception_skip:
-            current_exception_skip = self.results_exception_skip[api_name]
-            forward_status = None
-            backward_status = None
-            err_msg = ""
-            if current_exception_skip[Const.FORWARD] is not None:
-                forward_status = CompareConst.SKIP
-                err_msg += current_exception_skip[Const.FORWARD]
-            if current_exception_skip[Const.BACKWARD] is not None:
-                backward_status = CompareConst.SKIP
-                err_msg += current_exception_skip[Const.BACKWARD]
-            row = [api_name, forward_status, backward_status, err_msg]
-            result_csv.append(row)
-
-        write_csv(result_csv, csv_path, mode="a+")
-        logger.debug(f"Result CSV written successfully to {csv_path}.")
-
-        # keep the flag False so headers are not appended again
-        self.is_first_write = False
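As a usage sketch of the resume behaviour above (the result path is hypothetical, and DataManager is the class just defined): passing an existing result CSV switches the manager into checkpoint mode, and is_unique_api lets callers skip APIs that were already checked.

# hypothetical path; a real run derives it from csv_dir plus add_time_as_suffix
previous_result = "./op_result_output/result_20250520101911.csv"

manager = DataManager("./op_result_output", previous_result)  # resume mode
if manager.is_unique_api("Mint.split.1"):
    print("not checked yet, run the comparison")
else:
    print("already present in the result CSV, skip")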
-
-
-# ======== input context and aggregation classes ========
-class GlobalContext:
-    def __init__(self):
-        self.is_constructed = True
-        self.dump_data_dir = ""
-        self.framework = Const.MS_FRAMEWORK
-
-    def init(self, is_constructed, dump_data_dir, framework):
-        self.is_constructed = is_constructed
-        self.dump_data_dir = dump_data_dir
-        self.framework = framework
-
-    def get_dump_data_dir(self):
-        return self.dump_data_dir
-
-    def get_is_constructed(self):
-        return self.is_constructed
-
-    def get_framework(self):
-        return self.framework
-
-
-global_context = GlobalContext()
-
-
-class ApiInputAggregation:
-    def __init__(self, inputs, kwargs, gradient_inputs) -> None:
-        """
-        Args:
-            inputs: List[ComputeElement]
-            kwargs: dict{str: ComputeElement}
-            gradient_inputs: Union[List[ComputeElement], None]
-        """
-        self.inputs = inputs
-        self.kwargs = kwargs
-        self.gradient_inputs = gradient_inputs
-
-
-api_parent_module_mapping = {
-    (MsCompareConst.MINT, Const.MS_FRAMEWORK): mindspore.mint,
-    (MsCompareConst.MINT, Const.PT_FRAMEWORK): torch,
-    (MsCompareConst.MINT_FUNCTIONAL, Const.MS_FRAMEWORK): mindspore.mint.nn.functional,
-    (MsCompareConst.MINT_FUNCTIONAL, Const.PT_FRAMEWORK): torch.nn.functional,
-    (MsCompareConst.TENSOR_API, Const.MS_FRAMEWORK): mindspore.Tensor,
-    (MsCompareConst.TENSOR_API, Const.PT_FRAMEWORK): torch.Tensor,
-    (MsCompareConst.MINDTORCH_TENSOR, Const.MT_FRAMEWORK): mindtorch_tensor,
-    (MsCompareConst.MINDTORCH_TENSOR, Const.PT_FRAMEWORK): torch.Tensor,
-    (MsCompareConst.MINDTORCH, Const.MT_FRAMEWORK): mindtorch,
-    (MsCompareConst.MINDTORCH, Const.PT_FRAMEWORK): torch,
-    (MsCompareConst.MINDTORCH_FUNC, Const.MT_FRAMEWORK): mindtorch_func,
-    (MsCompareConst.MINDTORCH_FUNC, Const.PT_FRAMEWORK): torch.nn.functional,
-    (MsCompareConst.MINDTORCH_DIST, Const.MT_FRAMEWORK): mindtorch_dist,
-    (MsCompareConst.MINDTORCH_DIST, Const.PT_FRAMEWORK): torch.distributed,
-    (MsCompareConst.FUNCTIONAL_API, Const.MS_FRAMEWORK): mindspore.ops
-}
-
-
-api_parent_module_str_mapping = {
-    (MsCompareConst.MINT, Const.MS_FRAMEWORK): "mindspore.mint",
-    (MsCompareConst.MINT, Const.PT_FRAMEWORK): "torch",
-    (MsCompareConst.MINT_FUNCTIONAL, Const.MS_FRAMEWORK): "mindspore.mint.nn.functional",
-    (MsCompareConst.MINT_FUNCTIONAL, Const.PT_FRAMEWORK): "torch.nn.functional",
-    (MsCompareConst.TENSOR_API, Const.MS_FRAMEWORK): "mindspore.Tensor",
-    (MsCompareConst.TENSOR_API, Const.PT_FRAMEWORK): "torch.Tensor",
-    (MsCompareConst.MINDTORCH_TENSOR, Const.MT_FRAMEWORK): "mindtorch_tensor",
-    (MsCompareConst.MINDTORCH_TENSOR, Const.PT_FRAMEWORK): "torch.Tensor",
-    (MsCompareConst.MINDTORCH, Const.MT_FRAMEWORK): "mindtorch",
-    (MsCompareConst.MINDTORCH, Const.PT_FRAMEWORK): "torch",
-    (MsCompareConst.MINDTORCH_FUNC, Const.MT_FRAMEWORK): "mindtorch_func",
-    (MsCompareConst.MINDTORCH_FUNC, Const.PT_FRAMEWORK): "torch.nn.functional",
-    (MsCompareConst.MINDTORCH_DIST, Const.MT_FRAMEWORK): "mindtorch_dist",
-    (MsCompareConst.MINDTORCH_DIST, Const.PT_FRAMEWORK): "torch.distributed",
-    (MsCompareConst.FUNCTIONAL_API, Const.MS_FRAMEWORK): "mindspore.ops"
-}
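For illustration, a minimal sketch of how the two tables above cooperate: the (api_type, framework) key selects the parent module, while its string twin is only used to build readable error messages. The MINT/MS/PT constants below are invented stand-ins for the MsCompareConst/Const values.

import mindspore
import torch

MINT, MS, PT = "Mint", "mindspore", "pytorch"  # hypothetical constant values

module_mapping = {(MINT, MS): mindspore.mint, (MINT, PT): torch}
name_mapping = {(MINT, MS): "mindspore.mint", (MINT, PT): "torch"}

for framework in (MS, PT):
    parent = module_mapping[(MINT, framework)]
    # the same operator name resolves against each framework's parent module
    print(name_mapping[(MINT, framework)] + ".abs ->", getattr(parent, "abs"))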
-
-
-class ApiRunner:
-    def __call__(self, api_input_aggregation, api_name_str, forward_or_backward=Const.FORWARD,
-                 api_platform=Const.MS_FRAMEWORK):
-        '''
-        Args:
-            api_input_aggregation: ApiInputAggregation
-            api_name_str: str, e.g. "MintFunctional.relu.0"
-            forward_or_backward: str, Union["forward", "backward"]
-            api_platform: str, Union["mindspore", "torch", "mindtorch"]
-
-        Return:
-            outputs: list[ComputeElement]
-
-        Description:
-            run mindspore.mint/torch api
-        '''
-        api_type_str, api_sub_name = self.get_info_from_name(api_name_str, api_platform)
-        api_instance = self.get_api_instance(api_type_str, api_sub_name, api_platform)
-
-        return self.run_api(api_instance, api_input_aggregation, forward_or_backward, api_platform)
-
-    @staticmethod
-    def get_info_from_name(api_name_str, api_platform=Const.MS_FRAMEWORK):
-        """
-        Args:
-            api_name_str: str, the trimmed key of the data dict in api_info.json, e.g. "MintFunctional.relu.0"
-            api_platform: str, the platform for the API, either "mindspore" or "mindtorch".
-                It specifies which framework is being used. Default is "mindspore".
-        Return:
-            api_type_str: str, Union["MintFunctional", "Mint", "Tensor", "Torch", "Functional"]
-            api_sub_name: str, e.g. "relu"
-        """
-        api_name_list = api_name_str.split(Const.SEP)
-        if len(api_name_list) != 3:
-            err_msg = f"ApiRunner.get_info_from_name failed: api_name_str: {api_name_str} is not in defined format"
-            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
-        api_type_str, api_sub_name = api_name_list[0], api_name_list[1]
-        if api_type_str not in [MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL, MsCompareConst.TENSOR_API,
-                                MsCompareConst.FUNCTIONAL_API] \
-                and api_platform == Const.MS_FRAMEWORK:
-            err_msg = "ApiRunner.get_info_from_name failed: not a mint, mint.nn.functional or Tensor api"
-            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
-
-        if api_type_str not in MsCompareConst.MT_VALID_API_TYPES and api_platform == Const.MT_FRAMEWORK:
-            err_msg = "ApiRunner.get_info_from_name failed: not a torch, functional or Tensor api"
-            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
-        return api_type_str, api_sub_name
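A quick illustration of the naming convention parsed above, assuming Const.SEP is "." (the __main__ block later uses "." the same way); the dumped key is assumed to carry a forward/backward suffix:

api_full_name = "MintFunctional.relu.0.forward"          # key as it might appear in dump.json
api_name_str = ".".join(api_full_name.split(".")[:3])    # -> "MintFunctional.relu.0"

api_type_str, api_sub_name, call_index = api_name_str.split(".")
print(api_type_str, api_sub_name, call_index)            # MintFunctional relu 0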
"relu" - api_platform: str: Union["mindspore", "pytorch"] - - Return: - api_instance: function object - - Description: - get mindspore.mint/torch api function - mindspore.mint.{api_sub_name} <--> torch.{api_sub_name} - mindspore.mint.nn.functional.{api_sub_name} <--> torch.nn.functional.{api_sub_name} - """ - - api_parent_module = api_parent_module_mapping.get((api_type_str, api_platform)) - api_parent_module_str = api_parent_module_str_mapping.get((api_type_str, api_platform)) - full_api_name = api_parent_module_str + Const.SEP + api_sub_name - - if not hasattr(api_parent_module, api_sub_name): - err_msg = f"ApiRunner.get_api_instance failed: {full_api_name} is not found" - logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.ApiWrong)) - - api_instance = getattr(api_parent_module, api_sub_name) - if not callable(api_instance): - err_msg = f"ApiRunner.get_api_instance failed: {full_api_name} is not callable" - logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.ApiWrong)) - - return api_instance - - @staticmethod - def run_api(api_instance, api_input_aggregation, forward_or_backward, api_platform): - inputs = tuple(compute_element.get_parameter(get_origin=False, tensor_platform=api_platform) - for compute_element in api_input_aggregation.inputs) - kwargs = {key: value.get_parameter(get_origin=False, tensor_platform=api_platform) - for key, value in api_input_aggregation.kwargs.items()} - gradient_inputs = api_input_aggregation.gradient_inputs - - if forward_or_backward == Const.FORWARD: - forward_result = api_instance(*inputs, **kwargs) # can be single tensor or tuple - forward_result_tuple = convert_to_tuple(forward_result) - res_compute_element_list = [ComputeElement(parameter=api_res) for api_res in forward_result_tuple] - if api_platform == Const.MS_FRAMEWORK or api_platform == Const.MT_FRAMEWORK: - return res_compute_element_list, inputs, kwargs, forward_result_tuple - else: - if gradient_inputs is None: - err_msg = f"ApiRunner.run_api failed: run backward api but gradient_inputs is missing" - logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue)) - gradient_inputs = tuple(compute_element.get_parameter(get_origin=False, tensor_platform=api_platform) - for compute_element in gradient_inputs) - if api_platform == Const.MS_FRAMEWORK or api_platform == Const.MT_FRAMEWORK: - if len(gradient_inputs) == 1: - gradient_inputs = gradient_inputs[0] - - def api_with_kwargs(*forward_inputs): - return api_instance(*forward_inputs, **kwargs) - - grad_func = ops.GradOperation(get_all=True, sens_param=True)(api_with_kwargs) - backward_result = grad_func(*inputs, gradient_inputs) # can be single tensor or tuple - backward_result_tuple = convert_to_tuple(backward_result) - res_compute_element_list = [ComputeElement(parameter=api_res) for api_res in backward_result_tuple] - return res_compute_element_list, gradient_inputs, backward_result_tuple - else: - # set requires_grad - requires_grad_index = [] - for index, tensor in enumerate(inputs): - if isinstance(tensor, torch.Tensor) and \ - torch_dtype_to_dtype_str.get(tensor.dtype) in float_dtype_str_list: - setattr(tensor, "requires_grad", True) - requires_grad_index.append(index) - forward_results = api_instance(*inputs, **kwargs) - forward_results = convert_to_tuple(forward_results) - for forward_res, gradient_in in zip(forward_results, gradient_inputs): - forward_res.backward(gradient_in) - backward_result_list = [] - for index 
-
-    @staticmethod
-    def run_api(api_instance, api_input_aggregation, forward_or_backward, api_platform):
-        inputs = tuple(compute_element.get_parameter(get_origin=False, tensor_platform=api_platform)
-                       for compute_element in api_input_aggregation.inputs)
-        kwargs = {key: value.get_parameter(get_origin=False, tensor_platform=api_platform)
-                  for key, value in api_input_aggregation.kwargs.items()}
-        gradient_inputs = api_input_aggregation.gradient_inputs
-
-        if forward_or_backward == Const.FORWARD:
-            forward_result = api_instance(*inputs, **kwargs)  # can be a single tensor or a tuple
-            forward_result_tuple = convert_to_tuple(forward_result)
-            res_compute_element_list = [ComputeElement(parameter=api_res) for api_res in forward_result_tuple]
-            if api_platform == Const.MS_FRAMEWORK or api_platform == Const.MT_FRAMEWORK:
-                return res_compute_element_list, inputs, kwargs, forward_result_tuple
-        else:
-            if gradient_inputs is None:
-                err_msg = "ApiRunner.run_api failed: run backward api but gradient_inputs is missing"
-                logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
-            gradient_inputs = tuple(compute_element.get_parameter(get_origin=False, tensor_platform=api_platform)
-                                    for compute_element in gradient_inputs)
-            if api_platform == Const.MS_FRAMEWORK or api_platform == Const.MT_FRAMEWORK:
-                if len(gradient_inputs) == 1:
-                    gradient_inputs = gradient_inputs[0]
-
-                def api_with_kwargs(*forward_inputs):
-                    return api_instance(*forward_inputs, **kwargs)
-
-                grad_func = ops.GradOperation(get_all=True, sens_param=True)(api_with_kwargs)
-                backward_result = grad_func(*inputs, gradient_inputs)  # can be a single tensor or a tuple
-                backward_result_tuple = convert_to_tuple(backward_result)
-                res_compute_element_list = [ComputeElement(parameter=api_res) for api_res in backward_result_tuple]
-                return res_compute_element_list, gradient_inputs, backward_result_tuple
-            else:
-                # mark floating-point input tensors as requiring grad
-                requires_grad_index = []
-                for index, tensor in enumerate(inputs):
-                    if isinstance(tensor, torch.Tensor) and \
-                            torch_dtype_to_dtype_str.get(tensor.dtype) in float_dtype_str_list:
-                        setattr(tensor, "requires_grad", True)
-                        requires_grad_index.append(index)
-                forward_results = api_instance(*inputs, **kwargs)
-                forward_results = convert_to_tuple(forward_results)
-                for forward_res, gradient_in in zip(forward_results, gradient_inputs):
-                    forward_res.backward(gradient_in)
-                backward_result_list = []
-                for index in requires_grad_index:
-                    backward_result_list.append(getattr(inputs[index], "grad"))
-                res_compute_element_list = [ComputeElement(parameter=api_res) for api_res in backward_result_list]
-
-        return res_compute_element_list
-
-
-api_runner = ApiRunner()
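The MindSpore backward branch above relies on GradOperation with sens_param=True, which feeds an upstream gradient exactly like tensor.backward(grad) does on the PyTorch side. A self-contained sketch of that pattern (values and shapes are illustrative only):

import numpy as np
from mindspore import Tensor, ops

def square(x):
    return x * x

x = Tensor(np.array([1.0, 2.0, 3.0], dtype=np.float32))
sens = Tensor(np.ones(3, dtype=np.float32))  # upstream gradient, like backward(grad) in torch

grad_func = ops.GradOperation(get_all=True, sens_param=True)(square)
print(grad_func(x, sens))  # (Tensor([2., 4., 6.]),) == d(x*x)/dx * sens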
-
-
-# ======== data structure classes ========
-class ResultCsvEntry:
-    def __init__(self) -> None:
-        self.forward_pass_status = None
-        self.backward_pass_status = None
-        self.forward_err_msg = ""
-        self.backward_err_msg = ""
-        self.overall_err_msg = None
-
-
-class ProcessResultPacket:
-    def __init__(self, process_status, result, err_msg) -> None:
-        self.process_status = process_status
-        self.result = result
-        self.err_msg = err_msg
-
-
-class MstensorMetaData:
-    def __init__(self, dtype_str, npy_path, maximum, minimum, shape) -> None:
-        self.dtype_str = dtype_str
-        self.npy_path = npy_path
-        self.maximum = maximum
-        self.minimum = minimum
-        self.shape = shape
-
-
-class DtypeMetaData:
-    def __init__(self, dtype_str) -> None:
-        self.dtype_str = dtype_str
-
-
-class ComputeElement:
-    def __init__(self, compute_element_info=None, parameter=None):
-        self.supported_parameter_type = tuple(type_to_api_info_type_str.keys()) + tuple([torch.Tensor, tuple])
-        if parameter is not None:
-            self._init_with_parameter(parameter)
-        elif isinstance(compute_element_info, (list, dict)):
-            self._init_from_compute_element_info(compute_element_info)
-        elif compute_element_info is None:
-            self._init_from_null_compute_element_info()
-        else:
-            logger.error_log_with_exp(
-                "ComputeElement.__init__ failed: not init with parameter or compute_element info is not (list, dict)",
-                ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))
-
-    @staticmethod
-    def transfer_to_torch_tensor(ms_tensor):
-        '''
-        Args:
-            ms_tensor: mindspore.Tensor
-        Return:
-            torch_tensor: torch.Tensor
-        '''
-        ms_dtype = ms_tensor.dtype
-        dtype_str = ms_dtype_to_dtype_str.get(ms_dtype)
-        if dtype_str not in dtype_str_to_torch_dtype:
-            err_msg = f"ComputeElement.transfer_to_torch_tensor failed: no matching torch dtype for {dtype_str}"
-            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))
-        else:
-            torch_dtype = dtype_str_to_torch_dtype.get(dtype_str)
-
-        if dtype_str in int_dtype_str_list:
-            middle_dtype = mindspore.int64
-        else:
-            middle_dtype = mindspore.float64
-        np_ndarray = ms_tensor.astype(middle_dtype).numpy()
-        torch_tensor = torch.from_numpy(np_ndarray).to(torch_dtype)
-        return torch_tensor
-
-    @staticmethod
-    def transfer_to_mindtorch_tensor(ms_tensor):
-        """
-        Args:
-            ms_tensor: mindspore.Tensor
-        Return:
-            mindtorch_tensor: mindtorch.Tensor
-        """
-        ms_dtype = ms_tensor.dtype
-        dtype_str = ms_dtype_to_dtype_str.get(ms_dtype)
-        if dtype_str not in dtype_str_to_mindtorch_dtype:
-            err_msg = f"ComputeElement.transfer_to_mindtorch_tensor failed: no matching mindtorch dtype for {dtype_str}"
-            logger.error_log_with_exp(err_msg,
-                                      ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))
-        else:
-            mindtorch_dtype = dtype_str_to_mindtorch_dtype.get(dtype_str)
-
-        if dtype_str in int_dtype_str_list:
-            middle_dtype = mindspore.int64
-        else:
-            middle_dtype = mindspore.float64
-
-        np_ndarray = ms_tensor.astype(middle_dtype).numpy()
-        mindtorch_tensor = mindtorch.from_numpy(np_ndarray).to(mindtorch_dtype)
-        return mindtorch_tensor
-
-    @staticmethod
-    def transfer_to_mindspore_tensor(torch_tensor):
-        '''
-        Args:
-            torch_tensor: torch.Tensor
-        Return:
-            ms_tensor: mindspore.Tensor
-        '''
-        torch_dtype = torch_tensor.dtype
-        dtype_str = torch_dtype_to_dtype_str.get(torch_dtype)
-        if dtype_str not in dtype_str_to_ms_dtype:
-            err_msg = \
-                f"ComputeElement.transfer_to_mindspore_tensor failed: no matching mindspore dtype for {dtype_str}"
-            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))
-        else:
-            ms_dtype = dtype_str_to_ms_dtype.get(dtype_str)
-
-        if dtype_str in int_dtype_str_list:
-            middle_dtype = torch.int64
-        else:
-            middle_dtype = torch.float64
-        np_ndarray = torch_tensor.to(middle_dtype, copy=True).numpy()
-        ms_tensor = mindspore.Tensor.from_numpy(np_ndarray).astype(ms_dtype)
-        return ms_tensor
-
-    @staticmethod
-    def convert_inf_to_real_num(value, np_dtype):
-        # clamp +/-inf to the largest finite value representable by np_dtype
-        if value == float("inf"):
-            value = np.finfo(np_dtype).max
-        elif value == float("-inf"):
-            value = np.finfo(np_dtype).min
-        return value
-
-    def get_parameter(self, get_origin=True, tensor_platform=Const.MS_FRAMEWORK):
-        '''
-        Args:
-            get_origin: boolean
-            tensor_platform: str, Union["mindspore", "pytorch"]
-
-        Return:
-            parameter: Union[int, float, str, slice, tuple, torch.Tensor, mindspore.Tensor]
-        '''
-        if self.parameter is None:
-            return self.parameter
-        if isinstance(self.parameter, tuple):
-            return tuple([compute_element.get_parameter(get_origin=get_origin, tensor_platform=tensor_platform)
-                          for compute_element in self.parameter])
-        elif isinstance(self.parameter, self.supported_parameter_type):
-            parameter_tmp = self.parameter
-        elif isinstance(self.parameter, DtypeMetaData):
-            if tensor_platform == Const.MS_FRAMEWORK:
-                parameter_tmp = dtype_str_to_ms_dtype.get(self.parameter.dtype_str)
-            elif tensor_platform == Const.PT_FRAMEWORK:
-                parameter_tmp = dtype_str_to_torch_dtype.get(self.parameter.dtype_str)
-            elif tensor_platform == Const.MT_FRAMEWORK:
-                parameter_tmp = dtype_str_to_mindtorch_dtype.get(self.parameter.dtype_str)
-        elif isinstance(self.parameter, MstensorMetaData):
-            mstensor_meta_data = self.parameter
-            ms_dtype = dtype_str_to_ms_dtype.get(mstensor_meta_data.dtype_str)
-            if global_context.get_is_constructed():
-                np_dtype = dtype_str_to_np_dtype.get(mstensor_meta_data.dtype_str, DEFAULT_CONSTRUCT_NP_FLOAT_DTYPE)
-                ndarray = self._construct_ndarray(mstensor_meta_data.shape, mstensor_meta_data.maximum,
-                                                  mstensor_meta_data.minimum, np_dtype)
-            else:
-                ndarray = load_npy(mstensor_meta_data.npy_path)
-            parameter_tmp = mindspore.Tensor(ndarray, dtype=ms_dtype)
-        else:
-            err_msg = "ComputeElement.get_parameter failed: self.parameter type is not in " \
-                      "(int, float, str, slice, bool, torch.Tensor, mindspore.Tensor, MstensorMetaData)"
-            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))
-
-        # if necessary, transfer the tensor to the requested platform
-        if not get_origin and isinstance(parameter_tmp, mindspore.Tensor) and tensor_platform == Const.PT_FRAMEWORK:
-            parameter = self.transfer_to_torch_tensor(parameter_tmp)
-        elif not get_origin and isinstance(parameter_tmp, mindspore.Tensor) and tensor_platform == Const.MT_FRAMEWORK:
-            parameter = self.transfer_to_mindtorch_tensor(parameter_tmp)
-        elif not get_origin and isinstance(parameter_tmp, torch.Tensor) and tensor_platform == Const.MS_FRAMEWORK:
-            parameter = self.transfer_to_mindspore_tensor(parameter_tmp)
-        else:
-            parameter = parameter_tmp
-
-        return parameter
-
-    def get_shape(self):
-        return self.shape
-
-    def get_dtype(self):
-        return self.dtype_str
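A standalone sketch of the conversion trick used by the transfer methods above: round-trip through a wide intermediate dtype so numpy can bridge the frameworks (bfloat16, for instance, has no numpy equivalent). The dtypes here are illustrative.

import mindspore
import numpy as np
import torch

ms_tensor = mindspore.Tensor(np.array([1.5, 2.5], dtype=np.float16))

# widen to float64 first, then narrow again on the torch side
np_ndarray = ms_tensor.astype(mindspore.float64).numpy()
torch_tensor = torch.from_numpy(np_ndarray).to(torch.float16)
print(torch_tensor.dtype)  # torch.float16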
-
-    def _construct_ndarray(self, shape, maximum, minimum, np_dtype):
-        shape = tuple(shape)
-        np.random.seed({random_seed})
-        if np_dtype == np.bool_:
-            ndarray = np.random.rand(*shape) > 0.5
-        else:
-            maximum = self.convert_inf_to_real_num(maximum, np_dtype)
-            minimum = self.convert_inf_to_real_num(minimum, np_dtype)
-            ndarray = np.random.uniform(minimum, maximum, shape).astype(np_dtype)
-        return ndarray
-
-    def _init_from_null_compute_element_info(self):
-        self.parameter = None
-        self.shape = tuple()
-        self.dtype_str = "None"
-
-    def _init_from_compute_element_info(self, compute_element_info):
-        '''
-        Args:
-            compute_element_info: Union[list, dict]
-
-        Return:
-            void
-
-        init member attributes: self.shape, self.dtype_str, self.parameter
-        '''
-        if isinstance(compute_element_info, list):
-            self.shape = tuple()
-            self.dtype_str = TUPLE_TYPE_STR
-            self.parameter = tuple([ComputeElement(compute_element_info=sub_info)
-                                    for sub_info in compute_element_info])
-        else:
-            type_str = check_and_get_from_json_dict(compute_element_info, "type", "type field in api_info.json",
-                                                    accepted_type=str, accepted_value=api_info_type_str_to_type.keys())
-            self.shape = tuple()
-            self.dtype_str = type_str
-            if type_str == MINDSPORE_TENSOR_TYPE_STR:
-                self._init_from_mstensor_compute_element_info(compute_element_info)
-            else:
-                value = check_and_get_from_json_dict(compute_element_info, "value", "value field in api_info.json")
-                if type_str == MINDSPORE_DTYPE_TYPE_STR:
-                    self.parameter = DtypeMetaData(value)
-                elif type_str == SLICE_TYPE_STR:
-                    self.parameter = slice(*tuple(value))
-                else:  # type_str in ("str", "int", "float", "bool")
-                    self.parameter = value
-
-    def _init_from_mstensor_compute_element_info(self, compute_element_info):
-        '''
-        do not load the real tensor, only record its meta data
-        '''
-        dtype_str = check_and_get_from_json_dict(compute_element_info, "dtype", "dtype field in api_info.json",
-                                                 accepted_type=str, accepted_value=dtype_str_to_ms_dtype.keys())
-        shape = check_and_get_from_json_dict(compute_element_info, "shape", "shape field in api_info.json",
-                                             accepted_type=(list,))
-        if global_context.get_is_constructed():
-            maximum = check_and_get_from_json_dict(compute_element_info, "Max", "Max field in api_info.json",
-                                                   accepted_type=(int, float))
-            minimum = check_and_get_from_json_dict(compute_element_info, "Min", "Min field in api_info.json",
-                                                   accepted_type=(int, float))
-            npy_path = None
-        else:
-            maximum, minimum = None, None
-            data_name = check_and_get_from_json_dict(compute_element_info, "data_name",
-                                                     "data_name field in api_info.json", accepted_type=(str,))
-            npy_path = os.path.join(global_context.get_dump_data_dir(), data_name)
-        mstensor_meta_data = MstensorMetaData(dtype_str, npy_path, maximum, minimum, shape)
-        self.parameter = mstensor_meta_data
-        self.dtype_str = dtype_str
-        self.shape = tuple(shape)
-
-    def _init_with_parameter(self, parameter):
-        self.parameter = parameter
-        if isinstance(parameter, dict):
-            # a dict is treated as compute_element_info carrying 'type', 'shape', 'dtype' fields
-            return self._init_from_compute_element_info(parameter)
-        self.shape = tuple()
-        if not isinstance(parameter, self.supported_parameter_type):
-            err_msg = "ComputeElement._init_with_parameter failed: " \
-                      "parameter type is not in (int, float, str, slice, bool, torch.Tensor, mindspore.Tensor)"
-            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))
-        if isinstance(parameter, mindspore.Tensor):
-            self.shape = tuple(parameter.shape)
-            self.dtype_str = ms_dtype_to_dtype_str.get(parameter.dtype)
-        elif isinstance(parameter, torch.Tensor):
-            self.shape = tuple(parameter.shape)
-            self.dtype_str = torch_dtype_to_dtype_str.get(parameter.dtype)
-        elif isinstance(parameter, typing.Type):
-            self.dtype_str = MINDSPORE_DTYPE_TYPE_STR
-            self.parameter = DtypeMetaData(ms_dtype_to_dtype_str.get(parameter))
-        elif isinstance(parameter, torch.dtype):
-            self.dtype_str = TORCH_DTYPE_TYPE_STR
-            self.parameter = DtypeMetaData(torch_dtype_to_dtype_str.get(parameter))
-        elif isinstance(parameter, tuple):
-            self.dtype_str = TUPLE_TYPE_STR
-            self.parameter = tuple([ComputeElement(parameter=param) for param in parameter])
-        else:
-            self.dtype_str = type_to_api_info_type_str.get(type(parameter))
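To make the random_data mode concrete, here is a standalone sketch of the construction rule in _construct_ndarray: a fixed seed plus the recorded Max/Min/shape/dtype statistics reproduce a deterministic stand-in tensor. The statistics below are invented for the example; {random_seed} in the template is replaced by the configured seed.

import numpy as np

np.random.seed(1234)  # stands in for the substituted {random_seed}
shape, minimum, maximum = (2, 3), -1.0, 1.0  # statistics as recorded in the dumped api_info
ndarray = np.random.uniform(minimum, maximum, shape).astype(np.float32)
print(ndarray.shape, ndarray.dtype)  # (2, 3) float32, identical on every run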
-
-
-class BasicInfoAndStatus:
-    def __init__(self, api_name, bench_dtype, tested_dtype, shape, status, err_msg) -> None:
-        self.api_name = api_name
-        self.bench_dtype = bench_dtype
-        self.tested_dtype = tested_dtype
-        self.shape = shape
-        self.status = status
-        self.err_msg = err_msg
-
-
-# ======== api execution helpers ========
-def get_input(propagation):
-    args_info_forward = {args_info_forward}
-    kwargs_info_forward = {kwargs_info_forward}
-    args_info_backward = {args_info_backward}
-    forward_inputs = [ComputeElement(compute_element_info=compute_element_info)
-                      for compute_element_info in args_info_forward]
-    kwargs_compute_element_dict = {
-        key_str: ComputeElement(compute_element_info=compute_element_info)
-        for key_str, compute_element_info in kwargs_info_forward.items()
-    }
-    if args_info_backward:
-        gradient_inputs = [ComputeElement(compute_element_info=compute_element_info)
-                           for compute_element_info in args_info_backward]
-    else:
-        gradient_inputs = None
-    return ApiInputAggregation(
-        forward_inputs,
-        kwargs_compute_element_dict,
-        gradient_inputs
-    )
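The braces in args_info_forward and friends above are template placeholders, not Python code; op_generator.py substitutes the extracted argument descriptions before writing the script. After substitution the assignments might look like this (a hypothetical rendering, using the api_info.json field names seen above):

# hypothetical rendered output; real values come from the extracted api_info json
args_info_forward = [
    {"type": "mindspore.Tensor", "dtype": "Float32", "shape": [4, 2], "Max": 1.0, "Min": -1.0},
    {"type": "int", "value": 2}
]
kwargs_info_forward = {}
args_info_backward = [
    {"type": "mindspore.Tensor", "dtype": "Float32", "shape": [2, 2], "Max": 1.0, "Min": -1.0},
    {"type": "mindspore.Tensor", "dtype": "Float32", "shape": [2, 2], "Max": 1.0, "Min": -1.0}
]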
- """ - # get output - if forward_or_backward == Const.FORWARD: - tested_outputs, inputs, kwargs, forward_result_tuple = api_runner(api_input_aggregation, api_name_str, - forward_or_backward, - global_context.get_framework()) - print(f"inputs:{inputs}") - print(f"kwargs:{kwargs}") - print(f"forward_result_tuple:{forward_result_tuple}") - elif forward_or_backward == Const.BACKWARD: - tested_outputs, gradient_inputs, backward_result_tuple = api_runner(api_input_aggregation, api_name_str, - forward_or_backward, - global_context.get_framework()) - print(f"gradient_inputs:{gradient_inputs}") - print(f"backward_result_tuple:{backward_result_tuple}") - else: - tested_outputs = api_runner(api_input_aggregation, api_name_str, - forward_or_backward, global_context.get_framework()) - - bench_outputs = api_runner(api_input_aggregation, api_name_str, forward_or_backward, Const.PT_FRAMEWORK) - - tested_outputs = trim_output_compute_element_list(tested_outputs, forward_or_backward) - bench_outputs = trim_output_compute_element_list(bench_outputs, forward_or_backward) - - # compare output - output_list = [] - for i, (bench_out, tested_out) in enumerate(zip(bench_outputs, tested_outputs)): - api_name_with_slot = Const.SEP.join([api_name_str, forward_or_backward, Const.OUTPUT, str(i)]) - bench_dtype = bench_out.get_dtype() - tested_dtype = tested_out.get_dtype() - shape = bench_out.get_shape() - - compare_result_dict = dict() - for compare_algorithm_name, compare_algorithm in compare_algorithms.items(): - compare_result = compare_algorithm(bench_out, tested_out) - compare_result_dict[compare_algorithm_name] = compare_result - - if compare_result_dict.get(CompareConst.COSINE).pass_status == CompareConst.PASS and \ - compare_result_dict.get(CompareConst.MAX_ABS_ERR).pass_status == CompareConst.PASS: - status = CompareConst.PASS - err_msg = "" - else: - status = CompareConst.ERROR - err_msg = (compare_result_dict.get(CompareConst.COSINE).err_msg + - compare_result_dict.get(CompareConst.MAX_ABS_ERR).err_msg) - - # self.pre_forward_hook(api_name_str, None, inputs, kwargs) - basic_info_status = \ - BasicInfoAndStatus(api_name_with_slot, bench_dtype, tested_dtype, shape, status, err_msg) - output_list.append(tuple([api_name_str, forward_or_backward, basic_info_status, compare_result_dict])) - return output_list - - -if __name__ == "__main__": - framework = "{framework}" - dump_data_dir = "{real_data_path}" - api_name = "{api_name}" - api_full_name = "{api_full_name}" - api_name_str = ".".join(api_full_name.split(".")[:3]) - propagation = "{propagation}" - data_mode = "{data_mode}" - torch.manual_seed({random_seed}) - - data_manager = DataManager("./op_result_output", None) - create_directory("./op_result_output") - - is_constructed = data_mode == "random_data" - global_context.init(is_constructed, dump_data_dir, framework) - - for i in range({iter_times}): - print(f"iter: {{i}}:") - if propagation == BACKWARD: - - - backward_inputs_aggregation = get_input(propagation) - - backward_output_list = run_and_compare_helper(api_name_str, backward_inputs_aggregation, - Const.BACKWARD) - process_result_packet = ProcessResultPacket(process_status=MsCompareConst.ProcessStatus.SUCCESS, - result=backward_output_list, err_msg="") - - - if process_result_packet.process_status is MsCompareConst.ProcessStatus.SUCCESS: - data_manager.record(process_result_packet.result) - elif process_result_packet.process_status == MsCompareConst.ProcessStatus.EXCEPTION_SKIP: - data_manager.record_exception_skip(api_name_str, Const.BACKWARD, 
-
-
-if __name__ == "__main__":
-    framework = "{framework}"
-    dump_data_dir = "{real_data_path}"
-    api_name = "{api_name}"
-    api_full_name = "{api_full_name}"
-    api_name_str = ".".join(api_full_name.split(".")[:3])
-    propagation = "{propagation}"
-    data_mode = "{data_mode}"
-    torch.manual_seed({random_seed})
-
-    create_directory("./op_result_output")
-    data_manager = DataManager("./op_result_output", None)
-
-    is_constructed = data_mode == "random_data"
-    global_context.init(is_constructed, dump_data_dir, framework)
-
-    for i in range({iter_times}):
-        print(f"iter: {{i}}:")
-        if propagation == Const.BACKWARD:
-            backward_inputs_aggregation = get_input(propagation)
-            backward_output_list = run_and_compare_helper(api_name_str, backward_inputs_aggregation,
-                                                          Const.BACKWARD)
-            process_result_packet = ProcessResultPacket(process_status=MsCompareConst.ProcessStatus.SUCCESS,
-                                                        result=backward_output_list, err_msg="")
-
-            if process_result_packet.process_status is MsCompareConst.ProcessStatus.SUCCESS:
-                data_manager.record(process_result_packet.result)
-            elif process_result_packet.process_status == MsCompareConst.ProcessStatus.EXCEPTION_SKIP:
-                data_manager.record_exception_skip(api_name_str, Const.BACKWARD, process_result_packet.err_msg)
-
-            data_manager.save_results(api_name_str)
-        else:
-            forward_inputs_aggregation = get_input(propagation)
-            forward_output_list = run_and_compare_helper(api_name_str, forward_inputs_aggregation,
-                                                         Const.FORWARD)
-            process_result_packet = ProcessResultPacket(process_status=MsCompareConst.ProcessStatus.SUCCESS,
-                                                        result=forward_output_list, err_msg="")
-
-            if process_result_packet.process_status is MsCompareConst.ProcessStatus.SUCCESS:
-                data_manager.record(process_result_packet.result)
-            elif process_result_packet.process_status == MsCompareConst.ProcessStatus.EXCEPTION_SKIP:
-                data_manager.record_exception_skip(api_name_str, Const.FORWARD, process_result_packet.err_msg)
-
-            data_manager.save_results(api_name_str)
-
-    print("Compare finished.")
--
Gitee
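Finally, a note on how the remaining {placeholder} markers become a runnable script. The toy example below is a hypothetical sketch, not the actual op_generator.py logic; it only demonstrates why literal braces in the template, such as the {{i}} in the loop above, must be doubled when str.format performs the substitution:

# hypothetical sketch of the rendering step, for illustration only
toy_template = (
    'propagation = "{propagation}"\n'
    'for i in range({iter_times}):\n'
    '    print(f"iter: {{i}}:")\n'
)

rendered = toy_template.format(propagation="backward", iter_times=1)
print(rendered)
# propagation = "backward"
# for i in range(1):
#     print(f"iter: {i}:")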