From 4b811f917c844c7709dbe087b1e7d5f2fb1da4db Mon Sep 17 00:00:00 2001 From: gitee Date: Fri, 19 Jul 2024 10:51:42 +0800 Subject: [PATCH 1/2] add black_list --- .../atat/pytorch/api_accuracy_checker/common/config.py | 9 +++++++++ .../atat/pytorch/api_accuracy_checker/config.yaml | 1 + .../atat/pytorch/api_accuracy_checker/run_ut/run_ut.py | 4 ++++ 3 files changed, 14 insertions(+) diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/common/config.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/common/config.py index 0aceb691b2..a815a1c138 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/common/config.py +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/common/config.py @@ -17,6 +17,7 @@ class Config: def validate(self, key, value): validators = { 'white_list': list, + 'black_list': list, 'error_data_path': str, 'precision': int } @@ -35,6 +36,14 @@ class Config: if invalid_api: raise ValueError( f"{', '.join(invalid_api)} is not in support_wrap_ops.yaml, please check the white_list") + if key == 'black_list': + if not isinstance(value, list): + raise ValueError("black_list must be a list type") + if not all(isinstance(i, str) for i in value): + raise ValueError("All elements in black_list must be of str type") + invalid_api = [i for i in value if i not in WrapApi] + if invalid_api: + raise ValueError(f"{', '.join(invalid_api)} is not in support_wrap_ops.yaml, please check the black_list") return value def __getattr__(self, item): diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/config.yaml b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/config.yaml index 7f26c72aa3..2dac535dc0 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/config.yaml +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/config.yaml @@ -1,4 +1,5 @@ white_list: [] +black_list: [] error_data_path: './' precision: 14 \ No newline at end of file diff --git 
a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_ut.py index cd83a95801..294a2a9cbb 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_ut.py @@ -189,6 +189,10 @@ def run_ut(config): if is_unsupported_api(api_full_name): # TODO run_ut does not support to the npu fusion api and distributed api continue try: + if msCheckerConfig.black_list: + [_, api_name, _] = api_full_name.split(Const.SEP) + if api_name in set(msCheckerConfig.black_list): + continue if msCheckerConfig.white_list: [_, api_name, _] = api_full_name.split(Const.SEP) if api_name not in set(msCheckerConfig.white_list): -- Gitee From f4f56fd4098778b9ec7d84aa61ef4c33dff0642e Mon Sep 17 00:00:00 2001 From: gitee Date: Sat, 20 Jul 2024 11:27:52 +0800 Subject: [PATCH 2/2] add blacklist --- debug/accuracy_tools/atat/config/config.json | 5 ++ .../accuracy_tools/atat/core/common/const.py | 5 +- .../api_accuracy_checker/common/utils.py | 1 + .../api_accuracy_checker/run_ut/run_ut.py | 57 +++++++++++-------- .../accuracy_tools/atat/pytorch/pt_config.py | 41 +++++++++++++ 5 files changed, 83 insertions(+), 26 deletions(-) diff --git a/debug/accuracy_tools/atat/config/config.json b/debug/accuracy_tools/atat/config/config.json index 70a630a40a..c6077b75ae 100644 --- a/debug/accuracy_tools/atat/config/config.json +++ b/debug/accuracy_tools/atat/config/config.json @@ -24,5 +24,10 @@ "overflow_check": { "overflow_nums": 1, "check_mode":"all" + }, + "run_ut": { + "white_list": [], + "black_list": [], + "error_data_path": "./" } } \ No newline at end of file diff --git a/debug/accuracy_tools/atat/core/common/const.py b/debug/accuracy_tools/atat/core/common/const.py index dea829c3ff..7938f03f51 100644 --- a/debug/accuracy_tools/atat/core/common/const.py +++ b/debug/accuracy_tools/atat/core/common/const.py @@ -15,6 +15,8 @@ class 
Const: OFF = 'OFF' BACKWARD = 'backward' FORWARD = 'forward' + DEFAULT_LIST = [] + DEFAULT_PATH = './' # dump mode ALL = "all" @@ -52,12 +54,13 @@ class Const: ENV_ENABLE = "1" ENV_DISABLE = "0" MAX_SEED_VALUE = 4294967295 # 2**32 - 1 - TASK_LIST = ["tensor", "statistics", "overflow_check", "free_benchmark"] + TASK_LIST = ["tensor", "statistics", "overflow_check", "free_benchmark", "run_ut"] LEVEL_LIST = ["L0", "L1", "L2", "mix"] STATISTICS = "statistics" TENSOR = "tensor" OVERFLOW_CHECK = "overflow_check" FREE_BENCHMARK = "free_benchmark" + RUN_UT = "run_ut" ATTR_NAME_PREFIX = "wrap_" KERNEL_DUMP = "kernel_dump" DATA = "data" diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/common/utils.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/common/utils.py index 9e1b02c015..35719b4e51 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/common/utils.py +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/common/utils.py @@ -166,6 +166,7 @@ def initialize_save_path(save_path, dir_name): os.mkdir(data_path, mode=FileCheckConst.DATA_DIR_AUTHORITY) data_path_checker = FileChecker(data_path, FileCheckConst.DIR) data_path_checker.common_check() + return data_path def write_pt(file_path, tensor): diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_ut.py index 294a2a9cbb..514e33cb52 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_ut.py @@ -32,6 +32,7 @@ from atat.pytorch.common.parse_json import parse_json_info_forward_backward from atat.core.common.file_check import FileOpen, FileChecker, \ change_mode, check_file_suffix, check_link, check_path_before_create, create_directory from atat.pytorch.common.log import logger +from atat.pytorch.pt_config import parse_json_config from atat.core.common.const import Const, 
FileCheckConst, CompareConst current_time = time.strftime("%Y%m%d%H%M%S") @@ -39,7 +40,8 @@ UT_ERROR_DATA_DIR = 'ut_error_data' + current_time RESULT_FILE_NAME = "accuracy_checking_result_" + current_time + ".csv" DETAILS_FILE_NAME = "accuracy_checking_details_" + current_time + ".csv" RunUTConfig = namedtuple('RunUTConfig', ['forward_content', 'backward_content', 'result_csv_path', 'details_csv_path', - 'save_error_data', 'is_continue_run_ut', 'real_data_path']) + 'save_error_data', 'is_continue_run_ut', 'real_data_path', 'white_list', + 'black_list', 'error_data_path']) not_backward_list = ['repeat_interleave'] not_detach_set = {'resize_', 'resize_as_', 'set_', 'transpose_', 't_', 'squeeze_', 'unsqueeze_'} not_raise_dtype_set = {'type_as'} @@ -176,8 +178,7 @@ def run_ut(config): logger.info(f"UT task result will be saved in {config.result_csv_path}") logger.info(f"UT task details will be saved in {config.details_csv_path}") if config.save_error_data: - error_data_path = os.path.abspath(os.path.join(msCheckerConfig.error_data_path, UT_ERROR_DATA_DIR)) - logger.info(f"UT task error_datas will be saved in {error_data_path}") + logger.info(f"UT task error_datas will be saved in {config.error_data_path}") compare = Comparator(config.result_csv_path, config.details_csv_path, config.is_continue_run_ut) with FileOpen(config.result_csv_path, 'r') as file: csv_reader = csv.reader(file) @@ -188,21 +189,17 @@ def run_ut(config): continue if is_unsupported_api(api_full_name): # TODO run_ut does not support to the npu fusion api and distributed api continue + [_, api_name, _] = api_full_name.split(Const.SEP) try: - if msCheckerConfig.black_list: - [_, api_name, _] = api_full_name.split(Const.SEP) - if api_name in set(msCheckerConfig.black_list): - continue - if msCheckerConfig.white_list: - [_, api_name, _] = api_full_name.split(Const.SEP) - if api_name not in set(msCheckerConfig.white_list): - continue + if config.black_list and api_name in config.black_list: + 
continue + if config.white_list and api_name not in config.white_list: + continue data_info = run_torch_api(api_full_name, config.real_data_path, config.backward_content, api_info_dict) is_fwd_success, is_bwd_success = compare.compare_output(api_full_name, data_info) if config.save_error_data: - do_save_error_data(api_full_name, data_info, is_fwd_success, is_bwd_success) + do_save_error_data(api_full_name, data_info, config.error_data_path, is_fwd_success, is_bwd_success) except Exception as err: - [_, api_name, _] = api_full_name.split(Const.SEP) if "expected scalar type Long" in str(err): logger.warning(f"API {api_name} not support int32 tensor in CPU, please add {api_name} to CONVERT_API " f"'int32_to_int64' list in accuracy_tools/api_accuracy_check/common/utils.py file.") @@ -231,16 +228,16 @@ def is_unsupported_api(api_name): return flag -def do_save_error_data(api_full_name, data_info, is_fwd_success, is_bwd_success): +def do_save_error_data(api_full_name, data_info, error_data_path, is_fwd_success, is_bwd_success): if not is_fwd_success or not is_bwd_success: - processor = UtDataProcessor(os.path.join(msCheckerConfig.error_data_path, UT_ERROR_DATA_DIR)) + processor = UtDataProcessor(error_data_path) for element in data_info.in_fwd_data_list: processor.save_tensors_in_element(api_full_name + '.forward.input', element) - processor.save_tensors_in_element(api_full_name + '.forward.output.bench', data_info.bench_out) - processor.save_tensors_in_element(api_full_name + '.forward.output.device', data_info.device_out) + processor.save_tensors_in_element(api_full_name + '.forward.output.bench', data_info.bench_output) + processor.save_tensors_in_element(api_full_name + '.forward.output.device', data_info.device_output) processor.save_tensors_in_element(api_full_name + '.backward.input', data_info.grad_in) - processor.save_tensors_in_element(api_full_name + '.backward.output.bench', data_info.bench_grad_out) - processor.save_tensors_in_element(api_full_name 
+ '.backward.output.device', data_info.device_grad_out) + processor.save_tensors_in_element(api_full_name + '.backward.output.bench', data_info.bench_grad) + processor.save_tensors_in_element(api_full_name + '.backward.output.device', data_info.device_grad) def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict): @@ -318,14 +315,14 @@ def run_backward(args, grad, grad_index, out): return grad_out -def initialize_save_error_data(): - error_data_path = msCheckerConfig.error_data_path +def initialize_save_error_data(error_data_path): check_path_before_create(error_data_path) create_directory(error_data_path) - error_data_path_checker = FileChecker(msCheckerConfig.error_data_path, FileCheckConst.DIR, + error_data_path_checker = FileChecker(error_data_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE) error_data_path = error_data_path_checker.common_check() - initialize_save_path(error_data_path, UT_ERROR_DATA_DIR) + error_data_path = initialize_save_path(error_data_path, UT_ERROR_DATA_DIR) + return error_data_path def get_validated_result_csv_path(result_csv_path, mode): @@ -388,6 +385,8 @@ def _run_ut_parser(parser): required=False) parser.add_argument("-f", "--filter_api", dest="filter_api", action="store_true", help=" Whether to filter the api in the api_info_file.", required=False) + parser.add_argument("-config", "--config_path", dest="config_path", default="", type=str, + help=" The path of config.json", required=False) def preprocess_forward_content(forward_content): @@ -468,14 +467,22 @@ def run_ut_command(args): if args.result_csv_path: result_csv_path = get_validated_result_csv_path(args.result_csv_path, 'result') details_csv_path = get_validated_details_csv_path(result_csv_path) + white_list = msCheckerConfig.white_list + black_list = msCheckerConfig.black_list + error_data_path = msCheckerConfig.error_data_path + if args.config_path: + _, task_config = parse_json_config(args.config_path, Const.RUN_UT) + white_list = 
task_config.white_list + black_list = task_config.black_list + error_data_path = task_config.error_data_path if save_error_data: if args.result_csv_path: time_info = result_csv_path.split('.')[0].split('_')[-1] global UT_ERROR_DATA_DIR UT_ERROR_DATA_DIR = 'ut_error_data' + time_info - initialize_save_error_data() + error_data_path = initialize_save_error_data(error_data_path) run_ut_config = RunUTConfig(forward_content, backward_content, result_csv_path, details_csv_path, save_error_data, - args.result_csv_path, real_data_path) + args.result_csv_path, real_data_path, set(white_list), set(black_list), error_data_path) run_ut(run_ut_config) diff --git a/debug/accuracy_tools/atat/pytorch/pt_config.py b/debug/accuracy_tools/atat/pytorch/pt_config.py index 0674b91b34..c0d08fe66f 100644 --- a/debug/accuracy_tools/atat/pytorch/pt_config.py +++ b/debug/accuracy_tools/atat/pytorch/pt_config.py @@ -4,6 +4,10 @@ import os from atat.core.common_config import CommonConfig, BaseConfig from atat.core.common.file_check import FileOpen from atat.core.common.const import Const +from atat.pytorch.hook_module.utils import WrapFunctionalOps, WrapTensorOps, WrapTorchOps + + +WrapApi = set(WrapFunctionalOps) | set(WrapTensorOps) | set(WrapTorchOps) class TensorConfig(BaseConfig): @@ -61,6 +65,43 @@ class FreeBenchmarkCheckConfig(BaseConfig): if self.preheat_step and self.preheat_step == 0: raise Exception("preheat_step cannot be 0") + +class RunUTConfig(BaseConfig): + def __init__(self, json_config): + super().__init__(json_config) + self.white_list = json_config.get("white_list", Const.DEFAULT_LIST) + self.black_list = json_config.get("black_list", Const.DEFAULT_LIST) + self.error_data_path = json_config.get("error_data_path", Const.DEFAULT_PATH) + self.check_run_ut_config() + + def check_run_ut_config(self): + self.check_white_list_config() + self.check_black_list_config() + self.check_error_data_path_config() + + def check_white_list_config(self): + if not isinstance(self.white_list, 
list): + raise Exception("white_list must be a list type") + if not all(isinstance(item, str) for item in self.white_list): + raise Exception("All elements in white_list must be string type") + invalid_api = [item for item in self.white_list if item not in WrapApi] + if invalid_api: + raise Exception("Invalid api in white_list: {}".format(invalid_api)) + + def check_black_list_config(self): + if not isinstance(self.black_list, list): + raise Exception("black_list must be a list type") + if not all(isinstance(item, str) for item in self.black_list): + raise Exception("All elements in black_list must be string type") + invalid_api = [item for item in self.black_list if item not in WrapApi] + if invalid_api: + raise Exception("Invalid api in black_list: {}".format(invalid_api)) + + def check_error_data_path_config(self): + if not os.path.exists(self.error_data_path): + raise Exception("error_data_path: {} does not exist".format(self.error_data_path)) + + def parse_task_config(task, json_config): default_dic = {} if task == Const.TENSOR: -- Gitee