From 988e1061640d60a1757affc8f1a0d640a0b94eff Mon Sep 17 00:00:00 2001 From: medivh-x Date: Tue, 30 Mar 2021 18:35:28 +0800 Subject: [PATCH] support reduce scalar tensor --- tf_adapter_2.x/python/npu_device/_api/distribute/hccl.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tf_adapter_2.x/python/npu_device/_api/distribute/hccl.py b/tf_adapter_2.x/python/npu_device/_api/distribute/hccl.py index 26e7aa761..33970a819 100644 --- a/tf_adapter_2.x/python/npu_device/_api/distribute/hccl.py +++ b/tf_adapter_2.x/python/npu_device/_api/distribute/hccl.py @@ -1,4 +1,3 @@ -from collections import Iterable from npu_device._api.distribute import hccl_ops from npu_device.npu_device import global_npu_ctx @@ -26,7 +25,7 @@ def _all_reduce(values, reduction, fusion, fusion_id, group): reduction = 'sum' topo_guarder = tf.group(values) - if isinstance(values, Iterable): + if isinstance(values, (list, tuple,)): reduced_values = [] for value in values: reduced_value = hccl_ops.allreduce(value, reduction, fusion, fusion_id, group) @@ -56,7 +55,7 @@ def all_reduce(values, reduction, fusion=1, fusion_id=-1, group="hccl_world_grou def _broadcast(values, root_rank, fusion, fusion_id, group): - if isinstance(values, Iterable): + if isinstance(values, (list, tuple,)): for value in values: value.assign(hccl_ops.broadcast([value], root_rank, fusion, fusion_id, group)[0]) else: -- Gitee