diff --git a/tf_adapter/python/npu_bridge/estimator/npu/npu_scope.py b/tf_adapter/python/npu_bridge/estimator/npu/npu_scope.py index 63eb0ce6544f87a3009bfaf367d25ed79145c156..20547f5879dadb73794dbd8b359c2619b3591c90 100644 --- a/tf_adapter/python/npu_bridge/estimator/npu/npu_scope.py +++ b/tf_adapter/python/npu_bridge/estimator/npu/npu_scope.py @@ -218,3 +218,23 @@ def disable_autofuse(): } with ops.get_default_graph()._attr_scope(attrs): yield + + +@contextlib.contextmanager +def limit_core_num_scope(op_aicore_num="0", op_vectorore_num="0"): + """ + Limit the aic abd aiv core num of autofuse operators within the scope. + """ + if not isinstance(op_vectorore_num, str): + raise ValueError("Param op_vectorore_num must be string.") + try: + int_vector_core_num = int(op_vectorore_num) + except ValueError: + raise ValueError("Param op_vectorore_num can not be translated into a valid int number.") + if not (0 < int_vector_core_num < 48): + raise ValueError("Param op_vectorore_num must be in a valid range.") + attrs = { + "_op_vectorcore_num": attr_value_pb2.AttrValue(s=compat.as_bytes(op_vectorore_num)) + } + with ops.get_default_graph()._attr_scope(attrs): + yield