diff --git a/debug/accuracy_tools/msprobe/docs/33.generate_operator_PyTorch.md b/debug/accuracy_tools/msprobe/docs/33.generate_operator_PyTorch.md
new file mode 100644
index 0000000000000000000000000000000000000000..d44989ac3549e61e28f2d1897139b78fa52abfc9
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/docs/33.generate_operator_PyTorch.md
@@ -0,0 +1,107 @@
+# Single-Operator API Auto-Generation Script
+
+## 1 Introduction
+
+The single-operator API auto-generation script extracts a suspicious operator from dump data, reproduces it as a standalone API call, and outputs single-API accuracy comparison results. Specifically, the tool extracts the suspicious API's forward and backward information from the dump data, regenerates the API's forward and backward pass from that information, and finally compares the NPU/GPU results against the CPU results with the **new accuracy-standard comparison method**, giving comparison results under several comparison methods. The tool supports a **random-data mode and a real-data mode**a.
+
+a. When generating the single-API script, you can either let the tool construct random inputs to obtain the data or reproduce the API with the real dumped input data. Random-data mode (task: "statistics") runs fast and gives quick results, but its data fidelity is low, so it can only roughly indicate an accuracy problem; real-data mode (task: "tensor") runs slightly slower than random-data mode, but its data fidelity is high, so it can pinpoint an accuracy problem accurately.
+
+## 2 Usage
+
+### Prerequisites
+1. Install msprobe. See [msprobe installation](./01.installation.md).
+2. The training process has been dumped and a dump.json file is available. See
+   [data collection for MindSpore](./06.data_dump_MindSpore.md) or [data collection for MSAdapter](./29.data_dump_MSAdapter.md); note that level="L1" must be configured.
+
+3. A suspected accuracy issue has been narrowed down to a specific operator whose name is known, e.g. Mint.split.1, Functional.softmax.3, Tensor.add.0, or Torch.matmul.5.
+
+### 2.1 Configure config_op.json
+The parameters for single-API reproduction are configured as follows (example: reproducing the Mint.split.1 operator):
+```json
+{
+ "dump_json_path": "./dump.json",
+ "api_name": "Mint.split.1",
+ "extract_api_path": "Mint.split.1.json",
+ "propagation": "backward",
+ "data_mode": "random_data",
+ "random_seed": 42,
+ "iter_times": 1
+}
+```
+**Configuration parameters**
+
+ | Parameter | Description | Required |
+ | ---------------------------- |----------------------------------------------------------------------------| ---------------------------------- |
+ | dump_json_path | Path to dump.json, which contains the information of all dumped operators; can be omitted if the suspicious API has already been extracted and saved. | No |
+ | api_name | Operator name, e.g. Functional.softmax.3, Tensor.add.0, Torch.matmul.5. | No |
+ | extract_api_path | Path of the JSON file the suspicious API is extracted to. | Yes |
+ | propagation | Whether to reproduce the operator's forward or backward pass; defaults to forward. | No |
+ | data_mode | Whether to reproduce the operator with random data (random_data) or real data (real_data); defaults to random_data. | No |
+ | random_seed | Effective only in random_data mode; the manually set random seed, defaults to 1234. | No |
+ | iter_times | Effective only in random_data mode; how many times the single API is run. For security reasons, at most 1000. | No |
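As a cross-check of the constraints in the table above, they can be sketched as a small standalone validator (a hypothetical helper for illustration, not the tool's actual implementation; field names follow config_op.json):

```python
# Hypothetical validator mirroring the documented config_op.json constraints.
def validate_config(cfg: dict) -> None:
    if not cfg.get("extract_api_path"):  # the only required field
        raise ValueError("extract_api_path is required")
    if cfg.get("propagation", "forward") not in ("forward", "backward"):
        raise ValueError("propagation must be 'forward' or 'backward'")
    if cfg.get("data_mode", "random_data") not in ("random_data", "real_data"):
        raise ValueError("data_mode must be 'random_data' or 'real_data'")
    iter_times = cfg.get("iter_times", 1)
    if not isinstance(iter_times, int) or not 1 <= iter_times <= 1000:
        raise ValueError("iter_times must be an int in [1, 1000]")


validate_config({
    "extract_api_path": "Mint.split.1.json",
    "propagation": "backward",
    "data_mode": "random_data",
    "random_seed": 42,
    "iter_times": 1,
})  # a valid config raises nothing
```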
+
+### 2.2 Run the command to generate the single-API script
+After config_op.json is configured, run:
+```
+msprobe -f mindspore op_generate -i ./config_op.json -o ./
+```
+or
+
+enter the generate_op_script folder of mstt
+```
+cd mstt/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script
+```
+and run
+```
+python op_generator.py -i ./config_op.json -o ./
+```
+**Arguments**
+ | Argument | Description | Required |
+ | ---------------------------- | ------------------------------------------------------------ | ---------------------------------- |
+ | -i or --config_input | Path to config_op.json | Yes |
+ | -o or --api_output_path | Output path for the generated single-API script | Yes |
+
+### 2.3 Run the single-API script
+After op_generator.py finishes, a single-API script named api_name.py is generated in the specified path, e.g. Mint.split.1.forward.py, Functional.softmax.3.backward.py, Tensor.add.0.forward.py, Torch.matmul.5.backward.py.
+
+Run the single-API script to obtain the comparison results under the different comparison methods:
+```
+python api_name.py
+```
+
+**Output description**
+
+The `accuracy_checking_result_{timestamp}.csv` and `accuracy_checking_details_{timestamp}.csv` files produced by the single-operator script contain the following:
+
+`accuracy_checking_details_{timestamp}.csv`
+
+| Field | Meaning |
+| ------------------- | ------------------------------------------------------------ |
+| API Name | Name of the API. |
+| Bench Dtype | Data type of the benchmark API data. |
+| Tested Dtype | Data type of the tested API data. |
+| Shape | Shape information of the API. |
+| Cosine | Cosine similarity between the tested data and the benchmark data. |
+| MaxAbsErr | Maximum absolute error between the tested data and the benchmark data. |
+| MaxRelativeErr | Maximum relative error between the tested data and the benchmark data. |
+| Status | Check status of the API: pass means the test passed, error means it failed. |
+| Message | Additional information. |
+
+Note: PyTorch cannot backpropagate through tensors with an integer dtype, while MindSpore can. The backward check therefore only compares outputs whose dtype is floating point.
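For reference, the three comparison metrics in the table above can be computed from the raw outputs roughly as follows (a minimal numpy sketch; the tool's actual implementation, e.g. its zero-division handling, may differ):

```python
import numpy as np

def compare_metrics(bench, tested, eps=1e-8):
    """Cosine similarity, max absolute error and max relative error of two tensors."""
    b = np.asarray(bench, dtype=np.float64).ravel()
    t = np.asarray(tested, dtype=np.float64).ravel()
    cosine = float(np.dot(b, t) / (np.linalg.norm(b) * np.linalg.norm(t) + eps))
    max_abs_err = float(np.max(np.abs(b - t)))
    max_rel_err = float(np.max(np.abs(b - t) / (np.abs(b) + eps)))
    return cosine, max_abs_err, max_rel_err

cos, mae, mre = compare_metrics([1.0, 2.0, 3.0], [1.0, 2.0, 3.0])
# identical tensors: cosine close to 1, both errors 0
```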
+
+`accuracy_checking_result_{timestamp}.csv`
+
+| Field | Meaning |
+| --------------------- | ----------------- |
+| API Name | Name of the API. |
+| Forward Test Success | Whether the forward API passed the test: pass means passed, error means failed. |
+| Backward Test Success | Whether the backward API passed the test: pass means passed, error means failed; blank means the API has no backward output. |
+| Message | Additional information. |
+
+Whether Forward Test Success and Backward Test Success pass is determined by the cosine-similarity and maximum-absolute-error results in `accuracy_checking_details_{timestamp}.csv`. For the exact rules, see [3 API check criteria](#3-api-check-criteria).
+Note that an API's forward (or backward) pass may have several outputs; each output gets its own row in `accuracy_checking_details_{timestamp}.csv`, and the API is marked pass in `accuracy_checking_result_{timestamp}.csv` only if all of its rows are pass; a single error row marks it error.
+
+## 3 API check criteria
+
+ - The API check criteria judge whether an API meets the accuracy standard from the cosine-similarity and maximum-absolute-error values in `accuracy_checking_details_{timestamp}.csv`.
+ - A cosine similarity greater than 0.99 together with a maximum absolute error smaller than 0.0001 is marked "pass"; otherwise the result is marked "error".
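The per-row criterion and the all-rows aggregation rule described above can be sketched as follows (thresholds are the ones stated in this document; the helper names are illustrative):

```python
COS_THRESHOLD = 0.99
MAX_ABS_ERR_THRESHOLD = 0.0001

def row_status(cosine, max_abs_err):
    # a single detail row passes only if both criteria hold
    if cosine > COS_THRESHOLD and max_abs_err < MAX_ABS_ERR_THRESHOLD:
        return "pass"
    return "error"

def api_status(row_statuses):
    # the result CSV marks an API pass only when every output row passed
    return "pass" if all(s == "pass" for s in row_statuses) else "error"

print(row_status(0.9999, 1e-6))       # pass
print(api_status(["pass", "pass"]))   # pass
print(api_status(["pass", "error"]))  # error
```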
\ No newline at end of file
diff --git a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json
new file mode 100644
index 0000000000000000000000000000000000000000..68a47dc26c3cb770e0e3c9a2ce2ada89dcec76c6
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json
@@ -0,0 +1,9 @@
+{
+ "dump_json_path": "./dump.json",
+ "api_name": "Mint.split.1",
+ "extract_api_path": "Mint.split.1.json",
+ "propagation": "backward",
+ "data_mode": "random_data",
+ "random_seed": 1234,
+ "iter_times": 1
+}
\ No newline at end of file
diff --git a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..38304d525069b99367377235479cbd10ebd76158
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py
@@ -0,0 +1,446 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Standard library
+import argparse
+import json
+import os
+import re
+import string
+
+# Application-specific modules
+from msprobe.core.common.file_utils import (
+ FileOpen,
+ load_json,
+ save_json,
+ make_dir,
+ change_mode,
+)
+from msprobe.core.common.utils import (
+ check_file_or_directory_path,
+ check_op_str_pattern_valid,
+ is_int,
+)
+from msprobe.core.common.const import Const, MonitorConst, MsgConst, FileCheckConst
+from msprobe.core.common.log import logger
+from msprobe.core.common.decorator import recursion_depth_decorator
+
+OPERATOR_TYPE = ("Functional", "Tensor", "Torch", "Mint")
+
+API_INFO = 2
+FOUR_SEGMENT = 4
+FIVE_SEGMENT = 5
+DATA_NAME = "data_name"
+API_MAX_LENGTH = 30
+PROPAGATION_LIST = [Const.FORWARD, Const.BACKWARD]
+DATAMODE_LIST = ["random_data", "real_data"]
+ITER_MAX_TIMES = 1000
+FRAMEWORK = 'framework'
+REAL_DATA_PATH = 'real_data_path'
+EXCLUED = {FRAMEWORK, REAL_DATA_PATH}
+
+
+class APIInfo:
+ def __init__(self, api_full_name, api_info_dict, backward_info=None):
+ self.api_full_name = api_full_name
+ self.api_info_dict = api_info_dict
+ self.backward_info = backward_info
+
+ @property
+ def api_type(self):
+ return self.api_full_name.split(Const.SEP, -1)[0]
+
+ @classmethod
+ def from_json(cls, json_content, propagation):
+ forward_name, forward_dict = list(json_content.items())[0]
+ forward_info = cls(api_full_name=forward_name, api_info_dict=forward_dict)
+
+ if propagation == Const.BACKWARD:
+ backward_name, backward_dict = list(json_content.items())[1]
+ backward_info = cls(api_full_name=backward_name, api_info_dict=backward_dict)
+ forward_info.backward_info = backward_info
+
+ if not forward_info.is_supported_type():
+ raise ValueError(f"type {forward_info.api_type} of API is not supported!")
+
+ return forward_info
+
+ def is_supported_type(self):
+ return self.api_type in OPERATOR_TYPE
+
+
+class CommonConfig:
+ def __init__(self, json_config):
+ self.dump_json_path = json_config.get('dump_json_path')
+ self.api_name = json_config.get('api_name')
+ self.extract_api_path = json_config.get('extract_api_path')
+ self.propagation = json_config.get('propagation')
+ self.data_mode = json_config.get('data_mode')
+ self.random_seed = json_config.get('random_seed')
+ self.iter_times = json_config.get('iter_times')
+ self._check_config()
+
+ def check_user_settings(self):
+ iter_t = self.iter_times
+ if iter_t <= 0 or iter_t > ITER_MAX_TIMES:
+            raise ValueError(f"iter_times should range from 1 to {ITER_MAX_TIMES}.")
+
+ json_file = self.extract_api_path
+ propagation = self.propagation
+
+ json_content = load_json(json_file)
+
+ # ensure the dict is not empty
+ if not json_content:
+ raise ValueError(f'json file is empty!')
+
+ # ensure json_content is of type dict
+ if not isinstance(json_content, dict):
+ raise ValueError(f'content of json file is not a dict!')
+
+ # ensure the length of json_content is within allowed limits
+
+ filtered = {k: v for k, v in json_content.items() if k not in EXCLUED}
+
+ if len(filtered) > API_INFO:
+            raise ValueError("json file contains more than one API; it should only contain one API's forward and backward info")
+
+ # Retrieve the first API name and dictionary
+ forward_item = next(iter(json_content.items()), None)
+ if not forward_item or not isinstance(forward_item[1], dict) or not forward_item[1]:
+ raise ValueError(f'Invalid forward API data in json_content!')
+
+ # if propagation is backward, ensure json file contains forward and backward info
+ if propagation == Const.BACKWARD and len(filtered) < API_INFO:
+            raise ValueError('Backward propagation requires both forward and backward info!')
+
+ # if propagation is backward, ensure it has valid data
+ if propagation == Const.BACKWARD:
+ backward_item = list(json_content.items())[1]
+ if not isinstance(backward_item[1], dict) or not backward_item[1]:
+ raise ValueError(f'Invalid backward API data in json_content!')
+
+ return json_content
+
+ def _check_config(self):
+ if self.dump_json_path:
+ check_file_or_directory_path(self.dump_json_path)
+ if self.api_name:
+ check_op_str_pattern_valid(self.api_name)
+ if len(self.api_name) > API_MAX_LENGTH:
+ raise ValueError(f'API name {self.api_name} is too long!')
+ make_dir(os.path.dirname(self.extract_api_path))
+ if self.propagation and self.propagation not in PROPAGATION_LIST:
+ raise ValueError(f'propagation is invalid, it should be one of {PROPAGATION_LIST}')
+ if self.data_mode and self.data_mode not in DATAMODE_LIST:
+ raise ValueError(f'data_mode is invalid, it should be one of {DATAMODE_LIST}')
+ if not is_int(self.random_seed):
+ raise ValueError(f'random_seed is invalid, it should be an int')
+ if not is_int(self.iter_times):
+ raise ValueError(f'iter_times is invalid, it should be an int')
+
+
+class APIExtractor:
+ def __init__(self, api_name, dump_json_path, output_file):
+ self.api_name = api_name
+ self.dump_json_path = dump_json_path
+ self.output_file = output_file
+ self.data = None
+ self.framework = None
+ self.real_data_path = None
+
+ def extract_op(self):
+ self.data = load_json(self.dump_json_path)
+        # read the framework field from the dump data
+ self.framework = self.data.get(FRAMEWORK, None)
+
+ new_data = {}
+        extract_key_pattern = re.compile(rf"^{re.escape(self.api_name)}\..+")  # match keys equal to or prefixed by the api name plus a separator
+
+ self.real_data_path = self.data.get('dump_data_dir', '')
+
+ for key, value in self.data.get('data', {}).items():
+ if extract_key_pattern.match(key):
+ if self.real_data_path:
+ value = self.load_real_data_path(value, self.real_data_path)
+ new_data[key] = value
+
+ if self.real_data_path is not None:
+ new_data[REAL_DATA_PATH] = self.real_data_path
+
+        # carry the framework field over
+ if self.framework is not None:
+ new_data[FRAMEWORK] = self.framework
+ if not new_data:
+            logger.warning(f"The api '{self.api_name}' does not exist in the file.")
+ else:
+ save_json(self.output_file, new_data, indent=4)
+ logger.info(
+ f"The api '{self.api_name}' has been successfully extracted and saved in: {self.output_file}")
+
+ def load_real_data_path(self, value, dump_data_dir):
+ parameters = [Const.INPUT_ARGS, Const.GRAD_INPUT, Const.INPUT, Const.OUTPUT, Const.GRAD_OUTPUT]
+ for parameter in parameters:
+ for v in value.get(parameter, []):
+ if v is not None:
+ self.update_data_name(v, dump_data_dir)
+ return value
+
+ @recursion_depth_decorator("OpGenerator: APIExtractor.update_data_name")
+ def update_data_name(self, data, dump_data_dir):
+ if isinstance(data, list):
+ for item in data:
+ self.update_data_name(item, dump_data_dir)
+ elif DATA_NAME in data:
+ data[DATA_NAME] = os.path.join(dump_data_dir, data[DATA_NAME])
+
+
+class OperatorScriptGenerator:
+ def __init__(self, common_config, args_info_forward, kwargs_info_forward, args_info_backward):
+ self.common_config = common_config
+ self.args_info_forward = args_info_forward
+ self.kwargs_info_forward = kwargs_info_forward
+ self.args_info_backward = args_info_backward
+
+ @staticmethod
+ def extract_detailed_api_segments(full_api_name):
+        """
+        Function Description:
+            Split the full API name into its type, name and call ordinal.
+        Parameter:
+            full_api_name: Full name of the API including direction. Example: Torch.matmul.5.forward
+        Return:
+            api_type: Type of the API. Example: Torch, Tensor, Functional, Mint.
+            api_name: Name of the API. Example: matmul, mul, etc.
+            api_order: How many times the same API has been called. Example: 0, 5, etc.
+        """
+ api_parts = full_api_name.split(Const.SEP)
+ api_parts_length = len(api_parts)
+ api_type, api_name, api_order = None, None, None
+ if api_parts_length == FOUR_SEGMENT:
+ api_type, api_name, api_order, _ = api_parts
+ elif api_parts_length == FIVE_SEGMENT:
+ api_type, prefix, api_name, api_order, _ = api_parts
+ api_name = Const.SEP.join([prefix, api_name])
+ return api_type, api_name, api_order
+
+ @staticmethod
+ def generate_forward_inputs_code(args_info):
+ names = []
+
+ def collect(info):
+ if isinstance(info, dict):
+ names.append(info["parameter_name"])
+ else:
+ for sub in info:
+ collect(sub)
+
+ collect(args_info)
+
+ return (
+ " forward_inputs = [\n"
+ " ComputeElement(parameter=info)\n"
+ " for info in (" + ", ".join(names) + ")\n"
+ " ]\n"
+ )
+
+ @staticmethod
+ def generate_kwargs_compute_element_dict_code():
+ return (
+            " # ---- build the ComputeElement dict for kwargs ----\n"
+ " kwargs_compute_element_dict = {\n"
+ " key_str: ComputeElement(compute_element_info=compute_element_info)\n"
+ " for key_str, compute_element_info in kwargs_device.items()\n"
+ " }\n"
+ )
+
+ @staticmethod
+ def generate_gradient_inputs_code(args_info_backward):
+ names = []
+
+ def collect(info):
+ if isinstance(info, dict):
+ names.append(info["parameter_name"])
+ else:
+ for sub in info:
+ collect(sub)
+
+ collect(args_info_backward)
+
+ return (
+            " # -- build the ComputeElement list for backward gradients -- #\n"
+ " gradient_inputs = [\n"
+ " ComputeElement(parameter=info)\n"
+ " for info in (" + ", ".join(names) + ")\n"
+ " ]\n"
+ )
+
+ def get_settings(self, api_full_name):
+        '''
+        internal_settings contain all information needed for the operator program.
+        keys:
+            propagation: forward or backward
+            api_full_name: api_type.api_name.ordinal_number
+            api_type: type of the API, one of torch.nn.functional, torch.Tensor or torch
+            api_name: name of the API
+            ordinal_number: how many times the same API has been called
+            direction_status: same as propagation
+            random_seed: random seed used in random_data mode
+            data_mode: random_data or real_data
+            iter_times: number of data groups to generate in random_data mode; fixed to 1 in real_data mode
+            args_info_forward: dumped info of the forward args
+            kwargs_info_forward: dumped info of the forward kwargs
+            args_info_backward: dumped info of the backward gradient args
+        '''
+ # Generate an internal setting dictionary based on user settings
+ # including API name, type, comparison standard, random seed, number of iterations and other information
+ internal_settings = {}
+ internal_settings["propagation"] = self.common_config.propagation
+ internal_settings["api_full_name"] = api_full_name
+ api_type, api_name, ordinal_number = self.extract_detailed_api_segments(api_full_name)
+ if api_type == "Functional":
+ internal_settings["api_type"] = "torch.nn.functional"
+ elif api_type == "Tensor":
+ internal_settings["api_type"] = "torch.Tensor"
+ else:
+ internal_settings["api_type"] = "torch"
+ internal_settings["api_name"] = api_name
+ internal_settings["ordinal_number"] = ordinal_number
+ internal_settings["direction_status"] = self.common_config.propagation
+ internal_settings["random_seed"] = self.common_config.random_seed
+ internal_settings["data_mode"] = self.common_config.data_mode
+ if self.common_config.data_mode == "real_data":
+ internal_settings["iter_times"] = 1
+ else:
+ internal_settings["iter_times"] = self.common_config.iter_times
+
+ internal_settings["args_info_forward"] = self.args_info_forward
+ internal_settings["kwargs_info_forward"] = self.kwargs_info_forward
+ internal_settings["args_info_backward"] = self.args_info_backward
+
+ return internal_settings
+
+
+def _op_generator_parser(parser):
+    parser.add_argument("-i", "--config_input", dest="config_input", type=str,
+                        help="Path of the config json file", required=True)
+    parser.add_argument("-o", "--api_output_path", dest="api_output_path", type=str,
+                        help="Output path of the generated single-API script.", required=True)
+
+
+def parse_json_config(json_file_path):
+ if not json_file_path:
+ raise Exception("config_input path can not be empty, please check.")
+ json_config = load_json(json_file_path)
+ common_config = CommonConfig(json_config)
+ return common_config
+
+
+def _run_operator_generate_commond(cmd_args):
+ common_config = parse_json_config(cmd_args.config_input)
+
+    framework, real_data_path = None, None
+    if common_config.dump_json_path:
+        api_extract = APIExtractor(common_config.api_name, common_config.dump_json_path,
+                                   common_config.extract_api_path)
+        api_extract.extract_op()
+        framework = api_extract.framework
+        real_data_path = api_extract.real_data_path
+ check_file_or_directory_path(common_config.extract_api_path)
+ check_file_or_directory_path(cmd_args.api_output_path, isdir=True)
+ json_content = common_config.check_user_settings()
+ api_info = APIInfo.from_json(json_content, common_config.propagation)
+
+ if common_config.propagation == Const.BACKWARD:
+ # read and check json
+ api_full_name_forward, api_info_dict_forward = api_info.api_full_name, api_info.api_info_dict
+ api_full_name_backward, api_info_dict_backward = (api_info.backward_info.api_full_name,
+ api_info.backward_info.api_info_dict)
+ args_info_forward = api_info_dict_forward.get(Const.INPUT_ARGS)
+ kwargs_info_forward = api_info_dict_forward.get(Const.INPUT_KWARGS)
+        if Const.GRAD_INPUT in api_info_dict_backward:
+            args_info_backward = api_info_dict_backward.get(Const.GRAD_INPUT)
+        elif Const.INPUT in api_info_dict_backward:
+            args_info_backward = api_info_dict_backward.get(Const.INPUT)
+        else:
+            raise ValueError('backward API info contains neither grad_input nor input!')
+ op_generate = OperatorScriptGenerator(common_config, args_info_forward, kwargs_info_forward, args_info_backward)
+ internal_settings = op_generate.get_settings(api_full_name_backward)
+ internal_settings[FRAMEWORK] = framework
+ internal_settings[REAL_DATA_PATH] = real_data_path
+ else:
+ # read and check json
+ api_full_name_forward, api_info_dict_forward = api_info.api_full_name, api_info.api_info_dict
+
+ args_info_forward = api_info_dict_forward.get(Const.INPUT_ARGS)
+
+ kwargs_info_forward = api_info_dict_forward.get(Const.INPUT_KWARGS)
+
+ op_generate = OperatorScriptGenerator(common_config, args_info_forward, kwargs_info_forward, None)
+ internal_settings = op_generate.get_settings(api_full_name_forward)
+ internal_settings[FRAMEWORK] = framework
+ internal_settings[REAL_DATA_PATH] = real_data_path
+
+ template_path = os.path.join(os.path.dirname(__file__), "operator_replication.template")
+ operator_script_path = os.path.join(cmd_args.api_output_path,
+ "{0}.py".format(internal_settings.get("api_full_name")))
+
+ class SafeDict(dict):
+ def __missing__(self, key):
+            # leave {key} in the output if it's not in the dict
+ return '{' + key + '}'
+
+ class RobustFormatter(string.Formatter):
+ def vformat(self, format_string, args, kwargs):
+ result = []
+            # parse() splits the template into literal text and placeholder fields
+ for literal, field_name, format_spec, conversion in self.parse(format_string):
+                # emit the literal text
+ result.append(literal)
+ if field_name is None:
+ continue
+ try:
+                    # resolve the field and format it normally
+ obj, _ = self.get_field(field_name, args, kwargs)
+ if conversion:
+ obj = self.convert_field(obj, conversion)
+ result.append(self.format_field(obj, format_spec))
+ except Exception:
+                    # on KeyError or ValueError alike, write the placeholder back verbatim as {field_name!conversion:format_spec}
+ placeholder = '{' + field_name
+ if conversion:
+ placeholder += '!' + conversion
+ if format_spec:
+ placeholder += ':' + format_spec
+ placeholder += '}'
+ result.append(placeholder)
+ return ''.join(result)
+
+ fmt = RobustFormatter()
+ with FileOpen(template_path, 'r') as ftemp, FileOpen(operator_script_path, 'w') as fout:
+ code_template = ftemp.read()
+        # use fmt.format here rather than str.format_map
+ fout.write(fmt.format(code_template, **internal_settings))
+
+ change_mode(operator_script_path, FileCheckConst.DATA_FILE_AUTHORITY)
+
+ logger.info(f"Generate operator script successfully and the name is {operator_script_path}.")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ _op_generator_parser(parser)
+ cmd_args = parser.parse_args()
+ _run_operator_generate_commond(cmd_args)
diff --git a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template
new file mode 100644
index 0000000000000000000000000000000000000000..2e26f606ace3ab7b5c57fc46b13118e9ad90487d
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template
@@ -0,0 +1,215 @@
+import os
+import re
+import stat
+import time
+from enum import Enum, auto
+from abc import ABC, abstractmethod
+import csv
+
+import gc
+import sys
+from pathlib import Path
+import mindspore
+from mindspore import ops
+
+
+from tabulate import tabulate
+
+import logging
+
+import traceback
+
+
+
+def error_log_with_exp(self, msg: str, exp: Exception):
+    """
+    msg: the error message to log
+    exp: the Exception instance to record
+    """
+    # pass the exception's type, message and traceback to .error() via exc_info
+    self.error(msg, exc_info=(type(exp), exp, exp.__traceback__))
+
+# attach the helper method to Logger
+logging.Logger.error_log_with_exp = error_log_with_exp
+
+
+
+# Basic configuration: set the log level to INFO and log to the console by default
+logging.basicConfig(level=logging.INFO,
+ format='%(asctime)s [%(levelname)s] %(message)s',
+ datefmt='%H:%M:%S')
+
+logger = logging.getLogger()
+
+
+# ======= 常数类 =======
+
+class CodedException(Exception):
+ def __init__(self, code, error_info=''):
+ super().__init__()
+ self.code = code
+ self.error_info = self.err_strs.get(code) + error_info
+
+ def __str__(self):
+ return self.error_info
+
+
+class ApiAccuracyCheckerException(CodedException):
+ ParseJsonFailed = 0
+ UnsupportType = 1
+ WrongValue = 2
+ ApiWrong = 3
+ err_strs = {
+ ParseJsonFailed: "[msprobe] Api Accuracy Checker parse json failed: ",
+ UnsupportType: "[msprobe] Api Accuracy Checker get unsupported type: ",
+ WrongValue: "[msprobe] Api Accuracy Checker get wrong value: ",
+ ApiWrong: "[msprobe] Api Accuracy Checker something wrong with api: ",
+ }
+
+
+class FileCheckConst:
+ """
+ Class for file check const
+ """
+ READ_ABLE = "read"
+ WRITE_ABLE = "write"
+ READ_WRITE_ABLE = "read and write"
+ DIRECTORY_LENGTH = 4096
+ FILE_NAME_LENGTH = 255
+ FILE_VALID_PATTERN = r"^[a-zA-Z0-9_.:/-]+$"
+ FILE_PATTERN = r'^[a-zA-Z0-9_./-]+$'
+ PKL_SUFFIX = ".pkl"
+ NUMPY_SUFFIX = ".npy"
+ JSON_SUFFIX = ".json"
+ PT_SUFFIX = ".pt"
+ CSV_SUFFIX = ".csv"
+ XLSX_SUFFIX = ".xlsx"
+ YAML_SUFFIX = ".yaml"
+ IR_SUFFIX = ".ir"
+ ZIP_SUFFIX = ".zip"
+ SHELL_SUFFIX = ".sh"
+ MAX_PKL_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024
+ MAX_NUMPY_SIZE = 10737418240 # 10 * 1024 * 1024 * 1024
+ MAX_JSON_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024
+ MAX_PT_SIZE = 10737418240 # 10 * 1024 * 1024 * 1024
+ MAX_CSV_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024
+ MAX_XLSX_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024
+ MAX_YAML_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024
+ MAX_IR_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024
+ MAX_ZIP_SIZE = 10737418240 # 10 * 1024 * 1024 * 1024
+ MAX_FILE_IN_ZIP_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024
+ COMMOM_FILE_SIZE = 1048576 # 1 * 1024 * 1024
+ DIR = "dir"
+ FILE = "file"
+ DATA_DIR_AUTHORITY = 0o750
+ DATA_FILE_AUTHORITY = 0o640
+ FILE_SIZE_DICT = {
+ PKL_SUFFIX: MAX_PKL_SIZE,
+ NUMPY_SUFFIX: MAX_NUMPY_SIZE,
+ JSON_SUFFIX: MAX_JSON_SIZE,
+ PT_SUFFIX: MAX_PT_SIZE,
+ CSV_SUFFIX: MAX_CSV_SIZE,
+ XLSX_SUFFIX: MAX_XLSX_SIZE,
+ YAML_SUFFIX: MAX_YAML_SIZE,
+ IR_SUFFIX: MAX_IR_SIZE,
+ ZIP_SUFFIX: MAX_ZIP_SIZE
+ }
+ CSV_BLACK_LIST = r'^[+-=%@\+\-=%@]|;[+-=%@\+\-=%@]'
+
+class Const:
+ MAX_DEPTH = 10
+ PT_FRAMEWORK = "pytorch"
+ MS_FRAMEWORK = "mindspore"
+ MT_FRAMEWORK = "mindtorch"
+ SEP = "."
+ KWARGS = 'kwargs'
+ INPUT = 'input'
+ OUTPUT = 'output'
+ INPUT_ARGS = 'input_args'
+ INPUT_KWARGS = 'input_kwargs'
+ GRAD_INPUT = 'grad_input'
+ GRAD_OUTPUT = 'grad_output'
+ BACKWARD = 'backward'
+ FORWARD = 'forward'
+
+
+class CompareConst:
+ # compare result data
+ PASS = 'pass'
+ WARNING = 'Warning'
+ ERROR = 'error'
+ TRUE = 'TRUE'
+ FALSE = 'FALSE'
+ SKIP = 'SKIP'
+
+ # compare result column name
+ COSINE = "Cosine"
+ EUC_DIST = "EucDist"
+ MAX_ABS_ERR = "MaxAbsErr"
+ MAX_RELATIVE_ERR = "MaxRelativeErr"
+ MIN_RELATIVE_ERR = "MinRelativeErr"
+ MEAN_RELATIVE_ERR = "MeanRelativeErr"
+ NORM_RELATIVE_ERR = "NormRelativeErr"
+
+ # accuracy standards
+ COS_THRESHOLD = 0.99
+ MAX_ABS_ERR_THRESHOLD = 0.001
+ MAX_RELATIVE_ERR_THRESHOLD = 0.001
+ COS_MAX_THRESHOLD = 0.9
+ MAX_ABS_ERR_MAX_THRESHOLD = 1
+
+class MsCompareConst:
+ # api_info field
+ MINT = "Mint"
+ MINT_FUNCTIONAL = "MintFunctional"
+ TENSOR_API = "Tensor"
+ FUNCTIONAL_API = "Functional"
+ FUSION_API = "FUSION"
+
+ API_NAME_STR_LENGTH = 4
+ MAX_RECURSION_DEPTH = 20
+
+ # Mindtorch api_info field
+ MINDTORCH_TENSOR = "Tensor"
+ MINDTORCH = "Torch"
+ MINDTORCH_FUNC = "Functional"
+ MINDTORCH_NPU = "NPU"
+ MINDTORCH_DIST = "Distributed"
+
+ MT_VALID_API_TYPES = [
+ MINDTORCH, MINDTORCH_FUNC, MINDTORCH_TENSOR
+ ]
+ SUPPORTED_FUSION_LIST = ["flash_attention_score"]
+
+ TASK_FIELD = "task"
+ STATISTICS_TASK = "statistics"
+ FRAMEWORK = "framework"
+ TENSOR_TASK = "tensor"
+ DUMP_DATA_DIR_FIELD = "dump_data_dir"
+ DATA_FIELD = "data"
+
+ # supported api yaml
+ SUPPORTED_API_LIST_FILE = "checker_support_api.yaml"
+ SUPPORTED_TENSOR_LIST_KEY = "tensor"
+
+ # detail_csv
+ DETAIL_CSV_API_NAME = "API Name"
+ DETAIL_CSV_BENCH_DTYPE = "Bench Dtype"
+ DETAIL_CSV_TESTED_DTYPE = "Tested Dtype"
+ DETAIL_CSV_SHAPE = "Shape"
+ DETAIL_CSV_PASS_STATUS = "Status"
+ DETAIL_CSV_MESSAGE = "Message"
+ DETAIL_CSV_FILE_NAME = "accuracy_checking_details"
+
+ # result_csv
+ RESULT_CSV_FORWARD_TEST_SUCCESS = "Forward Test Success"
+ RESULT_CSV_BACKWARD_TEST_SUCCESS = "Backward Test Success"
+ RESULT_CSV_FILE_NAME = "accuracy_checking_result"
+
+ EPSILON = 1e-8
+
+ class ProcessStatus:
+ SUCCESS = "success"
+ API_NOT_FOUND = "api_not_found"
+ EXCEPTION_SKIP = "exception_skip"
+