diff --git a/test/test_network_ops/test_softmaxcrossentropywithlogits.py b/test/test_network_ops/test_softmaxcrossentropywithlogits.py
new file mode 100644
index 0000000000000000000000000000000000000000..3dc58dd5b7d928f6b2e02bab0619a3c1acdb8444
--- /dev/null
+++ b/test/test_network_ops/test_softmaxcrossentropywithlogits.py
@@ -0,0 +1,49 @@
+# Copyright (c) 2020, Huawei Technologies.All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.common_utils import TestCase, run_tests
+from torch_npu.testing.common_device_type import Dtypes, instantiate_device_type_tests
+from torch_npu.testing.util_test import create_common_tensor
+
+
+class TestSoftmaxCrossentropyWithLogits(TestCase):
+    """Forward checks for torch_npu.npu_softmax_cross_entropy_with_logits."""
+
+    def cpu_op_exec(self, input1, label):
+        """Reference loss on CPU: -sum(label * log_softmax(input1)) over the last axis."""
+        output = -(label * torch.log_softmax(input1, dim=-1)).sum(dim=-1)
+        return output.numpy()
+
+    def npu_op_exec(self, input1, label):
+        """Run the NPU operator and return its result as a numpy array."""
+        output = torch_npu.npu_softmax_cross_entropy_with_logits(input1, label)
+        return output.to("cpu").numpy()
+
+    def test_softmaxcross(self, device):
+        """The NPU result matches the recorded value and the CPU reference."""
+        cpu_input = torch.tensor([[1., 2., 3., 4.]])
+        cpu_label = torch.tensor([[1., 2., 3., 4.]])
+        # Value recorded for this exact input; kept as a regression anchor.
+        exresult = torch.tensor([14.4019])
+        npu_output = self.npu_op_exec(cpu_input.npu(), cpu_label.npu())
+        self.assertRtolEqual(exresult.numpy(), npu_output)
+        self.assertRtolEqual(self.cpu_op_exec(cpu_input, cpu_label), npu_output)
+
+
+instantiate_device_type_tests(TestSoftmaxCrossentropyWithLogits, globals(), except_for="cpu")
+if __name__ == "__main__":
+    run_tests()
diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml
index f5c9f3adca73808821ba19df04bf786383590f39..54ea275b8ade5cbe00e96f0d664765e0ad0f8185 100644
--- a/torch_npu/csrc/aten/npu_native_functions.yaml
+++ b/torch_npu/csrc/aten/npu_native_functions.yaml
@@ -1913,6 +1913,10 @@ custom:
   - func: npu_indexing.out(Tensor self, int[] begin, int[] end, int[] strides, int begin_mask=0, int end_mask=0, int ellipsis_mask=0, int new_axis_mask=0, int shrink_axis_mask=0, *, Tensor(a!) out) -> Tensor(a!)
   - func: npu_bmmV2(Tensor self, Tensor mat2, int[] output_sizes) -> Tensor
     variants: function, method
+  - func: npu_softmax_cross_entropy_with_logits(Tensor self, Tensor labels) -> Tensor
+    variants: function, method
+  - func: npu_softmax_cross_entropy_with_logits_backward(Tensor grad, Tensor self, Tensor labels) -> Tensor
+    variants: function, method
 custom_autograd:
   - func: npu_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor
   - func: npu_convolution_transpose(Tensor input, Tensor weight, Tensor? bias, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups) -> Tensor
