From c380c47f5c72dc8d00da08cc259e4c4cc521806c Mon Sep 17 00:00:00 2001 From: shenpengcheng Date: Fri, 18 Mar 2022 13:37:14 +0800 Subject: [PATCH] =?UTF-8?q?EmbeddingBagBackward=E3=80=81EmbeddingBag=20?= =?UTF-8?q?=E7=AE=97=E5=AD=90=E8=BF=81=E7=A7=BB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit code check --- test/test_network_ops/test_embedding_bag.py | 45 ++++++++++ .../test_embedding_bag_backward.py | 65 ++++++++++++++ .../ops/EmbeddingBagBackwardKernelNpu.cpp | 57 ++++++++++++ .../csrc/aten/ops/EmbeddingBagKernelNpu.cpp | 87 +++++++++++++++++++ 4 files changed, 254 insertions(+) create mode 100644 test/test_network_ops/test_embedding_bag.py create mode 100644 test/test_network_ops/test_embedding_bag_backward.py create mode 100644 torch_npu/csrc/aten/ops/EmbeddingBagBackwardKernelNpu.cpp create mode 100644 torch_npu/csrc/aten/ops/EmbeddingBagKernelNpu.cpp diff --git a/test/test_network_ops/test_embedding_bag.py b/test/test_network_ops/test_embedding_bag.py new file mode 100644 index 00000000000..0a86761e65f --- /dev/null +++ b/test/test_network_ops/test_embedding_bag.py @@ -0,0 +1,45 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import torch
+import torch_npu
+import numpy as np
+import torch.nn.functional as F
+
+from torch_npu.testing.testcase import TestCase, run_tests
+from torch_npu.testing.common_utils import create_common_tensor
+
+
+class TestEmbeddingBag(TestCase):
+    """Checks that F.embedding_bag on NPU tensors matches the CPU result."""
+
+    def test_embedding_bag_1d(self):
+        """1-D indices with explicit offsets [0, 4]."""
+        # F.embedding_bag takes (input, weight, offsets); passing the weight first is deprecated.
+        weight = torch.rand(10, 3)
+        indices = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
+        offsets = torch.tensor([0, 4])
+        cpu_output = F.embedding_bag(indices, weight, offsets).detach().numpy()
+        npu_output = F.embedding_bag(indices.npu(), weight.npu(), offsets.npu()).cpu().detach().numpy()
+        self.assertRtolEqual(cpu_output, npu_output)
+
+    def test_embedding_bag_2d(self):
+        """2-D indices without offsets."""
+        weight = torch.rand(10, 3)
+        indices = torch.tensor([[1, 2, 4, 5, 4, 3, 2, 9], [1, 2, 4, 5, 4, 3, 2, 9]])
+        cpu_output = F.embedding_bag(indices, weight).detach().numpy()
+        npu_output = F.embedding_bag(indices.npu(), weight.npu()).cpu().detach().numpy()
+        self.assertRtolEqual(cpu_output, npu_output)
+
+if __name__ == "__main__":
+    run_tests()
\ No newline at end of file
diff --git a/test/test_network_ops/test_embedding_bag_backward.py b/test/test_network_ops/test_embedding_bag_backward.py
new file mode 100644
index 00000000000..b1c0beb6b7e
--- /dev/null
+++ b/test/test_network_ops/test_embedding_bag_backward.py
@@ -0,0 +1,65 @@
+# Copyright (c) 2020, Huawei Technologies.All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+import torch_npu
+import numpy as np
+import torch.nn.functional as F
+
+from torch_npu.testing.testcase import TestCase, run_tests
+from torch_npu.testing.common_utils import create_common_tensor
+
+
+class TestEmbeddingBackward(TestCase):
+    # NOTE(review): both helpers call F.embedding, not F.embedding_bag, so this test does not
+    # appear to reach the _embedding_bag_backward kernel added by this patch - confirm intent.
+    def cpu_op_exec(self, weight, indices):
+        """Run F.embedding forward and backward on CPU; return (output, weight grad) as numpy."""
+        weight.requires_grad_(True)
+        result = F.embedding(indices, weight, scale_grad_by_freq=True, padding_idx=37)
+        result.backward(torch.ones_like(result))
+        return result.detach().numpy(), weight.grad.detach().numpy()
+
+    def npu_op_exec(self, weight, indices):
+        """Same as cpu_op_exec, but copies the output and gradient to CPU before converting."""
+        weight.requires_grad_(True)
+        result = F.embedding(indices, weight, scale_grad_by_freq=True, padding_idx=37)
+        result.backward(torch.ones_like(result))
+        host_out = result.to("cpu")
+        host_grad = weight.grad.to("cpu")
+        return host_out.detach().numpy(), host_grad.detach().numpy()
+
+    def test_embedding_backward_shape_format_fp32(self):
+        """Compare CPU and NPU output/gradient for every weight-shape x indices-shape pair."""
+        format_list = [0]
+        shape_list1 = [[40, 32], [40, 1024], [40000, 1024], [33712, 1024]]
+        shape_list2 = [[40], [40], [3125], [64, 7]]
+        weight_items = [[np.float32, fmt, shape] for fmt in format_list for shape in shape_list1]
+        indices_items = [[np.int64, fmt, shape] for fmt in format_list for shape in shape_list2]
+        for weight_item in weight_items:
+            for indices_item in indices_items:
+                # Value range [1, 1]: presumably yields all-ones weights - confirm create_common_tensor.
+                weight_cpu, weight_npu = create_common_tensor(weight_item, 1, 1)
+                # Smallest of all but the last weight dim (the row count for these 2-D shapes).
+                index_bound = min(weight_item[2][0:-1])
+                indices_cpu, indices_npu = create_common_tensor(indices_item, 0, index_bound)
+
+                cpu_out, cpu_grad = self.cpu_op_exec(weight_cpu, indices_cpu)
+                npu_out, npu_grad = self.npu_op_exec(weight_npu, indices_npu)
+
+                self.assertRtolEqual(cpu_out, npu_out)
+                self.assertRtolEqual(cpu_grad, npu_grad)
+
+if __name__ == "__main__":
+    run_tests()
+
diff --git a/torch_npu/csrc/aten/ops/EmbeddingBagBackwardKernelNpu.cpp b/torch_npu/csrc/aten/ops/EmbeddingBagBackwardKernelNpu.cpp
new file mode 100644
index 00000000000..756c601f06a
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/EmbeddingBagBackwardKernelNpu.cpp
@@ -0,0 +1,57 @@
+// Copyright (c) 2021 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+// CPU fallback for _embedding_bag_backward: copies every tensor input to the host, runs
+// at::_embedding_bag_backward there, densifies the result and moves it to indices' device.
+at::Tensor NPUNativeFunctions::_embedding_bag_backward(
+    const at::Tensor& grad,
+    const at::Tensor& indices,
+    const at::Tensor& offsets,
+    const at::Tensor& offset2bag,
+    const at::Tensor& bag_size,
+    const at::Tensor& maximum_indices,
+    int64_t num_weights,
+    bool scale_grad_by_freq,
+    int64_t mode,
+    bool sparse,
+    const c10::optional<at::Tensor>& per_sample_weights_opt) {
+  const at::Tensor& per_sample_weights = c10::value_or_else(per_sample_weights_opt, [] {return at::Tensor();});
+
+  at::Tensor grad_cpu = grad.to("cpu");
+  at::Tensor indices_cpu = indices.to("cpu");
+  at::Tensor offsets_cpu = offsets.to("cpu");
+  at::Tensor offset2bag_cpu = offset2bag.to("cpu");
+  at::Tensor bag_size_cpu = bag_size.to("cpu");
+  at::Tensor maximum_indices_cpu = maximum_indices.to("cpu");
+  // One conditional copy; re-declaring this tensor in a nested scope would discard the host copy.
+  const at::Tensor per_sample_weights_cpu =
+      per_sample_weights.defined() ? per_sample_weights.to("cpu") : per_sample_weights;
+
+  at::Tensor result = at::_embedding_bag_backward(
+      grad_cpu, indices_cpu, offsets_cpu, offset2bag_cpu, bag_size_cpu,
+      maximum_indices_cpu, num_weights, scale_grad_by_freq, mode, sparse, per_sample_weights_cpu);
+
+  // NOTE(review): presumably needed because the host call can return a sparse gradient - confirm.
+  result = at::native::sparse_to_dense(result);
+  return result.to(indices.device());
+}
+} // namespace native
+} // namespace at_npu
\ No newline at end of file
diff --git a/torch_npu/csrc/aten/ops/EmbeddingBagKernelNpu.cpp b/torch_npu/csrc/aten/ops/EmbeddingBagKernelNpu.cpp
new file mode 100644
index 00000000000..fba1bd56811
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/EmbeddingBagKernelNpu.cpp
@@ -0,0 +1,87 @@
+// Copyright (c) 2020 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+// Output shape {bag count, weight.size(1)}: offsets.size(0) bags for 1-D indices, else indices.size(0).
+c10::SmallVector<int64_t, SIZE> _embedding_bag_npu_output_size(
+    const at::Tensor& weight,
+    const at::Tensor& indices,
+    const at::Tensor& offsets) {
+  const int64_t bagCount = (indices.dim() == 1) ? offsets.size(0) : indices.size(0);
+  c10::SmallVector<int64_t, SIZE> outputSize = {bagCount, weight.size(1)};
+  return outputSize;
+}
+
+// Maps the integer mode to its name: 0 -> "sum", 1 -> "mean", anything else -> "max".
+// Takes int64_t rather than bool so that mode 2 is not collapsed to 1 and "max" stays reachable.
+string get_mode_str(int64_t mode) {
+  if (mode == 0) {
+    return "sum";
+  }
+  return (mode == 1) ? "mean" : "max";
+}
+
+// CPU fallback for _embedding_bag: copies the inputs to the host, runs at::native::_embedding_bag_cpu
+// and moves the four results (output, offset2bag, bag_size, max_indices) to weight's device.
+tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> NPUNativeFunctions::_embedding_bag(
+    const at::Tensor& weight,
+    const at::Tensor& indices,
+    const at::Tensor& offsets,
+    bool scale_grad_by_freq,
+    int64_t mode,
+    bool sparse,
+    const c10::optional<at::Tensor>& per_sample_weights_opt,
+    bool include_last_offset) {
+  const at::Tensor& per_sample_weights = c10::value_or_else(per_sample_weights_opt, [] {return at::Tensor();});
+
+  // NOTE(review): requires_grad_() on the host copy looks unnecessary in a forward kernel - confirm.
+  at::Tensor weight_cpu = weight.to("cpu").requires_grad_();
+  at::Tensor indices_cpu = indices.to("cpu");
+  at::Tensor offsets_cpu = offsets.to("cpu");
+  // One conditional copy; re-declaring this tensor in a nested scope would discard the host copy.
+  const at::Tensor per_sample_weights_cpu =
+      per_sample_weights.defined() ? per_sample_weights.to("cpu") : per_sample_weights;
+
+  auto result = at::native::_embedding_bag_cpu(
+      weight_cpu, indices_cpu, offsets_cpu, scale_grad_by_freq, mode, sparse,
+      per_sample_weights_cpu, include_last_offset);
+
+  at::Tensor output = std::get<0>(result).to(weight.device());
+  at::Tensor offset2bag = std::get<1>(result).to(weight.device());
+  at::Tensor bag_size = std::get<2>(result).to(weight.device());
+  at::Tensor max_indices = std::get<3>(result).to(weight.device());
+
+  return std::tie(output, offset2bag, bag_size, max_indices);
+}
+
+// Forward-only variant: forwards all arguments unchanged to _embedding_bag.
+tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> NPUNativeFunctions::_embedding_bag_forward_only(
+    const at::Tensor& weight,
+    const at::Tensor& indices,
+    const at::Tensor& offsets,
+    bool scale_grad_by_freq,
+    int64_t mode,
+    bool sparse,
+    const c10::optional<at::Tensor>& per_sample_weights_opt,
+    bool include_last_offset) {
+  return _embedding_bag(weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights_opt, include_last_offset);
+}
+} // namespace native
+} // namespace at_npu
\ No newline at end of file
-- 
Gitee