diff --git a/test/distributed/test_allreduce.py b/test/distributed/test_allreduce.py index 5e9a26cbc83f5b2ef73fe54a5c3703e916b1c449..746c8f3d6524e1a134f982c0571c8a624c65fb21 100644 --- a/test/distributed/test_allreduce.py +++ b/test/distributed/test_allreduce.py @@ -122,6 +122,25 @@ class HcomAllReduceTest(TestCase): HcomAllReduceTest._init_dist_hccl, expected, input1, world_size, dist.ReduceOp.AVG) + @skipIfUnsupportMultiNPU(2) + def test_dist_all_reduce_premul_sum(self): + ranks = [2] + dtype_list = [np.float32] + format_list = [0, 2, 3] + shape_format = [[i, j, [2, 3, 16]] + for i in dtype_list + for j in format_list] + factor_list = [2.0, torch.tensor(3.0)] + for world_size in ranks: + for shape in shape_format: + for factor in factor_list: + exp_input, input1 = create_common_tensor(shape, -10, 10) + reduce_op = dist._make_nccl_premul_sum(factor) + expected = self._construct_excepted_result(exp_input, world_size, reduce_op) + self._test_multiprocess(HcomAllReduceTest._test_all_reduce, + HcomAllReduceTest._init_dist_hccl, expected, input1, world_size, + reduce_op) + if __name__ == '__main__': run_tests() diff --git a/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp b/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp index dc9ae622481619d41f82aa00896752878bbb268c..e1ce5b70801c348559fdd5561a32e3feddf47c63 100644 --- a/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp +++ b/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp @@ -138,7 +138,19 @@ HcclReduceOp getHcclReduceOp(const c10d::ReduceOp reduceOp, at::Tensor& input) // represent a bool (see hcclDataType mapping). 
return HCCL_REDUCE_MAX; } - + + if (reduceOp == c10d::ReduceOp::PREMUL_SUM) { + // HCCL does not support ReduceOp::PREMUL_SUM yet + // PTA supports it by multiplying first, then summing + auto supplement = dynamic_cast<c10d::NCCLPreMulSumSupplement*>(reduceOp.supplement_.get()); + if (supplement->tensor_factor.numel() == 0) { + input = input * supplement->double_factor; + } else { + input = input * supplement->tensor_factor; + } + return HCCL_REDUCE_SUM; + } + if (unsupportedOp.find(reduceOp) != unsupportedOp.end()) { TORCH_CHECK(false, "Cannot use ReduceOp." + unsupportedOp[reduceOp] + " with HCCL",