From 3437141c69ab5140b282570fc572a791c64a6518 Mon Sep 17 00:00:00 2001
From: sunxing <346736790@qq.com>
Date: Tue, 8 Feb 2022 17:25:58 +0800
Subject: [PATCH] Migrate the BmmV2 custom operator
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 test/test_network_ops/test_bmmV2.py           |  74 +++++
 torch_npu/csrc/aten/npu_native_functions.yaml |   2 +
 torch_npu/csrc/aten/ops/BmmV2KernelNpu.cpp    | 302 ++++++++++++++++++
 3 files changed, 378 insertions(+)
 create mode 100644 test/test_network_ops/test_bmmV2.py
 create mode 100644 torch_npu/csrc/aten/ops/BmmV2KernelNpu.cpp

diff --git a/test/test_network_ops/test_bmmV2.py b/test/test_network_ops/test_bmmV2.py
new file mode 100644
index 00000000000..1b92a3ea276
--- /dev/null
+++ b/test/test_network_ops/test_bmmV2.py
@@ -0,0 +1,74 @@
+# Copyright (c) 2020, Huawei Technologies. All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.common_utils import TestCase, run_tests
+from torch_npu.testing.common_device_type import instantiate_device_type_tests
+from torch_npu.testing.util_test import create_common_tensor
+
+
+class TestBatchMatMulV2(TestCase):
+    def cpu_op_exec(self, input1, input2):
+        output = torch.bmm(input1, input2)
+        output = output.numpy()
+        return output
+
+    def npu_op_exec(self, input1, input2):
+        output = torch_npu.npu_bmmV2(input1, input2, [])
+        output = output.to("cpu")
+        output = output.numpy()
+        return output
+
+    def bmm_auto_list_exec(self, shape):
+        for item in shape:
+            cpu_input1, npu_input1 = create_common_tensor(item[0], 0, 10)
+            cpu_input2, npu_input2 = create_common_tensor(item[1], 0, 10)
+            if cpu_input1.dtype == torch.float16:
+                cpu_input1 = cpu_input1.to(torch.float32)
+            if cpu_input2.dtype == torch.float16:
+                cpu_input2 = cpu_input2.to(torch.float32)
+            cpu_output = self.cpu_op_exec(cpu_input1, cpu_input2)
+            npu_output = self.npu_op_exec(npu_input1, npu_input2)
+            cpu_output = cpu_output.astype(npu_output.dtype)
+            self.assertRtolEqual(cpu_output, npu_output)
+
+    def test_batchmatmul_shape_format_fp16_3d(self, device):
+        format_list = [0, 3, 29]
+        shape_list = [(1, 3, 2)]
+        shape_format1 = [[np.float16, i, j]
+                         for i in format_list for j in shape_list]
+        format_list = [0, 3, 29]
+        shape_list = [(1, 2, 3)]
+        shape_format2 = [[np.float16, i, j]
+                         for i in format_list for j in shape_list]
+        shape_format = [[i, j] for i in shape_format1 for j in shape_format2]
+        self.bmm_auto_list_exec(shape_format)
+
+    def test_batchmatmul_shape_format_fp32_3d(self, device):
+        format_list = [0, 3, 29]
+        shape_list = [(1, 3, 2)]
+        shape_format1 = [[np.float32, i, j]
+                         for i in format_list for j in shape_list]
+        format_list = [0, 3, 29]
+        shape_list = [(1, 2, 3)]
+        shape_format2 = [[np.float32, i, j]
+                         for i in format_list for j in shape_list]
+        shape_format = [[i, j] for i in shape_format1 for j in shape_format2]
+        self.bmm_auto_list_exec(shape_format)
+
+
+instantiate_device_type_tests(TestBatchMatMulV2, globals(), except_for="cpu")
+if __name__ == "__main__":
+    run_tests()
\ No newline at end of file
diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml
index 1f08a26654f..f5c9f3adca7 100644
--- a/torch_npu/csrc/aten/npu_native_functions.yaml
+++ b/torch_npu/csrc/aten/npu_native_functions.yaml
@@ -1911,6 +1911,8 @@ custom:
   - func: npu_slice.out(Tensor self, int[] offsets, int[] size, *, Tensor(a!) out) -> Tensor(a!)
   - func: npu_indexing(Tensor self, int[] begin, int[] end, int[] strides, int begin_mask=0, int end_mask=0, int ellipsis_mask=0, int new_axis_mask=0, int shrink_axis_mask=0) -> Tensor
   - func: npu_indexing.out(Tensor self, int[] begin, int[] end, int[] strides, int begin_mask=0, int end_mask=0, int ellipsis_mask=0, int new_axis_mask=0, int shrink_axis_mask=0, *, Tensor(a!) out) -> Tensor(a!)
+  - func: npu_bmmV2(Tensor self, Tensor mat2, int[] output_sizes) -> Tensor
+    variants: function, method
 custom_autograd:
   - func: npu_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor
   - func: npu_convolution_transpose(Tensor input, Tensor weight, Tensor? bias, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups) -> Tensor
diff --git a/torch_npu/csrc/aten/ops/BmmV2KernelNpu.cpp b/torch_npu/csrc/aten/ops/BmmV2KernelNpu.cpp
new file mode 100644
index 00000000000..2f696eb0ad3
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/BmmV2KernelNpu.cpp
@@ -0,0 +1,302 @@
+// Copyright (c) 2020 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+
+// Returns true when the last two dims are a transposed view of the base storage,
+// so BatchMatMul can consume the tensor directly via its adj_x attributes.
+bool is_transpose_last_two_dims_v2(const at::Tensor& Tensors) {
+  if (Tensors.dim() < 2) {
+    return false;
+  }
+  auto storage_size = Tensors.storage().get_npu_desc().storage_sizes_;
+  int64_t numel = at::prod_intlist(storage_size);
+
+  int64_t dim1 = Tensors.dim() - 1;
+  int64_t dim2 = Tensors.dim() - 2;
+
+  int64_t tensor_size = Tensors.storage().nbytes() / Tensors.element_size();
+  auto tensor_desc = Tensors.storage().get_npu_desc();
+  if (tensor_desc.base_sizes_.size() == Tensors.dim() &&
+      Tensors.stride(dim2) == 1 && Tensors.stride(dim1) == Tensors.size(dim2) &&
+      Tensors.size(dim1) == tensor_desc.base_sizes_[dim2] &&
+      Tensors.size(dim2) == tensor_desc.base_sizes_[dim1] &&
+      tensor_size == numel) {
+    return true;
+  } else {
+    return false;
+  }
+}
+
+// Infers the broadcast output shape [batch..., m, n] of mat1 @ mat2;
+// a 1-D operand contributes a size-1 m or n.
+c10::SmallVector<int64_t, SIZE> bmm_v2_output_size(const at::Tensor& mat1, const at::Tensor& mat2) {
+  auto dim_tensor1 = mat1.dim();
+  auto dim_tensor2 = mat2.dim();
+
+  int64_t m = dim_tensor1 == 1 ? 1 : mat1.size(-2);
+  int64_t n = dim_tensor2 == 1 ? 1 : mat2.size(-1);
+
+  auto batch_a = array_to_small_vector(at::IntArrayRef(mat1.sizes().data(), std::max<int64_t>(dim_tensor1 - 2, 0)));
+  auto batch_b = array_to_small_vector(at::IntArrayRef(mat2.sizes().data(), std::max<int64_t>(dim_tensor2 - 2, 0)));
+
+  // Left-pad the shorter batch shape with 1s so both batch shapes have the same rank.
+  batch_a.insert(batch_a.begin(), std::max(batch_a.size(), batch_b.size()) - batch_a.size(), 1);
+  batch_b.insert(batch_b.begin(), std::max(batch_a.size(), batch_b.size()) - batch_b.size(), 1);
+
+  c10::SmallVector<int64_t, SIZE> output_size;
+  for (size_t i = 0; i < batch_a.size(); ++i) {
+    if (batch_a[i] == 1) {
+      output_size.emplace_back(batch_b[i]);
+    } else if (batch_b[i] == 1) {
+      output_size.emplace_back(batch_a[i]);
+    } else if (batch_a[i] != batch_b[i]) {
+      AT_ERROR("mat1 and mat2 cannot broadcast, but they are mat1 ",
+          mat1.sizes(), " mat2 ", mat2.sizes());
+    } else {
+      output_size.emplace_back(batch_a[i]);
+    }
+  }
+  output_size.emplace_back(m);
+  output_size.emplace_back(n);
+
+  return output_size;
+}
+
+// Runs the BatchMatMul NPU operator on the (possibly transposed) inputs.
+at::Tensor pure_bmm_v2_npu(const at::Tensor& self, const at::Tensor& mat2, const c10::SmallVector<int64_t, SIZE>& output_size) {
+  auto tensor1 = self.dim() == 1 ? self.view({1, self.size(0)}) : self;
+  auto tensor2 = mat2.dim() == 1 ? mat2.view({mat2.size(0), 1}) : mat2;
+
+  at::Tensor result;
+
+  if (tensor1.scalar_type() == at::ScalarType::Half) {
+    result = at::empty_with_format(output_size, tensor1.options(), ACL_FORMAT_FRACTAL_NZ);
+  } else {
+    result = at::empty_with_format(output_size, tensor1.options(), ACL_FORMAT_ND);
+  }
+
+  at::Tensor contiguous_self = tensor1;
+  at::Tensor contiguous_mat2 = tensor2;
+  bool is_self_t = is_transpose_last_two_dims_v2(tensor1);
+  bool is_mat2_t = is_transpose_last_two_dims_v2(tensor2);
+
+  if (!is_self_t) {
+    contiguous_self = NpuUtils::format_contiguous(tensor1);
+  }
+  if (!is_mat2_t) {
+    contiguous_mat2 = NpuUtils::format_contiguous(tensor2);
+  }
+
+  auto func1 = [&contiguous_self]() {
+    bool pass = false;
+    return std::tie(pass, contiguous_self);
+  };
+  auto func2 = [&contiguous_mat2]() {
+    bool pass = false;
+    return std::tie(pass, contiguous_mat2);
+  };
+
+  // executing the NPU operator
+  OpCommand cmd;
+  cmd.Name("BatchMatMul")
+      .InputWithFunc(func1)
+      .InputWithFunc(func2)
+      .Output(result)
+      .Attr("adj_x1", is_self_t)
+      .Attr("adj_x2", is_mat2_t)
+      .Run();
+
+  return result;
+}
+
+at::Tensor reshape_tensor_self(const at::Tensor& self, c10::SmallVector<int64_t, SIZE>& expect_output_size) {
+  // self, expect_output: [5,6,7,17], [1,6,7,65]
+  // self permute + reshape: [5,6,7,17] -> [6,7,5,17] -> [6,7,85]
+  c10::SmallVector<int64_t, SIZE> self_permute_idx;
+  c10::SmallVector<int64_t, SIZE> self_batch_idx;
+
+  for (int64_t i = 0; i < self.dim(); ++i) {
+    if (i < self.dim() - 2) {
+      if (expect_output_size[i] == 1) {
+        self_batch_idx.emplace_back(i);
+        continue;
+      }
+    } else if (i == self.dim() - 1) {
+      for (int64_t j = 0; j < self_batch_idx.size(); ++j) {
+        self_permute_idx.emplace_back(self_batch_idx[j]);
+      }
+    }
+    self_permute_idx.emplace_back(i);
+  }
+  at::Tensor tmp_self = self.permute(self_permute_idx);
+
+  int64_t m_idx = 0;
+  c10::SmallVector<int64_t, SIZE> tmp_self_size;
+  c10::SmallVector<int64_t, SIZE> tmp_self_size_low;
+
+  m_idx = self.dim() - self_batch_idx.size() - 1;
+  tmp_self_size = array_to_small_vector(tmp_self.sizes());
+  tmp_self_size_low.insert(tmp_self_size_low.end(), tmp_self_size.begin(), tmp_self_size.begin() + m_idx);
+  tmp_self_size_low.emplace_back(-1);
+  tmp_self = tmp_self.reshape(tmp_self_size_low);
+  return tmp_self;
+}
+
+at::Tensor reshape_tensor_mat2(const at::Tensor& mat2, c10::SmallVector<int64_t, SIZE>& expect_output_size) {
+  // mat2, expect_output_size: [5,6,17,65], [1,6,7,65]
+  // mat2 permute + reshape: [5,6,17,65] -> [6,5,17,65] -> [6,85,65]
+  c10::SmallVector<int64_t, SIZE> mat2_permute_idx;
+  c10::SmallVector<int64_t, SIZE> mat2_batch_idx;
+
+  for (int64_t i = 0; i < mat2.dim(); ++i) {
+    if (i < mat2.dim() - 2) {
+      if (expect_output_size[i] == 1) {
+        mat2_batch_idx.emplace_back(i);
+        continue;
+      }
+    } else if (i == mat2.dim() - 2) {
+      for (int64_t j = 0; j < mat2_batch_idx.size(); ++j) {
+        mat2_permute_idx.emplace_back(mat2_batch_idx[j]);
+      }
+    }
+    mat2_permute_idx.emplace_back(i);
+  }
+  at::Tensor tmp_mat2 = mat2.permute(mat2_permute_idx);
+
+  int64_t k_idx = 0;
+  c10::SmallVector<int64_t, SIZE> tmp_mat2_size;
+  c10::SmallVector<int64_t, SIZE> tmp_mat2_size_low;
+
+  k_idx = mat2.dim() - mat2_batch_idx.size() - 2;
+  tmp_mat2_size = array_to_small_vector(tmp_mat2.sizes());
+  tmp_mat2_size_low.insert(tmp_mat2_size_low.end(), tmp_mat2_size.begin(), tmp_mat2_size.begin() + k_idx);
+  tmp_mat2_size_low.insert(tmp_mat2_size_low.end(), {-1, mat2.size(-1)});
+  tmp_mat2 = tmp_mat2.reshape(tmp_mat2_size_low);
+  return tmp_mat2;
+}
+
+c10::SmallVector<int64_t, SIZE> align_small_vector(c10::SmallVector<int64_t, SIZE> svec, c10::SmallVector<int64_t, SIZE> golden_svec) {
+  // svec, golden: [6,7,65], [5,6,7,65]
+  // expect: [6,7,65] -> [1,6,7,65]
+  c10::SmallVector<int64_t, SIZE> tmp_svec;
+  tmp_svec = svec;
+  int64_t size_to_fill = golden_svec.size() - svec.size();
+  if (size_to_fill > 0) {
+    tmp_svec.insert(tmp_svec.begin(), size_to_fill, 1);
+  }
+  return tmp_svec;
+}
+
+// Broadcasts self and mat2 to a common batch shape and flattens the batch dims,
+// so BatchMatMul sees plain 3-D inputs: [batch, m, k] and [batch, k, n].
+void expand_tensor(at::Tensor& self, at::Tensor& mat2, c10::SmallVector<int64_t, SIZE>& expand_output_size) {
+  self = self.dim() == 1 ? self.view({1, self.size(0)}) : self;
+  mat2 = mat2.dim() == 1 ? mat2.view({mat2.size(0), 1}) : mat2;
+  int64_t m = self.size(-2);
+  int64_t k1 = self.size(-1);
+  int64_t k2 = mat2.size(-2);
+  int64_t n = mat2.size(-1);
+
+  std::vector<int64_t> expand_batch_portion(expand_output_size.begin(), expand_output_size.end() - 2);
+  std::vector<int64_t> self_expand_size(expand_batch_portion);
+  std::vector<int64_t> mat2_expand_size(expand_batch_portion);
+
+  self_expand_size.insert(self_expand_size.end(), {m, k1});
+  mat2_expand_size.insert(mat2_expand_size.end(), {k2, n});
+
+  int64_t expand_batch_product = std::accumulate(expand_batch_portion.begin(), expand_batch_portion.end(),
+      1L, std::multiplies<int64_t>());
+
+  std::vector<int64_t> self_bmm_view({expand_batch_product});
+  std::vector<int64_t> mat2_bmm_view({expand_batch_product});
+  self_bmm_view.insert(self_bmm_view.end(), {m, k1});
+  mat2_bmm_view.insert(mat2_bmm_view.end(), {k2, n});
+
+  self = self.expand(self_expand_size).reshape(self_bmm_view);
+  mat2 = mat2.expand(mat2_expand_size).reshape(mat2_bmm_view);
+}
+
+at::Tensor NPUNativeFunctions::npu_bmmV2(const at::Tensor& self, const at::Tensor& mat2, at::IntArrayRef output_sizes) {
+  auto expect_output_size = array_to_small_vector(output_sizes);
+  auto infer_output_size = bmm_v2_output_size(self, mat2);
+  at::Tensor tmp_self = self;
+  at::Tensor tmp_mat2 = mat2;
+
+  // forward propagation
+  if (expect_output_size.empty()) {
+    // special issue for dim n*n
+    if (tmp_self.dim() == tmp_mat2.dim()) {
+      return pure_bmm_v2_npu(tmp_self, tmp_mat2, infer_output_size);
+    }
+    // avoid some accuracy error caused by transdata
+    expand_tensor(tmp_self, tmp_mat2, infer_output_size);
+    expect_output_size = infer_output_size;
+    infer_output_size = bmm_v2_output_size(tmp_self, tmp_mat2);
+
+    auto res = pure_bmm_v2_npu(tmp_self, tmp_mat2, infer_output_size).view(expect_output_size);
+    infer_output_size = expect_output_size;
+
+    if (self.dim() == 1) {
+      // [k][b, k, n] -> [b, 1, n] -> [b, n]
+      infer_output_size.erase(infer_output_size.end() - 2);
+      return res.view(infer_output_size);
+    } else if (mat2.dim() == 1) {
+      // [b, m, k][k] -> [b, m, 1] -> [b, m]
+      infer_output_size.erase(infer_output_size.end() - 1);
+      return res.view(infer_output_size);
+    }
+    return res;
+  }
+
+  // backward propagation
+  c10::SmallVector<int64_t, SIZE> tmp_expect_output_size = expect_output_size;
+  c10::SmallVector<int64_t, SIZE> axis_reduce;
+  c10::SmallVector<int64_t, SIZE> tmp_self_size;
+  c10::SmallVector<int64_t, SIZE> tmp_mat2_size;
+
+  tmp_expect_output_size = align_small_vector(expect_output_size, infer_output_size);
+  for (int i = 0; i < tmp_expect_output_size.size(); ++i) {
+    if (tmp_expect_output_size[i] != infer_output_size[i]) {
+      axis_reduce.emplace_back(i);
+    }
+  }
+
+  // no reduce_sum
+  if (axis_reduce.empty()) {
+    // special issue for dim n*n
+    if (tmp_self.dim() == tmp_mat2.dim()) {
+      return pure_bmm_v2_npu(tmp_self, tmp_mat2, infer_output_size);
+    }
+    // avoid some accuracy error caused by transdata
+    expand_tensor(tmp_self, tmp_mat2, infer_output_size);
+    infer_output_size = bmm_v2_output_size(tmp_self, tmp_mat2);
+    return pure_bmm_v2_npu(tmp_self, tmp_mat2, infer_output_size).view(expect_output_size);
+  }
+
+  // reduce sum without accuracy error: fold the reduced batch axes into the
+  // contraction (k) dimension so the summation happens inside BatchMatMul
+  tmp_self_size = align_small_vector(array_to_small_vector(self.sizes()), infer_output_size);
+  tmp_mat2_size = align_small_vector(array_to_small_vector(mat2.sizes()), infer_output_size);
+  tmp_self = self.reshape(tmp_self_size);
+  tmp_mat2 = mat2.reshape(tmp_mat2_size);
+  tmp_self = reshape_tensor_self(tmp_self, tmp_expect_output_size);
+  tmp_mat2 = reshape_tensor_mat2(tmp_mat2, tmp_expect_output_size);
+  infer_output_size = bmm_v2_output_size(tmp_self, tmp_mat2);
+  // avoid some accuracy error caused by transdata
+  expand_tensor(tmp_self, tmp_mat2, infer_output_size);
+  infer_output_size = bmm_v2_output_size(tmp_self, tmp_mat2);
+  return pure_bmm_v2_npu(tmp_self, tmp_mat2, infer_output_size).view(expect_output_size);
+}
+
+} // namespace native
+} // namespace at_npu
\ No newline at end of file
--
Gitee
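
Reviewer note (not part of the patch): a minimal usage sketch of the new op, assuming a reachable NPU device and a torch_npu build that includes this change; it mirrors the broadcast case used in the kernel comments above.

import torch
import torch_npu

# npu_bmmV2 broadcasts the batch dimensions, unlike torch.bmm.
mat1 = torch.rand(5, 6, 7, 17).npu()   # [..., m, k]
mat2 = torch.rand(1, 6, 17, 65).npu()  # [..., k, n]

# An empty output_sizes list (as in the test) lets the kernel infer the broadcast output shape.
result = torch_npu.npu_bmmV2(mat1, mat2, [])
print(result.shape)  # expected: torch.Size([5, 6, 7, 65])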