From 77aab60b8b34c369e02286598970c5ceb34c5a70 Mon Sep 17 00:00:00 2001 From: c00420053 Date: Wed, 25 Oct 2023 16:08:53 +0800 Subject: [PATCH] set inf_nan default enable --- tf_adapter/python/npu_bridge/__init__.py | 2 ++ tf_adapter_2.x/python/npu_device/npu_device.py | 1 + 2 files changed, 3 insertions(+) diff --git a/tf_adapter/python/npu_bridge/__init__.py b/tf_adapter/python/npu_bridge/__init__.py index 428dd33a0..a7e296c54 100644 --- a/tf_adapter/python/npu_bridge/__init__.py +++ b/tf_adapter/python/npu_bridge/__init__.py @@ -23,6 +23,8 @@ from npu_bridge.helper import helper from npu_bridge.estimator.npu import npu_estimator from npu_bridge.hccl import hccl_ops from npu_bridge.estimator.npu.npu_plugin import npu_close +from npu_bridge.estimator.npu import npu_plugin +npu_plugin.set_device_sat_mode(1) atexit.register(npu_close) __all__ = [_s for _s in dir() if not _s.startswith('_')] diff --git a/tf_adapter_2.x/python/npu_device/npu_device.py b/tf_adapter_2.x/python/npu_device/npu_device.py index b0036738f..0e0752a8e 100644 --- a/tf_adapter_2.x/python/npu_device/npu_device.py +++ b/tf_adapter_2.x/python/npu_device/npu_device.py @@ -168,6 +168,7 @@ def open(device_id=None): """Initiate and return a NPU device handle""" if device_id is None: device_id = int(os.getenv("ASCEND_DEVICE_ID", '0')) + set_device_sat_mode(1) with _npu_ctx_lock: if not isinstance(context.context(), _ContextWithDefaultDevice): -- Gitee