diff --git a/debug/accuracy_tools/msprobe/docs/33.generate_operator_PyTorch.md b/debug/accuracy_tools/msprobe/docs/33.generate_operator_PyTorch.md
new file mode 100644
index 0000000000000000000000000000000000000000..d44989ac3549e61e28f2d1897139b78fa52abfc9
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/docs/33.generate_operator_PyTorch.md
@@ -0,0 +1,107 @@
+# 单算子API自动生成脚本
+
+## 1 简介
+
+单算子API自动生成脚本通过提取dump数据中的可疑算子,对其进行单API复现,输出单API精度的比对结果。具体而言,该工具从dump数据中提取可疑API的前反向信息,根据前反向数据生成单API的前反向过程,最后通过**新精度标准比对法**将 NPU/GPU 和 CPU 的结果进行比对,从而给出不同比对方法下的比对结果。本工具支持**随机生成模式和真实数据模式**a。
+
+a. 在生成单API脚本时,可以选择由工具构造随机数获得 dump 数据,或选择真实输入的数据进行单API复现。随机生成模式(对应 task: "statistics")执行效率高,可以快速获得结果,但数据精度低,只能大致判断精度问题;真实数据模式(对应 task: "tensor")执行效率略低于随机生成模式,但数据精度高,可以准确判断精度问题。
+
+## 2 使用方式
+
+### 前提
+1. 安装 msprobe。详见[ msprobe 安装](./01.installation.md)章节。
+2. 已完成对训练过程的dump,获得dump.json文件。数据采集方式详见[MindSpore 场景下的数据采集](./06.data_dump_MindSpore.md)章节或[MSAdapter 场景下的数据采集](./29.data_dump_MSAdapter.md)章节,注意需要配置 level="L1"。
+3. 发现某个算子疑似存在精度问题,并已得知算子名,如Mint.split.1、Functional.softmax.3、Tensor.add.0、Torch.matmul.5等。
+
+### 2.1 配置config_op.json
+单API复现参数配置如下(以复现 Mint.split.1 算子的反向过程为例):
+```
+{
+    "dump_json_path": "./dump.json",
+    "api_name": "Mint.split.1",
+    "extract_api_path": "Mint.split.1.json",
+    "propagation": "backward",
+    "data_mode": "random_data",
+    "random_seed": 42,
+    "iter_times": 1
+}
+```
+**配置文件参数说明**
+
+ | 参数名称 | 解释 | 是否必选 |
+ | ---------------------------- |----------------------------------------------------------------------------| ---------------------------------- |
+ | dump_json_path | dump.json的文件路径,包含所有dump算子的信息;如果已经提取了可疑算子并保存,可以不指定。 | 否 |
+ | api_name | 算子名,如Functional.softmax.3、Tensor.add.0、Torch.matmul.5等。 | 否 |
+ | extract_api_path | 提取可疑算子的json文件路径。 | 是 |
+ | propagation | 选择复现算子的forward还是backward,默认为forward。 | 否 |
+ | data_mode | 选择以随机数据(random_data)还是真实数据(real_data)模式复现算子,默认为random_data。 | 否 |
+ | random_seed | 仅random_data模式有效,表示手动设定的随机种子,默认为1234。 | 否 |
+ | iter_times | 
仅random_data模式有效,表示单API运行的次数;出于安全考虑,最大支持设置为1000。 | 否 |
+
+### 2.2 运行命令生成单API脚本
+config_op.json配置好后,运行如下命令:
+```
+msprobe -f mindspore op_generate -i ./config_op.json -o ./
+```
+或者
+
+进入到mstt的generate_op_script文件夹
+```
+cd mstt/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script
+```
+运行
+```
+python op_generator.py -i ./config_op.json -o ./
+```
+**参数说明**
+ | 参数名称 | 解释 | 是否必选 |
+ | ---------------------------- | ------------------------------------------------------------ | ---------------------------------- |
+ | -i 或 --config_input | config_op.json的路径 | 是 |
+ | -o 或 --api_output_path | 单API脚本的输出路径 | 是 |
+
+### 2.3 运行单API脚本
+运行完op_generator.py后,会在指定路径下生成名为api_name.py的单API脚本,例如Mint.split.1.forward.py、Functional.softmax.3.backward.py、Tensor.add.0.forward.py、Torch.matmul.5.backward.py。
+
+运行单API脚本即可获得不同比对方法下的比对结果:
+```
+python api_name.py
+```
+
+**运行结果说明**
+
+单算子脚本生成的 `accuracy_checking_result_{timestamp}.csv` 和 `accuracy_checking_details_{timestamp}.csv` 文件内容详情如下:
+
+`accuracy_checking_details_{timestamp}.csv`
+
+| 字段 | 含义 |
+| ------------------- | ------------------------------------------------------------ |
+| API Name | API 名称。 |
+| Bench Dtype | 标杆数据的 API 数据类型。 |
+| Tested Dtype | 被检验数据的 API 数据类型。 |
+| Shape | API 的 Shape 信息。 |
+| Cosine | 被检验数据与标杆数据的余弦相似度。 |
+| MaxAbsErr | 被检验数据与标杆数据的最大绝对误差。 |
+| MaxRelativeErr | 被检验数据与标杆数据的最大相对误差。 |
+| Status | API 预检通过状态,pass 表示通过测试,error 表示未通过。 |
+| Message | 提示信息。 |
+
+注意:PyTorch 无法对 dtype 为整数类型的 tensor 进行反向求导,而 MindSpore 支持。反向过程的预检仅比较 dtype 为浮点型的输出。
+
+`accuracy_checking_result_{timestamp}.csv`
+
+| 字段 | 含义 |
+| --------------------- | ----------------- |
+| API Name | API 名称。 |
+| Forward Test Success | 前向 API 是否通过测试,pass 为通过,error 为错误。 |
+| Backward Test Success | 反向 API 是否通过测试,pass 为通过,error 为错误;如果为空,代表该 API 没有反向输出。 |
+| Message | 提示信息。 |
+
+Forward Test Success 和 Backward Test Success 是否通过测试,由 `accuracy_checking_details_{timestamp}.csv` 中余弦相似度、最大绝对误差的判定结果共同决定。具体规则详见 [4.1 API 
预检指标](#41-api-预检指标)。
+需要注意的是,`accuracy_checking_details_{timestamp}.csv` 中一个 API 的前向(反向)可能存在多个输出,此时每个输出各记录一行;而在 `accuracy_checking_result_{timestamp}.csv` 中,只有该 API 的所有输出均为 pass 时才标记为 pass,只要存在一个 error 即标记为 error。
+
+### 4.1 API 预检指标
+
+ - API 预检指标是通过对 `accuracy_checking_details_{timestamp}.csv` 中余弦相似度、最大绝对误差的数值进行判断,得出该 API 是否符合精度标准的参考指标。
+ - 余弦相似度大于 0.99,并且最大绝对误差小于 0.001 时,标记为“pass”,否则标记为“error”。
\ No newline at end of file
diff --git a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json
new file mode 100644
index 0000000000000000000000000000000000000000..68a47dc26c3cb770e0e3c9a2ce2ada89dcec76c6
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json
@@ -0,0 +1,9 @@
+{
+    "dump_json_path": "./dump.json",
+    "api_name": "Mint.split.1",
+    "extract_api_path": "Mint.split.1.json",
+    "propagation": "backward",
+    "data_mode": "random_data",
+    "random_seed": 1234,
+    "iter_times": 1
+}
\ No newline at end of file
diff --git a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..38304d525069b99367377235479cbd10ebd76158
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py
@@ -0,0 +1,446 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# 标准库 +import argparse +import json +import os +import re +import string + +# 应用程序自定义模块 +from msprobe.core.common.file_utils import ( + FileOpen, + load_json, + save_json, + make_dir, + change_mode, +) +from msprobe.core.common.utils import ( + check_file_or_directory_path, + check_op_str_pattern_valid, + is_int, +) +from msprobe.core.common.const import Const, MonitorConst, MsgConst, FileCheckConst +from msprobe.core.common.log import logger +from msprobe.core.common.decorator import recursion_depth_decorator + +OPERATOR_TYPE = ("Functional", "Tensor", "Torch", "Mint") + +API_INFO = 2 +FOUR_SEGMENT = 4 +FIVE_SEGMENT = 5 +DATA_NAME = "data_name" +API_MAX_LENGTH = 30 +PROPAGATION_LIST = [Const.FORWARD, Const.BACKWARD] +DATAMODE_LIST = ["random_data", "real_data"] +ITER_MAX_TIMES = 1000 +FRAMEWORK = 'framework' +REAL_DATA_PATH = 'real_data_path' +EXCLUED = {FRAMEWORK, REAL_DATA_PATH} + + +class APIInfo: + def __init__(self, api_full_name, api_info_dict, backward_info=None): + self.api_full_name = api_full_name + self.api_info_dict = api_info_dict + self.backward_info = backward_info + + @property + def api_type(self): + return self.api_full_name.split(Const.SEP, -1)[0] + + @classmethod + def from_json(cls, json_content, propagation): + forward_name, forward_dict = list(json_content.items())[0] + forward_info = cls(api_full_name=forward_name, api_info_dict=forward_dict) + + if propagation == Const.BACKWARD: + backward_name, backward_dict = list(json_content.items())[1] + backward_info = cls(api_full_name=backward_name, api_info_dict=backward_dict) + 
forward_info.backward_info = backward_info
+
+        if not forward_info.is_supported_type():
+            raise ValueError(f"type {forward_info.api_type} of API is not supported!")
+
+        return forward_info
+
+    def is_supported_type(self):
+        return self.api_type in OPERATOR_TYPE
+
+
+class CommonConfig:
+    def __init__(self, json_config):
+        self.dump_json_path = json_config.get('dump_json_path')
+        self.api_name = json_config.get('api_name')
+        self.extract_api_path = json_config.get('extract_api_path')
+        self.propagation = json_config.get('propagation')
+        self.data_mode = json_config.get('data_mode')
+        self.random_seed = json_config.get('random_seed')
+        self.iter_times = json_config.get('iter_times')
+        self._check_config()
+
+    def check_user_settings(self):
+        iter_t = self.iter_times
+        if iter_t <= 0 or iter_t > ITER_MAX_TIMES:
+            raise ValueError(f"iter_times should be in range from 1 to {ITER_MAX_TIMES}.")
+
+        json_file = self.extract_api_path
+        propagation = self.propagation
+
+        json_content = load_json(json_file)
+
+        # ensure the dict is not empty
+        if not json_content:
+            raise ValueError('json file is empty!')
+
+        # ensure json_content is of type dict
+        if not isinstance(json_content, dict):
+            raise ValueError('content of json file is not a dict!')
+
+        # ensure the length of json_content is within allowed limits
+        filtered = {k: v for k, v in json_content.items() if k not in EXCLUED}
+
+        if len(filtered) > API_INFO:
+            raise ValueError('json file should contain only one API, with at most its forward and backward info')
+
+        # Retrieve the first API name and dictionary
+        forward_item = next(iter(json_content.items()), None)
+        if not forward_item or not isinstance(forward_item[1], dict) or not forward_item[1]:
+            raise ValueError('Invalid forward API data in json_content!')
+
+        # if propagation is backward, ensure json file contains forward and backward info
+        if propagation == Const.BACKWARD and len(filtered) < API_INFO:
+            raise ValueError('Backward propagation requires both 
forward and backward info!')
+
+        # if propagation is backward, ensure it has valid data
+        if propagation == Const.BACKWARD:
+            backward_item = list(json_content.items())[1]
+            if not isinstance(backward_item[1], dict) or not backward_item[1]:
+                raise ValueError('Invalid backward API data in json_content!')
+
+        return json_content
+
+    def _check_config(self):
+        if self.dump_json_path:
+            check_file_or_directory_path(self.dump_json_path)
+        if self.api_name:
+            check_op_str_pattern_valid(self.api_name)
+            if len(self.api_name) > API_MAX_LENGTH:
+                raise ValueError(f'API name {self.api_name} is too long!')
+        make_dir(os.path.dirname(self.extract_api_path))
+        if self.propagation and self.propagation not in PROPAGATION_LIST:
+            raise ValueError(f'propagation is invalid, it should be one of {PROPAGATION_LIST}')
+        if self.data_mode and self.data_mode not in DATAMODE_LIST:
+            raise ValueError(f'data_mode is invalid, it should be one of {DATAMODE_LIST}')
+        if not is_int(self.random_seed):
+            raise ValueError('random_seed is invalid, it should be an int')
+        if not is_int(self.iter_times):
+            raise ValueError('iter_times is invalid, it should be an int')
+
+
+class APIExtractor:
+    def __init__(self, api_name, dump_json_path, output_file):
+        self.api_name = api_name
+        self.dump_json_path = dump_json_path
+        self.output_file = output_file
+        self.data = None
+        self.framework = None
+        self.real_data_path = None
+
+    def extract_op(self):
+        self.data = load_json(self.dump_json_path)
+        # 获取 framework 信息
+        self.framework = self.data.get(FRAMEWORK, None)
+
+        new_data = {}
+        # 匹配以 api_name 开头、且后接 "." 分隔后缀的条目,如 Mint.split.1.forward
+        extract_key_pattern = re.compile(rf"^{re.escape(self.api_name)}\..+")
+
+        self.real_data_path = self.data.get('dump_data_dir', '')
+
+        for key, value in self.data.get('data', {}).items():
+            if extract_key_pattern.match(key):
+                if self.real_data_path:
+                    value = self.load_real_data_path(value, self.real_data_path)
+                new_data[key] = value
+
+        if self.real_data_path is not None:
+            
new_data[REAL_DATA_PATH] = self.real_data_path
+
+        # 将 framework 信息一并写入
+        if self.framework is not None:
+            new_data[FRAMEWORK] = self.framework
+        if not new_data:
+            logger.warning(f"Warning: The api '{self.api_name}' does not exist in the file.")
+        else:
+            save_json(self.output_file, new_data, indent=4)
+            logger.info(
+                f"The api '{self.api_name}' has been successfully extracted and saved in: {self.output_file}")
+
+    def load_real_data_path(self, value, dump_data_dir):
+        parameters = [Const.INPUT_ARGS, Const.GRAD_INPUT, Const.INPUT, Const.OUTPUT, Const.GRAD_OUTPUT]
+        for parameter in parameters:
+            for v in value.get(parameter, []):
+                if v is not None:
+                    self.update_data_name(v, dump_data_dir)
+        return value
+
+    @recursion_depth_decorator("OpGenerator: APIExtractor.update_data_name")
+    def update_data_name(self, data, dump_data_dir):
+        if isinstance(data, list):
+            for item in data:
+                self.update_data_name(item, dump_data_dir)
+        elif isinstance(data, dict) and DATA_NAME in data:
+            data[DATA_NAME] = os.path.join(dump_data_dir, data[DATA_NAME])
+
+
+class OperatorScriptGenerator:
+    def __init__(self, common_config, args_info_forward, kwargs_info_forward, args_info_backward):
+        self.common_config = common_config
+        self.args_info_forward = args_info_forward
+        self.kwargs_info_forward = kwargs_info_forward
+        self.args_info_backward = args_info_backward
+
+    @staticmethod
+    def extract_detailed_api_segments(full_api_name):
+        """
+        Function Description:
+            Split the full API name into api_type, api_name and call ordinal.
+        Parameter:
+            full_api_name: full name of the API. Example: Torch.matmul.0.forward
+        Return:
+            api_type: type of the API. Example: Torch, Tensor, etc.
+            api_name: name of the API. Example: matmul, mul, etc.
+            api_order: ordinal number of the API call. Example: 0, 1, etc.
+ """ + api_parts = full_api_name.split(Const.SEP) + api_parts_length = len(api_parts) + api_type, api_name, api_order = None, None, None + if api_parts_length == FOUR_SEGMENT: + api_type, api_name, api_order, _ = api_parts + elif api_parts_length == FIVE_SEGMENT: + api_type, prefix, api_name, api_order, _ = api_parts + api_name = Const.SEP.join([prefix, api_name]) + return api_type, api_name, api_order + + @staticmethod + def generate_forward_inputs_code(args_info): + names = [] + + def collect(info): + if isinstance(info, dict): + names.append(info["parameter_name"]) + else: + for sub in info: + collect(sub) + + collect(args_info) + + return ( + " forward_inputs = [\n" + " ComputeElement(parameter=info)\n" + " for info in (" + ", ".join(names) + ")\n" + " ]\n" + ) + + @staticmethod + def generate_kwargs_compute_element_dict_code(): + return ( + " # ---- 构造 kwargs 对应的 ComputeElement 字典 ----\n" + " kwargs_compute_element_dict = {\n" + " key_str: ComputeElement(compute_element_info=compute_element_info)\n" + " for key_str, compute_element_info in kwargs_device.items()\n" + " }\n" + ) + + @staticmethod + def generate_gradient_inputs_code(args_info_backward): + names = [] + + def collect(info): + if isinstance(info, dict): + names.append(info["parameter_name"]) + else: + for sub in info: + collect(sub) + + collect(args_info_backward) + + return ( + " # —— 构造反向梯度 ComputeElement 列表 —— #\n" + " gradient_inputs = [\n" + " ComputeElement(parameter=info)\n" + " for info in (" + ", ".join(names) + ")\n" + " ]\n" + ) + + def get_settings(self, api_full_name): + ''' + internal_settings contain all information needed for the operator program. 
+ keys: + api_full_name: api_type.api_name.ordinal_number + api_type: type of API, one of torch.nn.functional, torch.Tensor or Torch + api_name: name of API + ordinal_number: how many times the same api has been called + direction_status: forward + random_seed: if mode is random_data, random seed is random_seed + iter_times: if mode is random_data, generate iter_times group of data; if mode is real_data, + iter_times does not matter + args_element_assignment: code for args assignment + args_list_generator_device: code for generate args list on device + args_list_generator_bench: code for generate args list on bench + kwargs_value_assignment: code for kwargs assignment + kwargs_dict_generator_device: code for generate kwargs dict on device + kwargs_dict_generator_bench: code for generate kwargs dict on bench + ''' + # Generate an internal setting dictionary based on user settings + # including API name, type, comparison standard, random seed, number of iterations and other information + internal_settings = {} + internal_settings["propagation"] = self.common_config.propagation + internal_settings["api_full_name"] = api_full_name + api_type, api_name, ordinal_number = self.extract_detailed_api_segments(api_full_name) + if api_type == "Functional": + internal_settings["api_type"] = "torch.nn.functional" + elif api_type == "Tensor": + internal_settings["api_type"] = "torch.Tensor" + else: + internal_settings["api_type"] = "torch" + internal_settings["api_name"] = api_name + internal_settings["ordinal_number"] = ordinal_number + internal_settings["direction_status"] = self.common_config.propagation + internal_settings["random_seed"] = self.common_config.random_seed + internal_settings["data_mode"] = self.common_config.data_mode + if self.common_config.data_mode == "real_data": + internal_settings["iter_times"] = 1 + else: + internal_settings["iter_times"] = self.common_config.iter_times + + internal_settings["args_info_forward"] = self.args_info_forward + 
internal_settings["kwargs_info_forward"] = self.kwargs_info_forward
+        internal_settings["args_info_backward"] = self.args_info_backward
+
+        return internal_settings
+
+
+def _op_generator_parser(parser):
+    parser.add_argument("-i", "--config_input", dest="config_input", type=str,
+                        help="Path of the config json file", required=True)
+    parser.add_argument("-o", "--api_output_path", dest="api_output_path", type=str,
+                        help="Output path of the generated operator script", required=True)
+
+
+def parse_json_config(json_file_path):
+    if not json_file_path:
+        raise ValueError("config_input path cannot be empty, please check.")
+    json_config = load_json(json_file_path)
+    common_config = CommonConfig(json_config)
+    return common_config
+
+
+def _run_operator_generate_commond(cmd_args):
+    common_config = parse_json_config(cmd_args.config_input)
+
+    # 若未提供 dump_json_path,framework/real_data_path 保持为 None
+    framework = None
+    real_data_path = None
+    if common_config.dump_json_path:
+        api_extract = APIExtractor(common_config.api_name, common_config.dump_json_path, common_config.extract_api_path)
+        api_extract.extract_op()
+        framework = api_extract.framework
+        real_data_path = api_extract.real_data_path
+    check_file_or_directory_path(common_config.extract_api_path)
+    check_file_or_directory_path(cmd_args.api_output_path, isdir=True)
+    json_content = common_config.check_user_settings()
+    api_info = APIInfo.from_json(json_content, common_config.propagation)
+
+    if common_config.propagation == Const.BACKWARD:
+        # read and check json
+        api_full_name_forward, api_info_dict_forward = api_info.api_full_name, api_info.api_info_dict
+        api_full_name_backward, api_info_dict_backward = (api_info.backward_info.api_full_name,
+                                                          api_info.backward_info.api_info_dict)
+        args_info_forward = api_info_dict_forward.get(Const.INPUT_ARGS)
+        kwargs_info_forward = api_info_dict_forward.get(Const.INPUT_KWARGS)
+        args_info_backward = None
+        if Const.GRAD_INPUT in api_info_dict_backward:
+            args_info_backward = api_info_dict_backward.get(Const.GRAD_INPUT)
+        elif Const.INPUT in api_info_dict_backward:
+            args_info_backward = 
api_info_dict_backward.get(Const.INPUT) + op_generate = OperatorScriptGenerator(common_config, args_info_forward, kwargs_info_forward, args_info_backward) + internal_settings = op_generate.get_settings(api_full_name_backward) + internal_settings[FRAMEWORK] = framework + internal_settings[REAL_DATA_PATH] = real_data_path + else: + # read and check json + api_full_name_forward, api_info_dict_forward = api_info.api_full_name, api_info.api_info_dict + + args_info_forward = api_info_dict_forward.get(Const.INPUT_ARGS) + + kwargs_info_forward = api_info_dict_forward.get(Const.INPUT_KWARGS) + + op_generate = OperatorScriptGenerator(common_config, args_info_forward, kwargs_info_forward, None) + internal_settings = op_generate.get_settings(api_full_name_forward) + internal_settings[FRAMEWORK] = framework + internal_settings[REAL_DATA_PATH] = real_data_path + + template_path = os.path.join(os.path.dirname(__file__), "operator_replication.template") + operator_script_path = os.path.join(cmd_args.api_output_path, + "{0}.py".format(internal_settings.get("api_full_name"))) + + class SafeDict(dict): + def __missing__(self, key): + # leave {key} in the output if it’s not in the dict + return '{' + key + '}' + + class RobustFormatter(string.Formatter): + def vformat(self, format_string, args, kwargs): + result = [] + # parse() 会把文本和每个占位符拆开 + for literal, field_name, format_spec, conversion in self.parse(format_string): + # 输出字面文本 + result.append(literal) + if field_name is None: + continue + try: + # 正常获取变量并格式化 + obj, _ = self.get_field(field_name, args, kwargs) + if conversion: + obj = self.convert_field(obj, conversion) + result.append(self.format_field(obj, format_spec)) + except Exception: + # 不管是 KeyError 还是 ValueError,都原样回写 {field_name[:format_spec]} + placeholder = '{' + field_name + if conversion: + placeholder += '!' 
+ conversion + if format_spec: + placeholder += ':' + format_spec + placeholder += '}' + result.append(placeholder) + return ''.join(result) + + fmt = RobustFormatter() + with FileOpen(template_path, 'r') as ftemp, FileOpen(operator_script_path, 'w') as fout: + code_template = ftemp.read() + # 这里用 fmt.format,不用 format_map + fout.write(fmt.format(code_template, **internal_settings)) + + change_mode(operator_script_path, FileCheckConst.DATA_FILE_AUTHORITY) + + logger.info(f"Generate operator script successfully and the name is {operator_script_path}.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + _op_generator_parser(parser) + cmd_args = parser.parse_args() + _run_operator_generate_commond(cmd_args) diff --git a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template new file mode 100644 index 0000000000000000000000000000000000000000..2e26f606ace3ab7b5c57fc46b13118e9ad90487d --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template @@ -0,0 +1,215 @@ +import os +import re +import stat +import time +from enum import Enum, auto +from abc import ABC, abstractmethod +import csv + +import gc +import sys +from pathlib import Path +import mindspore +from mindspore import ops + + +from tabulate import tabulate + +import logging + +import traceback + + + +def error_log_with_exp(self, msg: str, exp: Exception): + """ + msg: 你的错误提示 + exp: 你要记录的 Exception 实例 + """ + # 将 Exception 的类型、消息和 traceback 通过 exc_info 参数一并传给 .error() + self.error(msg, exc_info=(type(exp), exp, exp.__traceback__)) + +# 把它挂到 Logger 上 +logging.Logger.error_log_with_exp = error_log_with_exp + + + +# 1. 
基本配置:设置日志级别为 INFO,默认输出到控制台 +logging.basicConfig(level=logging.INFO, + format='%(asctime)s [%(levelname)s] %(message)s', + datefmt='%H:%M:%S') + +logger = logging.getLogger() + + +# ======= 常数类 ======= + +class CodedException(Exception): + def __init__(self, code, error_info=''): + super().__init__() + self.code = code + self.error_info = self.err_strs.get(code) + error_info + + def __str__(self): + return self.error_info + + +class ApiAccuracyCheckerException(CodedException): + ParseJsonFailed = 0 + UnsupportType = 1 + WrongValue = 2 + ApiWrong = 3 + err_strs = { + ParseJsonFailed: "[msprobe] Api Accuracy Checker parse json failed: ", + UnsupportType: "[msprobe] Api Accuracy Checker get unsupported type: ", + WrongValue: "[msprobe] Api Accuracy Checker get wrong value: ", + ApiWrong: "[msprobe] Api Accuracy Checker something wrong with api: ", + } + + +class FileCheckConst: + """ + Class for file check const + """ + READ_ABLE = "read" + WRITE_ABLE = "write" + READ_WRITE_ABLE = "read and write" + DIRECTORY_LENGTH = 4096 + FILE_NAME_LENGTH = 255 + FILE_VALID_PATTERN = r"^[a-zA-Z0-9_.:/-]+$" + FILE_PATTERN = r'^[a-zA-Z0-9_./-]+$' + PKL_SUFFIX = ".pkl" + NUMPY_SUFFIX = ".npy" + JSON_SUFFIX = ".json" + PT_SUFFIX = ".pt" + CSV_SUFFIX = ".csv" + XLSX_SUFFIX = ".xlsx" + YAML_SUFFIX = ".yaml" + IR_SUFFIX = ".ir" + ZIP_SUFFIX = ".zip" + SHELL_SUFFIX = ".sh" + MAX_PKL_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024 + MAX_NUMPY_SIZE = 10737418240 # 10 * 1024 * 1024 * 1024 + MAX_JSON_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024 + MAX_PT_SIZE = 10737418240 # 10 * 1024 * 1024 * 1024 + MAX_CSV_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024 + MAX_XLSX_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024 + MAX_YAML_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024 + MAX_IR_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024 + MAX_ZIP_SIZE = 10737418240 # 10 * 1024 * 1024 * 1024 + MAX_FILE_IN_ZIP_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024 + COMMOM_FILE_SIZE = 1048576 # 1 * 1024 * 1024 + DIR = "dir" + FILE = "file" + 
DATA_DIR_AUTHORITY = 0o750 + DATA_FILE_AUTHORITY = 0o640 + FILE_SIZE_DICT = { + PKL_SUFFIX: MAX_PKL_SIZE, + NUMPY_SUFFIX: MAX_NUMPY_SIZE, + JSON_SUFFIX: MAX_JSON_SIZE, + PT_SUFFIX: MAX_PT_SIZE, + CSV_SUFFIX: MAX_CSV_SIZE, + XLSX_SUFFIX: MAX_XLSX_SIZE, + YAML_SUFFIX: MAX_YAML_SIZE, + IR_SUFFIX: MAX_IR_SIZE, + ZIP_SUFFIX: MAX_ZIP_SIZE + } + CSV_BLACK_LIST = r'^[+-=%@\+\-=%@]|;[+-=%@\+\-=%@]' + +class Const: + MAX_DEPTH = 10 + PT_FRAMEWORK = "pytorch" + MS_FRAMEWORK = "mindspore" + MT_FRAMEWORK = "mindtorch" + SEP = "." + KWARGS = 'kwargs' + INPUT = 'input' + OUTPUT = 'output' + INPUT_ARGS = 'input_args' + INPUT_KWARGS = 'input_kwargs' + GRAD_INPUT = 'grad_input' + GRAD_OUTPUT = 'grad_output' + BACKWARD = 'backward' + FORWARD = 'forward' + + +class CompareConst: + # compare result data + PASS = 'pass' + WARNING = 'Warning' + ERROR = 'error' + TRUE = 'TRUE' + FALSE = 'FALSE' + SKIP = 'SKIP' + + # compare result column name + COSINE = "Cosine" + EUC_DIST = "EucDist" + MAX_ABS_ERR = "MaxAbsErr" + MAX_RELATIVE_ERR = "MaxRelativeErr" + MIN_RELATIVE_ERR = "MinRelativeErr" + MEAN_RELATIVE_ERR = "MeanRelativeErr" + NORM_RELATIVE_ERR = "NormRelativeErr" + + # accuracy standards + COS_THRESHOLD = 0.99 + MAX_ABS_ERR_THRESHOLD = 0.001 + MAX_RELATIVE_ERR_THRESHOLD = 0.001 + COS_MAX_THRESHOLD = 0.9 + MAX_ABS_ERR_MAX_THRESHOLD = 1 + +class MsCompareConst: + # api_info field + MINT = "Mint" + MINT_FUNCTIONAL = "MintFunctional" + TENSOR_API = "Tensor" + FUNCTIONAL_API = "Functional" + FUSION_API = "FUSION" + + API_NAME_STR_LENGTH = 4 + MAX_RECURSION_DEPTH = 20 + + # Mindtorch api_info field + MINDTORCH_TENSOR = "Tensor" + MINDTORCH = "Torch" + MINDTORCH_FUNC = "Functional" + MINDTORCH_NPU = "NPU" + MINDTORCH_DIST = "Distributed" + + MT_VALID_API_TYPES = [ + MINDTORCH, MINDTORCH_FUNC, MINDTORCH_TENSOR + ] + SUPPORTED_FUSION_LIST = ["flash_attention_score"] + + TASK_FIELD = "task" + STATISTICS_TASK = "statistics" + FRAMEWORK = "framework" + TENSOR_TASK = "tensor" + DUMP_DATA_DIR_FIELD = 
"dump_data_dir" + DATA_FIELD = "data" + + # supported api yaml + SUPPORTED_API_LIST_FILE = "checker_support_api.yaml" + SUPPORTED_TENSOR_LIST_KEY = "tensor" + + # detail_csv + DETAIL_CSV_API_NAME = "API Name" + DETAIL_CSV_BENCH_DTYPE = "Bench Dtype" + DETAIL_CSV_TESTED_DTYPE = "Tested Dtype" + DETAIL_CSV_SHAPE = "Shape" + DETAIL_CSV_PASS_STATUS = "Status" + DETAIL_CSV_MESSAGE = "Message" + DETAIL_CSV_FILE_NAME = "accuracy_checking_details" + + # result_csv + RESULT_CSV_FORWARD_TEST_SUCCESS = "Forward Test Success" + RESULT_CSV_BACKWARD_TEST_SUCCESS = "Backward Test Success" + RESULT_CSV_FILE_NAME = "accuracy_checking_result" + + EPSILON = 1e-8 + + class ProcessStatus: + SUCCESS = "success" + API_NOT_FOUND = "api_not_found" + EXCEPTION_SKIP = "exception_skip" +
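
补充说明(示意):op_generator.py 中的 `APIExtractor.extract_op` 依据正则 `^{api_name}\..+` 从 dump.json 的 `data` 字段中筛选出可疑 API 的前反向条目。下面用一段独立的 Python 小样例勾勒这一匹配逻辑;其中 `dump_content` 为虚构的最小数据,`extract_api_entries` 为示意函数名,并非 msprobe 的实际接口:

```python
import re


def extract_api_entries(dump_data: dict, api_name: str) -> dict:
    """从 dump 数据的 data 字段中提取某个 API 的前反向条目(示意实现)。"""
    # 与 op_generator.py 中的 extract_key_pattern 一致:
    # 仅匹配 "api_name" 后紧跟 "." 分隔后缀的键,如 Mint.split.1.forward
    pattern = re.compile(rf"^{re.escape(api_name)}\..+")
    return {k: v for k, v in dump_data.get("data", {}).items() if pattern.match(k)}


# 构造一份虚构的最小 dump 内容(非真实 dump.json)
dump_content = {
    "framework": "mindspore",
    "data": {
        "Mint.split.1.forward": {"input_args": [], "input_kwargs": {}},
        "Mint.split.1.backward": {"grad_input": []},
        "Mint.split.10.forward": {"input_args": []},
        "Functional.softmax.3.forward": {"input_args": []},
    },
}

extracted = extract_api_entries(dump_content, "Mint.split.1")
print(sorted(extracted))  # → ['Mint.split.1.backward', 'Mint.split.1.forward']
```

注意 Mint.split.10 不会被 api_name 为 Mint.split.1 的配置误匹配,因为 `re.escape(api_name)` 之后的模式要求紧跟分隔符 "."。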