diff --git a/debug/accuracy_tools/api_accuracy_checker/common/config.py b/debug/accuracy_tools/api_accuracy_checker/common/config.py index 7d355b7b9b8854d3819d9fb6562451e69f5fc60d..f9b882f47b107a474b9f888e76fe94defd2b26fe 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/config.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/config.py @@ -1,8 +1,11 @@ import os import yaml from api_accuracy_checker.common.utils import check_file_or_directory_path +from api_accuracy_checker.hook_module.utils import WrapFunctionalOps, WrapTensorOps, WrapTorchOps from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileOpen +WrapApi = set(WrapFunctionalOps) | set(WrapTensorOps) | set(WrapTorchOps) + class Config: def __init__(self, yaml_file): @@ -19,7 +22,8 @@ class Config: 'dump_step': int, 'error_data_path': str, 'target_iter': list, - 'precision': int + 'precision': int, + 'white_list': list } if not isinstance(value, validators.get(key)): raise ValueError(f"{key} must be {validators[key].__name__} type") @@ -34,6 +38,14 @@ class Config: raise ValueError("All elements in target_iter must be greater than or equal to 0") if key == 'precision' and value < 0: raise ValueError("precision must be greater than 0") + if key == 'white_list': + if not isinstance(value, list): + raise ValueError("white_list must be a list type") + if not all(isinstance(i, str) for i in value): + raise ValueError("All elements in white_list must be of str type") + invalid_api = [i for i in value if i not in WrapApi] + if invalid_api: + raise ValueError(f"{', '.join(invalid_api)} is not in support_wrap_ops.yaml, please check the white_list") return value def __getattr__(self, item): @@ -42,13 +54,14 @@ class Config: def __str__(self): return '\n'.join(f"{key}={value}" for key, value in self.config.items()) - def update_config(self, dump_path, real_data=False, target_iter=None): + def update_config(self, dump_path=None, real_data=False, target_iter=None, white_list=None): if target_iter is None: target_iter = self.config.get('target_iter',[1]) args = { - "dump_path": dump_path, + "dump_path": dump_path if dump_path else self.config.get("dump_path", './'), "real_data": real_data, - "target_iter": target_iter + "target_iter": target_iter if target_iter else self.config.get("target_iter", [1]), + "white_list": white_list if white_list else self.config.get("white_list", []) } for key, value in args.items(): if key in self.config: diff --git a/debug/accuracy_tools/api_accuracy_checker/config.yaml b/debug/accuracy_tools/api_accuracy_checker/config.yaml index 4a1420eb4636b71530394d7ee64cf39af7a8523a..0bd145893e83c00c5aff120b82e537d28f4664eb 100644 --- a/debug/accuracy_tools/api_accuracy_checker/config.yaml +++ b/debug/accuracy_tools/api_accuracy_checker/config.yaml @@ -5,4 +5,5 @@ dump_step: 1000 error_data_path: './' target_iter: [1] precision: 14 +white_list: [] \ No newline at end of file diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py index fc1a57bc7b6e9ac72a8090ce0714497eb66bcb2f..677ac2c5c206a4ab97132808994e04999a4398d7 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py @@ -49,7 +49,7 @@ class DumpUtil(object): if DumpUtil.call_num in msCheckerConfig.target_iter: set_dump_switch("ON") elif DumpUtil.call_num > max(msCheckerConfig.target_iter): - raise Exception("Model pretest: exit after iteration {}".format(DumpUtil.call_num)) + raise Exception("Model pretest: exit after iteration {}".format(DumpUtil.call_num - 1)) else: set_dump_switch("OFF") DumpUtil.call_num += 1 diff --git a/debug/accuracy_tools/api_accuracy_checker/hook_module/utils.py b/debug/accuracy_tools/api_accuracy_checker/hook_module/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7d16ac993ed45faa0f9b48bb64050592e15ef4d2 --- /dev/null +++ b/debug/accuracy_tools/api_accuracy_checker/hook_module/utils.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import os +import yaml + +from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileOpen + +cur_path = os.path.dirname(os.path.realpath(__file__)) +yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") +with FileOpen(yaml_path, 'r') as f: + Ops = yaml.safe_load(f) + WrapFunctionalOps = Ops.get('functional') + WrapTensorOps = Ops.get('tensor') + WrapTorchOps = Ops.get('torch') \ No newline at end of file diff --git a/debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_functional.py b/debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_functional.py index a1c9af127f25afc713bd16a1153cf3bfcab292f2..e224838c67f6f4210370574f1d5fc43226c0a7de 100644 --- a/debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_functional.py +++ b/debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_functional.py @@ -22,13 +22,10 @@ import yaml from api_accuracy_checker.hook_module.hook_module import HOOKModule from api_accuracy_checker.common.utils import torch_device_guard +from api_accuracy_checker.common.config import msCheckerConfig +from api_accuracy_checker.hook_module.utils import WrapFunctionalOps from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileOpen -cur_path = os.path.dirname(os.path.realpath(__file__)) -yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") -with FileOpen(yaml_path, 'r') as f: - WrapFunctionalOps = yaml.safe_load(f).get('functional') - for f in dir(torch.nn.functional): locals().update({f: getattr(torch.nn.functional, f)}) @@ -36,7 +33,10 @@ for f in dir(torch.nn.functional): def get_functional_ops(): global WrapFunctionalOps _all_functional_ops = dir(torch.nn.functional) - return set(WrapFunctionalOps) & set(_all_functional_ops) + if msCheckerConfig.white_list: + return set(WrapFunctionalOps) & set(_all_functional_ops) & set(msCheckerConfig.white_list) + else: + return set(WrapFunctionalOps) & set(_all_functional_ops) class HOOKFunctionalOP(object): diff --git a/debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_tensor.py b/debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_tensor.py index 547955ec1876e1f4a464ff0268ec845c46b2f1a9..03e73d4b2d3bde91c2644770bb0f9b32850984c5 100644 --- a/debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_tensor.py +++ b/debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_tensor.py @@ -22,18 +22,18 @@ import yaml from api_accuracy_checker.hook_module.hook_module import HOOKModule from api_accuracy_checker.common.utils import torch_device_guard +from api_accuracy_checker.common.config import msCheckerConfig +from api_accuracy_checker.hook_module.utils import WrapTensorOps from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileOpen -cur_path = os.path.dirname(os.path.realpath(__file__)) -yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") -with FileOpen(yaml_path, 'r') as f: - WrapTensorOps = yaml.safe_load(f).get('tensor') - def get_tensor_ops(): global WrapTensorOps _tensor_ops = dir(torch._C._TensorBase) - return set(WrapTensorOps) & set(_tensor_ops) + if msCheckerConfig.white_list: + return set(WrapTensorOps) & set(_tensor_ops) & set(msCheckerConfig.white_list) + else: + return set(WrapTensorOps) & set(_tensor_ops) class HOOKTensor(object): diff --git a/debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_torch.py b/debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_torch.py index 70461b0acc211eca8b4a0cf1832fb304119c7ae2..458af67ae95acaada158b2c25a191838a17a02db 100644 --- a/debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_torch.py +++ b/debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_torch.py @@ -22,19 +22,18 @@ import yaml from api_accuracy_checker.hook_module.hook_module import HOOKModule from api_accuracy_checker.common.utils import torch_device_guard - +from api_accuracy_checker.common.config import msCheckerConfig +from api_accuracy_checker.hook_module.utils import WrapTorchOps from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileOpen -cur_path = os.path.dirname(os.path.realpath(__file__)) -yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") -with FileOpen(yaml_path, 'r') as f: - WrapTorchOps = yaml.safe_load(f).get('torch') - def get_torch_ops(): global WrapTorchOps _torch_ops = dir(torch._C._VariableFunctionsClass) - return set(WrapTorchOps) & set(_torch_ops) + if msCheckerConfig.white_list: + return set(WrapTorchOps) & set(_torch_ops) & set(msCheckerConfig.white_list) + else: + return set(WrapTorchOps) & set(_torch_ops) class HOOKTorchOP(object): diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_overflow_check.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_overflow_check.py index 1174326565822ba687ed4f9922770ea51654bfff..fb455434cf5464d416b16cb6a903aa9d29c2236b 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_overflow_check.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_overflow_check.py @@ -4,17 +4,13 @@ import sys import torch_npu import torch from tqdm import tqdm -from api_accuracy_checker.run_ut.run_ut import exec_api, generate_device_params, run_backward, init_environment, \ - get_api_info +from api_accuracy_checker.run_ut.run_ut import exec_api, generate_device_params, run_backward, get_api_info from api_accuracy_checker.common.utils import print_info_log, print_warn_log, get_json_contents, api_info_preprocess, \ print_error_log from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileCheckConst, check_file_suffix, check_link -init_environment() - - def check_tensor_overflow(x): if isinstance(x, torch.Tensor) and x.numel() != 0 and x.dtype != torch.bool: if len(x.shape) == 0: diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index 5c67338d4e44bc85f4d415833ee06cd451498d48..0c0f3305c7104e87f64d6002996ed63c342c2eb9 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -32,19 +32,6 @@ from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileOpen ut_error_data_dir = 'ut_error_data' -def init_environment(): - cur_path = os.path.dirname(os.path.realpath(__file__)) - yaml_path = os.path.join(cur_path, "../hook_module/support_wrap_ops.yaml") - with FileOpen(yaml_path, 'r') as f: - WrapFunctionalOps = yaml.safe_load(f).get('functional') - for f in dir(torch.nn.functional): - if f != "__name__": - locals().update({f: getattr(torch.nn.functional, f)}) - - -init_environment() - - def exec_api(api_type, api_name, args, kwargs): if api_type == "Functional": functional_api = FunctionalOPTemplate(api_name, str, False) @@ -125,6 +112,10 @@ def run_ut(forward_file, backward_file, out_path, save_error_data): compare = Comparator(out_path) for api_full_name, api_info_dict in tqdm(forward_content.items()): try: + if msCheckerConfig.white_list: + [_, api_name, _] = api_full_name.split("*") + if api_name not in set(msCheckerConfig.white_list): + continue data_info = run_torch_api(api_full_name, api_setting_dict, backward_content, api_info_dict) is_fwd_success, is_bwd_success = compare.compare_output(api_full_name, data_info.bench_out,