diff --git a/docs/mindspore/source_en/model_infer/ms_infer/parallel.md b/docs/mindspore/source_en/model_infer/ms_infer/parallel.md
index 7a8c02c20688013af442a043cd4cee82316eac1a..90d1b16851dedb7b13b0bee7a50f9e41a1606fe8 100644
--- a/docs/mindspore/source_en/model_infer/ms_infer/parallel.md
+++ b/docs/mindspore/source_en/model_infer/ms_infer/parallel.md
@@ -226,10 +226,10 @@ Starting with the original implementation of `nn.Dense` in MindSpore, we can bui
     def construct(self, x):
         origin_dtype = x.dtype
         x = self.cast(x, self.dtype)
-        out = self.bmm(x, self.weight)
+        output_parallel = self.bmm(x, self.weight)
         if self.has_bias:
             output_parallel = self.bias_add(output_parallel, self.cast(self.bias, self.dtype))
-        out = self.all_reduce(out)
+        out = self.all_reduce(output_parallel)
         out = self.cast(out, origin_dtype)
         return out
 ```
diff --git a/docs/mindspore/source_zh_cn/model_infer/ms_infer/parallel.md b/docs/mindspore/source_zh_cn/model_infer/ms_infer/parallel.md
index 593939780b3196bb5eefbf8533534b80f711d636..6d92b3815ddfaa96cd7939e33b26c32b098b22ea 100644
--- a/docs/mindspore/source_zh_cn/model_infer/ms_infer/parallel.md
+++ b/docs/mindspore/source_zh_cn/model_infer/ms_infer/parallel.md
@@ -224,10 +224,10 @@
     def construct(self, x):
         origin_dtype = x.dtype
         x = self.cast(x, self.dtype)
-        out = self.bmm(x, self.weight)
+        output_parallel = self.bmm(x, self.weight)
         if self.has_bias:
             output_parallel = self.bias_add(output_parallel, self.cast(self.bias, self.dtype))
-        out = self.all_reduce(out)
+        out = self.all_reduce(output_parallel)
         out = self.cast(out, origin_dtype)
         return out
 ```
diff --git a/docs/sample_code/infer_code/model_dev.py b/docs/sample_code/infer_code/model_dev.py
index ec0948f0d57177093e28049f21a9668b640ffd46..814ee81dbb528b50371dc08fb33f3aa7a6378c7f 100644
--- a/docs/sample_code/infer_code/model_dev.py
+++ b/docs/sample_code/infer_code/model_dev.py
@@ -21,7 +21,6 @@ from mindspore import Parameter, Tensor, nn, ops
 from mindspore.common import dtype as mstype
 from mindspore.common.initializer import initializer
 from mindspore.communication import create_group, get_group_size, get_rank, init
-import mindspore.communication as comm
 
 
 class ConfigHelper:
@@ -35,8 +34,7 @@ class ConfigHelper:
                  batch_size,
                  seq_length, dtype,
                  num_heads,
-                 has_bias=False
-                 ):
+                 has_bias=False):
         self.vocab_size = vocab_size
         self.hidden_size = hidden_size
         self.ffn_hidden_size = ffn_hidden_size
@@ -115,7 +113,6 @@ class ColumnParallelLinear(nn.Cell):
         return state_dict
 
 
-
 class GatherLastDim(nn.Cell):
     """
     Gather the last dimension across all parallel ranks
@@ -124,9 +121,10 @@ class GatherLastDim(nn.Cell):
         super().__init__()
         self.world_size = COMMUN_HELPER.get_tensor_model_parallel_group_size()
         self.split = ops.Split(axis=0, output_num=self.world_size)
+        self.all_gather = ops.AllGather(group=COMMUN_HELPER.get_tensor_model_parallel_group())
 
     def construct(self, input_):
-        output = comm.comm_func.all_gather_into_tensor(input_, group=COMMUN_HELPER.get_tensor_model_parallel_group())[0]
+        output = self.all_gather(input_)
         tensor_list = self.split(output)
         output = ops.cat(tensor_list, axis=-1)
         return output
@@ -155,6 +153,7 @@ class RowParallelLinear(nn.Cell):
         self.bias = Parameter(initializer(bias_init, (self.in_channels_per_partition), self.dtype), name="bias")
         self.bias_add = ops.Add()
         self.bmm = ops.BatchMatMul(transpose_b=True)
+        self.all_reduce = ops.AllReduce(group=COMMUN_HELPER.get_tensor_model_parallel_group())
         self.cast = ops.Cast()
 
     def construct(self, x):
@@ -163,7 +162,7 @@ class RowParallelLinear(nn.Cell):
         output_parallel = self.bmm(x, self.weight)
         if self.has_bias:
             output_parallel = self.bias_add(output_parallel, self.cast(self.bias, self.dtype))
-        output = comm.comm_func.all_reduce(output_parallel, group=COMMUN_HELPER.get_tensor_model_parallel_group())[0]
+        output = self.all_reduce(output_parallel)
         output = self.cast(output, origin_dtype)
         return output
 
@@ -205,6 +204,7 @@ class VocabParallelEmbedding(nn.Cell):
             ),
             name="embedding_weight",
         )
+        self.all_reduce = ops.AllReduce(group=COMMUN_HELPER.get_tensor_model_parallel_group())
         self.max_index_per_partition = Tensor(self.num_embeddings_per_partition - 1, dtype=mstype.int32)
         self.expand_dims = ops.ExpandDims()
         self.gather = ops.Gather()
@@ -222,7 +222,7 @@ class VocabParallelEmbedding(nn.Cell):
         input_mask = self.expand_dims(input_mask, -1)
         output_parallel = self.gather(self.embedding_weight, truncated_x, 0)
         output_parallel = self.mul(output_parallel, input_mask)
-        output = comm.comm_func.all_reduce(output_parallel, group=COMMUN_HELPER.get_tensor_model_parallel_group())[0]
+        output = self.all_reduce(output_parallel)
         return output
 
     def sharded_state_dict(self):