From f1489a3418c030bf274207e5a84b25d74108f1d4 Mon Sep 17 00:00:00 2001
From: pipihugh
Date: Wed, 16 Feb 2022 18:09:25 +0800
Subject: [PATCH] Port the layer_norm_eval custom operator
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 test/test_network_ops/test_layer_norm_eval.py | 53 +++++++++++
 torch_npu/csrc/aten/npu_native_functions.yaml |  1 +
 .../csrc/aten/ops/LayerNormEvalKernelNpu.cpp  | 90 +++++++++++++++++++
 torch_npu/utils/module.py                     |  2 +-
 4 files changed, 145 insertions(+), 1 deletion(-)
 create mode 100644 test/test_network_ops/test_layer_norm_eval.py
 create mode 100644 torch_npu/csrc/aten/ops/LayerNormEvalKernelNpu.cpp

diff --git a/test/test_network_ops/test_layer_norm_eval.py b/test/test_network_ops/test_layer_norm_eval.py
new file mode 100644
index 00000000000..d03703b6d01
--- /dev/null
+++ b/test/test_network_ops/test_layer_norm_eval.py
@@ -0,0 +1,53 @@
+# Copyright (c) 2020 Huawei Technologies Co., Ltd
+# Copyright (c) 2019, Facebook CORPORATION.
+# All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch_npu
+import torch.nn as nn
+import numpy as np
+
+from torch_npu.testing.common_utils import TestCase, run_tests
+from torch_npu.testing.common_device_type import instantiate_device_type_tests
+from torch_npu.testing.util_test import create_common_tensor
+
+class TestLayerNormEval(TestCase):
+    def cpu_op_exec(self, input1):
+        m = nn.LayerNorm(input1.size()[1:]).eval()
+        output = m(input1)
+        return output
+
+    def npu_op_exec(self, input1):
+        m = nn.LayerNorm(input1.size()[1:]).npu().eval()
+        output = m(input1)
+        output = output.to("cpu")
+        return output
+
+    def test_layernorm_shape_format(self, device):
+        shape_format = [
+            [np.float32, 0, (64, 10)],
+            [np.float32, 0, (256, 2048, 7, 7)],
+            [np.float32, 0, (32, 1, 3, 3)],
+            [np.float32, 0, (10, 128)]
+        ]
+        for item in shape_format:
+            cpu_input, npu_input = create_common_tensor(item, 1, 100)
+            cpu_output = self.cpu_op_exec(cpu_input)
+            npu_output = self.npu_op_exec(npu_input)
+            self.assertRtolEqual(cpu_output.detach().numpy(), npu_output.detach().numpy())
+
+instantiate_device_type_tests(TestLayerNormEval, globals(), except_for="cpu")
+if __name__ == "__main__":
+    run_tests()
diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml
index fac8c89f70e..1464fd57475 100644
--- a/torch_npu/csrc/aten/npu_native_functions.yaml
+++ b/torch_npu/csrc/aten/npu_native_functions.yaml
@@ -1912,6 +1912,7 @@ custom:
 - func: npu_indexing.out(Tensor self, int[] begin, int[] end, int[] strides, int begin_mask=0, int end_mask=0, int ellipsis_mask=0, int new_axis_mask=0, int shrink_axis_mask=0, *, Tensor(a!) out) -> Tensor(a!)
 - func: npu_bmmV2(Tensor self, Tensor mat2, int[] output_sizes) -> Tensor
   variants: function, method
+- func: npu_layer_norm_eval(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05) -> Tensor
 - func: npu_softmax_cross_entropy_with_logits(Tensor self, Tensor labels) -> Tensor
   variants: function, method
 - func: npu_softmax_cross_entropy_with_logits_backward(Tensor grad, Tensor self, Tensor labels) -> Tensor
diff --git a/torch_npu/csrc/aten/ops/LayerNormEvalKernelNpu.cpp b/torch_npu/csrc/aten/ops/LayerNormEvalKernelNpu.cpp
new file mode 100644
index 00000000000..7aab63d130d
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/LayerNormEvalKernelNpu.cpp
@@ -0,0 +1,90 @@
+// Copyright (c) 2020 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <functional>
+#include <numeric>
+
+#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h"
+#include "torch_npu/csrc/framework/utils/OpTemplate.h"
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+
+at::Tensor NPUNativeFunctions::npu_layer_norm_eval(
+    const at::Tensor& input,
+    at::IntArrayRef normalized_shape,
+    const c10::optional<at::Tensor> &weight_opt,
+    const c10::optional<at::Tensor> &bias_opt,
+    double eps) {
+  const at::Tensor& weight = c10::value_or_else(weight_opt, [] {return at::Tensor();});
+  const at::Tensor& bias = c10::value_or_else(bias_opt, [] {return at::Tensor();});
+  const int normalized_ndim = normalized_shape.size();
+  const auto input_shape = input.sizes();
+  const auto input_ndim = input.dim();
+  const int axis = input_ndim - normalized_ndim;
+  // M: number of rows to normalize; N: number of elements per row.
+  const int64_t M = std::accumulate(
+      input_shape.cbegin(),
+      input_shape.cbegin() + axis,
+      1LL,
+      std::multiplies<int64_t>());
+  const int64_t N = std::accumulate(
+      input_shape.cbegin() + axis,
+      input_shape.cend(),
+      1LL,
+      std::multiplies<int64_t>());
+  at::Tensor Y = OpPreparation::ApplyTensor(input);
+  // Walk the dims backwards until the trailing element count reaches N;
+  // that axis is where normalization begins.
+  int64_t numels = 1;
+  int64_t begin_dim = 0;
+  c10::SmallVector<int64_t, 8> tmpSize;
+  for (int64_t i = input.dim() - 1; i >= 0; i--) {
+    numels *= input.size(i);
+    tmpSize.emplace_back(input.size(i));
+    if (numels == N) {
+      begin_dim = i;
+      break;
+    }
+  }
+  std::reverse(tmpSize.begin(), tmpSize.end());
+  // Fall back to all-ones weight and all-zeros bias when they are not given,
+  // and reshape provided ones to match the normalized shape.
+  at::Tensor resizeWeight = weight;
+  resizeWeight.requires_grad_(false);
+  at::Tensor resizeBias = bias;
+  resizeBias.requires_grad_(false);
+  if (!resizeWeight.defined()) {
+    resizeWeight = at::ones(tmpSize, input.options());
+  } else if (!resizeWeight.sizes().equals(tmpSize)) {
+    resizeWeight.resize_(tmpSize);
+  }
+  if (!resizeBias.defined()) {
+    resizeBias = at::zeros(tmpSize, input.options());
+  } else if (!resizeBias.sizes().equals(tmpSize)) {
+    resizeBias.resize_(tmpSize);
+  }
+  at::Tensor mean = OpPreparation::ApplyTensor(resizeWeight, {M});
+  at::Tensor rstd = OpPreparation::ApplyTensor(resizeWeight, {M});
+  OpCommand cmd;
+  cmd.Name("LayerNorm")
+      .Input(input)
+      .Input(resizeWeight)
+      .Input(resizeBias)
+      .Output(Y)
+      .Output(mean)
+      .Output(rstd)
+      .Attr("begin_norm_axis", begin_dim)
+      .Attr("begin_params_axis", begin_dim)
+      .Attr("epsilon", static_cast<float>(eps))
+      .Run();
+  return Y;
+}
+
+} // namespace native
+} // namespace at_npu
diff --git a/torch_npu/utils/module.py b/torch_npu/utils/module.py
index 332995804fd..8e2fb9ae030 100644
--- a/torch_npu/utils/module.py
+++ b/torch_npu/utils/module.py
@@ -117,7 +117,7 @@ def cast_weight(self, device):
 
 
 def layernorm_forward(self, input: torch.Tensor) -> torch.Tensor:
-    if self.training:
+    if self.training or (not input.is_npu):
        return torch.nn.functional.layer_norm(
            input, self.normalized_shape, self.weight, self.bias, self.eps)
     else:
-- 
Gitee
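
Reviewer note: the snippet below is a minimal sketch of how this change is expected to be exercised once the patch is built, not part of the change itself. It assumes an available NPU device, and that the new `custom:` yaml entry surfaces the kernel from Python as `torch_npu.npu_layer_norm_eval`, mirroring how the neighboring custom ops are exposed; the tolerance value is likewise an assumption.

    import torch
    import torch_npu
    import torch.nn as nn

    x = torch.rand(32, 128).npu()
    m = nn.LayerNorm(x.size()[1:]).npu().eval()

    # Path 1: the patched nn.LayerNorm forward. With training=False and an
    # NPU input, layernorm_forward routes to the custom eval kernel instead
    # of torch.nn.functional.layer_norm.
    y_module = m(x)

    # Path 2 (assumed binding name): call the custom op directly with the
    # same parameters the module holds.
    y_direct = torch_npu.npu_layer_norm_eval(
        x, m.normalized_shape, m.weight, m.bias, m.eps)

    print(torch.allclose(y_module.cpu(), y_direct.cpu(), atol=1e-4))

Both paths should agree, since the patched `layernorm_forward` in torch_npu/utils/module.py is just a dispatcher over the same kernel.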