diff --git a/test/test_network_ops/test_group_norm.py b/test/test_network_ops/test_group_norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..83fd8e02f635e7caf7f23aaf7124487b506f5d23
--- /dev/null
+++ b/test/test_network_ops/test_group_norm.py
@@ -0,0 +1,51 @@
+# Copyright (c) 2023 Huawei Technologies Co., Ltd
+# All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import copy
+import torch
+import numpy as np
+import torch_npu
+
+from torch_npu.testing.testcase import TestCase, run_tests
+from torch_npu.testing.decorator import graph_mode
+from torch_npu.testing.common_utils import create_common_tensor
+
+class TestGroupNorm(TestCase):
+    def cpu_op_exec(self, input1, num_groups):
+        output = torch.nn.functional.group_norm(input1, num_groups)
+        output_cpu = output.detach().numpy()
+        return output_cpu
+
+    def npu_op_exec(self, input1, num_groups):
+        output = torch.nn.functional.group_norm(input1, num_groups).to("npu")
+        output = output.to("cpu").detach().numpy()
+        return output
+
+    @graph_mode
+    def test_GroupNorm_default(self):
+        format_list = [0, 3]
+        shape_list = [[20, 6, 10, 10], [20, 2, 10, 10]]
+        dtype_list = [np.float32]
+        num_groups = [2]
+        shape_format = [
+            [[i, j, k], m] for i in dtype_list for j in format_list for k in shape_list for m in num_groups
+        ]
+        for item in shape_format:
+            cpu_input1, npu_input1 = create_common_tensor(item[0], -100, 100)
+            cpu_output = self.cpu_op_exec(cpu_input1, item[1])
+            npu_output = self.npu_op_exec(npu_input1, item[1])
+            self.assertRtolEqual(cpu_output, npu_output)
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/test/test_network_ops/test_group_norm_backward.py b/test/test_network_ops/test_group_norm_backward.py
new file mode 100644
index 0000000000000000000000000000000000000000..485ede2d3df3aedcedf9de1c3f28f03a4b9b291e
--- /dev/null
+++ b/test/test_network_ops/test_group_norm_backward.py
@@ -0,0 +1,71 @@
+# Copyright (c) 2023 Huawei Technologies Co., Ltd
+# All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import copy
+import torch
+import numpy as np
+import torch_npu
+
+from torch_npu.testing.testcase import TestCase, run_tests
+from torch_npu.testing.decorator import graph_mode
+from torch_npu.testing.common_utils import create_common_tensor
+
+
+class TestGroupNormBackward(TestCase):
+    def cpu_op_exec(self, input1, num_groups):
+        input1.requires_grad_(True)
+        output = torch.nn.functional.group_norm(input1, num_groups)
+        output.backward(torch.ones_like(output))
+        input_grad = input1.grad
+
+        output = output.detach().numpy()
+        input_grad = input_grad.detach().numpy()
+
+        return output, input_grad
+
+    def npu_op_exec(self, input1, num_groups):
+        input1.requires_grad_(True)
+        output = torch.nn.functional.group_norm(input1, num_groups).to("npu")
+        output.backward(torch.ones_like(output))
+        input_grad = input1.grad
+
+        output = output.detach().cpu().float().numpy()
+        input_grad = input_grad.detach().cpu().float().numpy()
+
+        return output, input_grad
+
+    @graph_mode
+    def test_GroupNorm_Backward_default(self):
+        # create inputs and params
+        format_list = [0, 3]
+        shape_list = [[20, 6, 10, 10], [20, 2, 10, 10]]
+        dtype_list = [np.float32]
+        num_groups = [2]
+        shape_format = [
+            [[i, j, k], m] for i in dtype_list for j in format_list for k in shape_list for m in num_groups
+        ]
+        for item in shape_format:
+            cpu_input1, npu_input1 = create_common_tensor(item[0], -100, 100)
+
+            # run exec
+            cpu_output, cpu_input_grad = self.cpu_op_exec(cpu_input1, item[1])
+            npu_output, npu_input_grad = self.npu_op_exec(npu_input1, item[1])
+
+            # compare results
+            self.assertRtolEqual(cpu_output, npu_output)
+            self.assertRtolEqual(cpu_input_grad, npu_input_grad)
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/torch_npu/csrc/aten/ops/GroupNormBackwardKernel.cpp b/torch_npu/csrc/aten/ops/GroupNormBackwardKernel.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..8fd68206feefe2e0645428b0084613bc9749e2c2
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/GroupNormBackwardKernel.cpp
@@ -0,0 +1,56 @@
+// Copyright (c) 2023 Huawei Technologies Co., Ltd
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+
+std::tuple<at::Tensor, at::Tensor, at::Tensor> NPUNativeFunctions::native_group_norm_backward(
+    const at::Tensor& dY,
+    const at::Tensor& X,
+    const at::Tensor& mean,
+    const at::Tensor& rstd,
+    const c10::optional<at::Tensor>& gamma_opt,
+    int64_t N,
+    int64_t C,
+    int64_t HxW,
+    int64_t group,
+    std::array<bool, 3> grad_input_mask) {
+  TORCH_WARN_ONCE("Warning: kernel [native_group_norm_backward] is not supported by NPU currently. "
+      "Now this kernel is running on CPU.");
+  at::Tensor dY_cpu = dY.to("cpu");
+  at::Tensor X_cpu = X.to("cpu");
+  at::Tensor mean_cpu = mean.to("cpu");
+  at::Tensor rstd_cpu = rstd.to("cpu");
+  const at::Tensor& gamma = c10::value_or_else(gamma_opt, [] {return at::Tensor();});
+  at::Tensor gamma_opt_cpu = gamma.defined() ?
gamma.to("cpu") : gamma; + + std::tuple result = at::native_group_norm_backward( + dY_cpu, X_cpu, mean_cpu, rstd_cpu, gamma_opt_cpu, N, C, HxW, group, grad_input_mask); + at::Tensor dX = std::get<0>(result); + at::Tensor dgamma = std::get<1>(result); + at::Tensor dbeta = std::get<2>(result); + + dX = dX.defined() ? dX.to(X.device()) : dX; + dgamma = dgamma.defined() ? dgamma.to(X.device()) : dgamma; + dbeta = dbeta.defined() ? dbeta.to(X.device()) : dbeta; + + return std::make_tuple(dX, dgamma, dbeta); +} + +} // namespace native +} // namespace at_npu diff --git a/torch_npu/csrc/aten/ops/GroupNormKernel.cpp b/torch_npu/csrc/aten/ops/GroupNormKernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a821168f0b6bb289de9bdbb1fffb0d11f67115af --- /dev/null +++ b/torch_npu/csrc/aten/ops/GroupNormKernel.cpp @@ -0,0 +1,73 @@ +// Copyright (c) 2023 Huawei Technologies Co., Ltd +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +std::tuple native_group_norm_out_npu( + at::Tensor& y, + at::Tensor& mean, + at::Tensor& variance, + at::Tensor& rstd, + const at::Tensor& X, + const c10::optional& gamma_opt, + const c10::optional& beta_opt, + int64_t num_groups, + double eps, + int64_t C) { + const at::Tensor& gamma_ = c10::value_or_else(gamma_opt, [] {return at::Tensor();}); + at::Tensor gamma = gamma_.defined() ? gamma_ : at::ones({C}, X.options()); + + const at::Tensor& beta_ = c10::value_or_else(beta_opt, [] {return at::Tensor();}); + at::Tensor beta = beta_.defined() ? beta_ : at::zeros({C}, X.options()); + + OpCommand cmd; + cmd.Name("GroupNorm") + .Input(X) + .Input(gamma) + .Input(beta) + .Output(y) + .Output(mean) + .Output(variance) + .Attr("num_groups", num_groups) + .Attr("eps", static_cast(eps)) + .Attr("is_training", true) + .Run(); + + rstd = 1.0 / (variance + eps).sqrt(); + return std::make_tuple(y, mean, rstd); +} + +std::tuple NPUNativeFunctions::native_group_norm( + const at::Tensor& X, + const c10::optional& gamma_opt, + const c10::optional& beta_opt, + int64_t N, + int64_t C, + int64_t HxW, + int64_t group, + double eps) { + at::Tensor result = OpPreparation::ApplyTensor(X); + at::Tensor mean = OpPreparation::ApplyTensor(X, {N * group}); + at::Tensor variance = OpPreparation::ApplyTensor(X, {N * group}); + at::Tensor rstd = OpPreparation::ApplyTensor(X, {N * group}); + return native_group_norm_out_npu(result, mean, variance, rstd, X, gamma_opt, beta_opt, group, eps, C); +} + +} // namespace native +} // namespace at_npu