diff --git a/test/test_network_ops/test_index_put.py b/test/test_network_ops/test_index_put.py
new file mode 100644
index 0000000000000000000000000000000000000000..deb7c643f93acf39d98772a4a0d1df8900c8c1e9
--- /dev/null
+++ b/test/test_network_ops/test_index_put.py
@@ -0,0 +1,144 @@
+# Copyright (c) 2020, Huawei Technologies. All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.common_utils import TestCase, run_tests
+from torch_npu.testing.common_device_type import instantiate_device_type_tests
+from torch_npu.testing.util_test import create_common_tensor
+
+class TestIndexPut(TestCase):
+    def cpu_op_exec(self, input1, indices, value):
+        output = input1.index_put(indices, value)
+        output = output.numpy()
+        return output
+
+    def npu_op_exec(self, input1, indices, value):
+        output = input1.index_put(indices, value)
+        output = output.to("cpu")
+        output = output.numpy()
+        return output
+
+    def cpu_op_inp_exec(self, input1, indices, value):
+        input1.index_put_(indices, value)
+        output = input1.numpy()
+        return output
+
+    def npu_op_inp_exec(self, input1, indices, value):
+        input1.index_put_(indices, value)
+        input1 = input1.to("cpu")
+        output = input1.numpy()
+        return output
+
+    def case_exec(self, shape):
+        cpu_indices2 = []
+        npu_indices2 = []
+        for item in shape:
+            cpu_input, npu_input = create_common_tensor(item[0], 1, 100)
+            for i in range(1, 3):
+                cpu_indices1, npu_indices1 = create_common_tensor(
+                    item[1], 1, 5)
+                cpu_indices2.append(cpu_indices1)
+                npu_indices2.append(npu_indices1)
+            cpu_value, npu_value = create_common_tensor(item[2], 1, 100)
+            cpu_output = self.cpu_op_exec(cpu_input, cpu_indices2, cpu_value)
+            npu_output = self.npu_op_exec(npu_input, npu_indices2, npu_value)
+            self.assertEqual(cpu_output, npu_output)
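+    # The fp16 helpers below compute the CPU reference in fp32 and cast the
+    # result back to fp16 before comparing against the fp16 NPU output.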
+    def case_exec_fp16(self, shape):
+        cpu_indices3 = []
+        npu_indices3 = []
+        for item in shape:
+            cpu_input, npu_input = create_common_tensor(item[0], 1, 100)
+            cpu_input = cpu_input.to(torch.float32)
+            for i in range(1, 3):
+                cpu_indices1, npu_indices1 = create_common_tensor(
+                    item[1], 1, 5)
+                cpu_indices3.append(cpu_indices1)
+                npu_indices3.append(npu_indices1)
+            cpu_value, npu_value = create_common_tensor(item[2], 1, 100)
+            cpu_value = cpu_value.to(torch.float32)
+            cpu_output = self.cpu_op_exec(cpu_input, cpu_indices3, cpu_value)
+            npu_output = self.npu_op_exec(npu_input, npu_indices3, npu_value)
+            cpu_output = cpu_output.astype(np.float16)
+            self.assertEqual(cpu_output, npu_output)
+
+    def case_inp_exec(self, shape):
+        cpu_indices4 = []
+        npu_indices4 = []
+        for item in shape:
+            cpu_input, npu_input = create_common_tensor(item[0], 1, 100)
+            for i in range(1, 3):
+                cpu_indices1, npu_indices1 = create_common_tensor(
+                    item[1], 1, 5)
+                cpu_indices4.append(cpu_indices1)
+                npu_indices4.append(npu_indices1)
+            cpu_value, npu_value = create_common_tensor(item[2], 1, 100)
+            cpu_output = self.cpu_op_inp_exec(
+                cpu_input, cpu_indices4, cpu_value)
+            npu_output = self.npu_op_inp_exec(
+                npu_input, npu_indices4, npu_value)
+            self.assertEqual(cpu_output, npu_output)
+
+    def case_inp_exec_fp16(self, shape):
+        cpu_indices5 = []
+        npu_indices5 = []
+        for item in shape:
+            cpu_input, npu_input = create_common_tensor(item[0], 1, 100)
+            cpu_input = cpu_input.to(torch.float32)
+            for i in range(1, 3):
+                cpu_indices1, npu_indices1 = create_common_tensor(
+                    item[1], 1, 5)
+                cpu_indices5.append(cpu_indices1)
+                npu_indices5.append(npu_indices1)
+            cpu_value, npu_value = create_common_tensor(item[2], 1, 100)
+            cpu_value = cpu_value.to(torch.float32)
+            cpu_output = self.cpu_op_inp_exec(
+                cpu_input, cpu_indices5, cpu_value)
+            npu_output = self.npu_op_inp_exec(
+                npu_input, npu_indices5, npu_value)
+            cpu_output = cpu_output.astype(np.float16)
+            self.assertEqual(cpu_output, npu_output)
+
+    def test_index_put_shape_format_fp32(self, device):
+        format_list = [0]
+        shape_list = [(5, 6)]
+        shape_format = [[[np.float32, i, j], [np.int64, 0, [1, 2]],
+                         [np.float32, 0, [1, 2]]]
+                        for i in format_list for j in shape_list]
+        self.case_exec(shape_format)
+        self.case_inp_exec(shape_format)
+
+    def test_index_put_shape_format_fp16(self, device):
+        format_list = [0]
+        shape_list = [(5, 6)]
+        shape_format = [[[np.float16, i, j], [np.int64, 0, [1, 2]],
+                         [np.float16, 0, [1, 2]]]
+                        for i in format_list for j in shape_list]
+        self.case_exec_fp16(shape_format)
+        self.case_inp_exec_fp16(shape_format)
+
+    def test_index_put_null(self, device):
+        cpu_input1 = torch.rand(2, 2)
+        cpu_input2 = torch.rand(2, 2)
+        cpu_mask_index = torch.tensor([[False, False], [False, False]])
+        npu_mask_index = cpu_mask_index.to("npu")
+        npu_input1 = cpu_input1.to("npu")
+        npu_input2 = cpu_input2.to("npu")
+        # All-False mask: the assignment must write no elements on either device.
+        cpu_input1[cpu_mask_index] = cpu_input2.detach()[cpu_mask_index]
+        npu_input1[npu_mask_index] = npu_input2.detach()[npu_mask_index]
+        self.assertEqual(cpu_input1, npu_input1.to("cpu"))
+
+instantiate_device_type_tests(TestIndexPut, globals(), except_for="cpu")
+if __name__ == "__main__":
+    run_tests()
diff --git a/torch_npu/csrc/aten/ops/IndexPutKernelNpu.cpp b/torch_npu/csrc/aten/ops/IndexPutKernelNpu.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..88e655d5e257e2d1ea910eda9edf59c0ee6d9468
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/IndexPutKernelNpu.cpp
@@ -0,0 +1,117 @@
+// Copyright (c) 2020 Huawei Technologies Co., Ltd
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
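+
+// NPU adapter for aten::index_put. A Long mask records which entries of
+// `indices` are defined; the mask and the defined index tensors are handed
+// to the NPU "IndexPut" operator. Half inputs are computed in fp32 and cast
+// back (see index_put_nocheck below).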
+
+#include <ATen/native/IndexingUtils.h>
+
+#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h"
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+
+at::Tensor& index_put_nocheck(
+    at::Tensor& result,
+    const at::Tensor& self,
+    const at::TensorList& indices,
+    const at::Tensor& value,
+    bool accumulate) {
+  if (value.numel() == 0) {
+    return result;
+  }
+
+  // masks mirrors indices: 1 where the index tensor is defined, 0 otherwise.
+  at::SmallVector<int64_t, N> masks;
+  std::vector<at::Tensor> allDefinedIndices;
+  for (int64_t i = 0; i < indices.size(); i++) {
+    if (indices[i].defined()) {
+      masks.emplace_back(1);
+      allDefinedIndices.emplace_back(indices[i]);
+    } else {
+      masks.emplace_back(0);
+    }
+  }
+
+  auto masksTensor = CalcuOpUtil::copy_tensor_host_to_device(
+      at::from_blob(masks.data(), {masks.size()}, dtype(at::ScalarType::Long)));
+
+  // Half inputs are computed in fp32; the result is cast back below.
+  at::Tensor tempSelf = self;
+  at::Tensor tempValue = value;
+  if (self.scalar_type() == at::ScalarType::Half) {
+    tempSelf = NPUNativeFunctions::npu_dtype_cast(self, at::ScalarType::Float);
+    tempValue = NPUNativeFunctions::npu_dtype_cast(value, at::ScalarType::Float);
+    result = NPUNativeFunctions::npu_dtype_cast(result, at::ScalarType::Float);
+  }
+
+  OpCommand cmd;
+  cmd.Name("IndexPut")
+      .Input(tempSelf)
+      .Input(tempValue)
+      .Input(masksTensor)
+      .Inputs(allDefinedIndices)
+      .Output(result)
+      .Attr("accumulate", accumulate)
+      .Run();
+
+  if (self.scalar_type() == at::ScalarType::Half) {
+    result = NPUNativeFunctions::npu_dtype_cast(result, at::ScalarType::Half);
+  }
+  return result;
+}
+
+at::Tensor NPUNativeFunctions::index_put(
+    const at::Tensor& self,
+    const c10::List<c10::optional<at::Tensor>>& indices,
+    const at::Tensor& value,
+    bool accumulate) {
+  return self.clone(at::MemoryFormat::Contiguous)
+      .index_put_(indices, value, accumulate);
+}
+
+at::Tensor& NPUNativeFunctions::index_put_(
+    at::Tensor& self,
+    const c10::List<c10::optional<at::Tensor>>& indices,
+    const at::Tensor& value,
+    const bool accumulate) {
+  return at::_index_put_impl_(
+      self, indices, value, accumulate, false);
+}
+
+at::Tensor& NPUNativeFunctions::_index_put_impl_(
+    at::Tensor& self,
+    const c10::List<c10::optional<at::Tensor>>& indices,
+    const at::Tensor& value,
+    const bool accumulate,
+    const bool unsafe) {
+  at::native::checkIndexTensorTypes(indices);
+  auto indices_cast = at::native::expandTensors(self, indices);
+
+  OpPreparation::CastBackToOriFormat(self);
+  at::Tensor valueCopy = value;
+  at::Tensor selfCopy = self;
+  OpPreparation::CastBackToOriFormat(valueCopy);
+
+  if (!NpuUtils::check_match(&self)) {
+    // self is not match-formatted: compute into a contiguous copy, then
+    // copy the result back into self.
+    at::Tensor contiguousSelf = NpuUtils::format_contiguous(selfCopy);
+    at::Tensor result = index_put_nocheck(
+        contiguousSelf, contiguousSelf, indices_cast, valueCopy, accumulate);
+    self.copy_(result);
+  } else {
+    index_put_nocheck(selfCopy, selfCopy, indices_cast, valueCopy, accumulate);
+    self.copy_(selfCopy);
+  }
+  return self;
+}
+} // namespace native
+} // namespace at_npu
\ No newline at end of file
diff --git a/torch_npu/csrc/aten/ops/BatchNormBackwardElemtKernelNpu.cpp b/torch_npu/csrc/aten/ops/normalization/BatchNormBackwardElemtKernelNpu.cpp
similarity index 100%
rename from torch_npu/csrc/aten/ops/BatchNormBackwardElemtKernelNpu.cpp
rename to torch_npu/csrc/aten/ops/normalization/BatchNormBackwardElemtKernelNpu.cpp