From 69a6133f2cee8e67da934112fef29cdac132b254 Mon Sep 17 00:00:00 2001 From: shenpengcheng Date: Wed, 26 Jan 2022 13:50:21 +0800 Subject: [PATCH 1/2] =?UTF-8?q?BitwiseOr=E7=AE=97=E5=AD=90=E8=BF=81?= =?UTF-8?q?=E7=A7=BB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit modify file code checkout fix checking fix check1 fix check2 fix check3 : fix check4 fix check5 code fix code fix2 --- .../test_binary_cross_entropy.py | 94 +++++++ test/test_network_ops/test_bitwise_or.py | 239 ++++++++++++++++++ .../csrc/aten/ops/BinaryCrossEntropy.cpp | 88 +++++++ .../csrc/aten/ops/BitwiseOrKernelNpu.cpp | 160 ++++++++++++ 4 files changed, 581 insertions(+) create mode 100644 test/test_network_ops/test_binary_cross_entropy.py create mode 100644 test/test_network_ops/test_bitwise_or.py create mode 100644 torch_npu/csrc/aten/ops/BinaryCrossEntropy.cpp create mode 100644 torch_npu/csrc/aten/ops/BitwiseOrKernelNpu.cpp diff --git a/test/test_network_ops/test_binary_cross_entropy.py b/test/test_network_ops/test_binary_cross_entropy.py new file mode 100644 index 0000000000..b7cf9b113f --- /dev/null +++ b/test/test_network_ops/test_binary_cross_entropy.py @@ -0,0 +1,94 @@ +# Copyright (c) 2020 Huawei Technologies Co., Ltd +# Copyright (c) 2019, Facebook CORPORATION. +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import copy +import torch +import torch_npu +import torch.nn as nn +import numpy as np + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import instantiate_device_type_tests + +LOWER = 0 +UPPER = 1 + +class TestBinaryCrossEntropy(TestCase): + + def generate_input(self, lower, upper, shape, dtype): + temp = np.random.uniform(lower, upper, shape).astype(dtype) + npu_input = torch.from_numpy(temp) + return npu_input + + def cpu_op_exec(self, predict, target, weight=None, reduction="mean"): + res = torch.nn.functional.binary_cross_entropy(predict, target, weight=weight, reduction=reduction) + return res.numpy() + + def cpu_op_exec_half(self, predict, target, weight=None, reduction="mean"): + res = torch.nn.functional.binary_cross_entropy(predict, target, weight=weight, reduction=reduction) + return res.type(torch.float16).numpy() + + def npu_op_exec(self, predict, target, weight=None, reduction="mean"): + predict = predict.to("npu") + target = target.to("npu") + if weight is not None: + weight = weight.to("npu") + res = torch.nn.functional.binary_cross_entropy(predict, target, weight=weight, reduction=reduction) + res = res.to("cpu") + return res.numpy() + + def test_binary_cross_entropy_float32(self, device): + for shape, weight_shape, reduction in [ + ((10, 64), None, "mean"), + ((10, 64), (10, 1), "mean"), + ((10, 64), None, "mean"), + ((10, 64), (10, 64), "mean"), + ((10, 64), (10, 64), "sum" ), + ((10, 64), (10, 64), "none") + ]: + predict = self.generate_input(LOWER, UPPER, shape, np.float32) + target = torch.empty(shape, dtype=torch.float32).random_(2) + weight = None + if weight_shape is not None: + weight = self.generate_input(LOWER, UPPER, weight_shape, np.float32) + cpu_output = self.cpu_op_exec(predict, target, weight=weight, reduction=reduction) + npu_output = self.npu_op_exec(predict, target, weight=weight, reduction=reduction) + self.assertRtolEqual(cpu_output, npu_output) + + def test_binary_cross_entropy_float16(self, device): + for shape, weight_shape, reduction in [ + ((10, 64), (10, 64), "sum"), + ((10, 64), (10, 64), "mean"), + ((10, 64), (10, 64), "none") + ]: + predict = self.generate_input(LOWER, UPPER, shape, np.float16) + target = torch.empty(shape, dtype=torch.float16).random_(2) + predict_32 = predict.type(torch.float32) + target_32 = target.type(torch.float32) + weight = None + weight_32 = None + if weight_shape is not None: + weight = self.generate_input(LOWER, UPPER, weight_shape, np.float16) + weight_32 = weight.type(torch.float32) + + npu_output = self.npu_op_exec(predict, target, weight=weight, reduction=reduction) + cpu_output = self.cpu_op_exec_half(predict_32, target_32, weight=weight_32, reduction=reduction) + self.assertRtolEqual(cpu_output, npu_output) + +instantiate_device_type_tests(TestBinaryCrossEntropy, globals(), except_for="cpu") +if __name__ == "__main__": + run_tests() diff --git a/test/test_network_ops/test_bitwise_or.py b/test/test_network_ops/test_bitwise_or.py new file mode 100644 index 0000000000..8b81687d3a --- /dev/null +++ b/test/test_network_ops/test_bitwise_or.py @@ -0,0 +1,239 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch_npu +import numpy as np +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor + +class TestBitwiseOr(TestCase): + def generate_data(self, min_d, max_d, shape, dtype): + input1 = np.random.uniform(min_d, max_d, shape).astype(dtype) + input2 = np.random.uniform(min_d, max_d, shape).astype(dtype) + + #modify from numpy.ndarray to torch.tensor + npu_input1 = torch.from_numpy(input1) + npu_input2 = torch.from_numpy(input2) + + return npu_input1, npu_input2 + + def generate_bool_data(self, min_d, max_d, shape): + input1 = np.random.uniform(min_d, max_d, shape) + input2 = np.random.uniform(min_d, max_d, shape) + input1 = input1.reshape(-1) + input2 = input2.reshape(-1) + seq = range(len(input1)) + for i in enumerate(seq): + if input1[i] < 0.5: + input1[i] = 0 + seq = range(len(input1)) + for i in enumerate(seq): + if input2[i] < 0.5: + input2[i] = 0 + input1 = input1.astype(np.bool) + input2 = input2.astype(np.bool) + input1 = input1.reshape(shape) + input2 = input2.reshape(shape) + #modify from numpy.ndarray to torch.tensor + npu_input1 = torch.from_numpy(input1) + npu_input2 = torch.from_numpy(input2) + + return npu_input1, npu_input2 + + def generate_single_data(self, min_d, max_d, shape, dtype): + input1 = np.random.uniform(min_d, max_d, shape).astype(dtype) + npu_input1 = torch.from_numpy(input1) + + return npu_input1 + + def generate_single_bool_data(self, min_d, max_d, shape): + input1 = np.random.uniform(min_d, max_d, shape) + input1 = input1.reshape(-1) + seq = range(len(input1)) + for i in enumerate(seq): + if input1[i] < 0.5: + input1[i] = 0 + input1 = input1.astype(np.bool) + input1 = input1.reshape(shape) + #modify from numpy.ndarray to torch.tensor + npu_input1 = torch.from_numpy(input1) + + return npu_input1 + + def generate_three_data(self, min_d, max_d, shape, dtype): + input1 = np.random.uniform(min_d, max_d, shape).astype(dtype) + input2 = np.random.uniform(min_d, max_d, shape).astype(dtype) + input3 = np.random.uniform(min_d, max_d, shape).astype(dtype) + + #modify from numpy.ndarray to torch.tensor + npu_input1 = torch.from_numpy(input1) + npu_input2 = torch.from_numpy(input2) + npu_input3 = torch.from_numpy(input3) + + return npu_input1, npu_input2, npu_input3 + + def npu_op_exec_out(self, input1, input2, input3): + input1 = input1.to("npu") + input2 = input2.to("npu") + output = input3.to("npu") + torch.bitwise_or(input1, input2, out=output) + output = output.to("cpu") + if output.dtype not in [torch.int32, torch.bool]: + output = output.to(torch.int32) + output = output.numpy() + return output + + def npu_mix_op_exec(self, input1, input2): + input1 = input1.to("npu") + input2 = input2.to("npu") + output = torch.bitwise_or(input1, input2) + output = output.to("cpu") + if output.dtype not in [torch.int32, torch.bool]: + output = output.to(torch.int32) + output = output.numpy() + return output + + def cpu_op_exec(self, input1, input2): + output = torch.bitwise_or(input1, input2) + if output.dtype not in [torch.int32, torch.bool]: + output = output.to(torch.int32) + output = output.numpy() + return output + + def cpu_op_exec_out(self, input1, input2, output): + torch.bitwise_or(input1, input2, out = output) + if output.dtype not in [torch.int32, torch.bool]: + output = output.to(torch.int32) + output = output.numpy() + return output + + def cpu_op_exec_scalar(self, input1, scalar): + output = torch.bitwise_or(input1, scalar) + if output.dtype not in [torch.int32, torch.bool]: + output = output.to(torch.int32) + output = output.numpy() + return output + + def cpu_op_exec_scalar_out(self, input1, scalar, output): + torch.bitwise_or(input1, scalar, out = output) + if output.dtype not in [torch.int32, torch.bool]: + output = output.to(torch.int32) + output = output.numpy() + return output + + def npu_op_exec_scalar_out(self, input1, scalar, output): + input1 = input1.to("npu") + output = output.to("npu") + output = torch.bitwise_or(input1, scalar, out = output) + output = output.to("cpu") + if output.dtype not in [torch.int32, torch.bool]: + output = output.to(torch.int32) + output = output.numpy() + return output + + def bitwise_or_tensor_out_result(self, shape_format): + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item[0], 0, 100) + cpu_input2, npu_input2 = create_common_tensor(item[0], 0, 100) + cpu_input3, npu_input3 = create_common_tensor(item[1], 0, 100) + cpu_output_out = self.cpu_op_exec(cpu_input1, cpu_input2) + npu_output_out = self.npu_op_exec_out(npu_input1, npu_input2, npu_input3) + cpu_output_out = cpu_output_out.astype(npu_output_out.dtype) + + self.assertRtolEqual(cpu_output_out, npu_output_out) + + def test_bitwise_or_tensor_out(self, device): + shape_format = [ + [[np.int16, 0, [128, 3, 224, 224]], [np.int16, 0, [3, 3, 3]]], + [[np.int16, 0, [128, 116, 14, 14]], [np.int16, 0, [128, 116, 14, 14]]], + [[np.int32, 0, [256, 128, 7, 7]], [np.int32, 0, [128, 256, 3, 3]]], + [[np.int32, 0, [2, 3, 3, 3]], [np.int32, 0, [3, 1, 3]]], + [[np.int32, 0, [128, 232, 7, 7]], [np.int32, 0, [128, 232, 7, 7]]], + ] + self.bitwise_or_tensor_out_result(shape_format) + + def bitwise_or_scalar_out_result(self, shape_format): + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item[0], 0, 100) + cpu_input2, npu_input2 = create_common_tensor(item[1], 0, 100) + scalar = np.random.randint(1, 5) + cpu_output_out = self.cpu_op_exec_scalar(cpu_input1, scalar) + npu_output_out = self.npu_op_exec_scalar_out(npu_input1, scalar, npu_input2) + cpu_output_out = cpu_output_out.astype(npu_output_out.dtype) + self.assertRtolEqual(cpu_output_out, npu_output_out) + + def test_bitwise_or_scalar_out(self, device): + shape_format = [ + [[np.int16, 0, [16, 3, 1111, 1212]], [np.int16, 0, [3, 3, 3]]], + [[np.int16, 0, [128, 116, 14, 14]], [np.int16, 0, [128, 116, 14, 14]]], + [[np.int32, 0, [1313, 3, 3, 3]], [np.int32, 0, [3, 1, 3]]], + [[np.int32, 0, [128, 232, 7, 7]], [np.int32, 0, [128, 232, 7, 7]]], + ] + self.bitwise_or_scalar_out_result(shape_format) + + def test_bitwise_or_int32(self, device): + npu_input1, npu_input2 = self.generate_data(0, 100, (2,3), np.int32) + cpu_output = self.cpu_op_exec_out(npu_input1, npu_input2, npu_input1) + npu_output = self.npu_op_exec_out(npu_input1, npu_input2, npu_input1) + self.assertRtolEqual(cpu_output, npu_output) + + def test_bitwise_or_bool_scalar(self, device): + npu_input1, npu_input2 = self.generate_data(0, 100, (2,3), np.int32) + cpu_output = self.cpu_op_exec_out(npu_input1, True, npu_input1) + npu_output = self.npu_op_exec_scalar_out(npu_input1, True, npu_input1) + self.assertRtolEqual(cpu_output, npu_output) + + def test_bitwise_or_int32_scalar(self, device): + npu_input1, npu_input2 = self.generate_data(0, 100, (2,3), np.int32) + cpu_output = self.cpu_op_exec_out(npu_input1, 1, npu_input1) + npu_output = self.npu_op_exec_scalar_out(npu_input1, 1, npu_input1) + self.assertRtolEqual(cpu_output, npu_output) + + def test_bitwise_or_int16(self, device): + npu_input1, npu_input2 = self.generate_data(0, 100, (1,6), np.int16) + cpu_output = self.cpu_op_exec_out(npu_input1, npu_input2, npu_input1) + npu_output = self.npu_op_exec_out(npu_input1, npu_input2, npu_input1) + self.assertRtolEqual(cpu_output, npu_output) + + def test_bitwise_or_int16_scalar(self, device): + npu_input1, npu_input2 = self.generate_data(0, 100, (2,3), np.int16) + cpu_output = self.cpu_op_exec_out(npu_input1, 1, npu_input1) + npu_output = self.npu_op_exec_scalar_out(npu_input1, 1, npu_input1) + self.assertRtolEqual(cpu_output, npu_output) + + def test_bitwise_or_int16_diff(self, device): + npu_input1 = self.generate_single_data(0, 100, (1,6), np.int16) + npu_input2 = self.generate_single_data(0, 100, (1,1), np.int16) + cpu_output = self.cpu_op_exec_out(npu_input1, npu_input2, npu_input1) + npu_output = self.npu_op_exec_out(npu_input1, npu_input2, npu_input1) + self.assertRtolEqual(cpu_output, npu_output) + + def test_bitwise_or_int16_out(self, device): + npu_input1, npu_input2, npu_input3 = self.generate_three_data(0, 100, (4,3), np.int16) + cpu_output = self.cpu_op_exec_out(npu_input1, npu_input2, npu_input3) + npu_output = self.npu_op_exec_out(npu_input1, npu_input2, npu_input3) + self.assertRtolEqual(cpu_output, npu_output) + + def test_bitwise_or_mix_dtype(self, device): + npu_input1 = self.generate_single_data(0, 100, (1,6), np.int32) + npu_input2 = self.generate_single_data(0, 100, (1,6), np.int16) + cpu_output = self.cpu_op_exec(npu_input1, npu_input2) + npu_output = self.npu_mix_op_exec(npu_input1, npu_input2) + self.assertRtolEqual(cpu_output, npu_output) + +instantiate_device_type_tests(TestBitwiseOr, globals(), except_for='cpu') +if __name__ == '__main__': + run_tests() diff --git a/torch_npu/csrc/aten/ops/BinaryCrossEntropy.cpp b/torch_npu/csrc/aten/ops/BinaryCrossEntropy.cpp new file mode 100644 index 0000000000..9c949d35a9 --- /dev/null +++ b/torch_npu/csrc/aten/ops/BinaryCrossEntropy.cpp @@ -0,0 +1,88 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ATen/native/npu/utils/OpAdapter.h" + +namespace at { +namespace native { +using namespace at::native::npu; + +Tensor& binary_cross_entropy_out_npu( + const Tensor& self, + const Tensor& target, + const c10::optional& weight_opt, + int64_t reduction, + Tensor& result) { + const Tensor& weight = c10::value_or_else(weight_opt, [] {return Tensor();}); + Tensor weightTensor = weight; + if (!weight.defined()) { + weightTensor = at::ones(self.sizes(), self.options()); + } + + // constructs the attr of the NPUAttrDesc + string reductionStr = "mean"; + if (reduction == Reduction::None) { + reductionStr = "none"; + } else if (reduction == Reduction::Sum) { + reductionStr = "sum"; + } + + // executing the NPU operator + OpCommand cmd; + cmd.Name("BinaryCrossEntropy") + .Input(self) + .Input(target) + .Input(weightTensor) + .Output(result) + .Attr("reduction", reductionStr) + .Run(); + return result; +} + +Tensor binary_cross_entropy_npu( + const Tensor& self, + const Tensor& target, + const c10::optional& weight_opt, + int64_t reduction) { + const Tensor& weight = c10::value_or_else(weight_opt, [] {return Tensor();}); + // calculate the output size + IntArrayRef outputSize; + + if (reduction == Reduction::None) { + outputSize = input_same_output_size(self); + } else { + outputSize = ArrayRef(); + } + + // construct the output tensor of the NPU + Tensor result = OpPreparation::ApplyTensor(self, outputSize); + if (self.numel() == 0) { + // In this scenario, needs to return nan. And the nan of the NPU can only be fp32. + result = result.to(at::kFloat).fill_(0); + result = result / 0; + return result; + } + + // calculate the output result of the NPU + binary_cross_entropy_out_npu(self, target, weight, reduction, result); + return result; +} +TORCH_LIBRARY_IMPL(aten, NPU, m) { + m.impl("binary_cross_entropy.out", TORCH_FN(binary_cross_entropy_out_npu)); + m.impl("binary_cross_entropy", TORCH_FN(binary_cross_entropy_npu)); +} +} // namespace native +} // namespace at \ No newline at end of file diff --git a/torch_npu/csrc/aten/ops/BitwiseOrKernelNpu.cpp b/torch_npu/csrc/aten/ops/BitwiseOrKernelNpu.cpp new file mode 100644 index 0000000000..41c4b69935 --- /dev/null +++ b/torch_npu/csrc/aten/ops/BitwiseOrKernelNpu.cpp @@ -0,0 +1,160 @@ +// Copyright (c) 2020, Huawei Technologies.All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h" +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +at::Tensor& bitwise_or_out_npu_nocheck( + at::Tensor& result, + const at::Tensor& self, + const at::Scalar other) { + // executing the NPU operator + string real_op_name = + (self.dtype() == at::ScalarType::Bool) ? "LogicalOr" : "BitwiseOr"; + + OpCommand cmd; + cmd.Name(real_op_name) + .Input(self) + .Input(other, self.scalar_type()) + .Output(result) + .Run(); + + return result; +} + +at::Tensor& NPUNativeFunctions::bitwise_or_out( + const at::Tensor& self, + const at::Scalar other, + at::Tensor& result) { + OpPreparation::CheckOut( + {self}, + result, + self); + bitwise_or_out_npu_nocheck(result, self, other); + return result; +} + +at::Tensor& bitwise_or_out_npu_nocheck( + at::Tensor& result, + const at::Tensor& self, + const at::Tensor& other) { + auto unified_result = OpPreparation::binary_op_check(result, self, other, true); + if (other.dim() == 0 && !other.is_npu()) { + NPUNativeFunctions::bitwise_or_out(self, other.item(), result); + } else if (self.dim() == 0 && !self.is_npu()) { + NPUNativeFunctions::bitwise_or_out(other, self.item(), result); + } else { + // executing the NPU operator + string real_op_name = + (self.dtype() == at::ScalarType::Bool) ? "LogicalOr" : "BitwiseOr"; + + OpCommand cmd; + cmd.Name(real_op_name) + .Expect(unified_result) + .Input(self) + .Input(other) + .Output(result) + .Run(); + } + return result; +} + +at::Tensor& NPUNativeFunctions::bitwise_or_out( + const at::Tensor& self, + const at::Tensor& other, + at::Tensor& result) { + bool isSelfWrapped = CalcuOpUtil::is_scalar_wrapped_to_tensor(self); + + at::Tensor outputTensor; + if (not isSelfWrapped) { + outputTensor = self; + } else { + outputTensor = other; + } + auto outputSize = broadcast_ops_npu_output_size(self, other); + OpPreparation::CheckOut( + {self}, + result, + CalcuOpUtil::get_tensor_npu_format(outputTensor), + outputTensor.scalar_type(), + outputSize); + if (!NpuUtils::check_match(&result)) { + at::Tensor contiguousResult = NpuUtils::format_contiguous(result); + at::Tensor newResult = bitwise_or_out_npu_nocheck(contiguousResult, self, other); + NpuUtils::format_fresh_view(result, newResult); + } else { + bitwise_or_out_npu_nocheck(result, self, other); + } + return result; +} + +at::Tensor NPUNativeFunctions::bitwise_or(const at::Tensor& self, const at::Tensor& other) { + // calculate the output size + bool isSelfWrapped = CalcuOpUtil::is_scalar_wrapped_to_tensor(self); + + at::Tensor outputTensor; + if (not isSelfWrapped) { + outputTensor = self; + } else { + outputTensor = other; + } + + auto outputSize = broadcast_ops_npu_output_size(self, other); + + // construct the output Tensor of the NPU + at::Tensor result = OpPreparation::ApplyTensor(outputTensor, outputSize); + // calculate the output result of the NPU + bitwise_or_out_npu_nocheck(result, self, other); + + return result; +} + +at::Tensor NPUNativeFunctions::bitwise_or(const at::Tensor& self, at::Scalar other) { + at::Tensor result = OpPreparation::ApplyTensor(self); + + // calculate the output result of the NPU + bitwise_or_out_npu_nocheck(result, self, other); + return result; +} + +at::Tensor& NPUNativeFunctions::bitwise_or_(at::Tensor& self, const at::Tensor& other) { + OpPreparation::CheckMemory({self, other}, {self}); + + if (!NpuUtils::check_match(&self)) { + at::Tensor contiguousSelf = NpuUtils::format_contiguous(self); + at::Tensor result = bitwise_or_out_npu_nocheck(contiguousSelf, contiguousSelf, other); + NpuUtils::format_fresh_view(self, result); + } else { + bitwise_or_out_npu_nocheck(self, self, other); + } + return self; +} + +at::Tensor& NPUNativeFunctions::bitwise_or_(at::Tensor& self, at::Scalar other) { + if (!NpuUtils::check_match(&self)) { + at::Tensor contiguousSelf = NpuUtils::format_contiguous(self); + at::Tensor result = bitwise_or_out_npu_nocheck(contiguousSelf, contiguousSelf, other); + NpuUtils::format_fresh_view(self, result); + } else { + bitwise_or_out_npu_nocheck(self, self, other); + } + return self; +} + +} // namespace native +} // namespace at_npu -- Gitee From 0c75ac00b8547aaeec8e161761c47d74ed5cfb64 Mon Sep 17 00:00:00 2001 From: shenpengcheng Date: Fri, 28 Jan 2022 17:49:44 +0800 Subject: [PATCH 2/2] code fix --- .../csrc/aten/ops/BinaryCrossEntropy.cpp | 88 ------------------- 1 file changed, 88 deletions(-) delete mode 100644 torch_npu/csrc/aten/ops/BinaryCrossEntropy.cpp diff --git a/torch_npu/csrc/aten/ops/BinaryCrossEntropy.cpp b/torch_npu/csrc/aten/ops/BinaryCrossEntropy.cpp deleted file mode 100644 index 9c949d35a9..0000000000 --- a/torch_npu/csrc/aten/ops/BinaryCrossEntropy.cpp +++ /dev/null @@ -1,88 +0,0 @@ -// Copyright (c) 2020 Huawei Technologies Co., Ltd -// Copyright (c) 2019, Facebook CORPORATION. -// All rights reserved. -// -// Licensed under the BSD 3-Clause License (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://opensource.org/licenses/BSD-3-Clause -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "ATen/native/npu/utils/OpAdapter.h" - -namespace at { -namespace native { -using namespace at::native::npu; - -Tensor& binary_cross_entropy_out_npu( - const Tensor& self, - const Tensor& target, - const c10::optional& weight_opt, - int64_t reduction, - Tensor& result) { - const Tensor& weight = c10::value_or_else(weight_opt, [] {return Tensor();}); - Tensor weightTensor = weight; - if (!weight.defined()) { - weightTensor = at::ones(self.sizes(), self.options()); - } - - // constructs the attr of the NPUAttrDesc - string reductionStr = "mean"; - if (reduction == Reduction::None) { - reductionStr = "none"; - } else if (reduction == Reduction::Sum) { - reductionStr = "sum"; - } - - // executing the NPU operator - OpCommand cmd; - cmd.Name("BinaryCrossEntropy") - .Input(self) - .Input(target) - .Input(weightTensor) - .Output(result) - .Attr("reduction", reductionStr) - .Run(); - return result; -} - -Tensor binary_cross_entropy_npu( - const Tensor& self, - const Tensor& target, - const c10::optional& weight_opt, - int64_t reduction) { - const Tensor& weight = c10::value_or_else(weight_opt, [] {return Tensor();}); - // calculate the output size - IntArrayRef outputSize; - - if (reduction == Reduction::None) { - outputSize = input_same_output_size(self); - } else { - outputSize = ArrayRef(); - } - - // construct the output tensor of the NPU - Tensor result = OpPreparation::ApplyTensor(self, outputSize); - if (self.numel() == 0) { - // In this scenario, needs to return nan. And the nan of the NPU can only be fp32. - result = result.to(at::kFloat).fill_(0); - result = result / 0; - return result; - } - - // calculate the output result of the NPU - binary_cross_entropy_out_npu(self, target, weight, reduction, result); - return result; -} -TORCH_LIBRARY_IMPL(aten, NPU, m) { - m.impl("binary_cross_entropy.out", TORCH_FN(binary_cross_entropy_out_npu)); - m.impl("binary_cross_entropy", TORCH_FN(binary_cross_entropy_npu)); -} -} // namespace native -} // namespace at \ No newline at end of file -- Gitee