diff --git a/test/test_network_ops/test_layer_norm_eval.py b/test/test_network_ops/test_layer_norm_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8b134dd9a60f6eccb1be422f726c743b1dfc699
--- /dev/null
+++ b/test/test_network_ops/test_layer_norm_eval.py
@@ -0,0 +1,51 @@
+# Copyright (c) 2020 Huawei Technologies Co., Ltd
+# Copyright (c) 2019, Facebook CORPORATION.
+# All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch_npu
+import torch.nn as nn
+import numpy as np
+
+from torch_npu.testing.testcase import TestCase, run_tests
+from torch_npu.testing.common_utils import create_common_tensor
+
+class TestLayerNormEval(TestCase):
+    def cpu_op_exec(self, input1):
+        m = nn.LayerNorm(input1.size()[1:]).eval()
+        output = m(input1)
+        return output
+
+    def npu_op_exec(self, input1):
+        m = nn.LayerNorm(input1.size()[1:]).npu().eval()
+        output = m(input1)
+        output = output.to("cpu")
+        return output
+
+    def test_layernorm_shape_format(self, device="npu"):
+        shape_format = [
+            [np.float32, 0, (64, 10)],
+            [np.float32, 0, (256, 2048, 7, 7)],
+            [np.float32, 0, (32, 1, 3, 3)],
+            [np.float32, 0, (10, 128)]
+        ]
+        for item in shape_format:
+            cpu_input, npu_input = create_common_tensor(item, 1, 100)
+            cpu_output = self.cpu_op_exec(cpu_input)
+            npu_output = self.npu_op_exec(npu_input)
+            self.assertRtolEqual(cpu_output.detach().numpy(), npu_output.detach().numpy())
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml
index b4c04c619f7991a453360e2260fc216d6566ea5a..9c1f5eb37a82d07a845c43eaed2636ed3dc0613f 100644
--- a/torch_npu/csrc/aten/npu_native_functions.yaml
+++ b/torch_npu/csrc/aten/npu_native_functions.yaml
@@ -1932,6 +1932,7 @@ custom:
   - func: npu_sub_sample(Tensor self, int per_images, float positive_fraction) -> Tensor
   - func: npu_yolo_boxes_encode(Tensor self, Tensor gt_bboxes, Tensor stride, bool performance_mode=False) -> Tensor
   - func: npu_scatter(Tensor self, Tensor indices, Tensor updates, int dim) -> Tensor
+  - func: npu_layer_norm_eval(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05) -> Tensor
 custom_autograd:
   - func: npu_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor
   - func: npu_convolution_transpose(Tensor input, Tensor weight, Tensor? bias, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups) -> Tensor
diff --git a/torch_npu/csrc/aten/ops/LayerNormEvalKernelNpu.cpp b/torch_npu/csrc/aten/ops/LayerNormEvalKernelNpu.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..634893cf097f2eb4614b5f84abccdbd6b8525294
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/LayerNormEvalKernelNpu.cpp
@@ -0,0 +1,90 @@
+// Copyright (c) 2020 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h"
+#include "torch_npu/csrc/framework/utils/OpTemplate.h"
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+
+at::Tensor NPUNativeFunctions::npu_layer_norm_eval(
+    const at::Tensor& input,
+    at::IntArrayRef normalized_shape,
+    const c10::optional<at::Tensor>& weight_opt,
+    const c10::optional<at::Tensor>& bias_opt,
+    double eps) {
+  const at::Tensor& weight = c10::value_or_else(weight_opt, [] { return at::Tensor(); });
+  const at::Tensor& bias = c10::value_or_else(bias_opt, [] { return at::Tensor(); });
+  const int normalized_ndim = normalized_shape.size();
+  const auto input_shape = input.sizes();
+  const auto input_ndim = input.dim();
+  const int axis = input_ndim - normalized_ndim;
+  const int64_t M = std::accumulate(
+      input_shape.cbegin(),
+      input_shape.cbegin() + axis,
+      1LL,
+      std::multiplies<int64_t>());
+  const int64_t N = std::accumulate(
+      input_shape.cbegin() + axis,
+      input_shape.cend(),
+      1LL,
+      std::multiplies<int64_t>());
+  at::Tensor result = OpPreparation::ApplyTensor(input);
+  int64_t numels = 1;
+  int64_t begin_dim = 0;
+  c10::SmallVector<int64_t, SIZE> tmpSize;
+  for (int64_t i = input.dim() - 1; i >= 0; i--) {
+    numels *= input.size(i);
+    tmpSize.emplace_back(input.size(i));
+    if (numels == N) {
+      begin_dim = i;
+      break;
+    }
+  }
+  std::reverse(tmpSize.begin(), tmpSize.end());
+  at::Tensor resizeWeight = weight;
+  resizeWeight.requires_grad_(false);
+  at::Tensor resizeBias = bias;
+  resizeBias.requires_grad_(false);
+  if (!resizeWeight.defined()) {
+    resizeWeight = at::ones(tmpSize, input.options());
+  } else if (!resizeWeight.sizes().equals(tmpSize)) {
+    resizeWeight.resize_(tmpSize);
+  }
+  if (!resizeBias.defined()) {
+    resizeBias = at::zeros(tmpSize, input.options());
+  } else if (!resizeBias.sizes().equals(tmpSize)) {
+    resizeBias.resize_(tmpSize);
+  }
+  at::Tensor mean = OpPreparation::ApplyTensor(resizeWeight, {M});
+  at::Tensor rstd = OpPreparation::ApplyTensor(resizeWeight, {M});
+  OpCommand cmd;
+  cmd.Name("LayerNorm")
+      .Input(input)
+      .Input(resizeWeight)
+      .Input(resizeBias)
+      .Output(result)
+      .Output(mean)
+      .Output(rstd)
+      .Attr("begin_norm_axis", begin_dim)
+      .Attr("begin_params_axis", begin_dim)
+      .Attr("epsilon", static_cast<float>(eps))
+      .Run();
+  return result;
+}
+}}
diff --git a/torch_npu/utils/module.py b/torch_npu/utils/module.py
index e0ff721267067acf209871e2a0b8a5058adc0aed..9d48d4060f8855179aaf2fe9391185eed222c15a 100644
--- a/torch_npu/utils/module.py
+++ b/torch_npu/utils/module.py
@@ -131,7 +131,7 @@ def cast_weight(self, device):
 
 
 def layernorm_forward(self, input: torch.Tensor) -> torch.Tensor:
-    if self.training:
+    if self.training or (not input.is_npu):
         return torch.nn.functional.layer_norm(
             input, self.normalized_shape, self.weight, self.bias, self.eps)
     else:
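
Usage note (not part of the patch): with this change, torch.nn.LayerNorm dispatches to the fused single-output NPU kernel only when the module is in eval mode and the input is an NPU tensor; training mode or a CPU input keeps the stock torch.nn.functional.layer_norm path. Below is a minimal sketch of the intended behavior, assuming an NPU device is available and that the custom op is exposed as torch_npu.npu_layer_norm_eval per the yaml registration above.

import torch
import torch_npu
import torch.nn as nn

# Assumes an NPU device is available; illustration only, not part of the patch.
x = torch.randn(64, 10).npu()
m = nn.LayerNorm(x.size()[1:]).npu().eval()

# Eval mode + NPU input: the patched layernorm_forward routes to the fused kernel.
y_fused = m(x)

# The custom op can also be called directly; it returns only the normalized
# output (mean and rstd stay internal to the kernel).
y_direct = torch_npu.npu_layer_norm_eval(x, m.normalized_shape, m.weight, m.bias, m.eps)

# Training mode (or a CPU input) falls back to torch.nn.functional.layer_norm.
m.train()
y_train = m(x)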