diff --git a/docs/mindspore/source_zh_cn/model_infer/ms_infer/parallel.md b/docs/mindspore/source_zh_cn/model_infer/ms_infer/parallel.md index 756dacdededd182fd26a57b22d13583a97712155..e8553bb9d0a0246d3e71c0014fc130ad43c094a8 100644 --- a/docs/mindspore/source_zh_cn/model_infer/ms_infer/parallel.md +++ b/docs/mindspore/source_zh_cn/model_infer/ms_infer/parallel.md @@ -147,7 +147,7 @@ def construct(self, x): origin_dtype = x.dtype x = self.cast(x, self.dtype) - out = self.matmul(x, self.weight) + out = self.matmul(x, ms.numpy.broadcast_to(self.weight, x.shape[:-2] + self.weight.shape)) if self.has_bias: output_parallel = self.bias_add( output_parallel, self.cast(self.bias, self.dtype)