diff --git a/debug/accuracy_tools/api_accuracy_checker/common/config.py b/debug/accuracy_tools/api_accuracy_checker/common/config.py index 5e67fd0078741619355f6d46c7bf9855e4351aee..a321f5820c5865fcdfb4eb20167ffdc4d0b4cd3c 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/config.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/config.py @@ -21,6 +21,7 @@ class Config: 'enable_dataloader': bool, 'target_iter': list, 'white_list': list, + 'black_list': list, 'error_data_path': str, 'precision': int, 'is_online': bool, @@ -53,6 +54,14 @@ class Config: invalid_api = [i for i in value if i not in WrapApi] if invalid_api: raise ValueError(f"{', '.join(invalid_api)} is not in support_wrap_ops.yaml, please check the white_list") + if key == 'black_list': + if not isinstance(value, list): + raise ValueError("black_list must be a list type") + if not all(isinstance(i, str) for i in value): + raise ValueError("All elements in black_list must be of str type") + invalid_api = [i for i in value if i not in WrapApi] + if invalid_api: + raise ValueError(f"{', '.join(invalid_api)} is not in support_wrap_ops.yaml, please check the black_list") if key == 'nfs_path': if value and not os.path.exists(value): raise ValueError(f"nfs path {value} doesn't exist.") @@ -64,13 +73,14 @@ class Config: def __str__(self): return '\n'.join(f"{key}={value}" for key, value in self.config.items()) - def update_config(self, dump_path=None, real_data=None, target_iter=None, white_list=None, enable_dataloader=None, - is_online=None, is_benchmark_device=True, port=None, host=None, rank_list=None): + def update_config(self, dump_path=None, real_data=None, target_iter=None, white_list=None, black_list=None, + enable_dataloader=None, is_online=None, is_benchmark_device=True, port=None, host=None, rank_list=None): args = { "dump_path": dump_path if dump_path is not None else self.config.get("dump_path", './'), "real_data": real_data if real_data is not None else self.config.get("real_data", False), "target_iter": target_iter if target_iter is not None else self.config.get("target_iter", [1]), "white_list": white_list if white_list is not None else self.config.get("white_list", []), + "black_list": black_list if black_list is not None else self.config.get("black_list", []), "enable_dataloader": enable_dataloader if enable_dataloader is not None else self.config.get("enable_dataloader", False), "is_online": is_online if is_online is not None else self.config.get("is_online", False), diff --git a/debug/accuracy_tools/api_accuracy_checker/config.yaml b/debug/accuracy_tools/api_accuracy_checker/config.yaml index c660181572298bb2e7e72724e5654968bfc59149..5a281cacede401eb7e79d3178dbfaf8595aedd98 100644 --- a/debug/accuracy_tools/api_accuracy_checker/config.yaml +++ b/debug/accuracy_tools/api_accuracy_checker/config.yaml @@ -3,6 +3,7 @@ real_data: False enable_dataloader: False target_iter: [1] white_list: [] +black_list: [] error_data_path: './' precision: 14 is_online: False diff --git a/debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_aten.py b/debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_aten.py index b51e93fb5801da79bbd5da88b8c30cb17a884e7f..e7da8320d48b3b3fcd4aba25415f61c9b4c40cba 100644 --- a/debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_aten.py +++ b/debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_aten.py @@ -19,6 +19,7 @@ import torch from api_accuracy_checker.hook_module.hook_module import HOOKModule from api_accuracy_checker.common.utils import torch_device_guard +from api_accuracy_checker.common.config import msCheckerConfig from api_accuracy_checker.hook_module.utils import WrapAtenOps, WhiteAtenOps from api_accuracy_checker.common.function_factory import npu_custom_grad_functions @@ -31,7 +32,13 @@ for f in dir(torch.ops.aten): def get_aten_ops(): global WrapAtenOps _all_aten_ops = dir(torch.ops.aten) - return set(WrapAtenOps) & set(_all_aten_ops) + available_ops = set(WrapAtenOps) & set(_all_aten_ops) + if msCheckerConfig.black_list: + available_ops = available_ops - set(msCheckerConfig.black_list) + if msCheckerConfig.white_list: + return available_ops & set(msCheckerConfig.white_list) + else: + return available_ops class HOOKAtenOP(object): diff --git a/debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_functional.py b/debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_functional.py index 967e9efc84123533422240c1f1fe530e73800bf0..53c0a9f80eb22f88f14549e5d69096a14ade7255 100644 --- a/debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_functional.py +++ b/debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_functional.py @@ -33,10 +33,13 @@ for f in dir(torch.nn.functional): def get_functional_ops(): global WrapFunctionalOps _all_functional_ops = dir(torch.nn.functional) + available_ops = set(WrapFunctionalOps) & set(_all_functional_ops) + if msCheckerConfig.black_list: + available_ops = available_ops - set(msCheckerConfig.black_list) if msCheckerConfig.white_list: - return set(WrapFunctionalOps) & set(_all_functional_ops) & set(msCheckerConfig.white_list) + return available_ops & set(msCheckerConfig.white_list) else: - return set(WrapFunctionalOps) & set(_all_functional_ops) + return available_ops class HOOKFunctionalOP(object): diff --git a/debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_npu_custom.py b/debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_npu_custom.py index a4a42df1102d752ad3b23221ff639c7c0f17878e..17bffa3ba6c70ccc58140413bb5c9d28ff14122c 100644 --- a/debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_npu_custom.py +++ b/debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_npu_custom.py @@ -37,11 +37,13 @@ def get_npu_ops(): _npu_ops = dir(torch.ops.npu) else: _npu_ops = dir(torch_npu._C._VariableFunctionsClass) - + available_ops = set(WrapNPUOps) & set(_npu_ops) + if msCheckerConfig.black_list: + available_ops = available_ops - set(msCheckerConfig.black_list) if msCheckerConfig.white_list: - return set(WrapNPUOps) & set(_npu_ops) & set(msCheckerConfig.white_list) + return available_ops & set(msCheckerConfig.white_list) else: - return set(WrapNPUOps) & set(_npu_ops) + return available_ops class HOOKNpuOP(object): diff --git a/debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_tensor.py b/debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_tensor.py index d60cac74baf15872854d71089df7cdd81925746e..7b098277328193cd06c1f8fd1d72bf6b8ab4296b 100644 --- a/debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_tensor.py +++ b/debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_tensor.py @@ -31,10 +31,13 @@ from ptdbg_ascend.src.python.ptdbg_ascend.common.utils import parameter_adapter def get_tensor_ops(): global WrapTensorOps _tensor_ops = dir(torch._C._TensorBase) + available_ops = set(WrapTensorOps) & set(_tensor_ops) + if msCheckerConfig.black_list: + available_ops = available_ops - set(msCheckerConfig.black_list) if msCheckerConfig.white_list: - return set(WrapTensorOps) & set(_tensor_ops) & set(msCheckerConfig.white_list) + return available_ops & set(msCheckerConfig.white_list) else: - return set(WrapTensorOps) & set(_tensor_ops) + return available_ops class HOOKTensor(object): diff --git a/debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_torch.py b/debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_torch.py index 9fbe343d374364a5904f36d0d82189ff856e14b3..54649d7ca5921c4ca948ae16261dfb14a1d180f2 100644 --- a/debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_torch.py +++ b/debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_torch.py @@ -30,10 +30,13 @@ from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileOpen def get_torch_ops(): global WrapTorchOps _torch_ops = dir(torch._C._VariableFunctionsClass) + available_ops = set(WrapTorchOps) & set(_torch_ops) + if msCheckerConfig.black_list: + available_ops = available_ops - set(msCheckerConfig.black_list) if msCheckerConfig.white_list: - return set(WrapTorchOps) & set(_torch_ops) & set(msCheckerConfig.white_list) + return available_ops & set(msCheckerConfig.white_list) else: - return set(WrapTorchOps) & set(_torch_ops) + return available_ops class HOOKTorchOP(object): diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index 4f5a85e2d64e9e80ea1098d078d0df4304b24cfb..a91a7334ed199d8d24f7eedbc21ed5d2eda697ae 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -205,6 +205,10 @@ def run_api_offline(config, compare, api_name_set): if api_full_name in api_name_set: continue try: + if msCheckerConfig.black_list: + [_, api_name, _] = api_full_name.split(Const.DELIMITER) + if api_name in set(msCheckerConfig.black_list): + continue if msCheckerConfig.white_list: [_, api_name, _] = api_full_name.split(Const.DELIMITER) if api_name not in set(msCheckerConfig.white_list): @@ -252,7 +256,10 @@ def run_api_online(config, compare): if not isinstance(api_data, ApiData): continue api_full_name = api_data.name - + if msCheckerConfig.black_list: + [_, api_name, _] = api_full_name.split(Const.DELIMITER) + if api_name in set(msCheckerConfig.black_list): + continue if msCheckerConfig.white_list: [_, api_name, _] = api_full_name.split(Const.DELIMITER) if api_name not in set(msCheckerConfig.white_list):