From 94e2d529c73f6687570dd4196e65241bccc052ed Mon Sep 17 00:00:00 2001
From: wuxiankun
Date: Mon, 24 Jan 2022 15:30:15 +0800
Subject: [PATCH 1/2] npu_scatter

---
 test/test_network_ops/test_scatterv1.py       | 39 +++++++++++++++
 torch_npu/csrc/aten/npu_native_functions.yaml |  3 ++
 .../csrc/aten/ops/ScatterV1KernelNpu.cpp      | 50 +++++++++++++++++++
 3 files changed, 92 insertions(+)
 create mode 100644 test/test_network_ops/test_scatterv1.py
 create mode 100644 torch_npu/csrc/aten/ops/ScatterV1KernelNpu.cpp

diff --git a/test/test_network_ops/test_scatterv1.py b/test/test_network_ops/test_scatterv1.py
new file mode 100644
index 0000000000..c894d46a9a
--- /dev/null
+++ b/test/test_network_ops/test_scatterv1.py
@@ -0,0 +1,39 @@
+# Copyright (c) 2020, Huawei Technologies.All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+import numpy as np
+from torch_npu.testing.common_device_type import dtypes, instantiate_device_type_tests
+from torch_npu.testing.util_test import create_common_tensor
+from torch_npu.testing.common_utils import TestCase, run_tests
+
+
+class TestScatterV1(TestCase):
+    def npu_op_exec(self, input1, indices, updates, dim):
+        output = torch.npu_scatter(input1, indices, updates, dim)
+        output = output.to("cpu")
+        output = output.numpy()
+        return output
+
+    def test_scatterv1(self, device):
+        input1 = torch.tensor([[1.6279, 0.1226], [0.9041, 1.0980]]).npu()
+        indices = torch.tensor([0, 1]).npu().to(torch.int32)
+        updates = torch.tensor([-1.1993, -1.5247]).npu()
+        dim = 0
+        exoutput = torch.tensor([[-1.1993, 0.1226], [0.9041, -1.5247]])
+        output = self.npu_op_exec(input1, indices, updates, dim)
+        self.assertRtolEqual(exoutput.numpy(), output)
+
+instantiate_device_type_tests(TestScatterV1, globals(), except_for="cpu")
+if __name__ == "__main__":
+    run_tests()
diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml
index 412fb9da7c..6b6eded8dc 100644
--- a/torch_npu/csrc/aten/npu_native_functions.yaml
+++ b/torch_npu/csrc/aten/npu_native_functions.yaml
@@ -283,6 +283,9 @@ custom:
   - func: npu_slice.out(Tensor self, int[] offsets, int[] size, *, Tensor(a!) out) -> Tensor(a!)
   - func: npu_indexing(Tensor self, int[] begin, int[] end, int[] strides, int begin_mask=0, int end_mask=0, int ellipsis_mask=0, int new_axis_mask=0, int shrink_axis_mask=0) -> Tensor
   - func: npu_indexing.out(Tensor self, int[] begin, int[] end, int[] strides, int begin_mask=0, int end_mask=0, int ellipsis_mask=0, int new_axis_mask=0, int shrink_axis_mask=0, *, Tensor(a!) out) -> Tensor(a!)
+  - func: npu_scatter(Tensor self, Tensor indices, Tensor updates, int dim) -> Tensor
+    variants: function, method
+
 custom_autograd:
   - func: npu_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor
   - func: npu_convolution_transpose(Tensor input, Tensor weight, Tensor? bias, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups) -> Tensor
diff --git a/torch_npu/csrc/aten/ops/ScatterV1KernelNpu.cpp b/torch_npu/csrc/aten/ops/ScatterV1KernelNpu.cpp
new file mode 100644
index 0000000000..08b77d8998
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/ScatterV1KernelNpu.cpp
@@ -0,0 +1,50 @@
+// Copyright (c) 2020 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+
+at::Tensor& scatter_out_npu_nocheck(
+    at::Tensor& output,
+    const at::Tensor& self,
+    const at::Tensor& indices,
+    const at::Tensor& updates,
+    int64_t dim) {
+  OpCommand cmd;
+  cmd.Name("ArgMaxGrad")
+      .Input(self)
+      .Input(indices)
+      .Input(updates)
+      .Output(output)
+      .Attr("dimension", dim)
+      .Run();
+
+  return output;
+}
+
+at::Tensor NPUNativeFunctions::scatter_npu(const at::Tensor& self, const at::Tensor& indices, const at::Tensor& updates, int64_t dim) {
+  at::Tensor outputs = OpPreparation::ApplyTensor(self);
+  scatter_out_npu_nocheck(outputs, self, indices, updates, dim);
+
+  return outputs;
+}
+
+}
+}
-- 
Gitee

From ddd797df9bc0bce132cb27a7579a9c7a53ac2556 Mon Sep 17 00:00:00 2001
From: wuxiankun
Date: Mon, 24 Jan 2022 17:17:14 +0800
Subject: [PATCH 2/2] fix bug

---
 test/test_network_ops/test_scatterv1.py        | 5 +++--
 torch_npu/csrc/aten/ops/ScatterV1KernelNpu.cpp | 2 +-
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/test/test_network_ops/test_scatterv1.py b/test/test_network_ops/test_scatterv1.py
index c894d46a9a..22d1546fbf 100644
--- a/test/test_network_ops/test_scatterv1.py
+++ b/test/test_network_ops/test_scatterv1.py
@@ -12,15 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import torch
+import torch_npu
 import numpy as np
-from torch_npu.testing.common_device_type import dtypes, instantiate_device_type_tests
+from torch_npu.testing.common_device_type import instantiate_device_type_tests
 from torch_npu.testing.util_test import create_common_tensor
 from torch_npu.testing.common_utils import TestCase, run_tests
 
 
 class TestScatterV1(TestCase):
     def npu_op_exec(self, input1, indices, updates, dim):
-        output = torch.npu_scatter(input1, indices, updates, dim)
+        output = torch_npu.npu_scatter(input1, indices, updates, dim)
         output = output.to("cpu")
         output = output.numpy()
         return output
diff --git a/torch_npu/csrc/aten/ops/ScatterV1KernelNpu.cpp b/torch_npu/csrc/aten/ops/ScatterV1KernelNpu.cpp
index 08b77d8998..9341316f6b 100644
--- a/torch_npu/csrc/aten/ops/ScatterV1KernelNpu.cpp
+++ b/torch_npu/csrc/aten/ops/ScatterV1KernelNpu.cpp
@@ -39,7 +39,7 @@ at::Tensor& scatter_out_npu_nocheck(
   return output;
 }
 
-at::Tensor NPUNativeFunctions::scatter_npu(const at::Tensor& self, const at::Tensor& indices, const at::Tensor& updates, int64_t dim) {
+at::Tensor NPUNativeFunctions::npu_scatter(const at::Tensor& self, const at::Tensor& indices, const at::Tensor& updates, int64_t dim) {
   at::Tensor outputs = OpPreparation::ApplyTensor(self);
   scatter_out_npu_nocheck(outputs, self, indices, updates, dim);
 
-- 
Gitee
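
Usage sketch: with both patches applied, the test above exercises the new binding end to
end. A minimal standalone call, mirroring the values in test_scatterv1 and assuming
torch_npu is built from this series and an Ascend NPU device is visible, would look like:

    import torch
    import torch_npu

    # updates[i] is written into input1 at row indices[i], column i (dim=0),
    # matching the expected output asserted in the test.
    input1 = torch.tensor([[1.6279, 0.1226], [0.9041, 1.0980]]).npu()
    indices = torch.tensor([0, 1]).npu().to(torch.int32)
    updates = torch.tensor([-1.1993, -1.5247]).npu()
    out = torch_npu.npu_scatter(input1, indices, updates, 0)
    # Expected (per the test): [[-1.1993, 0.1226], [0.9041, -1.5247]]
    print(out.cpu())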