From daf61545cb39fef8000edcb1efb980f3b2224f7d Mon Sep 17 00:00:00 2001 From: zhoufan37 Date: Wed, 26 Jan 2022 17:04:49 +0800 Subject: [PATCH] Add EmbeddingDenseBackward Operator --- .../test_embeddingdensebackward.py | 68 +++++++++++++++++++ .../ops/EmbeddingDenseBackwardKernelNpu.cpp | 65 ++++++++++++++++++ 2 files changed, 133 insertions(+) create mode 100644 test/test_network_ops/test_embeddingdensebackward.py create mode 100644 torch_npu/csrc/aten/ops/EmbeddingDenseBackwardKernelNpu.cpp diff --git a/test/test_network_ops/test_embeddingdensebackward.py b/test/test_network_ops/test_embeddingdensebackward.py new file mode 100644 index 00000000000..ea27b1a9baf --- /dev/null +++ b/test/test_network_ops/test_embeddingdensebackward.py @@ -0,0 +1,68 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch_npu +import numpy as np +import torch.nn.functional as F + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor + +class TestEmbeddingDenseBackward(TestCase): + def cpu_op_exec(self, weight, indices): + weight.requires_grad_(True) + out = F.embedding(indices, weight, scale_grad_by_freq=False, padding_idx=-1) + out.backward(torch.ones_like(out)) + grad_cpu = weight.grad + return out.detach().numpy(), grad_cpu.detach().numpy() + + def npu_op_exec(self, weight, indices): + weight.requires_grad_(True) + out = F.embedding(indices, weight, scale_grad_by_freq=False, padding_idx=-1) + out.backward(torch.ones_like(out)) + out_npu = out.to("cpu") + grad_npu = weight.grad + grad_npu = grad_npu.to("cpu") + return out_npu.detach().numpy(), grad_npu.detach().numpy() + + def test_embedding_dense_backward_shape_format_fp32(self, device): + format_list = [0] + shape_list1 = [[40, 32], [40, 1024], [40000, 1024], [33712, 1024]] + shape_list2 = [[40], [40], [40000], [33712]] + shape_format1 = [ + [np.float32, i, j] for i in format_list for j in shape_list1 + ] + shape_format2 = [ + [np.int64, i, j] for i in format_list for j in shape_list2 + ] + shape_format = [ + [shape_format1[i], shape_format2[i]] for i in range(len(shape_list1)) + ] + for item in shape_format: + weight_cpu, weight_npu = create_common_tensor(item[0], 1, 1) + indices_cpu, indices_npu = create_common_tensor(item[1], 0, min(item[0][2][0:-1])) + + cpu_out, cpu_grad = self.cpu_op_exec(weight_cpu, indices_cpu) + npu_out, npu_grad = self.npu_op_exec(weight_npu, indices_npu) + + self.assertRtolEqual(cpu_out, npu_out) + self.assertRtolEqual(cpu_grad, npu_grad) + + +instantiate_device_type_tests(TestEmbeddingDenseBackward, globals(), except_for="cpu") +if __name__ == "__main__": + run_tests() + diff --git a/torch_npu/csrc/aten/ops/EmbeddingDenseBackwardKernelNpu.cpp b/torch_npu/csrc/aten/ops/EmbeddingDenseBackwardKernelNpu.cpp new file mode 100644 index 00000000000..bfdbdae87da --- /dev/null +++ b/torch_npu/csrc/aten/ops/EmbeddingDenseBackwardKernelNpu.cpp @@ -0,0 +1,65 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +namespace { +at::Tensor& embedding_dense_backward_nocheck( + at::Tensor& result, + const at::Tensor& grad_output, + const at::Tensor& indices, + int64_t num_weights, + int64_t padding_idx, + bool scale_grad_by_freq) { + // indices must be int64 in pytorch, but npu can only support int32 + auto indices_int32 = NPUNativeFunctions::npu_dtype_cast(indices, at::kInt); + + OpCommand cmd; + cmd.Name("EmbeddingDenseGrad") + .Input(grad_output) + .Input(indices_int32) + .Attr("num_weights", num_weights) + .Attr("padding_idx", padding_idx) + .Attr("scale_grad_by_freq", scale_grad_by_freq) + .Output(result) + .Run(); + return result; +} +} // namespace + +at::Tensor NPUNativeFunctions::embedding_dense_backward( + const at::Tensor& grad_weight, + const at::Tensor& indices, + int64_t num_weights, + int64_t padding_idx, + bool scale_grad_by_freq) { + // calculate the output size + auto outputSize = {num_weights, grad_weight.size(-1)}; + + // construct the output tensor of the NPU + at::Tensor result = OpPreparation::ApplyTensor(grad_weight, outputSize); + + // calculate the output resugt of the NPU + embedding_dense_backward_nocheck( + result, grad_weight, indices, num_weights, padding_idx, scale_grad_by_freq); + + return result; +} +} // namespace native +} // namespace at_npu \ No newline at end of file -- Gitee