diff --git a/test/test_network_ops/test_baddbmm.py b/test/test_network_ops/test_baddbmm.py
new file mode 100644
index 0000000000000000000000000000000000000000..b59fcdcde5138d30cc2e3cf54492ae552a91b468
--- /dev/null
+++ b/test/test_network_ops/test_baddbmm.py
@@ -0,0 +1,114 @@
+# Copyright (c) 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.common_utils import TestCase, run_tests
+from torch_npu.testing.common_device_type import instantiate_device_type_tests
+from torch_npu.testing.util_test import create_common_tensor
+
+
+class TestBaddBmm(TestCase):
+    def generate_scalar(self, dtype, min1, max1):
+        if dtype == "int32":
+            scalar = np.random.randint(min1, max1)
+        else:
+            # float32 and float16 both use a random float scalar
+            scalar = np.random.uniform(min1, max1)
+        return scalar
+
+    def cpu_op_exec(self, input1, input2, input3, scalar1, scalar2):
+        output = torch.baddbmm(input1, input2, input3, beta=scalar1, alpha=scalar2)
+        output = output.numpy()
+        return output
+
+    def cpu_op_exec_(self, input1, input2, input3, scalar1, scalar2):
+        # in-place variant: input1 <- beta * input1 + alpha * (input2 @ input3)
+        input1.baddbmm_(input2, input3, beta=scalar1, alpha=scalar2)
+        input1 = input1.numpy()
+        return input1
+
+    def npu_op_exec(self, input1, input2, input3, scalar1, scalar2):
+        output = torch.baddbmm(input1, input2, input3, beta=scalar1, alpha=scalar2)
+        output = output.to("cpu")
+        output = output.numpy()
+        return output
+
+    def npu_op_exec_(self, input1, input2, input3, scalar1, scalar2):
+        input1.baddbmm_(input2, input3, beta=scalar1, alpha=scalar2)
+        input1 = input1.to("cpu")
+        input1 = input1.numpy()
+        return input1
+
+    def test_baddbmm_common_shape_format(self, device):
+        shape_format = [
+            [[np.float32, -1, (1, 3, 5)], [np.float32, -1, (1, 3, 4)],
+             [np.float32, -1, (1, 4, 5)], "float32"],
+            [[np.float32, -1, (6, 4, 3)], [np.float32, -1, (6, 4, 5)],
+             [np.float32, -1, (6, 5, 3)], "float32"],
+            [[np.float32, -1, (175, 455, 22)], [np.float32, -1, (175, 455, 116)],
+             [np.float32, -1, (175, 116, 22)], "float32"],
+            [[np.float32, -1, (25, 56, 12)], [np.float32, -1, (25, 56, 51)],
+             [np.float32, -1, (25, 51, 12)], "float32"]
+        ]
+
+        for item in shape_format:
+            cpu_input1, npu_input1 = create_common_tensor(item[0], 1, 10)
+            cpu_input2, npu_input2 = create_common_tensor(item[1], 1, 10)
+            cpu_input3, npu_input3 = create_common_tensor(item[2], 1, 10)
+            scalar1 = self.generate_scalar(item[3], 0, 10)
+            scalar2 = self.generate_scalar(item[3], 0, 10)
+            cpu_output = self.cpu_op_exec(cpu_input1, cpu_input2, cpu_input3, scalar1, scalar2)
+            npu_output = self.npu_op_exec(npu_input1, npu_input2, npu_input3, scalar1, scalar2)
+            self.assertRtolEqual(cpu_output, npu_output)
+            cpu_output_ = self.cpu_op_exec_(cpu_input1, cpu_input2, cpu_input3, scalar1, scalar2)
+            npu_output_ = self.npu_op_exec_(npu_input1, npu_input2, npu_input3, scalar1, scalar2)
+            self.assertRtolEqual(cpu_output_, npu_output_)
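+
+    # float16 baddbmm is not reliably available on CPU, so the float16 test
+    # below builds its reference by upcasting the CPU inputs to float32 and
+    # casting the result back to float16 before comparing with the NPU output.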
+    def test_baddbmm_float16_shape_format(self, device):
+        def cpu_op_exec_fp16(input1, input2, input3, scalar1, scalar2):
+            input1 = input1.to(torch.float32)
+            input2 = input2.to(torch.float32)
+            input3 = input3.to(torch.float32)
+            output = torch.baddbmm(input1, input2, input3, beta=scalar1, alpha=scalar2)
+            output = output.numpy()
+            output = output.astype(np.float16)
+            return output
+
+        shape_format = [
+            [[np.float16, -1, (1, 3, 5)], [np.float16, -1, (1, 3, 4)],
+             [np.float16, -1, (1, 4, 5)], "float16"],
+            [[np.float16, -1, (500, 40, 300)], [np.float16, -1, (500, 40, 500)],
+             [np.float16, -1, (500, 500, 300)], "float16"],
+            [[np.float16, -1, (175, 455, 22)], [np.float16, -1, (175, 455, 116)],
+             [np.float16, -1, (175, 116, 22)], "float16"],
+            [[np.float16, -1, (25, 21, 11)], [np.float16, -1, (25, 21, 34)],
+             [np.float16, -1, (25, 34, 11)], "float16"],
+        ]
+
+        for item in shape_format:
+            cpu_input1, npu_input1 = create_common_tensor(item[0], 1, 10)
+            cpu_input2, npu_input2 = create_common_tensor(item[1], 1, 10)
+            cpu_input3, npu_input3 = create_common_tensor(item[2], 1, 10)
+            scalar1 = self.generate_scalar(item[3], 0, 10)
+            scalar2 = self.generate_scalar(item[3], 0, 10)
+            cpu_output = cpu_op_exec_fp16(cpu_input1, cpu_input2, cpu_input3, scalar1, scalar2)
+            npu_output = self.npu_op_exec(npu_input1, npu_input2, npu_input3, scalar1, scalar2)
+            self.assertRtolEqual(cpu_output, npu_output)
+
+
+instantiate_device_type_tests(TestBaddBmm, globals(), except_for='cpu')
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/torch_npu/csrc/aten/ops/BaddbmmKernelNpu.cpp b/torch_npu/csrc/aten/ops/BaddbmmKernelNpu.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..df462d143f36f7dfd1c946755e5587a1dbeff2b6
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/BaddbmmKernelNpu.cpp
@@ -0,0 +1,103 @@
+// Copyright (c) 2020, Huawei Technologies. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
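+
+// This kernel assembles baddbmm from primitive NPU operations: a BatchMatMul
+// of tensor1 and tensor2, scalar multiplies by alpha and beta, and a final
+// add_out, i.e. result = beta * self + alpha * (tensor1 @ tensor2).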
+
+#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h"
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+
+c10::SmallVector<NPUTensorDesc, N> baddbmm_npu_input(
+    const at::Tensor& self,
+    const at::Tensor& other) {
+  return CalcuOpUtil::create_npu_input_tensor_desc({self, other});
+}
+
+c10::SmallVector<NPUTensorDesc, N> baddbmm_npu_output(
+    const at::Tensor& result) {
+  return CalcuOpUtil::create_npu_output_tensor_desc({result});
+}
+
+c10::SmallVector<NPUAttrDesc, N> baddbmm_npu_attr(
+    const at::Tensor& self,
+    const at::Tensor& mat2) {
+  // If the last two dims of an input are already transposed, pass adj_x1/adj_x2
+  // so BatchMatMul consumes it as-is instead of materializing a transposed copy.
+  bool isSelfT = CalcuOpUtil::is_transpose_last_two_dims(self);
+  bool isMat2T = CalcuOpUtil::is_transpose_last_two_dims(mat2);
+  NPUAttrDesc npuAttrSelfTranspose = NPUAttrDesc("adj_x1", isSelfT);
+  NPUAttrDesc npuAttrMat2Transpose = NPUAttrDesc("adj_x2", isMat2T);
+  c10::SmallVector<NPUAttrDesc, N> attrs = {npuAttrSelfTranspose, npuAttrMat2Transpose};
+  return attrs;
+}
+
+at::Tensor& NPUNativeFunctions::baddbmm_out(
+    const at::Tensor& self,
+    const at::Tensor& tensor1,
+    const at::Tensor& tensor2,
+    at::Scalar beta,
+    at::Scalar alpha,
+    at::Tensor& result) {
+  auto outputSize = baddbmm_npu_output_size(tensor1, tensor2);
+  at::Tensor BatchMatMulTensor = OpPreparation::ApplyTensor(self, outputSize);
+
+  // BatchMatMulTensor = tensor1 @ tensor2
+  auto inputs = baddbmm_npu_input(tensor1, tensor2);
+  auto outputs = baddbmm_npu_output(BatchMatMulTensor);
+  auto attrs = baddbmm_npu_attr(tensor1, tensor2);
+  CalcuOpUtil::execute_npu_operate("BatchMatMul", inputs, outputs, attrs);
+
+  // result = beta * self + alpha * (tensor1 @ tensor2)
+  at::Tensor alphaMulTensor = at::mul(BatchMatMulTensor, alpha);
+  at::Tensor betaMulTensor = at::mul(self, beta);
+  at::add_out(result, alphaMulTensor, betaMulTensor);
+
+  return result;
+}
+
+at::Tensor NPUNativeFunctions::baddbmm(
+    const at::Tensor& self,
+    const at::Tensor& tensor1,
+    const at::Tensor& tensor2,
+    at::Scalar beta,
+    at::Scalar alpha) {
+  auto outputSize = baddbmm_npu_output_size(tensor1, tensor2);
+  at::Tensor result = OpPreparation::ApplyTensor(self, outputSize);
+  NPUNativeFunctions::baddbmm_out(self, tensor1, tensor2, beta, alpha, result);
+  return result;
+}
+
+at::Tensor& NPUNativeFunctions::baddbmm_(
+    at::Tensor& self,
+    const at::Tensor& tensor1,
+    const at::Tensor& tensor2,
+    at::Scalar beta,
+    at::Scalar alpha) {
+  c10::SmallVector<at::Tensor, N> inputs = {self, tensor1, tensor2};
+  c10::SmallVector<at::Tensor, N> outputs = {self};
+  CalcuOpUtil::check_memory_over_laps(inputs, outputs);
+
+  if (!NpuUtils::check_match(&self)) {
+    // self is not a well-formed NPU tensor view; compute into a contiguous
+    // copy, then write the result back into the original view.
+    at::Tensor contiguousSelf = NpuUtils::format_contiguous(self);
+    at::Tensor result = NPUNativeFunctions::baddbmm_out(contiguousSelf, tensor1, tensor2, beta, alpha, contiguousSelf);
+    NpuUtils::format_fresh_view(self, result);
+  } else {
+    NPUNativeFunctions::baddbmm_out(self, tensor1, tensor2, beta, alpha, self);
+  }
+
+  return self;
+}
+
+} // namespace native
+} // namespace at_npu