diff --git a/tf_adapter/python/npu_bridge/estimator/npu/npu_scope.py b/tf_adapter/python/npu_bridge/estimator/npu/npu_scope.py index 72702e146f32e0bc4a1f3931e186b4a28ce30399..f098e444e07ebcb53ee0deec6f85db8c9fc0a9c7 100644 --- a/tf_adapter/python/npu_bridge/estimator/npu/npu_scope.py +++ b/tf_adapter/python/npu_bridge/estimator/npu/npu_scope.py @@ -30,5 +30,16 @@ def npu_variable_scope(placement=NpuExecutePlacement.ALL): attrs = { "_variable_placement" : attr_value_pb2.AttrValue(s=compat.as_bytes(placement.value)) } + with ops.get_default_graph()._attr_scope(attrs): + yield + +@contextlib.contextmanager +def keep_dtype_scope(): + """ + Specify which layers retain the original precision. + """ + attrs = { + "_keep_dtype": attr_value_pb2.AttrValue(b=True) + } with ops.get_default_graph()._attr_scope(attrs): yield \ No newline at end of file