diff --git a/torch_npu/csrc/aten/ops/op_api/MatmulKernelNpuOpApi.cpp b/torch_npu/csrc/aten/ops/op_api/MatmulKernelNpuOpApi.cpp
index f4c037313a265a583555a43186dd32a35aa0c30a..2f3dccabd8dc4606a28855f661e9627ce3bec33c 100644
--- a/torch_npu/csrc/aten/ops/op_api/MatmulKernelNpuOpApi.cpp
+++ b/torch_npu/csrc/aten/ops/op_api/MatmulKernelNpuOpApi.cpp
@@ -182,9 +182,13 @@ std::tuple<at::Tensor, at::Tensor> matmul_backward(const at::Tensor &grad,
   if (!grad.defined()) {
     return std::make_tuple(at::Tensor(), at::Tensor());
   }
+
+  // NZ -> ND: cast the incoming gradient to ND format before computing mat1/mat2 gradients
+  auto grad_nd = NPUNativeFunctions::npu_format_cast(grad, ACL_FORMAT_ND);
+
   // backward mat1 and mat2 separately
-  auto self_grad = matmul_mat1_backward(self, other, grad);
-  auto other_grad = matmul_mat2_backward(self, other, grad);
+  auto self_grad = matmul_mat1_backward(self, other, grad_nd);
+  auto other_grad = matmul_mat2_backward(self, other, grad_nd);
 
   // strip added dim: (5,1)->(5)
   if (other.dim() == 1 && other_grad.size(-1) == 1) {
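
For context, a minimal sketch of the call path this change sits on, written against the public libtorch C++ API rather than the torch_npu internals patched above. The shapes and variable names are illustrative assumptions; on an NPU build the backward of this matmul is expected to route through matmul_backward, which now casts the incoming gradient from the private NZ format to ND before computing the per-operand gradients. The sketch itself runs on CPU and only shows where that gradient comes from and what shapes the two results have.

// Illustrative sketch, not part of the patch.
#include <torch/torch.h>
#include <iostream>

int main() {
  auto self  = torch::randn({8, 16}, torch::requires_grad());
  auto other = torch::randn({16, 32}, torch::requires_grad());

  auto out = torch::matmul(self, other);  // forward: (8, 16) x (16, 32) -> (8, 32)
  out.sum().backward();                   // autograd feeds d(out) into matmul's backward

  std::cout << self.grad().sizes() << "\n";   // [8, 16]  (what matmul_mat1_backward produces)
  std::cout << other.grad().sizes() << "\n";  // [16, 32] (what matmul_mat2_backward produces)
  return 0;
}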