From b9e50ba661279b02108e984ff65363c29b9d8257 Mon Sep 17 00:00:00 2001
From: shenpengcheng
Date: Sat, 29 Jan 2022 15:16:54 +0800
Subject: [PATCH] =?UTF-8?q?BinaryCrossEntropyBackward=E7=AE=97=E5=AD=90?=
 =?UTF-8?q?=E8=BF=81=E7=A7=BB?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

fix code
---
 .../test_binary_cross_entropy_backward.py     | 145 ++++++++++++++++++
 .../BinaryCrossEntropyBackwardKernelNpu.cpp   |  71 +++++++++
 2 files changed, 216 insertions(+)
 create mode 100644 test/test_network_ops/test_binary_cross_entropy_backward.py
 create mode 100644 torch_npu/csrc/aten/ops/BinaryCrossEntropyBackwardKernelNpu.cpp

diff --git a/test/test_network_ops/test_binary_cross_entropy_backward.py b/test/test_network_ops/test_binary_cross_entropy_backward.py
new file mode 100644
index 0000000000..3f377d21d1
--- /dev/null
+++ b/test/test_network_ops/test_binary_cross_entropy_backward.py
@@ -0,0 +1,145 @@
+# Copyright (c) 2020 Huawei Technologies Co., Ltd
+# Copyright (c) 2019, Facebook CORPORATION.
+# All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import torch
+import torch_npu
+import torch.nn as nn
+import numpy as np
+
+from torch_npu.testing.common_utils import TestCase, run_tests
+from torch_npu.testing.common_device_type import instantiate_device_type_tests
+
+class TestBinaryCrossEntropyBackward(TestCase):
+    """Check NPU binary_cross_entropy loss and input gradient against CPU."""
+
+    def generate_data(self, min_val, max_val, shape, dtype):
+        """Build a random CPU tensor.
+
+        Args:
+            min_val: lower bound passed to np.random.uniform.
+            max_val: upper bound passed to np.random.uniform.
+            shape: shape of the generated array.
+            dtype: numpy dtype the samples are cast to.
+
+        Returns:
+            torch.Tensor created from the numpy samples.
+        """
+        samples = np.random.uniform(min_val, max_val, shape).astype(dtype)
+        return torch.from_numpy(samples)
+
+    def _forward_backward(self, input1, target, weight, reduction):
+        """Run binary_cross_entropy forward and backward on the given tensors.
+
+        Args:
+            input1: prediction tensor; marked as requiring grad in place.
+            target: label tensor passed to binary_cross_entropy.
+            weight: optional rescaling tensor, or None.
+            reduction: 'none', 'mean' or 'sum'.
+
+        Returns:
+            Tuple (detached loss, gradient accumulated on input1).
+        """
+        input1.requires_grad_(True)
+        loss = torch.nn.functional.binary_cross_entropy(
+            input1, target, weight=weight, reduction=reduction)
+        if reduction == 'none':
+            # A non-scalar loss needs an explicit upstream gradient.
+            loss.backward(torch.ones_like(input1))
+        else:
+            loss.backward()
+        return loss.detach(), input1.grad
+
+    def cpu_op_exec(self, input1, target, weight, reduction="mean"):
+        """Compute the CPU reference loss and input gradient.
+
+        float16 inputs are converted to float32 for the computation and the
+        results are cast back to float16 afterwards.
+
+        Returns:
+            Tuple (loss, input gradient) of numpy arrays.
+        """
+        is_half = input1.dtype == torch.float16
+        if is_half:
+            input1 = input1.to(torch.float32)
+            target = target.to(torch.float32)
+            if weight is not None:
+                weight = weight.to(torch.float32)
+        loss, grad = self._forward_backward(input1, target, weight, reduction)
+        loss, grad = loss.numpy(), grad.numpy()
+        if is_half:
+            loss, grad = loss.astype(np.float16), grad.astype(np.float16)
+        return loss, grad
+
+    def npu_op_exec(self, input1, target, weight, reduction="mean"):
+        """Compute the loss and input gradient on the NPU.
+
+        Returns:
+            Tuple (loss, input gradient) of numpy arrays copied back to CPU.
+        """
+        input1, target = input1.npu(), target.npu()
+        if weight is not None:
+            weight = weight.npu()
+        loss, grad = self._forward_backward(input1, target, weight, reduction)
+        return loss.cpu().numpy(), grad.cpu().numpy()
+
+    def _check_against_cpu(self, dtype, binary_target, with_weight):
+        """Assert CPU/NPU agreement on loss and gradient for every reduction.
+
+        Args:
+            dtype: numpy dtype of the generated tensors.
+            binary_target: True for hard 0/1 labels, False for soft labels
+                sampled between 0 and 1.
+            with_weight: whether a random weight tensor is passed.
+        """
+        shape = (10, 64)
+        for reduction in ("none", "mean", "sum"):
+            input1 = self.generate_data(0, 1, shape, dtype)
+            if binary_target:
+                # Truncating samples from [0, 2] gives hard labels; the clamp
+                # covers a sample that float rounding pushed up to exactly 2.
+                target = self.generate_data(0, 2, shape, dtype)
+                target = target.int().clamp(max=1).to(input1.dtype)
+            else:
+                target = self.generate_data(0, 1, shape, dtype)
+            weight = self.generate_data(0, 1, shape, dtype) if with_weight else None
+            cpu_output, cpu_grad = self.cpu_op_exec(
+                copy.deepcopy(input1), copy.deepcopy(target),
+                copy.deepcopy(weight), reduction=reduction)
+            npu_output, npu_grad = self.npu_op_exec(
+                input1, target, weight, reduction=reduction)
+            self.assertRtolEqual(cpu_output, npu_output)
+            self.assertRtolEqual(cpu_grad, npu_grad)
+
+    def test_binary_cross_entropy_backward_float16(self, device):
+        """float16, hard 0/1 targets, no weight."""
+        self._check_against_cpu(np.float16, binary_target=True, with_weight=False)
+
+    def test_binary_cross_entropy_backward_float32(self, device):
+        """float32, hard 0/1 targets, no weight."""
+        self._check_against_cpu(np.float32, binary_target=True, with_weight=False)
+
+    def test_binary_cross_entropy_backward_with_weight_float16(self, device):
+        """float16, hard 0/1 targets, random weight."""
+        self._check_against_cpu(np.float16, binary_target=True, with_weight=True)
+
+    def test_binary_cross_entropy_backward_with_weight_float32(self, device):
+        """float32, soft targets, random weight."""
+        self._check_against_cpu(np.float32, binary_target=False, with_weight=True)
+
+instantiate_device_type_tests(TestBinaryCrossEntropyBackward, globals(), except_for="cpu")
+if __name__ == "__main__":
+    run_tests()
diff --git a/torch_npu/csrc/aten/ops/BinaryCrossEntropyBackwardKernelNpu.cpp b/torch_npu/csrc/aten/ops/BinaryCrossEntropyBackwardKernelNpu.cpp
new file mode 100644
index 0000000000..b32cb34c7c
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/BinaryCrossEntropyBackwardKernelNpu.cpp
@@ -0,0 +1,71 @@
+// Copyright (c) 2020 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h"
+
+namespace at_npu {
+namespace native {
+
+// Launches the "BinaryCrossEntropyGrad" NPU operator and writes its output
+// into gradInput, which is also returned. An undefined weight is replaced by
+// a tensor of ones built from self.sizes() and self.options(). gradInput is
+// used as passed in; nothing about it is validated here.
+at::Tensor& binary_cross_entropy_backward_out_npu_nocheck(
+    at::Tensor& gradInput,
+    const at::Tensor& grad_output,
+    const at::Tensor& self,
+    const at::Tensor& target,
+    const at::Tensor& weight,
+    int64_t reduction) {
+  at::Tensor weightTensor = weight.defined() ? weight :
+              at::ones(self.sizes(), self.options());
+  std::string reductionStr = CalcuOpUtil::get_reduction_str(reduction);
+  OpCommand cmd;
+  cmd.Name("BinaryCrossEntropyGrad")
+      .Input(self)
+      .Input(target)
+      .Input(grad_output)
+      .Input(weightTensor)
+      .Output(gradInput)
+      .Attr("reduction", reductionStr)
+      .Run();
+  return gradInput;
+}
+
+// Out variant: computes the gradient into the caller-provided gradInput.
+// An empty weight_opt is forwarded as a default-constructed at::Tensor.
+at::Tensor& NPUNativeFunctions::binary_cross_entropy_backward_out(
+    const at::Tensor& grad_output, const at::Tensor& self, const at::Tensor& target,
+    const c10::optional<at::Tensor>& weight_opt, int64_t reduction, at::Tensor& gradInput) {
+  const at::Tensor& weight = c10::value_or_else(weight_opt, [] {return at::Tensor();});
+  binary_cross_entropy_backward_out_npu_nocheck(gradInput, grad_output, self, target, weight, reduction);
+  return gradInput;
+}
+
+// Functional variant: obtains the result tensor from
+// OpPreparation::ApplyTensor(self) and fills it through the nocheck helper.
+at::Tensor NPUNativeFunctions::binary_cross_entropy_backward(
+    const at::Tensor& grad_output, const at::Tensor& self, const at::Tensor& target,
+    const c10::optional<at::Tensor>& weight_opt, int64_t reduction) {
+  const at::Tensor& weight = c10::value_or_else(weight_opt, [] {return at::Tensor();});
+  at::Tensor gradInput = OpPreparation::ApplyTensor(self);
+  binary_cross_entropy_backward_out_npu_nocheck(gradInput, grad_output, self, target, weight, reduction);
+  return gradInput;
+}
+} // namespace native
+} // namespace at_npu
\ No newline at end of file
-- 
Gitee