From 93220b66e163dfa2128838a28fc69af659316e5a Mon Sep 17 00:00:00 2001
From: hxf12345677
Date: Wed, 16 Feb 2022 10:37:41 +0800
Subject: [PATCH] Port the npu_linear operator to 1.8.1
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 test/test_network_ops/test_npu_linear.py      |  63 ++++++++
 .../test_npu_linear_backward.py               |  81 ++++++++++
 torch_npu/csrc/aten/npu_native_functions.yaml |   4 +-
 torch_npu/csrc/aten/ops/LinearKernelNpu.cpp   | 142 ++++++++++++++++++
 4 files changed, 289 insertions(+), 1 deletion(-)
 create mode 100644 test/test_network_ops/test_npu_linear.py
 create mode 100644 test/test_network_ops/test_npu_linear_backward.py
 create mode 100644 torch_npu/csrc/aten/ops/LinearKernelNpu.cpp

diff --git a/test/test_network_ops/test_npu_linear.py b/test/test_network_ops/test_npu_linear.py
new file mode 100644
index 0000000000..fbf7329694
--- /dev/null
+++ b/test/test_network_ops/test_npu_linear.py
@@ -0,0 +1,63 @@
+# Copyright (c) 2020, Huawei Technologies. All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.common_utils import TestCase, run_tests
+from torch_npu.testing.common_device_type import instantiate_device_type_tests
+from torch_npu.testing.util_test import create_common_tensor
+
+class TestNpuLinear(TestCase):
+    def cpu_op_exec(self, x, weight, bias):
+        output = torch.nn.functional.linear(x, weight, bias)
+        output = output.numpy()
+        return output
+
+    def npu_op_exec(self, x, weight, bias):
+        output = torch_npu.npu_linear(x, weight, bias)
+        output = output.cpu().numpy()
+        return output
+
+    def test_npu_linear_shape_format_fp32(self, device):
+        shape_format = [
+            [[np.float32, -1, (6144, 1024)], [np.float32, -1, (256, 1024)], [np.float32, -1, (256)]],
+            [[np.float32, -1, (123, 456)], [np.float32, -1, (789, 456)], [np.float32, -1, (789)]],
+        ]
+
+        for item in shape_format:
+            cpu_x, npu_x = create_common_tensor(item[0], -2, 2)
+            cpu_w, npu_w = create_common_tensor(item[1], -2, 2)
+            cpu_b, npu_b = create_common_tensor(item[2], -2, 2)
+            cpu_output = self.cpu_op_exec(cpu_x, cpu_w, cpu_b)
+            npu_output = self.npu_op_exec(npu_x, npu_w, npu_b)
+            self.assertRtolEqual(cpu_output, npu_output, 0.0002)
+
+    def test_npu_linear_shape_format_fp16(self, device):
+        shape_format = [
+            [[np.float16, -1, (6144, 1024)], [np.float16, -1, (256, 1024)], [np.float16, -1, (256)]],
+            [[np.float16, -1, (123, 456)], [np.float16, -1, (789, 456)], [np.float16, -1, (789)]],
+        ]
+
+        for item in shape_format:
+            cpu_x, npu_x = create_common_tensor(item[0], -2, 2)
+            cpu_w, npu_w = create_common_tensor(item[1], -2, 2)
+            cpu_b, npu_b = create_common_tensor(item[2], -2, 2)
+            cpu_output = self.cpu_op_exec(cpu_x.float(), cpu_w.float(), cpu_b.float()).astype(np.float16)
+            npu_output = self.npu_op_exec(npu_x, npu_w, npu_b)
+            self.assertRtolEqual(cpu_output, npu_output)
+
+instantiate_device_type_tests(TestNpuLinear, globals(), except_for="cpu")
"__main__": + run_tests() \ No newline at end of file diff --git a/test/test_network_ops/test_npu_linear_backward.py b/test/test_network_ops/test_npu_linear_backward.py new file mode 100644 index 0000000000..df00143775 --- /dev/null +++ b/test/test_network_ops/test_npu_linear_backward.py @@ -0,0 +1,81 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor + +class TestNpuLinearBackward(TestCase): + def cpu_op_exec(self, x, weight, bias): + x.requires_grad = True + weight.requires_grad = True + bias.requires_grad = True + output = torch.nn.functional.linear(x, weight, bias) + loss = output.sum() + loss.backward() + list1 = [output.detach().numpy(), x.grad.numpy(), weight.grad.numpy(), bias.grad.numpy()] + return list1 + + def npu_op_exec(self, x, weight, bias): + x.requires_grad = True + weight.requires_grad = True + bias.requires_grad = True + output = torch_npu.npu_linear(x, weight, bias) + loss = output.sum() + loss.backward() + list2 = [output.cpu().detach().numpy(), x.grad.cpu().numpy(), weight.grad.cpu().numpy(), + bias.grad.cpu().numpy()] + return list2 + + def test_npu_linear_backward_shape_format_fp32(self, device): + shape_format = [ + [[np.float32, -1, (6144, 1024)], [np.float32, -1, (256, 1024)], [np.float32, -1, (256)]], + [[np.float32, -1, (123, 456)], [np.float32, -1, (789, 456)], [np.float32, -1, (789)]], + ] + + for item in shape_format: + cpu_x, npu_x = create_common_tensor(item[0], -2, 2) + cpu_w, npu_w = create_common_tensor(item[1], -2, 2) + cpu_b, npu_b = create_common_tensor(item[2], -2, 2) + getlist1 = self.cpu_op_exec(cpu_x, cpu_w, cpu_b) + getlist2 = self.npu_op_exec(npu_x, npu_w, npu_b) + self.assertRtolEqual(getlist1[0], getlist2[0], 0.0002) + self.assertRtolEqual(getlist1[1], getlist2[1]) + self.assertRtolEqual(getlist1[2], getlist2[2]) + self.assertRtolEqual(getlist1[3], getlist2[3]) + + def test_npu_linear_shape_format_fp16(self, device): + shape_format = [ + [[np.float16, -1, (6144, 1024)], [np.float16, -1, (256, 1024)], [np.float16, -1, (256)]], + [[np.float16, -1, (123, 456)], [np.float16, -1, (789, 456)], [np.float16, -1, (789)]], + ] + + for item in shape_format: + cpu_x, npu_x = create_common_tensor(item[0], -2, 2) + cpu_w, npu_w = create_common_tensor(item[1], -2, 2) + cpu_b, npu_b = create_common_tensor(item[2], -2, 2) + getlist1 = self.cpu_op_exec(cpu_x.float(), cpu_w.float(), cpu_b.float()) + getlist2 = self.npu_op_exec(npu_x, npu_w, npu_b) + self.assertRtolEqual(getlist1[0].astype(np.float16), getlist2[0]) + self.assertRtolEqual(getlist1[1].astype(np.float16), getlist2[1]) + self.assertRtolEqual(getlist1[2].astype(np.float16), getlist2[2]) + self.assertRtolEqual(getlist1[3].astype(np.float16), getlist2[3]) + 
+instantiate_device_type_tests(TestNpuLinearBackward, globals(), except_for="cpu")
+if __name__ == "__main__":
+    run_tests()
\ No newline at end of file
diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml
index f123789445..e250c74258 100644
--- a/torch_npu/csrc/aten/npu_native_functions.yaml
+++ b/torch_npu/csrc/aten/npu_native_functions.yaml
@@ -1921,9 +1921,11 @@ custom:
   - func: npu_confusion_transpose_backward(Tensor grad, int[] perm, int[] shape, bool transpose_first) -> Tensor
   - func: npu_one_hot(Tensor self, int num_classes=-1, int depth=1, Scalar on_value=1, Scalar off_value=0) -> Tensor
     variants: function, method
+  - func: npu_linear_backward(Tensor grad, Tensor input, Tensor weight) -> (Tensor, Tensor)
 custom_autograd:
   - func: npu_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor
   - func: npu_convolution_transpose(Tensor input, Tensor weight, Tensor? bias, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups) -> Tensor
   - func: fast_gelu(Tensor self) -> Tensor
   - func: npu_confusion_transpose(Tensor self, int[] perm, int[] shape, bool transpose_first) -> Tensor
-    variants: function, method
\ No newline at end of file
+    variants: function, method
+  - func: npu_linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor
diff --git a/torch_npu/csrc/aten/ops/LinearKernelNpu.cpp b/torch_npu/csrc/aten/ops/LinearKernelNpu.cpp
new file mode 100644
index 0000000000..ffe290ab1e
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/LinearKernelNpu.cpp
@@ -0,0 +1,142 @@
+// Copyright (c) 2020 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <torch/csrc/autograd/custom_function.h>
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+using torch::autograd::Function;
+using torch::autograd::AutogradContext;
+using tensor_list = std::vector<at::Tensor>;
+
+at::Tensor linear_npu(
+    const at::Tensor& input,
+    const at::Tensor& weight,
+    const c10::optional<at::Tensor>& bias_opt) {
+  const at::Tensor& bias = c10::value_or_else(bias_opt, [] {return at::Tensor();});
+  c10::SmallVector<int64_t, SIZE> outputSize = {input.size(0), weight.size(0)};
+  at::Tensor output = OpPreparation::ApplyTensor(input, outputSize);
+
+  int64_t offset_x = 0;
+  OpCommand cmd;
+  cmd.Name("MatMulV2")
+      .Input(input)
+      .Input(weight);
+  if (bias.defined()) {
+    cmd.Input(bias);
+  }
+  cmd.Output(output)
+      .Attr("transpose_x1", false)
+      .Attr("transpose_x2", true)
+      .Attr("offset_x", offset_x)
+      .Run();
+
+  return output;
+}
+
+at::Tensor linear_backward_out_npu(
+    at::Tensor& result,
+    const at::Tensor& input,
+    const at::Tensor& weight,
+    bool transpose_x1,
+    bool transpose_x2) {
+  int64_t offset_x = 0;
+  OpCommand cmd;
+  cmd.Name("MatMulV2")
+      .Input(input)
+      .Input(weight)
+      .Output(result)
+      .Attr("transpose_x1", transpose_x1)
+      .Attr("transpose_x2", transpose_x2)
+      .Attr("offset_x", offset_x)
+      .Run();
+  return result;
+}
+
+std::tuple<at::Tensor, at::Tensor> NPUNativeFunctions::npu_linear_backward(
+    const at::Tensor& grad,
+    const at::Tensor& input,
+    const at::Tensor& weight) {
+  c10::SmallVector<int64_t, SIZE> inputGradOutputSize = {
+      grad.size(0),
+      weight.size(1)};
+  c10::SmallVector<int64_t, SIZE> weightGradOutputSize = {
+      grad.size(1),
+      input.size(1)};
+  at::Tensor inputGrad = OpPreparation::ApplyTensor(input, inputGradOutputSize);
+  at::Tensor weightGrad = OpPreparation::ApplyTensor(weight, weightGradOutputSize);
+
+  if (CalcuOpUtil::get_tensor_npu_format(grad) == CalcuOpUtil::get_tensor_npu_format(weight)) {
+    linear_backward_out_npu(inputGrad, grad, weight, false, false);
+    linear_backward_out_npu(weightGrad, grad, input, true, false);
+  } else {
+    at::Tensor gradFormatcast = OpPreparation::ApplyTensor(grad, grad.sizes());
+    gradFormatcast = NPUNativeFunctions::npu_format_cast(grad, CalcuOpUtil::get_tensor_npu_format(weight));
+    linear_backward_out_npu(inputGrad, gradFormatcast, weight, false, false);
+    linear_backward_out_npu(weightGrad, gradFormatcast, input, true, false);
+  }
+
+  return std::tie(inputGrad, weightGrad);
+}
+
+class NPULinearFunction : public torch::autograd::Function<NPULinearFunction> {
+public:
+  static at::Tensor forward(AutogradContext *ctx,
+      const at::Tensor& input,
+      const at::Tensor& weight,
+      const c10::optional<at::Tensor>& bias_opt) {
+    ctx->saved_data["bias_has_value"] = bias_opt.has_value() ? bias_opt.value().requires_grad() : false;
+
+    at::AutoNonVariableTypeMode g;
+    ctx->save_for_backward({input, weight});
+    return linear_npu(input, weight, bias_opt);
+  }
+
+  static tensor_list backward(AutogradContext *ctx,
+      tensor_list grad_outputs) {
+    auto bias_has_value = ctx->saved_data["bias_has_value"].toBool();
+    auto saved = ctx->get_saved_variables();
+    auto input = saved[0];
+    auto weight = saved[1];
+
+    std::tuple<at::Tensor, at::Tensor> result = NPUNativeFunctions::npu_linear_backward(grad_outputs[0], input, weight);
+
+    tensor_list output;
+    if (bias_has_value) {
+      output = {std::get<0>(result),
+          std::get<1>(result),
+          grad_outputs[0]};
+    } else {
+      output = {std::get<0>(result),
+          std::get<1>(result),
+          at::Tensor()};
+    }
+    return output;
+  }
+};
+
+at::Tensor NPUNativeFunctions::npu_linear(const at::Tensor& input,
+    const at::Tensor& weight,
+    const c10::optional<at::Tensor>& bias_opt) {
+  return NPULinearFunction::apply(input, weight, bias_opt);
+}
+
+} // namespace native
+} // namespace at_npu
\ No newline at end of file
-- 
Gitee
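
A minimal usage sketch of the operator this patch adds (not part of the patch
itself): it assumes a torch_npu build containing this change and a reachable
NPU device. The shapes are illustrative, and the .npu() transfer mirrors what
create_common_tensor does in the tests above.

    import torch
    import torch_npu

    # npu_linear mirrors torch.nn.functional.linear: y = x @ weight.T + bias.
    # The kernel calls MatMulV2 with transpose_x2=true, so the weight stays
    # in (out_features, in_features) layout, as in the tests above.
    x = torch.randn(128, 1024).npu()   # input  (batch, in_features)
    w = torch.randn(256, 1024).npu()   # weight (out_features, in_features)
    b = torch.randn(256).npu()         # bias   (out_features,)

    y = torch_npu.npu_linear(x, w, b)
    print(y.shape)                     # torch.Size([128, 256])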