From 6a697770b2396dccf9661f01a860371cae05d2dc Mon Sep 17 00:00:00 2001
From: suyafeng s00639171
Date: Tue, 8 Aug 2023 17:51:11 +0800
Subject: [PATCH] support NZ grad backward

---
 torch_npu/csrc/aten/ops/op_api/MatmulKernelNpuOpApi.cpp | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/torch_npu/csrc/aten/ops/op_api/MatmulKernelNpuOpApi.cpp b/torch_npu/csrc/aten/ops/op_api/MatmulKernelNpuOpApi.cpp
index f4c037313a2..2f3dccabd8d 100644
--- a/torch_npu/csrc/aten/ops/op_api/MatmulKernelNpuOpApi.cpp
+++ b/torch_npu/csrc/aten/ops/op_api/MatmulKernelNpuOpApi.cpp
@@ -182,9 +182,13 @@ std::tuple<at::Tensor, at::Tensor> matmul_backward(const at::Tensor &grad,
   if (!grad.defined()) {
     return std::make_tuple(at::Tensor(), at::Tensor());
   }
+
+  // NZ -> ND
+  auto grad_nd = NPUNativeFunctions::npu_format_cast(grad, ACL_FORMAT_ND);
+
   // backward mat1 and mat2 separately
-  auto self_grad = matmul_mat1_backward(self, other, grad);
-  auto other_grad = matmul_mat2_backward(self, other, grad);
+  auto self_grad = matmul_mat1_backward(self, other, grad_nd);
+  auto other_grad = matmul_mat2_backward(self, other, grad_nd);
 
   // strip added dim: (5,1)->(5)
   if (other.dim() == 1 && other_grad.size(-1) == 1) {
-- 
Gitee