diff --git a/debug/accuracy_tools/api_accuracy_checker/common/utils.py b/debug/accuracy_tools/api_accuracy_checker/common/utils.py index 94907f01493272f95641047074c908c96a9449b8..9304a9ca4fe0bbeabcbb6bea6fb97fad61c42163 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/utils.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/utils.py @@ -36,7 +36,14 @@ except ImportError: else: IS_GPU = False -if not IS_GPU: +torch_without_guard_version_list = ['2.1'] +for version in torch_without_guard_version_list: + if torch.__version__.startswith(version): + torch_without_guard_version = True + break + else: + torch_without_guard_version = False +if not IS_GPU and not torch_without_guard_version: from torch_npu.utils.device_guard import torch_device_guard as torch_npu_device_guard device = collections.namedtuple('device', ['type', 'index']) @@ -456,7 +463,7 @@ def format_value(value): def torch_device_guard(func): - if IS_GPU: + if IS_GPU or torch_without_guard_version: return func # Parse args/kwargs matched torch.device objects