diff --git a/codegen/autograd/aclnn_derivatives.yaml b/codegen/autograd/aclnn_derivatives.yaml index 89e061624f28acda6cd95923bb4043ba61caf44d..d2a3c3c11d8a97dd7ee4fc3c8808a474f4ac4b8e 100644 --- a/codegen/autograd/aclnn_derivatives.yaml +++ b/codegen/autograd/aclnn_derivatives.yaml @@ -1,3 +1,5 @@ # Defines derivative formulas and Python signatures of methods on Variable # #If you need any guidance, please refer to the comments in derivatives.yaml in PyTorch. +- name: matmul(Tensor self, Tensor other) -> Tensor + self, other: matmul_backward(grad, self, other, grad_input_mask) diff --git a/test/test_network_ops/test_matmul.py b/test/test_network_ops/test_matmul.py index 7ade8c135f7932ed95122655d73e50c1df59914a..3b8b8033ec7a1b400340b653fbdcafd2cf50b472 100644 --- a/test/test_network_ops/test_matmul.py +++ b/test/test_network_ops/test_matmul.py @@ -86,7 +86,7 @@ class TestMatMul(TestCase): npu_output = npu_output.cpu() return npu_output.detach().cpu().numpy(), input1.grad.cpu().numpy(), input2.grad.cpu().numpy() - def matmul_backward_result(self, shape_format): + def matmul_backward_result(self, shape_format, transpose_mat1=False, transpose_mat2=False): for item in shape_format: mat1_cpu, mat1_npu = create_common_tensor(item[0], -10, 10) if mat1_cpu.dtype == torch.float16: @@ -94,6 +94,15 @@ class TestMatMul(TestCase): mat2_cpu, mat2_npu = create_common_tensor(item[1], -10, 10) if mat2_cpu.dtype == torch.float16: mat2_cpu = mat2_cpu.to(torch.float32) + + if transpose_mat1: + mat1_cpu = mat1_cpu.transpose(-2, -1) + mat1_npu = mat1_npu.transpose(-2, -1) + + if transpose_mat2: + mat2_cpu = mat2_cpu.transpose(-2, -1) + mat2_npu = mat2_npu.transpose(-2, -1) + cpu_output, cpu_mat1_grad, cpu_mat2_grad = self.op_exec_cpu(mat1_cpu, mat2_cpu) npu_output, npu_mat1_grad, npu_mat2_grad = self.op_exec_npu(mat1_npu, mat2_npu) @@ -198,5 +207,56 @@ class TestMatMul(TestCase): self.matmul_backward_result(shape_format) torch.npu.matmul.allow_hf32 = False + def test_matmul_backward_big_memory(self): + torch.npu.matmul.allow_hf32 = True + shape_format = [ + [[np.float16, 2, [8192, 4, 5120]], [np.float16, 2, [5120, 3416]]], + ] + self.matmul_backward_result(shape_format) + torch.npu.matmul.allow_hf32 = False + + def test_matmul_backward_transpose_right(self): + torch.npu.matmul.allow_hf32 = True + shape_format = [ + [[np.float16, 2, [5, 4, 2]], [np.float16, 2, [3, 2]]], + [[np.float16, 2, [3, 2, 1, 2]], [np.float16, 2, [3, 2]]], + [[np.float16, 2, [4, 2]], [np.float16, 2, [3, 2]]], + [[np.float16, 2, [10, 2]], [np.float16, 2, [4, 3, 2]]], + ] + self.matmul_backward_result(shape_format, False, True) + torch.npu.matmul.allow_hf32 = False + + def test_matmul_backward_transpose_left(self): + torch.npu.matmul.allow_hf32 = True + shape_format = [ + [[np.float16, 2, [2, 3, 4]], [np.float16, 2, [3, 2]]], # 8, 4 + [[np.float16, 2, [3, 2, 3, 4]], [np.float16, 2, [3, 2]]],#8, 4 + [[np.float16, 2, [3, 4]], [np.float16, 2, [3, 2]]],#2,2 + [[np.float16, 2, [3, 4]], [np.float16, 2, [5, 3, 2]]], ### !!!!!!!!!!!! 8, 10 + ] + self.matmul_backward_result(shape_format, True, False) + torch.npu.matmul.allow_hf32 = False + + def test_matmul_backward_transpose_both_bmm(self): + torch.npu.matmul.allow_hf32 = True + shape_format = [ + [[np.float16, 2, [5, 3, 4]], [np.float16, 2, [5, 2, 3]]], # 8, 4 + [[np.float16, 2, [5, 2, 3, 4]], [np.float16, 2, [2, 2, 3]]],#8, 4 + ] + self.matmul_backward_result(shape_format, True, True) + torch.npu.matmul.allow_hf32 = False + + + def test_matmul_backward_transpose_both(self): + torch.npu.matmul.allow_hf32 = True + shape_format = [ + [[np.float16, 2, [2, 3, 4]], [np.float16, 2, [2,3]]], + [[np.float16, 2, [3, 2, 3, 4]], [np.float16, 2, [2,3]]], + [[np.float16, 2, [3, 4]], [np.float16, 2, [2,3]]], + [[np.float16, 2, [3, 4]], [np.float16, 2, [5, 2, 3]]], ### !!!!!!!!!!!! + ] + self.matmul_backward_result(shape_format, True, True) + torch.npu.matmul.allow_hf32 = False + if __name__ == "__main__": run_tests() diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml index cd502a07f775039502cd7260e6cf7bce8fcf57ee..0d76bfdb800461e687c7a3a8ca0d8f4af57b0444 100644 --- a/torch_npu/csrc/aten/npu_native_functions.yaml +++ b/torch_npu/csrc/aten/npu_native_functions.yaml @@ -753,7 +753,7 @@ supported: op_api: True - is_pinned - func: is_set_to - op_api: False + op_api: False - func: isclose op_api: True - func: isfinite @@ -1672,12 +1672,12 @@ supported: op_api: False - func: silu_backward op_api: False + - func: matmul + op_api: True + - func: matmul.out + op_api: True autograd: - - func: matmul - op_api: False - - func: matmul.out - op_api: False - celu - celu_ - func: elu.out @@ -1977,6 +1977,8 @@ unsupported: - special_zeta.other_scalar_out custom: + - func: matmul_backward(Tensor grad, Tensor self, Tensor other, bool[2] grad_input_mask) -> (Tensor, Tensor) + op_api: True - func: _npu_storage_resize(Tensor self, int size) -> Tensor - func: npu_change_data_ptr(Tensor dst, Tensor src, int index) -> int - func: npu_transpose(Tensor self, int[] perm, bool require_contiguous=True) -> Tensor diff --git a/torch_npu/csrc/aten/ops/MatmulKernelNpu.cpp b/torch_npu/csrc/aten/ops/MatmulKernelNpu.cpp index 59b1c376e59b704718802f04bcff5648b502fead..9e878ecdf833355c874ee54c2209847cda97d6e5 100644 --- a/torch_npu/csrc/aten/ops/MatmulKernelNpu.cpp +++ b/torch_npu/csrc/aten/ops/MatmulKernelNpu.cpp @@ -155,5 +155,13 @@ at::Tensor& NPUNativeFunctions::matmul_out(const at::Tensor & tensor1, const at: at::namedinference::propagate_names_if_nonempty(result, maybe_outnames); return result; } + +std::tuple NPUNativeFunctions::matmul_backward( + const at::Tensor &grad, const at::Tensor &self, const at::Tensor &other, + std::array grad_input_mask) { + AT_ERROR("!!!!!Caution temp codes!!!!!!!!!!! "); + return std::make_tuple(at::Tensor(), at::Tensor()); +} + } // namespace native } // namespace at_npu diff --git a/torch_npu/csrc/aten/ops/op_api/MatmulKernelNpuOpApi.cpp b/torch_npu/csrc/aten/ops/op_api/MatmulKernelNpuOpApi.cpp index 97cb53768083bb7b52fb065d2576bf51261cb680..96c5d47e936e704188dcf4cb57dd85dbd7563ddb 100644 --- a/torch_npu/csrc/aten/ops/op_api/MatmulKernelNpuOpApi.cpp +++ b/torch_npu/csrc/aten/ops/op_api/MatmulKernelNpuOpApi.cpp @@ -14,10 +14,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include #include "torch_npu/csrc/aten/NPUNativeOpApiFunctions.h" #include "torch_npu/csrc/aten/ops/op_api/op_api_common.h" #include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include namespace at_npu { namespace native { @@ -29,14 +29,15 @@ using torch::autograd::AutogradContext; using torch::autograd::Function; using tensor_list = std::vector; -static c10::SmallVector get_output_size(const at::Tensor &tensor1, - const at::Tensor &tensor2) { +static c10::SmallVector +get_output_size(const at::Tensor &tensor1, const at::Tensor &tensor2) { c10::SmallVector output_size; auto dim_tensor1 = tensor1.dim(); auto dim_tensor2 = tensor2.dim(); TORCH_CHECK(dim_tensor1 > 0 && dim_tensor2 > 0, - "matmul got error dimentions: ", "(", dim_tensor1, ", ", dim_tensor2, ")"); + "matmul got error dimentions: ", "(", dim_tensor1, ", ", + dim_tensor2, ")"); if (dim_tensor1 == 1 && dim_tensor2 == 1) { output_size = {}; @@ -76,13 +77,15 @@ static c10::SmallVector get_output_size(const at::Tensor &tensor1 at::IntArrayRef batch_tensor1(tensor1.sizes().data(), dim_tensor1 - 2); int64_t p = tensor2.size(-1); at::IntArrayRef batch_tensor2(tensor2.sizes().data(), dim_tensor2 - 2); - std::vector expand_batch_portion = at::infer_size(batch_tensor1, batch_tensor2); + std::vector expand_batch_portion = + at::infer_size(batch_tensor1, batch_tensor2); c10::SmallVector output_expand_size(expand_batch_portion); output_expand_size.insert(output_expand_size.end(), {n, p}); output_size = output_expand_size; } else { - TORCH_CHECK(false, "matmul got error sizes: ", "(", dim_tensor1, ", ", dim_tensor2, ")"); + TORCH_CHECK(false, "matmul got error sizes: ", "(", dim_tensor1, ", ", + dim_tensor2, ")"); } return output_size; @@ -96,6 +99,12 @@ static inline void matmul_implement_npu(at::Tensor &out, const at::Tensor &self, return; } +//if input was column-major, return grad as column-order for efficiency +static inline bool is_column_major(const at::Tensor &mat) { + bool row_major = (mat.stride(-1) == 1 && mat.stride(-2) == mat.size(-1)); + return false == row_major; +} + at::Tensor matmul_mat1_backward(const at::Tensor self, const at::Tensor other, const at::Tensor grad_output) { /*mat1_grad = grad * mat2^T*/ @@ -119,19 +128,44 @@ at::Tensor matmul_mat1_backward(const at::Tensor self, const at::Tensor other, } at::Tensor output; - if (mat1.dim() == 2 && mat2.dim() > 2) { // mm - output = OpPreparation::ApplyTensorWithoutFormat(mat1.sizes(), grad.options()); - mat2 = mat2.transpose(-2, -1); - mat2 = mat2.reshape({-1, mat2.size(-1)}); - grad = grad.view({grad.size(-2), -1}); - matmul_implement_npu(output, grad, mat2); - output = output.reshape(self.sizes()); + if (mat1.dim() == 2) { // mm + + // 先转置在后面一个tensor:先转置再合并k轴??? + if (is_column_major(mat1)&&mat2.dim()==2) { + //if (is_column_major(mat1)) { + output = OpPreparation::ApplyTensorWithoutFormat(mat1.t().sizes(), grad.options()); + // 列主序, (mat2*grad^T)^T: + grad = grad.transpose(-2, -1); + grad = grad.reshape({-1, grad.size(-1)}); + mat2 = mat2.reshape({mat2.size(-2), -1}); //列向连续,列向融合?? + matmul_implement_npu(output, mat2, grad); + output = output.t(); + output = output.reshape(self.sizes()); + + }else { + output = OpPreparation::ApplyTensorWithoutFormat(mat1.sizes(), grad.options()); + //grad * mat2^T:先转置再合并k轴 + mat2 = mat2.transpose(-2, -1); + mat2 = mat2.reshape({-1, mat2.size(-1)}); + grad = grad.reshape({grad.size(-2), -1}); + matmul_implement_npu(output, grad, mat2); + output = output.reshape(self.sizes()); + } } else { // bmm - mat2 = mat2.transpose(-2, -1); - auto expend_sizes = get_output_size(grad, mat2); - output = OpPreparation::ApplyTensorWithoutFormat(expend_sizes, grad.options()); - matmul_implement_npu(output, grad, mat2); + if (is_column_major(mat1)) { //(mat2*grad^T)^T: + grad = grad.transpose(-2, -1); + auto expend_sizes = get_output_size(mat2, grad); + output = OpPreparation::ApplyTensorWithoutFormat(expend_sizes, grad.options()); + matmul_implement_npu(output, mat2, grad); + output = output.transpose(-2, -1); + }else { //grad * mat2^T + mat2 = mat2.transpose(-2, -1); + auto expend_sizes = get_output_size(grad, mat2); + output = OpPreparation::ApplyTensorWithoutFormat(expend_sizes, grad.options()); + matmul_implement_npu(output, grad, mat2); + } + } return output; @@ -159,32 +193,64 @@ at::Tensor matmul_mat2_backward(const at::Tensor self, const at::Tensor other, } at::Tensor output; - if (mat2.dim() == 2 && mat1.dim() > 2) { // mm - output = OpPreparation::ApplyTensorWithoutFormat(mat2.sizes(), mat1.options()); - mat1 = mat1.reshape({-1, mat1.size(-1)}); - grad = grad.reshape({-1, grad.size(-1)}); - mat1 = mat1.transpose(-2, -1); - matmul_implement_npu(output, mat1, grad); - output = output.reshape(other.sizes()); + if (mat2.dim() == 2) { // mm + if (is_column_major(mat2)) { + output = OpPreparation::ApplyTensorWithoutFormat(mat2.t().sizes(), mat1.options()); + // 列主序, (grad^T*mat1)^T: + grad = grad.reshape({-1, grad.size(-1)}); + mat1 = mat1.reshape({-1, mat1.size(-1)}); + grad = grad.transpose(-2, -1); + + matmul_implement_npu(output, grad, mat1); + output = output.t(); + output = output.reshape(other.sizes()); + + }else { + //mat1^T * grad:先合并k轴再转置 + output = OpPreparation::ApplyTensorWithoutFormat(mat2.sizes(), mat1.options()); + mat1 = mat1.reshape({-1, mat1.size(-1)}); + grad = grad.reshape({-1, grad.size(-1)}); + mat1 = mat1.transpose(-2, -1); + matmul_implement_npu(output, mat1, grad); + output = output.reshape(other.sizes()); + } + } else { // bmm - mat1 = mat1.transpose(-2, -1); - auto expend_sizes = get_output_size(mat1, grad); - output = OpPreparation::ApplyTensorWithoutFormat(expend_sizes, mat1.options()); - matmul_implement_npu(output, mat1, grad); + if (is_column_major(mat2)){ // (grad^T*mat1)^T: + grad = grad.transpose(-2, -1); + auto expend_sizes = get_output_size(grad, mat1); + output = OpPreparation::ApplyTensorWithoutFormat(expend_sizes, mat1.options()); + matmul_implement_npu(output, grad, mat1); + output = output.transpose(-2, -1); + + }else{ //mat1^T * grad + mat1 = mat1.transpose(-2, -1); + auto expend_sizes = get_output_size(mat1, grad); + output = OpPreparation::ApplyTensorWithoutFormat(expend_sizes, mat1.options()); + matmul_implement_npu(output, mat1, grad); + } } return output; } -std::tuple matmul_backward(const at::Tensor &grad, - const at::Tensor &self, - const at::Tensor &other) { +std::tuple NPUNativeOpApiFunctions::matmul_backward( + const at::Tensor &grad, const at::Tensor &self, + const at::Tensor &other, std::array grad_input_mask) { if (!grad.defined()) { return std::make_tuple(at::Tensor(), at::Tensor()); } + // backward mat1 and mat2 separately - auto self_grad = matmul_mat1_backward(self, other, grad); - auto other_grad = matmul_mat2_backward(self, other, grad); + at::Tensor self_grad; + at::Tensor other_grad; + if (grad_input_mask[1]) { + other_grad = matmul_mat2_backward(self, other, grad); + } + + if (grad_input_mask[0]) { + self_grad = matmul_mat1_backward(self, other, grad); + } // strip added dim: (5,1)->(5) if (other.dim() == 1 && other_grad.size(-1) == 1) { @@ -203,30 +269,11 @@ at::Tensor matmul_forward(const at::Tensor &self, const at::Tensor &mat2) { return out; } -class NPUMatmulOpApiFunction: public torch::autograd::Function { -public: - static at::Tensor forward(AutogradContext *ctx, const at::Tensor &self, - const at::Tensor &other) { - at::AutoNonVariableTypeMode g; - ctx->save_for_backward({self, other}); - auto result = matmul_forward(self, other); - return result; - } - static tensor_list backward(AutogradContext *ctx, tensor_list grad_outputs) { - auto saved = ctx->get_saved_variables(); - auto self = saved[0]; - auto other = saved[1]; - auto grads = matmul_backward(grad_outputs[0], self, other); - tensor_list output = {std::get<0>(grads), std::get<1>(grads)}; - return output; - } -}; - at::Tensor NPUNativeOpApiFunctions::matmul(const at::Tensor &tensor1, const at::Tensor &tensor2) { DO_COMPATIBILITY(aclnnMatmul, NPUNativeFunctions::matmul(tensor1, tensor2)); auto maybe_outnames = at::namedinference::compute_matmul_outnames(tensor1, tensor2); - auto result = NPUMatmulOpApiFunction::apply(tensor1, tensor2); + auto result = matmul_forward(tensor1, tensor2); at::namedinference::propagate_names_if_nonempty(result, maybe_outnames); return result; } diff --git a/torch_npu/csrc/framework/autograd/FunctionsManual.cpp b/torch_npu/csrc/framework/autograd/FunctionsManual.cpp index 6645f9ff3e03e603190f2e0b29682cd3bcaec442..a30a5821cad43c21bf624c09241ac8b4123a0181 100644 --- a/torch_npu/csrc/framework/autograd/FunctionsManual.cpp +++ b/torch_npu/csrc/framework/autograd/FunctionsManual.cpp @@ -43,7 +43,7 @@ #include #include -#include "torch_npu/csrc/aten/NPUNativeFunctions.h" +#include "torch_npu/csrc/aten/NPUNativeOpApiFunctions.h" #include "FunctionsManual.h" // Helper functions for autogenerated code @@ -127,6 +127,14 @@ Tensor fast_gelu_backward(const Tensor& grad, const Tensor& self) { return at_npu::native::NPUNativeFunctions::npu_fast_gelu_backward(grad, self); } +std::tuple +matmul_backward(const at::Tensor &grad, const at::Tensor &self, + const at::Tensor &other, std::array grad_input_mask) { + + return at_npu::native::NPUNativeOpApiFunctions::matmul_backward( + grad, self, other, grad_input_mask); +} + std::tuple rotary_mul_backward( const Tensor& grad, const Tensor& self, diff --git a/torch_npu/csrc/framework/autograd/FunctionsManual.h b/torch_npu/csrc/framework/autograd/FunctionsManual.h index a406a3d1fec1957faad730589806d88488591265..69180b8735a169896cd76275a8adc04e56253aa4 100644 --- a/torch_npu/csrc/framework/autograd/FunctionsManual.h +++ b/torch_npu/csrc/framework/autograd/FunctionsManual.h @@ -59,6 +59,10 @@ std::vector not_implemented_list(const char* name, const char* reason="" Tensor fast_gelu_backward(const Tensor& grad, const Tensor& self); +std::tuple +matmul_backward(const at::Tensor &grad, const at::Tensor &self, + const at::Tensor &other, std::array grad_input_mask); + std::tuple rotary_mul_backward( const Tensor& grad, const Tensor& self,