diff --git a/test/test_network_ops/test_layer_norm.py b/test/test_network_ops/test_layer_norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..743858a2f20da6862d547ade3ed8e570c8bdcdfe
--- /dev/null
+++ b/test/test_network_ops/test_layer_norm.py
@@ -0,0 +1,94 @@
+# Copyright (c) 2020 Huawei Technologies Co., Ltd
+# Copyright (c) 2019, Facebook CORPORATION.
+# All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+import torch_npu
+import torch.nn as nn
+import numpy as np
+
+from torch_npu.testing.common_utils import TestCase, run_tests
+from torch_npu.testing.common_device_type import instantiate_device_type_tests
+from torch_npu.testing.util_test import create_common_tensor
+
+class TestLayerNorm(TestCase):
+    def test_c10_layer_norm(self, device):
+        # test that we can call c10 ops and they return a reasonable result
+        X = torch.rand(5, 5, dtype=torch.float, device="cpu")
+        X = X.to("npu")
+        weight = torch.rand(*X.size()[1:], dtype=torch.float, device="cpu")
+        weight = weight.to("npu")
+        bias = torch.rand(*X.size()[1:], dtype=torch.float, device="cpu")
+        bias = bias.to("npu")
+        epsilon = 1e-4
+
+        expected_norm = torch.nn.functional.layer_norm(
+            X, X.size()[1:], weight=weight, bias=bias, eps=epsilon)
+        expected_norm_cpu = torch.nn.functional.layer_norm(
+            X.cpu(), X.size()[1:], weight=weight.cpu(), bias=bias.cpu(), eps=epsilon)
+        self.assertRtolEqual(expected_norm.cpu().numpy(), expected_norm_cpu.numpy())
+
+        actual_norm, actual_mean, actual_stdev = \
+            torch.ops._caffe2.LayerNorm(torch.tensor(X.cpu()), torch.tensor(
+                weight.cpu()), torch.tensor(bias.cpu()), 1, epsilon, True)
+        self.assertRtolEqual(expected_norm.cpu().numpy(), actual_norm.numpy())
+
+    def cpu_op_exec(self, input1):
+        m = nn.LayerNorm(input1.size()[1:])
+        output = m(input1)
+        return output
+
+    def npu_op_exec(self, input1):
+        m = nn.LayerNorm(input1.size()[1:]).npu()
+        output = m(input1)
+        output = output.to("cpu")
+        return output
+
+    def test_layer_norm_shape_format(self, device):
+        shape_format = [
+            [np.float32, 0, (64, 10)],
+            [np.float32, 0, (256, 2048, 7, 7)],
+            [np.float32, 0, (32, 1, 3, 3)],
+            [np.float32, 0, (10, 128)],
+            [np.float32, 2, (46, 16)],
+            [np.float32, 3, (2, 2, 2)],
+            [np.float32, 29, (3, 4, 5, 6)]
+        ]
+        for item in shape_format:
+            cpu_input, npu_input = create_common_tensor(item, 1, 100)
+            cpu_output = self.cpu_op_exec(cpu_input)
+            npu_output = self.npu_op_exec(npu_input)
+            self.assertRtolEqual(cpu_output.detach().numpy(), npu_output.detach().numpy())
+
+    def test_layer_norm_float16_format(self, device):
+        shape_format = [
+            [np.float16, 0, (64, 10)],
+            [np.float16, 0, (256, 2048, 7, 7)],
+            [np.float16, 0, (32, 1, 3, 3)],
+            [np.float16, 0, (10, 128)],
+            [np.float16, 2, (46, 16)],
+            [np.float16, 3, (2, 2, 2)],
+            [np.float16, 29, (3, 4, 5, 6)]
+        ]
+        for item in shape_format:
+            cpu_input, npu_input = create_common_tensor(item, 1, 10)
+            cpu_input = cpu_input.to(torch.float32)
+            cpu_output = self.cpu_op_exec(cpu_input)
+            npu_output = self.npu_op_exec(npu_input)
+            cpu_output = cpu_output.to(torch.float16)
+            self.assertRtolEqual(cpu_output.detach().numpy(), npu_output.detach().numpy())
+
+instantiate_device_type_tests(TestLayerNorm, globals(), except_for="cpu")
+if __name__ == "__main__":
+    run_tests()
\ No newline at end of file
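[Review note] The shape/format tests above compare the NPU kernel against torch.nn.LayerNorm on CPU. For readers checking the tolerances, the reference computation is plain layer normalization over every dim except dim 0. A minimal NumPy sketch of that reduction (illustrative only; layer_norm_ref is a hypothetical helper, not part of this PR):

    import numpy as np

    def layer_norm_ref(x, weight, bias, eps=1e-5):
        # nn.LayerNorm(x.size()[1:]) normalizes over all dims except dim 0;
        # np.var defaults to ddof=0 (biased), matching LayerNorm's variance.
        axes = tuple(range(1, x.ndim))
        mean = x.mean(axis=axes, keepdims=True)
        var = x.var(axis=axes, keepdims=True)
        return (x - mean) / np.sqrt(var + eps) * weight + bias

    x = np.random.rand(5, 5).astype(np.float32)
    out = layer_norm_ref(x, np.ones(5, np.float32), np.zeros(5, np.float32), eps=1e-4)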
diff --git a/test/test_network_ops/test_layer_norm_backward.py b/test/test_network_ops/test_layer_norm_backward.py
new file mode 100644
index 0000000000000000000000000000000000000000..1af28911a00d1c2be84428dba4fe613bc2cdc2bf
--- /dev/null
+++ b/test/test_network_ops/test_layer_norm_backward.py
@@ -0,0 +1,92 @@
+# Copyright (c) 2020, Huawei Technologies. All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.common_utils import TestCase, run_tests
+from torch_npu.testing.common_device_type import instantiate_device_type_tests
+from torch_npu.testing.util_test import create_common_tensor
+
+class TestLayerNorm(TestCase):
+    weight_grad = []
+
+    def getWeightGrad(self, grad):
+        self.weight_grad.append(grad.cpu())
+
+    def cpu_op_exec(self, input1, normalized_shape):
+        input1.requires_grad_(True)
+        input1.retain_grad()
+        m = torch.nn.LayerNorm(normalized_shape=normalized_shape)
+        res = m(input1)
+        w = torch.ones_like(res)
+        res.backward(w)
+
+        grad_output = input1.grad.detach().numpy()
+        grad_bias = m.bias.grad.detach().numpy()
+        grad_weight = m.weight.grad.detach().numpy()
+        return grad_output, grad_weight, grad_bias
+
+    def npu_op_exec_new(self, input1, normalized_shape):
+        input1 = input1.npu()  # move to NPU first so input1 stays a leaf and .grad is populated
+        input1.requires_grad_(True)
+        input1.retain_grad()
+        m = torch.nn.LayerNorm(normalized_shape=normalized_shape).npu()
+        m.weight.register_hook(lambda grad: self.getWeightGrad(grad))
+        res = m(input1)
+        w = torch.ones_like(res)
+        res.backward(w)
+
+        grad_output = input1.grad.cpu().detach().numpy()
+        grad_bias = m.bias.grad.cpu().detach().numpy()
+        grad_weight = m.weight.grad.cpu().detach().numpy()
+        return grad_output, grad_weight, grad_bias
+
+    def test_layernorm_shape_format(self, device):
+        shape_format = [
+            [np.float32, 3, [256, 32, 112, 112]],
+            [np.float16, 3, [256, 672, 7, 7]],
+            [np.float16, 3, [256, 288, 14, 14]],
+            [np.float16, 3, [1024, 58, 28, 28]],
+            [np.float16, 3, [1024, 116, 14, 14]],
+            [np.float16, 3, [1024, 24, 112, 112]],
+            [np.float16, 0, [1024, 58, 56, 56]],
+            [np.float16, 0, [1024, 58, 56, 56]],
+            [np.float16, 2, [1024, 24, 28, 28]],
+            [np.float16, 2, [1024, 116, 28, 28]],
+            [np.float16, 29, [1024, 232, 7, 7]],
+            [np.float16, 29, [1024, 232, 14, 14]],
+        ]
+        for item in shape_format:
+            cpu_input, npu_input = create_common_tensor(item, 1, 100)
+            if cpu_input.dtype == torch.float16:
+                cpu_input = cpu_input.to(torch.float32)
+
+            cpu_grad_output, cpu_grad_weight, cpu_grad_bias = self.cpu_op_exec(cpu_input, item[2][3])
+            npu_grad_output, npu_grad_weight, npu_grad_bias = self.npu_op_exec_new(npu_input, item[2][3])
+
+            cpu_grad_output = cpu_grad_output.astype(npu_grad_output.dtype)
+            cpu_grad_weight = cpu_grad_weight.astype(npu_grad_weight.dtype)
+            cpu_grad_bias = cpu_grad_bias.astype(npu_grad_bias.dtype)
+
+            self.assertRtolEqual(cpu_grad_output, npu_grad_output)
+            # TODO(ascend): Insufficient precision
+            # npu_grad_weight does not yet meet the default precision threshold
+            self.assertRtolEqual(cpu_grad_weight, npu_grad_weight, 1)
+            self.assertRtolEqual(cpu_grad_bias, npu_grad_bias)
+
+
+instantiate_device_type_tests(TestLayerNorm, globals(), except_for="cpu")
+if __name__ == "__main__":
+    run_tests()
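[Review note] The grad_weight comparison above is intentionally relaxed (see the TODO). When triaging a failure there, a float64 CPU run is a quick way to separate kernel error from accumulation rounding. A hypothetical standalone snippet, not part of the test suite:

    import torch

    # Same reduction as the test (normalized over the last dim, item[2][3]),
    # but in float64 so the CPU grads serve as a high-precision reference.
    x = torch.rand(8, 4, 28, 28, dtype=torch.float64, requires_grad=True)
    m = torch.nn.LayerNorm(28).double()
    y = m(x)
    y.backward(torch.ones_like(y))
    ref_dx = x.grad          # analogue of cpu_grad_output
    ref_dw = m.weight.grad   # analogue of cpu_grad_weight (the relaxed check)
    ref_db = m.bias.grad     # analogue of cpu_grad_bias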
diff --git a/torch_npu/csrc/aten/ops/LayerNormBackwardKernelNpu.cpp b/torch_npu/csrc/aten/ops/LayerNormBackwardKernelNpu.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..40fcfd5027641859356828e610cd62e5c91fb087
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/LayerNormBackwardKernelNpu.cpp
@@ -0,0 +1,149 @@
+// Copyright (c) 2020 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+
+std::tuple<at::Tensor&, at::Tensor&, at::Tensor&> layer_norm_backward_npu_nocheck(
+    at::Tensor& dX,
+    at::Tensor& dgamma,
+    at::Tensor& dbeta,
+    const at::Tensor& dY,
+    const at::Tensor& X,
+    const at::Tensor& mean,
+    const at::Tensor& variance,
+    const at::Tensor& gamma,
+    int64_t M,
+    int64_t N)
+{
+  // reshape mean/variance to X's shape with the gamma-covered trailing dims set to 1
+  at::SmallVector<int64_t, SIZE> tmpSize = array_to_small_vector(X.sizes());
+  for (int i = X.dim() - gamma.dim(); i < X.dim(); i++) {
+    tmpSize[i] = 1;
+  }
+  at::Tensor mean_ex = mean.reshape(tmpSize);
+  at::Tensor variance_ex = variance.reshape(tmpSize);
+  double eps = 1e-05;
+
+  OpCommand cmd;
+  cmd.Name("LayerNormGrad")
+      .Input(dY)
+      .Input(X)
+      .Input(variance_ex)
+      .Input(mean_ex)
+      .Input(gamma)
+      .Output(dX)
+      .Output(dgamma)
+      .Output(dbeta)
+      .Run();
+
+  return std::tuple<at::Tensor&, at::Tensor&, at::Tensor&>(dX, dgamma, dbeta);
+}
+
+std::tuple<at::Tensor, at::Tensor, at::Tensor> layer_norm_backward_npu_support(
+    const at::Tensor& dY,
+    const at::Tensor& X,
+    const at::Tensor& mean,
+    const at::Tensor& variance,
+    const c10::optional<at::Tensor>& gamma_ex,
+    int64_t M,
+    int64_t N,
+    std::array<bool, 3> output_mask) {
+  const at::Tensor& gamma = c10::value_or_else(gamma_ex, [] {return at::Tensor();});
+  at::Tensor dX;
+  at::Tensor dgamma;
+  at::Tensor dbeta;
+  at::Tensor gammaTemp = gamma;
+
+  // rebuild gamma's shape from the trailing dims of X whose product is N
+  at::SmallVector<int64_t, SIZE> tmpSize;
+  int64_t numels = 1;
+  for (int64_t i = X.dim() - 1; i >= 0; i--) {
+    numels *= X.size(i);
+    tmpSize.emplace_back(X.size(i));
+    if (numels == N) {
+      break;
+    }
+  }
+  std::reverse(tmpSize.begin(), tmpSize.end());
+  if (!gamma.defined()) {
+    gammaTemp = at::ones(tmpSize, X.options());
+  } else if (!gamma.sizes().equals(tmpSize)) {
+    gammaTemp.resize_(tmpSize);
+  }
+
+  // calculate the output size
+  auto outputSizes = layer_norm_backward_npu_output_size(dY, X, mean, variance, gammaTemp, M, N);
+
+  if (M <= 0) {
+    dX = at::native::empty_like(X, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+    dgamma = at::native::zeros_like(gammaTemp, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+    dbeta = at::native::zeros_like(gammaTemp, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+    return std::make_tuple(std::move(dX), std::move(dgamma), std::move(dbeta));
+  }
+
+  // construct the output tensor
+  dX = OpPreparation::ApplyTensor(X, std::get<0>(outputSizes));
+  dgamma = OpPreparation::ApplyTensor(gammaTemp, std::get<1>(outputSizes));
+  dbeta = OpPreparation::ApplyTensor(gammaTemp, std::get<2>(outputSizes));
+
+  // calculate the output result of the NPU
+  return layer_norm_backward_npu_nocheck(dX, dgamma, dbeta, dY, X, mean, variance, gammaTemp, M, N);
+}
+
+std::tuple<at::Tensor, at::Tensor, at::Tensor> NPUNativeFunctions::native_layer_norm_backward(
+    const at::Tensor& dY,
+    const at::Tensor& X,
+    at::IntArrayRef normalized_shape,
+    const at::Tensor& mean,
+    const at::Tensor& variance,
+    const c10::optional<at::Tensor>& gamma,
+    const c10::optional<at::Tensor>& beta,
+    std::array<bool, 3> output_mask) {
+  const int normalized_ndim = normalized_shape.size();
+  const auto input_shape = X.sizes();
+  const auto input_ndim = X.dim();
+
+  if (input_ndim < normalized_ndim ||
+      !input_shape.slice(input_ndim - normalized_ndim)
+           .equals(normalized_shape)) {
+    std::stringstream ss;
+    ss << "Given normalized_shape=" << normalized_shape
+       << ", expected input with shape [*";
+    for (auto size : normalized_shape) {
+      ss << ", " << size;
+    }
+    ss << "], but got input of size " << input_shape;
+    AT_ERROR(ss.str());
+  }
+
+  const int axis = input_ndim - normalized_ndim;
+  const int64_t M = std::accumulate(
+      input_shape.cbegin(),
+      input_shape.cbegin() + axis,
+      1LL,
+      std::multiplies<int64_t>());
+  const int64_t N = std::accumulate(
+      input_shape.cbegin() + axis,
+      input_shape.cend(),
+      1LL,
+      std::multiplies<int64_t>());
+
+  return layer_norm_backward_npu_support(dY, X, mean, variance, gamma, M, N, output_mask);
+}
+
+}} // namespace at_npu::native
\ No newline at end of file
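[Review note] The one non-obvious piece of shape logic in the backward kernel is the tmpSize reshape: mean/variance arrive flattened to [M], while the LayerNormGrad op wants them shaped like X with the gamma-covered trailing dims collapsed to 1 so they broadcast. Restated in Python with hypothetical example shapes:

    # X normalized over its last dim, so gamma has 1 dim.
    x_shape = [1024, 58, 28, 28]
    gamma_ndim = 1
    tmp_size = list(x_shape)
    for i in range(len(x_shape) - gamma_ndim, len(x_shape)):
        tmp_size[i] = 1
    assert tmp_size == [1024, 58, 28, 1]  # product == M, broadcasts against X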
diff --git a/torch_npu/csrc/aten/ops/LayerNormKernelNpu.cpp b/torch_npu/csrc/aten/ops/LayerNormKernelNpu.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..698f68aec58ec0e994c0fabed5bf4c425fbf1a67
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/LayerNormKernelNpu.cpp
@@ -0,0 +1,166 @@
+// Copyright (c) 2020 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h" +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +std::tuple layer_norm_npu_support( + const at::Tensor& input, + const c10::optional& weight_ex, + const c10::optional& bias_ex, + int64_t M, + int64_t N, + double eps) { + const at::Tensor& weight_ = c10::value_or_else(weight_ex, [] {return at::Tensor();}); + at::Tensor weight = weight_; + const at::Tensor& bias_ = c10::value_or_else(bias_ex, [] {return at::Tensor();}); + at::Tensor bias = bias_; + + DCHECK_EQ(input.numel(), M * N); + DCHECK(!weight.defined() || weight.numel() == N); + DCHECK(!bias.defined() || bias.numel() == N); + + at::Tensor Y = OpPreparation::ApplyTensor(input); + at::Tensor mean; + at::Tensor variance; + if (M < 0) { + mean = OpPreparation::ApplyTensorWithFormat({M}, input.options(), ACL_FORMAT_ND); + variance = OpPreparation::ApplyTensorWithFormat({M}, input.options(), ACL_FORMAT_ND); + } else { + int64_t numels = 1; + int64_t begin_dim = 0; + + // the output of mean and rstd is Multidimension + at::SmallVector reduceDims; + + // the input of weight is Multidimension + at::SmallVector weightDims; + for (int64_t i = 0; i < input.dim(); i++) { + numels *= input.size(i); + reduceDims.emplace_back(input.size(i)); + if(numels == M){ + begin_dim = i + 1; + while (++i < input.dim()) { + reduceDims.emplace_back(1); + weightDims.emplace_back(input.size(i)); + } + break; + } + } + + if (!weight.defined()) { + weight = at::ones(weightDims, input.options()); + } else if (!weight.sizes().equals(weightDims)) { + weight.resize_(weightDims); + } + + if (!bias.defined()) { + bias = at::zeros(weightDims, input.options()); + } else if (!bias.sizes().equals(weightDims)) { + bias.resize_(weightDims); + } + + mean = OpPreparation::ApplyTensorWithFormat(reduceDims, weight.options(), ACL_FORMAT_ND); + variance = OpPreparation::ApplyTensorWithFormat(reduceDims, weight.options(), ACL_FORMAT_ND); + + OpCommand cmd; + cmd.Name("LayerNorm") + .Input(input) + .Input(weight) + .Input(bias) + .Output(Y) + .Output(mean) + .Output(variance) + .Attr("begin_norm_axis", begin_dim) + .Attr("begin_params_axis", begin_dim) + .Attr("epsilon", static_cast(eps)) + .Run(); + + } + + mean = mean.reshape({M}); + variance = variance.reshape({M}); + + return std::tie(Y, mean, variance); +} + +std::tuple NPUNativeFunctions::native_layer_norm( + const at::Tensor& input, + at::IntArrayRef normalized_shape, + const c10::optional& weight_ex, + const c10::optional& bias_ex, + double eps) { + const at::Tensor& weight = c10::value_or_else(weight_ex, [] {return at::Tensor();}); + const at::Tensor& bias = c10::value_or_else(bias_ex, [] {return at::Tensor();}); + const int normalized_ndim = normalized_shape.size(); + TORCH_CHECK( + normalized_ndim >= 1, + "Expected normalized_shape to be at least 1-dimensional, i.e., ", + "containing at least one element, but got normalized_shape = ", + normalized_shape); + TORCH_CHECK( + !weight.defined() || weight.sizes().equals(normalized_shape), + "Expected weight to be of same shape as normalized_shape, but got ", + "weight of shape ", + weight.sizes(), + " and normalized_shape = ", + normalized_shape); + TORCH_CHECK( + !bias.defined() || bias.sizes().equals(normalized_shape), + "Expected bias to be of same shape as normalized_shape, but got ", + "bias of shape ", + bias.sizes(), + " and normalized_shape = ", + normalized_shape); + + const auto input_shape = input.sizes(); + 
+  const auto input_ndim = input.dim();
+
+  if (input_ndim < normalized_ndim ||
+      !input_shape.slice(input_ndim - normalized_ndim)
+           .equals(normalized_shape)) {
+    std::stringstream ss;
+    ss << "Given normalized_shape=" << normalized_shape
+       << ", expected input with shape [*";
+    for (auto size : normalized_shape) {
+      ss << ", " << size;
+    }
+    ss << "], but got input of size " << input_shape;
+    AT_ERROR(ss.str());
+  }
+
+  const int axis = input_ndim - normalized_ndim;
+  const int64_t M = std::accumulate(
+      input_shape.cbegin(),
+      input_shape.cbegin() + axis,
+      1LL,
+      std::multiplies<int64_t>());
+  const int64_t N = std::accumulate(
+      input_shape.cbegin() + axis,
+      input_shape.cend(),
+      1LL,
+      std::multiplies<int64_t>());
+
+  const auto& X = input.is_contiguous() ? input : input.contiguous();
+  const auto& gamma = weight.is_contiguous() ? weight : weight.contiguous();
+  const auto& beta = bias.is_contiguous() ? bias : bias.contiguous();
+  return layer_norm_npu_support(X, gamma, beta, M, N, eps);
+}
+
+}} // namespace at_npu::native
\ No newline at end of file
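[Review note] For reference, the M/N split both kernels rely on, restated in Python (illustrative shapes): M multiplies the dims before the normalization axis, N the dims from the axis onward, and the begin_norm_axis/begin_params_axis attrs passed to the "LayerNorm" op both equal that axis.

    import math

    input_shape = (256, 2048, 7, 7)
    normalized_shape = (2048, 7, 7)
    axis = len(input_shape) - len(normalized_shape)   # 1
    M = math.prod(input_shape[:axis])                 # 256
    N = math.prod(input_shape[axis:])                 # 2048 * 7 * 7
    assert (M, N) == (256, 100352)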