diff --git a/torch_npu/csrc/aten/ops/SoftmaxCrossEntropyWithLogitsKernelNpu.cpp b/torch_npu/csrc/aten/ops/SoftmaxCrossEntropyWithLogitsKernelNpu.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..73793d896c34ac63ead4299d0e970d3a5a1fb8b2
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/SoftmaxCrossEntropyWithLogitsKernelNpu.cpp
@@ -0,0 +1,133 @@
+// Copyright (c) 2020 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+#include <torch/csrc/autograd/custom_function.h>
+
+#include <tuple>
+#include <vector>
+
+namespace at_npu {
+namespace native {
+using torch::autograd::Function;
+using torch::autograd::AutogradContext;
+using tensor_list = std::vector<at::Tensor>;
+
+// Launches the "SoftmaxCrossEntropyWithLogits" NPU operator with `self` and
+// `labels` as inputs. `result` and `backprop` are bound, in that order, as the
+// operator's two outputs and must already be allocated by the caller.
+std::tuple<at::Tensor, at::Tensor> softmax_cross_entropy_with_logits_npu_nocheck(
+    at::Tensor& result,
+    at::Tensor& backprop,
+    const at::Tensor& self,
+    const at::Tensor& labels) {
+  OpCommand cmd;
+  cmd.Name("SoftmaxCrossEntropyWithLogits")
+      .Input(self)
+      .Input(labels)
+      .Output(result)
+      .Output(backprop)
+      .Run();
+
+  return std::make_tuple(result, backprop);
+}
+
+// Allocates both operator outputs and runs the kernel. Output shapes come from
+// softmax_cross_entropy_with_logits_impl_npu_output_size(self) and each tensor
+// is created with OpPreparation::ApplyTensor(self, size).
+// Returns {result, backprop}.
+std::tuple<at::Tensor, at::Tensor> softmax_cross_entropy_with_logits_impl_npu(
+    const at::Tensor& self,
+    const at::Tensor& labels) {
+  auto outputSizes =
+      softmax_cross_entropy_with_logits_impl_npu_output_size(self);
+  at::Tensor result = OpPreparation::ApplyTensor(self, std::get<0>(outputSizes));
+  at::Tensor backprop = OpPreparation::ApplyTensor(self, std::get<1>(outputSizes));
+
+  softmax_cross_entropy_with_logits_npu_nocheck(result, backprop, self, labels);
+
+  return std::make_tuple(result, backprop);
+}
+
+// Forward entry point: returns the first operator output and discards the second.
+// Fails with a descriptive message when `self` is not on an NPU device.
+at::Tensor NPUNativeFunctions::npu_softmax_cross_entropy_with_logits(
+    const at::Tensor& self,
+    const at::Tensor& labels) {
+  TORCH_CHECK(
+      self.device().type() == c10::DeviceType::NPU,
+      "npu_softmax_cross_entropy_with_logits: expected self on an NPU device, but got ",
+      self.device());
+  return std::get<0>(softmax_cross_entropy_with_logits_impl_npu(self, labels));
+}
+
+// Backward entry point: re-runs the kernel on (self, labels), takes its second
+// output and scales it by `grad`.
+at::Tensor NPUNativeFunctions::npu_softmax_cross_entropy_with_logits_backward(
+    const at::Tensor& grad,
+    const at::Tensor& self,
+    const at::Tensor& labels) {
+  // Same device guard as the forward entry point.
+  TORCH_CHECK(
+      self.device().type() == c10::DeviceType::NPU,
+      "npu_softmax_cross_entropy_with_logits_backward: expected self on an NPU device, but got ",
+      self.device());
+  at::Tensor backprop =
+      std::get<1>(softmax_cross_entropy_with_logits_impl_npu(self, labels));
+  // grad gains a trailing dimension before the multiply; presumably this lets a
+  // per-row gradient broadcast across backprop's last axis -- confirm shapes.
+  return backprop * grad.unsqueeze(-1);
+}
+
+// Custom autograd node pairing the forward kernel with its backward.
+class NPUSoftmaxCrossEntropyWithLogitsFunction
+    : public torch::autograd::Function<NPUSoftmaxCrossEntropyWithLogitsFunction> {
+public:
+  static at::Tensor forward(AutogradContext *ctx,
+      const at::Tensor& self,
+      const at::Tensor& labels) {
+    at::AutoNonVariableTypeMode g;
+    // Both inputs are tensors, so hand them to save_for_backward rather than
+    // parking one of them in saved_data.
+    ctx->save_for_backward({self, labels});
+    return NPUNativeFunctions::npu_softmax_cross_entropy_with_logits(self, labels);
+  }
+
+  static tensor_list backward(AutogradContext *ctx,
+      tensor_list grad_outputs) {
+    auto saved = ctx->get_saved_variables();
+    const auto& self = saved[0];
+    const auto& labels = saved[1];
+
+    at::Tensor result = NPUNativeFunctions::npu_softmax_cross_entropy_with_logits_backward(
+        grad_outputs[0], self, labels);
+    // One slot per forward input: a gradient for self, none for labels.
+    return {result, at::Tensor()};
+  }
+};
+
+// Applies NPUSoftmaxCrossEntropyWithLogitsFunction to (self, labels).
+// NOTE(review): nothing in this patch calls this wrapper, and the yaml entry
+// for npu_softmax_cross_entropy_with_logits sits under `custom:` -- confirm
+// whether the op was meant to be routed through it.
+at::Tensor npu_softmax_cross_entropy_with_logits_autograd(const at::Tensor& self,
+    const at::Tensor& labels) {
+  return NPUSoftmaxCrossEntropyWithLogitsFunction::apply(self, labels);
+}
+
+} // namespace native
+} // namespace at_npu