From 5b2c892bf9f014c1cf6299872c9f0d0c4a02fcca Mon Sep 17 00:00:00 2001
From: lgq1997
Date: Thu, 7 Jul 2022 14:09:08 +0800
Subject: [PATCH] Add adaptation for the ScaledMaskedSoftmax operator
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../test_scaled_masked_softmax.py             | 109 ++++++++++++++++++
 torch_npu/csrc/aten/npu_native_functions.yaml |   1 +
 .../aten/ops/ScaledMaskedSoftmaxKernelNpu.cpp | 101 ++++++++++++++++
 3 files changed, 211 insertions(+)
 create mode 100644 test/test_network_ops/test_scaled_masked_softmax.py
 create mode 100644 torch_npu/csrc/aten/ops/ScaledMaskedSoftmaxKernelNpu.cpp

diff --git a/test/test_network_ops/test_scaled_masked_softmax.py b/test/test_network_ops/test_scaled_masked_softmax.py
new file mode 100644
index 0000000000..f5899e0444
--- /dev/null
+++ b/test/test_network_ops/test_scaled_masked_softmax.py
@@ -0,0 +1,109 @@
+# Copyright (c) 2022, Huawei Technologies. All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+import torch
+import numpy as np
+import torch.nn.functional as F
+import torch_npu
+
+from torch_npu.testing.testcase import TestCase, run_tests
+from torch_npu.testing.common_utils import create_common_tensor
+
+
+class TestScaledMaskedSoftmax(TestCase):
+    def cpu_op_exec_forward(self, x, mask, scale, fixed_triu_mask):
+        x = x.float()
+        if fixed_triu_mask:
+            mask_tri = torch.triu(torch.ones(mask.shape, device=mask.device), diagonal=1).bool()
+            output = F.softmax((x * scale).masked_fill(mask_tri, value=-10000), dim=-1).half()
+        else:
+            output = F.softmax((x * scale).masked_fill(mask, value=-10000), dim=-1).half()
+        return output.detach().numpy()
+
+    def npu_op_exec_forward(self, x, mask, scale, fixed_triu_mask):
+        output = torch_npu.npu_scaled_masked_softmax(x, mask, scale, fixed_triu_mask)
+        return output.cpu().detach().numpy()
+
+    def cpu_op_exec_backward(self, x, y_grad, mask, scale, fixed_triu_mask):
+        x.requires_grad_(True)
+        x_fp32 = x.float()
+        y_grad = y_grad.float()
+        if fixed_triu_mask:
+            mask_tri = torch.triu(torch.ones(mask.shape, device=mask.device), diagonal=1).bool()
+            output = F.softmax((x_fp32 * scale).masked_fill(mask_tri, value=-10000), dim=-1).half()
+            output.backward(y_grad)
+        else:
+            output = F.softmax((x_fp32 * scale).masked_fill(mask, value=-10000), dim=-1).half()
+            output.backward(y_grad)
+        x_grad = x.grad
+        return x_grad.detach().numpy()
+
+    def npu_op_exec_backward(self, x, y_grad, mask, scale, fixed_triu_mask):
+        x.requires_grad_(True)
+        output = torch_npu.npu_scaled_masked_softmax(x, mask, scale, fixed_triu_mask)
+        output.backward(y_grad)
+        x_grad = x.grad
+        return x_grad.half().cpu().detach().numpy()
+
+    def test_scaled_masked_softmax_shape_format(self):
+        shape_format = [
+            [[np.float16, 29, (16, 6, 128, 128)], [np.float16, 29, (16, 6, 128, 128)]],
+            [[np.float16, 2, (16, 6, 128, 512)], [np.float16, 2, (16, 1, 128, 512)]],
+            [[np.float16, 0, (16, 6, 512, 512)], [np.float16, 0, (16, 1, 512, 512)]],
+            [[np.float32, 29, (16, 6, 128, 128)], [np.float32, 29, (16, 6, 128, 128)]],
+            [[np.float32, 2, (16, 6, 128, 512)], [np.float32, 2, (16, 1, 128, 512)]],
+            [[np.float32, 0, (16, 6, 512, 512)], [np.float32, 0, (16, 1, 512, 512)]],
+        ]
+
+        # forward ut test
+        for item in shape_format:
+            cpu_input, npu_input = create_common_tensor(item[0], -5, 5)
+            cpu_mask, npu_mask = create_common_tensor(item[1], -1, 1)
+            cpu_mask = cpu_mask > 0
+            npu_mask = npu_mask > 0
+            if cpu_input.dtype == torch.float16:
+                cpu_input = cpu_input.to(torch.float32)
+            scale = random.uniform(-1, 1)
+            fixed_triu_mask = False
+            cpu_output = self.cpu_op_exec_forward(cpu_input, cpu_mask,
+                                                  scale, fixed_triu_mask)
+            npu_output = self.npu_op_exec_forward(npu_input, npu_mask,
+                                                  scale, fixed_triu_mask)
+            cpu_output = cpu_output.astype(npu_output.dtype)
+            self.assertRtolEqual(cpu_output, npu_output)
+
+        # backward ut test
+        for item in shape_format:
+            cpu_input, npu_input = create_common_tensor(item[0], -5, 5)
+            cpu_y_grad, npu_y_grad = create_common_tensor(item[0], -5, 5)
+            cpu_mask, npu_mask = create_common_tensor(item[1], -1, 1)
+            cpu_mask = cpu_mask > 0
+            npu_mask = npu_mask > 0
+            if cpu_input.dtype == torch.float16:
+                cpu_input = cpu_input.to(torch.float32)
+            if cpu_y_grad.dtype == torch.float16:
+                cpu_y_grad = cpu_y_grad.to(torch.float32)
+            scale = random.uniform(-1, 1)
+            fixed_triu_mask = False
+            cpu_x_grad = self.cpu_op_exec_backward(cpu_input, cpu_y_grad,
+                                                   cpu_mask, scale, fixed_triu_mask)
+            npu_x_grad = self.npu_op_exec_backward(npu_input, npu_y_grad,
+                                                   npu_mask, scale, fixed_triu_mask)
+            cpu_x_grad = cpu_x_grad.astype(npu_x_grad.dtype)
+            self.assertRtolEqual(cpu_x_grad, npu_x_grad)
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml
index 8a7eb6ec1a..8f74262501 100644
--- a/torch_npu/csrc/aten/npu_native_functions.yaml
+++ b/torch_npu/csrc/aten/npu_native_functions.yaml
@@ -1999,6 +1999,7 @@ custom_autograd:
   variants: function
 - func: _dropout_with_byte_mask(Tensor self, float p) -> (Tensor, Tensor)
 - func: npu_dropout_with_add_softmax(Tensor self, Tensor x1, Scalar alpha, float prob, int dim) -> (Tensor, Tensor, Tensor)
+- func: npu_scaled_masked_softmax(Tensor x, Tensor mask, Scalar scale=1, bool fixed_triu_mask=False) -> Tensor
   variants: function, method
 - func: npu_multi_head_attention(Tensor query, Tensor key, Tensor value, Tensor query_weight, Tensor key_weight, Tensor value_weight, Tensor attn_mask, Tensor out_proj_weight, Tensor? query_bias, Tensor? key_bias, Tensor? value_bias, Tensor? out_proj_bias, Tensor? dropout_mask, int attn_head_num, int attn_dim_per_head, int src_len, int tgt_len, float dropout_prob, bool softmax_use_float) -> Tensor[]
 - func: npu_dropout_do_mask(Tensor self, Tensor mask, float p) -> (Tensor, Tensor)
diff --git a/torch_npu/csrc/aten/ops/ScaledMaskedSoftmaxKernelNpu.cpp b/torch_npu/csrc/aten/ops/ScaledMaskedSoftmaxKernelNpu.cpp
new file mode 100644
index 0000000000..38ea7584ee
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/ScaledMaskedSoftmaxKernelNpu.cpp
@@ -0,0 +1,101 @@
+// Copyright (c) 2022 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <torch/csrc/autograd/custom_function.h>
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+using torch::autograd::AutogradContext;
+using tensor_list = std::vector<at::Tensor>;
+
+at::Tensor npu_scaled_masked_softmax_forward(
+    const at::Tensor& self,
+    const at::Tensor& mask,
+    at::Scalar scale,
+    bool fixed_triu_mask) {
+  at::Tensor result = OpPreparation::ApplyTensor(self, self.options().dtype(at::kHalf));
+  OpCommand cmd;
+  cmd.Name("ScaledMaskedSoftmax")
+      .Input(self)
+      .Input(mask)
+      .Output(result)
+      .Attr("scale", scale)
+      .Attr("fixed_triu_mask", fixed_triu_mask)
+      .Run();
+  return result;
+}
+
+at::Tensor npu_scaled_masked_softmax_backward(
+    const at::Tensor& y_grad,
+    const at::Tensor& y,
+    const at::Tensor& mask,
+    at::Scalar scale,
+    bool fixed_triu_mask) {
+  at::Tensor result = OpPreparation::ApplyTensor(y_grad, y_grad.options().dtype(at::kHalf));
+  OpCommand cmd;
+  cmd.Name("ScaledMaskedSoftmaxGrad")
+      .Input(y_grad)
+      .Input(y)
+      .Input(mask)
+      .Output(result)
+      .Attr("scale", scale)
+      .Attr("fixed_triu_mask", fixed_triu_mask)
+      .Run();
+  return result;
+}
+
+class NPUscalemsFunction : public torch::autograd::Function<NPUscalemsFunction> {
+public:
+  static at::Tensor forward(AutogradContext *ctx,
+      const at::Tensor& self,
+      const at::Tensor& mask,
+      at::Scalar scale,
+      bool fixed_triu_mask) {
+    ctx->saved_data["scale"] = scale;
+    ctx->saved_data["fixed_triu_mask"] = fixed_triu_mask;
+    auto result = npu_scaled_masked_softmax_forward(self, mask, scale, fixed_triu_mask);
+    ctx->save_for_backward({result, mask});
+    return result;
+  }
+
+  static tensor_list backward(AutogradContext *ctx,
+      tensor_list grad_outputs) {
+    auto scale = ctx->saved_data["scale"].toScalar();
+    auto fixed_triu_mask = ctx->saved_data["fixed_triu_mask"].toBool();
+    auto saved = ctx->get_saved_variables();
+    auto result0 = saved[0];
+    auto mask = saved[1];
+    auto result = npu_scaled_masked_softmax_backward(
+        grad_outputs[0], result0, mask, scale, fixed_triu_mask);
+    tensor_list output = {result, at::Tensor(), at::Tensor(), at::Tensor()};
+    return output;
+  }
+};
+
+at::Tensor NPUNativeFunctions::npu_scaled_masked_softmax(
+    const at::Tensor& self,
+    const at::Tensor& mask,
+    at::Scalar scale,
+    bool fixed_triu_mask) {
+  auto result = NPUscalemsFunction::apply(self, mask, scale, fixed_triu_mask);
+  return result;
+}
+
+} // namespace native
+} // namespace at_npu
\ No newline at end of file
-- 
Gitee
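
Note: below is a minimal usage sketch of the operator added by this patch, assuming a torch_npu build that includes it and a visible Ascend device. The device string, shapes, and scale value are illustrative assumptions, not values taken from the patch; the mask broadcasting over the head dimension mirrors the (16, 1, H, W) mask cases exercised in the test above.

    import torch
    import torch_npu

    # Assumption: an Ascend device is available as "npu:0".
    x = torch.randn(16, 6, 128, 128).half().to("npu:0")
    mask = torch.randn(16, 1, 128, 128).to("npu:0") > 0

    # Equivalent to softmax((x * scale).masked_fill(mask, -10000), dim=-1),
    # per the CPU reference in the test; with fixed_triu_mask=True the op
    # builds an upper-triangular mask from `mask`'s shape instead of using
    # its contents. Arguments are positional: (x, mask, scale, fixed_triu_mask).
    y = torch_npu.npu_scaled_masked_softmax(x, mask, 0.5, False)

Since the kernel allocates its output with at::kHalf (see ApplyTensor in the forward path), the result is fp16 even for fp32 input, which is why the test casts the CPU reference to half before comparing.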