From b1d536328c1c84a4578ec1d4569ab28b075bfb91 Mon Sep 17 00:00:00 2001 From: shenpengcheng Date: Fri, 28 Jan 2022 14:17:34 +0800 Subject: [PATCH] =?UTF-8?q?BinaryCrossEntropyWithLogits=20=E7=AE=97?= =?UTF-8?q?=E5=AD=90=E8=BF=81=E7=A7=BB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit fix code --- .../test_binary_cross_entropy_with_logits.py | 209 ++++++++++++++++++ .../BinaryCrossEntropyWithLogitsKernelNpu.cpp | 71 ++++++ 2 files changed, 280 insertions(+) create mode 100644 test/test_network_ops/test_binary_cross_entropy_with_logits.py create mode 100644 torch_npu/csrc/aten/ops/BinaryCrossEntropyWithLogitsKernelNpu.cpp diff --git a/test/test_network_ops/test_binary_cross_entropy_with_logits.py b/test/test_network_ops/test_binary_cross_entropy_with_logits.py new file mode 100644 index 0000000000..e93c3db305 --- /dev/null +++ b/test/test_network_ops/test_binary_cross_entropy_with_logits.py @@ -0,0 +1,209 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy +import torch +import torch_npu +import torch.nn as nn +import numpy as np + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import instantiate_device_type_tests + +class TestBinaryCrossEntropyWithLogits(TestCase): + + def generate_two_input(self, lower, upper, shape, dtype): + x = np.random.uniform(lower, upper, shape).astype(dtype) + y = np.random.uniform(lower, upper, shape).astype(dtype) + + npu_input = torch.from_numpy(x) + target_input = torch.from_numpy(y) + + return npu_input, target_input + + def generate_one_input(self, lower, upper, shape, dtype): + x = np.random.uniform(lower, upper, shape).astype(dtype) + npu_input = torch.from_numpy(x) + return npu_input + + def cpu_op_exec(self, input1, target, weight=None, pos_weight=None, reduction="mean"): + criterion = torch.nn.BCEWithLogitsLoss(weight=weight, pos_weight=pos_weight, + reduction=reduction) + res = criterion(input1, target) + return res.numpy() + + def npu_op_exec(self, input1, target, weight=None, pos_weight=None, reduction="mean"): + input1 = input1.to("npu") + target = target.to("npu") + if weight is not None: + weight = weight.to("npu") + if pos_weight is not None: + pos_weight = pos_weight.to("npu") + + criterion = torch.nn.BCEWithLogitsLoss(weight=weight, pos_weight=pos_weight, + reduction=reduction) + criterion = criterion.to("npu") + res = criterion(input1, target) + res = res.to("cpu") + return res.numpy() + + def cpu_op_func_exec(self, input1, target, weight=None, pos_weight=None, reduction="mean"): + res = torch.nn.functional.binary_cross_entropy_with_logits(input1, target, weight=weight, pos_weight=pos_weight, + reduction=reduction) + return res.numpy() + + def npu_op_func_exec(self, input1, target, weight=None, pos_weight=None, reduction="mean"): + input1 = input1.to("npu") + target = target.to("npu") + if weight is not None: + weight = weight.to("npu") + if pos_weight is not None: + pos_weight = pos_weight.to("npu") + + res = torch.nn.functional.binary_cross_entropy_with_logits(input1, target, weight=weight, pos_weight=pos_weight, + reduction=reduction) + res = res.to("cpu") + return res.numpy() + + def test_binary_cross_with_logits_float32(self, device): + for shape, weight_shape, pos_weight_shape, reduction in [ + ((10, 64), None, None, "mean"), + ((10, 64), (10, 1), None, "mean"), + ((10, 64), None, (64,), "mean"), + ((10, 64), None, None, "none"), + ((10, 64), (10, 1), None, "none"), + ((10, 64), None, (64,), "none"), + ((10, 64), None, None, "sum"), + ((10, 64), (10, 1), None, "sum"), + ((10, 64), None, (64,), "sum"), + ((10, 64), (10, 64), (10, 64), "mean"), + ((10, 64), (10, 64), (10, 64), "sum"), + ((10, 64), (10, 64), (10, 64), "none") + ]: + input1 = self.generate_one_input(0, 10, shape, np.float32) + target = torch.empty(shape, dtype=torch.float32).random_(2) + weight = None + pos_weight = None + if weight_shape is not None: + weight = self.generate_one_input(0, 10, weight_shape, np.float32) + if pos_weight_shape is not None: + pos_weight = self.generate_one_input(0, 10, pos_weight_shape, np.float32) + cpu_output = self.cpu_op_exec(input1, target, weight=weight, pos_weight=pos_weight, reduction=reduction) + npu_output = self.npu_op_exec(input1, target, weight=weight, pos_weight=pos_weight, reduction=reduction) + self.assertRtolEqual(cpu_output, npu_output) + + def test_binary_cross_with_logits_float16(self, device): + for shape, weight_shape, pos_weight_shape, reduction in [ + ((10, 64), None, None, "mean"), + ((10, 64), (10, 1), None, "mean"), + ((10, 64), None, (64,), "mean"), + ((10, 64), None, None, "none"), + ((10, 64), (10, 1), None, "none"), + ((10, 64), None, (64,), "none"), + ((10, 64), None, None, "sum"), + ((10, 64), (10, 1), None, "sum"), + ((10, 64), None, (64,), "sum"), + ((10, 64), (10, 64), (10, 64), "sum"), + ((10, 64), (10, 64), (10, 64), "mean"), + ((10, 64), (10, 64), (10, 64), "none") + ]: + input1 = self.generate_one_input(0, 10, shape, np.float16) + target = torch.empty(shape, dtype=torch.float16).random_(2) + input_32 = input1.type(torch.float32) + target_32 = target.type(torch.float32) + weight = None + weight_32 = None + pos_weight = None + pos_weight_32 = None + + if weight_shape is not None: + weight = self.generate_one_input(0, 10, weight_shape, np.float16) + weight_32 = weight.type(torch.float32) + if pos_weight_shape is not None: + pos_weight = self.generate_one_input(0, 10, pos_weight_shape, np.float16) + pos_weight_32 = pos_weight.type(torch.float32) + + npu_output = self.npu_op_exec(input1, target, weight=weight, pos_weight=pos_weight, reduction=reduction) + cpu_output = self.cpu_op_exec(input_32, target_32, weight=weight_32, pos_weight=pos_weight_32, + reduction=reduction) + cpu_output = cpu_output.astype(np.float16) + self.assertRtolEqual(cpu_output, npu_output) + + def test_binary_cross_with_logits_function_float32(self, device): + for shape, weight_shape, pos_weight_shape, reduction in [ + ((10, 64), None, None, "mean"), + ((10, 64), (10, 1), None, "mean"), + ((10, 64), None, (64,), "mean"), + ((10, 64), None, None, "none"), + ((10, 64), (10, 1), None, "none"), + ((10, 64), None, (64,), "none"), + ((10, 64), None, None, "sum"), + ((10, 64), (10, 1), None, "sum"), + ((10, 64), None, (64,), "sum"), + ((10, 64), (10, 64), (10, 64), "mean"), + ((10, 64), (10, 64), (10, 64), "sum"), + ((10, 64), (10, 64), (10, 64), "none") + ]: + input1 = self.generate_one_input(0, 2, shape, np.float32) + target = torch.empty(shape, dtype=torch.float32).random_(2) + weight = None + pos_weight = None + if weight_shape is not None: + weight = self.generate_one_input(0, 2, weight_shape, np.float32) + if pos_weight_shape is not None: + pos_weight = self.generate_one_input(0, 2, pos_weight_shape, np.float32) + cpu_output = self.cpu_op_func_exec(input1, target, weight=weight, pos_weight=pos_weight, reduction=reduction) + npu_output = self.npu_op_func_exec(input1, target, weight=weight, pos_weight=pos_weight, reduction=reduction) + self.assertRtolEqual(cpu_output, npu_output) + + def test_binary_cross_with_logits_function_float16(self, device): + for shape, weight_shape, pos_weight_shape, reduction in [ + ((10, 64), None, None, "mean"), + ((10, 64), (10, 1), None, "mean"), + ((10, 64), None, (64,), "mean"), + ((10, 64), None, None, "none"), + ((10, 64), (10, 1), None, "none"), + ((10, 64), None, (64,), "none"), + ((10, 64), None, None, "sum"), + ((10, 64), (10, 1), None, "sum"), + ((10, 64), None, (64,), "sum"), + ((10, 64), (10, 64), (10, 64), "sum"), + ((10, 64), (10, 64), (10, 64), "mean"), + ((10, 64), (10, 64), (10, 64), "none") + ]: + input1 = self.generate_one_input(0, 2, shape, np.float16) + target = torch.empty(shape, dtype=torch.float16).random_(2) + input_32 = input1.type(torch.float32) + target_32 = target.type(torch.float32) + weight = None + weight_32 = None + pos_weight = None + pos_weight_32 = None + + if weight_shape is not None: + weight = self.generate_one_input(0, 2, weight_shape, np.float16) + weight_32 = weight.type(torch.float32) + if pos_weight_shape is not None: + pos_weight = self.generate_one_input(0, 2, pos_weight_shape, np.float16) + pos_weight_32 = pos_weight.type(torch.float32) + + npu_output = self.npu_op_func_exec(input1, target, weight=weight, pos_weight=pos_weight, reduction=reduction) + cpu_output = self.cpu_op_func_exec(input_32, target_32, weight=weight_32, pos_weight=pos_weight_32, + reduction=reduction) + + cpu_output = cpu_output.astype(np.float16) + self.assertRtolEqual(cpu_output, npu_output) + +instantiate_device_type_tests(TestBinaryCrossEntropyWithLogits, globals(), except_for="cpu") +if __name__ == "__main__": + run_tests() diff --git a/torch_npu/csrc/aten/ops/BinaryCrossEntropyWithLogitsKernelNpu.cpp b/torch_npu/csrc/aten/ops/BinaryCrossEntropyWithLogitsKernelNpu.cpp new file mode 100644 index 0000000000..6b4b2d8eba --- /dev/null +++ b/torch_npu/csrc/aten/ops/BinaryCrossEntropyWithLogitsKernelNpu.cpp @@ -0,0 +1,71 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h" +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +at::Tensor NPUNativeFunctions::binary_cross_entropy_with_logits( + const at::Tensor& self, + const at::Tensor& target, + const c10::optional& weight_opt, + const c10::optional& pos_weight_opt, + int64_t reduction) { + const at::Tensor& weight = c10::value_or_else(weight_opt, [] {return at::Tensor();}); + const at::Tensor& pos_weight = c10::value_or_else(pos_weight_opt, [] {return at::Tensor();}); + at::IntArrayRef outputSize; + int64_t resultformat = CalcuOpUtil::get_tensor_npu_format(self); + + if (reduction == at::Reduction::None) { + outputSize = input_same_output_size(self); + } else { + outputSize = at::ArrayRef(); + resultformat = ACL_FORMAT_ND; + } + + at::Tensor result = OpPreparation::ApplyTensorWithFormat(outputSize, target.options(), resultformat); + at::Tensor weightTensor; + if (weight.defined()) { + weightTensor = NpuUtils::format_contiguous(weight); + } else { + weightTensor = at::ones(target.sizes(), target.options()); + } + + at::Tensor posWeightTensor; + if (pos_weight.defined()) { + posWeightTensor = NpuUtils::format_contiguous(pos_weight); + } else { + posWeightTensor = at::ones(target.sizes(), target.options()); + } + + string reductionStr = CalcuOpUtil::get_reduction_str(reduction); + OpCommand cmd; + cmd.Name("SigmoidCrossEntropyWithLogitsV2") + .Input(self.to(target.dtype())) + .Input(target) + .Input(weightTensor) + .Input(posWeightTensor) + .Output(result) + .Attr("reduction", reductionStr) + .Run(); + + return result; +} +} // namespace native +} // namespace at_npu \ No newline at end of file -- Gitee