diff --git a/tf_adapter_2.x/python/npu_device/_api/distribute/hccl.py b/tf_adapter_2.x/python/npu_device/_api/distribute/hccl.py
index 26e7aa76198bc3ffc4b291b999ab27426870712a..33970a819d986e04bd7ad46382debeeea9ec5366 100644
--- a/tf_adapter_2.x/python/npu_device/_api/distribute/hccl.py
+++ b/tf_adapter_2.x/python/npu_device/_api/distribute/hccl.py
@@ -1,4 +1,3 @@
-from collections import Iterable
 from npu_device._api.distribute import hccl_ops
 from npu_device.npu_device import global_npu_ctx
@@ -26,7 +25,7 @@ def _all_reduce(values, reduction, fusion, fusion_id, group):
         reduction = 'sum'
     topo_guarder = tf.group(values)
-    if isinstance(values, Iterable):
+    if isinstance(values, (list, tuple,)):
         reduced_values = []
         for value in values:
             reduced_value = hccl_ops.allreduce(value, reduction, fusion, fusion_id, group)
@@ -56,7 +55,7 @@ def all_reduce(values, reduction, fusion=1, fusion_id=-1, group="hccl_world_grou
 def _broadcast(values, root_rank, fusion, fusion_id, group):
-    if isinstance(values, Iterable):
+    if isinstance(values, (list, tuple,)):
         for value in values:
             value.assign(hccl_ops.broadcast([value], root_rank, fusion, fusion_id, group)[0])
     else:
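
For context (not part of the patch): `from collections import Iterable` is a deprecated alias that was removed in Python 3.10, so the old import fails outright on newer interpreters. The narrower `(list, tuple)` test also only takes the multi-value branch for real containers, not for any object that merely defines `__iter__`. A minimal sketch of the difference, using a hypothetical `FakeTensor` stand-in instead of a real tensor so it runs without TensorFlow or the NPU package:

```python
from collections.abc import Iterable  # the surviving location of the ABC


class FakeTensor:
    """Stand-in for a tensor-like object that is iterable but not a container."""

    def __iter__(self):
        return iter([1.0, 2.0])


values = FakeTensor()

# Old check: a single tensor-like object already counts as Iterable,
# so it would be unpacked element by element.
print(isinstance(values, Iterable))        # True

# New check: only genuine lists/tuples of values take the multi-value path.
print(isinstance(values, (list, tuple)))   # False
```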