diff --git a/test/test_network_ops/test_renorm.py b/test/test_network_ops/test_renorm.py
new file mode 100644
index 0000000000000000000000000000000000000000..338517fe5e27eb6d3c7c77b2b3050fa6ca5fd7f6
--- /dev/null
+++ b/test/test_network_ops/test_renorm.py
@@ -0,0 +1,272 @@
+# Copyright (c) 2020, Huawei Technologies. All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.common_utils import TestCase, run_tests
+from torch_npu.testing.common_device_type import instantiate_device_type_tests
+from torch_npu.testing.util_test import create_common_tensor
+
+class TestRenorm(TestCase):
+    def generate_data(self, min_d, max_d, shape, dtype):
+        input_x = np.random.uniform(min_d, max_d, shape).astype(dtype)
+        npu_input = torch.from_numpy(input_x)
+        return npu_input
+
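+    # torch.renorm itself rejects non-positive p, so the p == 0 cases are
+    # checked against the hand-rolled CPU reference below. It takes the
+    # 0-"norm" of each sub-tensor along `dim` to be its count of nonzero
+    # elements and, where that count exceeds maxnorm, scales the sub-tensor
+    # by maxnorm / count (the NPU kernel is assumed to define p == 0 the
+    # same way).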
+    def get_p0_result_cpu(self, input_x, dim, maxnorm=1.0):
+        input_x = input_x.numpy()
+        dims = len(input_x.shape)
+        shape_list = []
+        for i in range(dims):
+            if i != dim:
+                shape_list.append(i)
+        shape_list = tuple(shape_list)
+        tmp = (input_x != 0)
+        N = np.sum(tmp, shape_list, keepdims=True)
+        N = np.where(N > maxnorm, maxnorm / (N + 1e-7), 1.0)
+        output = input_x * N
+        return output
+
+    def cpu_op_exec(self, input_x, p, dim, maxnorm):
+        if p == 0:
+            output = self.get_p0_result_cpu(input_x, dim, maxnorm)
+        else:
+            output = torch.renorm(input_x, p, dim, maxnorm)
+            output = output.numpy()
+        return output.astype(np.float32)
+
+    def npu_op_exec(self, input_x, p, dim, maxnorm):
+        input1 = input_x.to("npu")
+        output = torch.renorm(input1, p, dim, maxnorm)
+        output = output.to("cpu")
+        output = output.numpy()
+        return output
+
+    def npu_op_exec_out(self, input_x, p, dim, maxnorm, output_y):
+        input_x = input_x.to("npu")
+        output_y = output_y.to("npu")
+        torch.renorm(input_x, p, dim, maxnorm, out=output_y)
+        output_y = output_y.to("cpu")
+        output_y = output_y.numpy()
+        return output_y
+
+    def npu_op_exec_inplace(self, input_x, p, dim, maxnorm):
+        input_x = input_x.to("npu")
+        input_x.renorm_(p, dim, maxnorm)
+        output = input_x.to("cpu")
+        output = output.numpy()
+        return output
+
+    def test_renorm_3_3_4_0_1(self, device):
+        input_x1 = self.generate_data(-1, 1, (3, 3), np.float32)
+        cpu_output1 = self.cpu_op_exec(input_x1, 4, 0, 1)
+        npu_output1 = self.npu_op_exec(input_x1, 4, 0, 1)
+        self.assertRtolEqual(cpu_output1, npu_output1)
+
+    def test_renorm_3_3_1_1_1(self, device):
+        input_x1 = self.generate_data(-1, 1, (3, 3), np.float32)
+        cpu_output1 = self.cpu_op_exec(input_x1, 1, 1, 1)
+        npu_output1 = self.npu_op_exec(input_x1, 1, 1, 1)
+        self.assertRtolEqual(cpu_output1, npu_output1)
+
+    def test_renorm_3_3_0_0_1_float16(self, device):
+        input_x1 = self.generate_data(-10, 10, (3, 3), np.float16)
+        input_x1_cpu = input_x1.float()
+        cpu_output1 = self.cpu_op_exec(input_x1_cpu, 0, 0, 1).astype(np.float16)
+        npu_output1 = self.npu_op_exec(input_x1, 0, 0, 1)
+        self.assertRtolEqual(cpu_output1, npu_output1)
+
+    def test_renorm_3_3_0_0_1(self, device):
+        input_x1 = self.generate_data(-10, 10, (3, 3), np.float32)
+        cpu_output1 = self.cpu_op_exec(input_x1, 0, 0, 1)
+        npu_output1 = self.npu_op_exec(input_x1, 0, 0, 1)
+        self.assertRtolEqual(cpu_output1, npu_output1)
+
+    def test_renorm_3_3_4_0_1_float16(self, device):
+        input_x1 = self.generate_data(-1, 1, (3, 3), np.float16)
+        input_x1_cpu = input_x1.float()
+        cpu_output1 = self.cpu_op_exec(input_x1_cpu, 4, 0, 1).astype(np.float16)
+        npu_output1 = self.npu_op_exec(input_x1, 4, 0, 1)
+        self.assertRtolEqual(cpu_output1, npu_output1)
+
+    def test_renorm_3_3_1_1_1_float16(self, device):
+        input_x1 = self.generate_data(-1, 1, (3, 3), np.float16)
+        input_x1_cpu = input_x1.float()
+        cpu_output1 = self.cpu_op_exec(input_x1_cpu, 1, 1, 1).astype(np.float16)
+        npu_output1 = self.npu_op_exec(input_x1, 1, 1, 1)
+        self.assertRtolEqual(cpu_output1, npu_output1)
+
+    def test_renorm_3_3_1_0_1(self, device):
+        input_x1 = self.generate_data(-1, 1, (3, 3), np.float32)
+        cpu_output1 = self.cpu_op_exec(input_x1, 1, 0, 1)
+        npu_output1 = self.npu_op_exec(input_x1, 1, 0, 1)
+        self.assertRtolEqual(cpu_output1, npu_output1)
+
+    def test_renorm_3_3_3_3_1_1(self, device):
+        input_x1 = self.generate_data(-1, 1, (3, 3, 3), np.float32)
+        cpu_output1 = self.cpu_op_exec(input_x1, 3, 1, 1)
+        npu_output1 = self.npu_op_exec(input_x1, 3, 1, 1)
+        self.assertRtolEqual(cpu_output1, npu_output1)
+
+    def test_renorm_3_3_3_2_2_1(self, device):
+        input_x1 = self.generate_data(-1, 1, (3, 3, 3), np.float32)
+        cpu_output1 = self.cpu_op_exec(input_x1, 2, 2, 1)
+        npu_output1 = self.npu_op_exec(input_x1, 2, 2, 1)
+        self.assertRtolEqual(cpu_output1, npu_output1)
+
+    def test_renorm_3_3_3_3_2_0_1(self, device):
+        input_x1 = self.generate_data(-1, 1, (3, 3, 3, 3), np.float32)
+        cpu_output1 = self.cpu_op_exec(input_x1, 2, 0, 1)
+        npu_output1 = self.npu_op_exec(input_x1, 2, 0, 1)
+        self.assertRtolEqual(cpu_output1, npu_output1)
+
+    def test_renorm_3_3_3_3_3_3_1(self, device):
+        input_x1 = self.generate_data(-1, 1, (3, 3, 3, 3), np.float32)
+        cpu_output1 = self.cpu_op_exec(input_x1, 3, 3, 1)
+        npu_output1 = self.npu_op_exec(input_x1, 3, 3, 1)
+        self.assertRtolEqual(cpu_output1, npu_output1)
+
+    def test_renorm_3_3_3_3_3_4_4_1(self, device):
+        input_x1 = self.generate_data(-1, 1, (3, 3, 3, 3, 3), np.float32)
+        cpu_output1 = self.cpu_op_exec(input_x1, 4, 4, 1)
+        npu_output1 = self.npu_op_exec(input_x1, 4, 4, 1)
+        self.assertRtolEqual(cpu_output1, npu_output1)
+
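+    # out-variant tests: torch.renorm(..., out=output_y) writes into a
+    # preallocated tensor of the same shape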
+    def test_renorm_3_3_4_0_1_out(self, device):
+        input_x1 = self.generate_data(-1, 1, (3, 3), np.float32)
+        output_y = self.generate_data(-1, 1, (3, 3), np.float32)
+        cpu_output1 = self.cpu_op_exec(input_x1, 4, 0, 1)
+        npu_output1 = self.npu_op_exec_out(input_x1, 4, 0, 1, output_y)
+        self.assertRtolEqual(cpu_output1, npu_output1)
+
+    def test_renorm_3_3_1_1_1_out(self, device):
+        input_x1 = self.generate_data(-1, 1, (3, 3), np.float32)
+        output_y = self.generate_data(-1, 1, (3, 3), np.float32)
+        cpu_output1 = self.cpu_op_exec(input_x1, 1, 1, 1)
+        npu_output1 = self.npu_op_exec_out(input_x1, 1, 1, 1, output_y)
+        self.assertRtolEqual(cpu_output1, npu_output1)
+
+    def test_renorm_3_3_1_0_1_out(self, device):
+        input_x1 = self.generate_data(-1, 1, (3, 3), np.float32)
+        output_y = self.generate_data(-1, 1, (3, 3), np.float32)
+        cpu_output1 = self.cpu_op_exec(input_x1, 1, 0, 1)
+        npu_output1 = self.npu_op_exec_out(input_x1, 1, 0, 1, output_y)
+        self.assertRtolEqual(cpu_output1, npu_output1)
+
+    def test_renorm_3_3_3_3_1_1_out(self, device):
+        input_x1 = self.generate_data(-1, 1, (3, 3, 3), np.float32)
+        output_y = self.generate_data(-1, 1, (3, 3, 3), np.float32)
+        cpu_output1 = self.cpu_op_exec(input_x1, 3, 1, 1)
+        npu_output1 = self.npu_op_exec_out(input_x1, 3, 1, 1, output_y)
+        self.assertRtolEqual(cpu_output1, npu_output1)
+
+    def test_renorm_30_40_50_2_1_1_out_fp16(self, device):
+        input_x1 = self.generate_data(-1, 1, (30, 40, 50), np.float16)
+        output_y = self.generate_data(-1, 1, (30, 40, 50), np.float16)
+        input_cpu = input_x1.float()
+        cpu_output1 = self.cpu_op_exec(input_cpu, 2, 1, 1)
+        cpu_output1 = cpu_output1.astype(np.float16)
+        npu_output1 = self.npu_op_exec_out(input_x1, 2, 1, 1, output_y)
+        self.assertRtolEqual(cpu_output1, npu_output1)
+
+    def test_renorm_30_40_50_2_0_2_out_fp16(self, device):
+        input_x1 = self.generate_data(-1, 1, (30, 40, 50), np.float16)
+        output_y = self.generate_data(-1, 1, (30, 40, 50), np.float16)
+        input_cpu = input_x1.float()
+        cpu_output1 = self.cpu_op_exec(input_cpu, 2, 0, 2)
+        cpu_output1 = cpu_output1.astype(np.float16)
+        npu_output1 = self.npu_op_exec_out(input_x1, 2, 0, 2, output_y)
+        self.assertRtolEqual(cpu_output1, npu_output1)
+
+    def test_renorm_3_3_3_2_2_1_out(self, device):
+        input_x1 = self.generate_data(-1, 1, (3, 3, 3), np.float32)
+        output_y = self.generate_data(-1, 1, (3, 3, 3), np.float32)
+        cpu_output1 = self.cpu_op_exec(input_x1, 2, 2, 1)
+        npu_output1 = self.npu_op_exec_out(input_x1, 2, 2, 1, output_y)
+        self.assertRtolEqual(cpu_output1, npu_output1)
+
+    def test_renorm_3_3_3_3_2_0_1_out(self, device):
+        input_x1 = self.generate_data(-1, 1, (3, 3, 3, 3), np.float32)
+        output_y = self.generate_data(-1, 1, (3, 3, 3, 3), np.float32)
+        cpu_output1 = self.cpu_op_exec(input_x1, 2, 0, 1)
+        npu_output1 = self.npu_op_exec_out(input_x1, 2, 0, 1, output_y)
+        self.assertRtolEqual(cpu_output1, npu_output1)
+
+    def test_renorm_3_3_3_3_3_3_1_out(self, device):
+        input_x1 = self.generate_data(-1, 1, (3, 3, 3, 3), np.float32)
+        output_y = self.generate_data(-1, 1, (3, 3, 3, 3), np.float32)
+        cpu_output1 = self.cpu_op_exec(input_x1, 3, 3, 1)
+        npu_output1 = self.npu_op_exec_out(input_x1, 3, 3, 1, output_y)
+        self.assertRtolEqual(cpu_output1, npu_output1)
+
+    def test_renorm_3_3_3_3_3_4_4_1_out(self, device):
+        input_x1 = self.generate_data(-1, 1, (3, 3, 3, 3, 3), np.float32)
+        output_y = self.generate_data(-1, 1, (3, 3, 3, 3, 3), np.float32)
+        cpu_output1 = self.cpu_op_exec(input_x1, 4, 4, 1)
+        npu_output1 = self.npu_op_exec_out(input_x1, 4, 4, 1, output_y)
+        self.assertRtolEqual(cpu_output1, npu_output1)
+
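+    # in-place tests: tensor.renorm_(p, dim, maxnorm) overwrites its input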
+    def test_renorm_3_3_4_0_1_inplace(self, device):
+        input_x1 = self.generate_data(-1, 1, (3, 3), np.float32)
+        cpu_output1 = self.cpu_op_exec(input_x1, 4, 0, 1)
+        npu_output1 = self.npu_op_exec_inplace(input_x1, 4, 0, 1)
+        self.assertRtolEqual(cpu_output1, npu_output1)
+
+    def test_renorm_3_3_1_1_1_inplace(self, device):
+        input_x1 = self.generate_data(-1, 1, (3, 3), np.float32)
+        cpu_output1 = self.cpu_op_exec(input_x1, 1, 1, 1)
+        npu_output1 = self.npu_op_exec_inplace(input_x1, 1, 1, 1)
+        self.assertRtolEqual(cpu_output1, npu_output1)
+
+    def test_renorm_3_3_1_0_1_inplace(self, device):
+        input_x1 = self.generate_data(-1, 1, (3, 3), np.float32)
+        cpu_output1 = self.cpu_op_exec(input_x1, 1, 0, 1)
+        npu_output1 = self.npu_op_exec_inplace(input_x1, 1, 0, 1)
+        self.assertRtolEqual(cpu_output1, npu_output1)
+
+    def test_renorm_3_3_3_3_1_1_inplace(self, device):
+        input_x1 = self.generate_data(-1, 1, (3, 3, 3), np.float32)
+        cpu_output1 = self.cpu_op_exec(input_x1, 3, 1, 1)
+        npu_output1 = self.npu_op_exec_inplace(input_x1, 3, 1, 1)
+        self.assertRtolEqual(cpu_output1, npu_output1)
+
+    def test_renorm_3_3_3_2_2_1_inplace(self, device):
+        input_x1 = self.generate_data(-1, 1, (3, 3, 3), np.float32)
+        cpu_output1 = self.cpu_op_exec(input_x1, 2, 2, 1)
+        npu_output1 = self.npu_op_exec_inplace(input_x1, 2, 2, 1)
+        self.assertRtolEqual(cpu_output1, npu_output1)
+
+    def test_renorm_3_3_3_3_2_0_1_inplace(self, device):
+        input_x1 = self.generate_data(-1, 1, (3, 3, 3, 3), np.float32)
+        cpu_output1 = self.cpu_op_exec(input_x1, 2, 0, 1)
+        npu_output1 = self.npu_op_exec_inplace(input_x1, 2, 0, 1)
+        self.assertRtolEqual(cpu_output1, npu_output1)
+
+    def test_renorm_3_3_3_3_3_3_1_inplace(self, device):
+        input_x1 = self.generate_data(-1, 1, (3, 3, 3, 3), np.float32)
+        cpu_output1 = self.cpu_op_exec(input_x1, 3, 3, 1)
+        npu_output1 = self.npu_op_exec_inplace(input_x1, 3, 3, 1)
+        self.assertRtolEqual(cpu_output1, npu_output1)
+
+    def test_renorm_3_3_3_3_3_4_4_1_inplace(self, device):
+        input_x1 = self.generate_data(-1, 1, (3, 3, 3, 3, 3), np.float32)
+        cpu_output1 = self.cpu_op_exec(input_x1, 4, 4, 1)
+        npu_output1 = self.npu_op_exec_inplace(input_x1, 4, 4, 1)
+        self.assertRtolEqual(cpu_output1, npu_output1)
+
+instantiate_device_type_tests(TestRenorm, globals(), except_for='cpu')
+if __name__ == "__main__":
+    run_tests()
\ No newline at end of file
diff --git a/torch_npu/csrc/aten/ops/RenormKernelNpu.cpp b/torch_npu/csrc/aten/ops/RenormKernelNpu.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..96572651f5ed279ebfd33409e1936cf44cdf216f
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/RenormKernelNpu.cpp
@@ -0,0 +1,147 @@
+// Copyright (c) 2020, Huawei Technologies. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
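+//
+// Implementation note: as used below, the NPU "Renorm" operator emits
+// per-sub-tensor scale factors (shape 1 in every dim except `dim`) rather
+// than the renormalized tensor itself; the factors are broadcast back to
+// self's shape and multiplied with self. Half inputs are computed in float
+// and cast back at the end.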
+ +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +c10::SmallVector renorm_npu_output_size( + const at::Tensor& self, + int64_t dim) { + c10::SmallVector outSize; + for(int64_t i=0; i < self.dim(); i++) { + if(i != dim) { + outSize.emplace_back(1); + } else { + outSize.emplace_back(self.sizes()[i]); + } + } + return outSize; +} + +at::Tensor& renorm_compute( + at::Tensor& result, + const at::Tensor& self, + at::Scalar p, + int64_t dim, + at::Scalar maxnorm) { + float p_value = CalcuOpUtil::get_scalar_float_value(p); + float maxnorm_value = CalcuOpUtil::get_scalar_float_value(maxnorm); + + OpCommand cmd; + cmd.Name("Renorm") + .Input(self) + .Output(result) + .Attr("p", p_value) + .Attr("maxnorm", maxnorm_value) + .Attr("dim", dim) + .Run(); + return result; +} + +at::Tensor& renorm_out_nocheck( + at::Tensor& result, + const at::Tensor& self, + at::Scalar p, + int64_t dim, + at::Scalar maxnorm) { + auto ori_type = self.scalar_type(); + if(ori_type != c10::ScalarType::Half && ori_type != c10::ScalarType::Float) { + AT_ERROR("Renorm only support float16 or float32 type."); + } + if(result.scalar_type() != ori_type) { + AT_ERROR("result's type must be equal to input's."); + } + dim = CalcuOpUtil::make_wrap_dim(dim, self.dim()); + auto outputSize = renorm_npu_output_size(self, dim); + at::Tensor result_bak = OpPreparation::ApplyTensorWithSizes( + outputSize, + self.options().dtype(at::kFloat)); + if(ori_type == c10::ScalarType::Half) { + at::Tensor self_no_name = self.rename(c10::nullopt); + at::Tensor result_no_name = result.rename(c10::nullopt); + self_no_name = NPUNativeFunctions::npu_dtype_cast(self_no_name, c10::ScalarType::Float); + result_no_name = NPUNativeFunctions::npu_dtype_cast(result_no_name, c10::ScalarType::Float); + renorm_compute( + result_bak, + self_no_name, + p, + dim, + maxnorm); + + at::Tensor result_broadcast = NPUNativeFunctions::npu_broadcast(result_bak, self.sizes()); + at::mul_out(result_no_name, result_broadcast, self_no_name); + NPUNativeFunctions::npu_dtype_cast_(result, result_no_name); + } else { + renorm_compute( + result_bak, + self, + p, + dim, + maxnorm); + + at::Tensor result_broadcast = NPUNativeFunctions::npu_broadcast(result_bak, self.sizes()); + at::mul_out(result, result_broadcast, self); + } + return result; +} + +at::Tensor& NPUNativeFunctions::renorm_out( + const at::Tensor& self, + at::Scalar p, + int64_t dim, + at::Scalar maxnorm, + at::Tensor& result) { + auto outputSize = input_same_output_size(self); + OpPreparation::CheckOut( + {self}, + result, + self, + outputSize); + + c10::SmallVector inputs = {self}; + c10::SmallVector outputs = {self}; + CalcuOpUtil::check_memory_over_laps(inputs, outputs); + + if (!NpuUtils::check_match(&self)) { + at::Tensor contiguousSelf = NpuUtils::format_contiguous(self); + at::Tensor checkresult = renorm_out_nocheck(contiguousSelf, contiguousSelf, p, dim, maxnorm); + NpuUtils::format_fresh_view(result, checkresult); + } else { + renorm_out_nocheck(result, self, p, dim, maxnorm); + } + + return result; +} + +at::Tensor NPUNativeFunctions::renorm(const at::Tensor& self, at::Scalar p, int64_t dim, at::Scalar maxnorm) { + // calculate the output size + auto outputSize = input_same_output_size(self); + at::Tensor result = OpPreparation::ApplyTensorWithSizes( + outputSize, + self.options()); + + return renorm_out_nocheck(result, self, p, dim, maxnorm); +} + 
+at::Tensor& NPUNativeFunctions::renorm_(at::Tensor& self, at::Scalar p, int64_t dim, at::Scalar maxnorm) { + + return renorm_out(self, p, dim, maxnorm, self); +} + +} // namespace native +} // namespace at_npu \ No newline at end of file