From 4269f15ae0d5ff9c9771ca0d5705a163f688197f Mon Sep 17 00:00:00 2001 From: w30042780 Date: Thu, 31 Jul 2025 21:23:54 +0800 Subject: [PATCH 1/2] =?UTF-8?q?PpMatMul=E4=B8=8B=E6=B2=89canndev?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/kernels/kernels/matmul/CMakeLists.txt | 14 +- .../kernels/matmul/matmul_operation.cpp | 50 +++---- ...p_mat_mul_bf16nd_bf16nd_bf16nd_kernel.cpp} | 137 ++++++++---------- ...pp_mat_mul_bf16nd_bf16nd_f32nd_kernel.cpp} | 115 +++++++-------- ...p_mat_mul_bf16nd_bf16nz_bf16nd_kernel.cpp} | 133 ++++++++--------- .../pp_mat_mul_f16nd_f16nd_f16nd_kernel.cpp | 62 ++++++++ .../pp_mat_mul_f16nd_f16nd_f32nd_kernel.cpp} | 112 +++++++------- .../pp_mat_mul_f16nd_f16nz_f16nd_kernel.cpp | 68 +++++++++ .../pp_mat_mul_f16nz_f16nz_f16nz_kernel.cpp} | 123 +++++++--------- .../pp_matmul_ein_sum_kernel.cpp | 51 ------- .../pp_matmul_f16_kernel.cpp | 45 ------ src/ops_infer/linear/linear_ops_runner.cpp | 22 +-- .../matmul/test_pp_matmul_accum.py | 36 ++--- .../kernelstest/matmul/test_pp_matmul_bf16.py | 56 +++---- .../matmul/test_pp_matmul_ein_sum_f16.py | 40 ++--- .../kernelstest/matmul/test_pp_matmul_f16.py | 100 ++++++------- .../matmul/test_pp_matmul_with_bias.py | 36 ++--- .../kernelstest/matmul/test_ppmatmul_310p.py | 4 +- 18 files changed, 600 insertions(+), 604 deletions(-) rename src/kernels/kernels/matmul/{pp_matmul_accum_kernel/pp_matmul_accum_kernel.cpp => pp_mat_mul_bf16nd_bf16nd_bf16nd_kernel/pp_mat_mul_bf16nd_bf16nd_bf16nd_kernel.cpp} (41%) rename src/kernels/kernels/matmul/{pp_matmul_bf16_kernel/pp_matmul_bf16_kernel.cpp => pp_mat_mul_bf16nd_bf16nd_f32nd_kernel/pp_mat_mul_bf16nd_bf16nd_f32nd_kernel.cpp} (48%) rename src/kernels/kernels/matmul/{pp_matmul_with_bias_kernel/pp_matmul_with_bias_kernel.cpp => pp_mat_mul_bf16nd_bf16nz_bf16nd_kernel/pp_mat_mul_bf16nd_bf16nz_bf16nd_kernel.cpp} (38%) create mode 100644 src/kernels/kernels/matmul/pp_mat_mul_f16nd_f16nd_f16nd_kernel/pp_mat_mul_f16nd_f16nd_f16nd_kernel.cpp rename src/kernels/kernels/matmul/{pp_matmul_bf16_nd_nz_nd_kernel/pp_matmul_bf16_nd_nz_nd_kernel.cpp => pp_mat_mul_f16nd_f16nd_f32nd_kernel/pp_mat_mul_f16nd_f16nd_f32nd_kernel.cpp} (47%) create mode 100644 src/kernels/kernels/matmul/pp_mat_mul_f16nd_f16nz_f16nd_kernel/pp_mat_mul_f16nd_f16nz_f16nd_kernel.cpp rename src/kernels/kernels/matmul/{pp_matmul_f16_opt_kernel/pp_matmul_f16_opt_kernel.cpp => pp_mat_mul_f16nz_f16nz_f16nz_kernel/pp_mat_mul_f16nz_f16nz_f16nz_kernel.cpp} (60%) delete mode 100644 src/kernels/kernels/matmul/pp_matmul_ein_sum_kernel/pp_matmul_ein_sum_kernel.cpp delete mode 100644 src/kernels/kernels/matmul/pp_matmul_f16_kernel/pp_matmul_f16_kernel.cpp diff --git a/src/kernels/kernels/matmul/CMakeLists.txt b/src/kernels/kernels/matmul/CMakeLists.txt index 2fa62f2a..cd6bba9a 100644 --- a/src/kernels/kernels/matmul/CMakeLists.txt +++ b/src/kernels/kernels/matmul/CMakeLists.txt @@ -4,19 +4,12 @@ set(matmul_srcs ${CMAKE_CURRENT_LIST_DIR}/batch_matmul_kernel/batch_matmul_nz_kernel.cpp ${CMAKE_CURRENT_LIST_DIR}/matmul_kernel/matmul_nd_kernel.cpp ${CMAKE_CURRENT_LIST_DIR}/matmul_kernel/matmul_nz_kernel.cpp - ${CMAKE_CURRENT_LIST_DIR}/pp_matmul_bf16_kernel/pp_matmul_bf16_kernel.cpp - ${CMAKE_CURRENT_LIST_DIR}/pp_matmul_bf16_nd_nz_nd_kernel/pp_matmul_bf16_nd_nz_nd_kernel.cpp ${CMAKE_CURRENT_LIST_DIR}/pp_matmul_f16_mix_kernel/pp_matmul_f16_mix_kernel.cpp - ${CMAKE_CURRENT_LIST_DIR}/pp_matmul_f16_kernel/pp_matmul_f16_kernel.cpp - ${CMAKE_CURRENT_LIST_DIR}/pp_matmul_f16_opt_kernel/pp_matmul_f16_opt_kernel.cpp ${CMAKE_CURRENT_LIST_DIR}/pp_matmul_i8_kernel/pp_matmul_i8_kernel.cpp ${CMAKE_CURRENT_LIST_DIR}/pp_matmul_f16_m300_kernel/pp_matmul_f16_m300_kernel.cpp ${CMAKE_CURRENT_LIST_DIR}/pp_matmul_i8_nz_kernel/pp_matmul_i8_nz_kernel.cpp ${CMAKE_CURRENT_LIST_DIR}/pp_matmul_i8_nz_compress_kernel/pp_matmul_i8_nz_compress_kernel.cpp ${CMAKE_CURRENT_LIST_DIR}/pp_matmul_kernel/pp_matmul_nz_kernel.cpp - ${CMAKE_CURRENT_LIST_DIR}/pp_matmul_accum_kernel/pp_matmul_accum_kernel.cpp - ${CMAKE_CURRENT_LIST_DIR}/pp_matmul_with_bias_kernel/pp_matmul_with_bias_kernel.cpp - ${CMAKE_CURRENT_LIST_DIR}/pp_matmul_ein_sum_kernel/pp_matmul_ein_sum_kernel.cpp ${CMAKE_CURRENT_LIST_DIR}/tiling/matmul_nd_tiling.cpp ${CMAKE_CURRENT_LIST_DIR}/tiling/matmul_nz_tiling.cpp ${CMAKE_CURRENT_LIST_DIR}/tiling/pp_matmul_mix_tiling.cpp @@ -24,6 +17,13 @@ set(matmul_srcs ${CMAKE_CURRENT_LIST_DIR}/tiling/pp_matmul_i8_nz_tiling.cpp ${CMAKE_CURRENT_LIST_DIR}/tiling/pp_matmul_nz_tiling.cpp ${CMAKE_CURRENT_LIST_DIR}/tiling/pp_matmul_nd_tiling.cpp + ${CMAKE_CURRENT_LIST_DIR}/pp_mat_mul_bf16nd_bf16nd_bf16nd_kernel/pp_mat_mul_bf16nd_bf16nd_bf16nd_kernel.cpp + ${CMAKE_CURRENT_LIST_DIR}/pp_mat_mul_bf16nd_bf16nd_f32nd_kernel/pp_mat_mul_bf16nd_bf16nd_f32nd_kernel.cpp + ${CMAKE_CURRENT_LIST_DIR}/pp_mat_mul_bf16nd_bf16nz_bf16nd_kernel/pp_mat_mul_bf16nd_bf16nz_bf16nd_kernel.cpp + ${CMAKE_CURRENT_LIST_DIR}/pp_mat_mul_f16nd_f16nd_f16nd_kernel/pp_mat_mul_f16nd_f16nd_f16nd_kernel.cpp + ${CMAKE_CURRENT_LIST_DIR}/pp_mat_mul_f16nd_f16nd_f32nd_kernel/pp_mat_mul_f16nd_f16nd_f32nd_kernel.cpp + ${CMAKE_CURRENT_LIST_DIR}/pp_mat_mul_f16nd_f16nz_f16nd_kernel/pp_mat_mul_f16nd_f16nz_f16nd_kernel.cpp + ${CMAKE_CURRENT_LIST_DIR}/pp_mat_mul_f16nz_f16nz_f16nz_kernel/pp_mat_mul_f16nz_f16nz_f16nz_kernel.cpp ) add_operation(MatMulOperation "${matmul_srcs}") diff --git a/src/kernels/kernels/matmul/matmul_operation.cpp b/src/kernels/kernels/matmul/matmul_operation.cpp index 1d14d76d..5bfc849e 100644 --- a/src/kernels/kernels/matmul/matmul_operation.cpp +++ b/src/kernels/kernels/matmul/matmul_operation.cpp @@ -30,19 +30,20 @@ constexpr uint32_t PP_MATMUL_I8_BF16_WEIGHT_NZ_KERNEL_KEY = 0b1'00'00'10'0'1'0'1 constexpr uint32_t PP_MATMUL_I8_FP16_WEIGHT_NZ_KERNEL_KEY = 0b1'00'00'01'0'1'0'11; constexpr uint32_t PP_MATMUL_I8_KERNEL_KEY = 0b1'00'00'01'0'0'0'11; constexpr uint32_t PP_MATMUL_I8_WEIGHT_NZ_KERNEL_KEY = 0b1'00'00'01'0'1'0'11; -constexpr uint32_t PP_MATMUL_F16_KERNEL_KEY = 0b1'01'01'01'0'0'0'00; constexpr uint32_t PP_MATMUL_F16_MIX_KERNEL_KEY = 0b1'01'01'01'0'0'0'01; -constexpr uint32_t PP_MATMUL_BF16_KERNEL_KEY = 0b1'10'10'10'0'0'0'00; -constexpr uint32_t PP_MATMUL_F16_OPT_KERNEL_KEY = 0b1'01'01'01'0'1'0'00; -constexpr uint32_t PP_MATMUL_BF16_ND_NZ_ND_KERNEL_KEY = 0b1'10'10'10'0'1'0'00; -constexpr uint32_t PP_MATMUL_NZ_F16_KERNEL_KEY = 0b0'01'01'01'1'1'1'00; constexpr uint32_t PP_MATMUL_I8_NZ_KERNEL_KEY = 0b0'00'00'01'1'1'1'11; constexpr uint32_t PP_MATMUL_I8_NZ_COMPRESS_KERNEL_KEY = 0b0'00'00'01'1'1'1'11; -constexpr uint32_t PP_MATMUL_ACCUM_KERNEL_KEY = 0b1'10'10'11'0'0'0'00; constexpr uint32_t PP_MATMUL_FP_ND_ND_KERNEL_KEY = 0b0'01'01'01'0'0'0'00; constexpr uint32_t PP_MATMUL_I8_NZ_M300_KERNEL_KEY = 0b0'00'00'01'0'1'0'11; constexpr uint32_t PP_MATMUL_I8_ND_M300_KERNEL_KEY = 0b0'00'00'01'0'0'0'11; constexpr uint32_t PP_MATMUL_F16_NZ_M300_KERNEL_KEY = 0b0'01'01'01'0'1'0'00; +constexpr uint32_t PP_MatMul_F16ND_F16ND_F16ND_KERNEL_KEY = 0b1'01'01'01'0'0'0'10; +constexpr uint32_t PP_MatMul_BF16ND_BF16ND_BF16ND_KERNEL_KEY = 0b1'10'10'10'0'0'0'10; +constexpr uint32_t PP_MatMul_F16ND_F16ND_F32ND_KERNEL_KEY = 0b1'01'01'11'0'0'0'10; +constexpr uint32_t PP_MatMul_BF16ND_BF16ND_F32ND_KERNEL_KEY = 0b1'10'10'11'0'0'0'10; +constexpr uint32_t PP_MatMul_F16ND_F16NZ_F16ND_KERNEL_KEY = 0b1'01'01'01'0'1'0'10; +constexpr uint32_t PP_MatMul_BF16ND_BF16NZ_BF16ND_KERNEL_KEY = 0b1'10'10'10'0'1'0'10; +constexpr uint32_t PP_MatMul_F16NZ_F16NZ_F16NZ_KERNEL_KEY = 0b0'01'01'01'1'1'1'10; } // namespace @@ -107,16 +108,7 @@ public: kernelKey = (kernelKey << INPUT_BIT_COUNT) + (inTensorCount - DIM_2); MKI_LOG(INFO) << "kernelKey: " << kernelKey; MKI_LOG(INFO) << ">>> PpMatmulType:" << static_cast(opParam.matmulType); - if (opParam.matmulType == MmType::MATMUL_ACCUM_ATOMIC) { - return GetKernelByName("PpMatmulAccumAtomicKernel"); - } - if (opParam.matmulType == MmType::MATMUL_WITH_BIAS) { - return GetKernelByName("PpMatmulWithBiasKernel"); - } - if (opParam.matmulType == MmType::MATMUL_EIN_SUM) { - MKI_CHECK(!opParam.transposeA, "Unsupported transposed A_matrix", return nullptr); - return GetKernelByName("PpMatmulEinSumKernel"); - } + // 先判断w8a8compress if (isSparseDequant) { switch (kernelKey) { @@ -139,12 +131,7 @@ public: case PP_MATMUL_I8_BF16_WEIGHT_NZ_KERNEL_KEY: return GetKernelByName("PpMatmulW8A8Bf16NDNZKernel"); case PP_MATMUL_I8_KERNEL_KEY: return GetKernelByName("PpMatmulW8A8Kernel"); case PP_MATMUL_I8_WEIGHT_NZ_KERNEL_KEY: return GetKernelByName("PpMatmulW8A8WeightNzKernel"); - case PP_MATMUL_F16_KERNEL_KEY: return GetKernelByName("PpMatMulF16Kernel"); case PP_MATMUL_F16_MIX_KERNEL_KEY: return GetKernelByName("PpMatMulF16MixKernel"); - case PP_MATMUL_BF16_KERNEL_KEY: return GetKernelByName("PpMatMulBf16Kernel"); - case PP_MATMUL_F16_OPT_KERNEL_KEY: return GetKernelByName("PpMatMulF16OptKernel"); - case PP_MATMUL_BF16_ND_NZ_ND_KERNEL_KEY: return GetKernelByName("PpMatMulBf16NdNzNdKernel"); - case PP_MATMUL_NZ_F16_KERNEL_KEY: return GetKernelByName("PpMatMulNzF16Kernel"); case PP_MATMUL_I8_NZ_KERNEL_KEY: if (platform == PlatformType::ASCEND_910A) { return GetKernelByName("PpMatMulI8NzKernel"); @@ -155,6 +142,20 @@ public: case PP_MATMUL_I8_ND_M300_KERNEL_KEY: return GetKernelByName("PpMatMulI8Kernel"); case PP_MATMUL_I8_NZ_M300_KERNEL_KEY: return GetKernelByName("PpMatMulI8NdNzKernel"); case PP_MATMUL_F16_NZ_M300_KERNEL_KEY: return GetKernelByName("PpMatmulF16NzM300Kernel"); + case PP_MatMul_F16ND_F16ND_F16ND_KERNEL_KEY: return GetKernelByName("PpMatMulF16NDF16NDF16NDKernel"); + case PP_MatMul_BF16ND_BF16ND_BF16ND_KERNEL_KEY: return GetKernelByName("PpMatMulBF16NDBF16NDBF16NDKernel"); + case PP_MatMul_F16ND_F16ND_F32ND_KERNEL_KEY: return GetKernelByName("PpMatMulF16NDF16NDF32NDKernel"); + case PP_MatMul_BF16ND_BF16ND_F32ND_KERNEL_KEY: return GetKernelByName("PpMatMulBF16NDBF16NDF32NDKernel"); + case PP_MatMul_F16ND_F16NZ_F16ND_KERNEL_KEY: return GetKernelByName("PpMatMulF16NDF16NZF16NDKernel"); + case PP_MatMul_BF16ND_BF16NZ_BF16ND_KERNEL_KEY: return GetKernelByName("PpMatMulBF16NDBF16NZBF16NDKernel"); + case PP_MatMul_F16NZ_F16NZ_F16NZ_KERNEL_KEY: + { + if (platform == PlatformType::ASCEND_910A) { + return GetKernelByName("PpMatMulNzF16Kernel"); + } else { + return GetKernelByName("PpMatMulF16NZF16NZF16NZKernel"); + } + } default: MKI_LOG(ERROR) << "No matched kernel for matmul operation."; return nullptr; } return nullptr; @@ -276,12 +277,7 @@ public: if (param.enDequant) { return 5; // There're 5 inputs if enable post dequant. } - bool fuseAdd = (param.withBias || param.matmulType == MmType::MATMUL_ACCUM_ATOMIC || - param.matmulType == MmType::MATMUL_WITH_BIAS); - if (fuseAdd) { - return 3; // 3 withBias matmul - } - return 2; // matmul has 2 inputs + return 4; // matmul has 2 inputs } protected: diff --git a/src/kernels/kernels/matmul/pp_matmul_accum_kernel/pp_matmul_accum_kernel.cpp b/src/kernels/kernels/matmul/pp_mat_mul_bf16nd_bf16nd_bf16nd_kernel/pp_mat_mul_bf16nd_bf16nd_bf16nd_kernel.cpp similarity index 41% rename from src/kernels/kernels/matmul/pp_matmul_accum_kernel/pp_matmul_accum_kernel.cpp rename to src/kernels/kernels/matmul/pp_mat_mul_bf16nd_bf16nd_bf16nd_kernel/pp_mat_mul_bf16nd_bf16nd_bf16nd_kernel.cpp index 07a30f22..a51af2b7 100644 --- a/src/kernels/kernels/matmul/pp_matmul_accum_kernel/pp_matmul_accum_kernel.cpp +++ b/src/kernels/kernels/matmul/pp_mat_mul_bf16nd_bf16nd_bf16nd_kernel/pp_mat_mul_bf16nd_bf16nd_bf16nd_kernel.cpp @@ -1,74 +1,63 @@ -/* - * Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - */ -#include -#include -#include -#include -#include "asdops/params/params.h" -#include "kernels/matmul/common/common.h" -#include "kernels/matmul/common/common_tiling.h" -#include "kernels/matmul/tiling/pp_matmul_tiling.h" -#include "kernels/matmul/tiling/tiling_data.h" - -namespace AsdOps { -class PpMatmulAccumCommonKernel : public KernelBase { -public: - explicit PpMatmulAccumCommonKernel(const std::string &kernelName, const BinHandle *handle) noexcept - : KernelBase(kernelName, handle) - { - } - - bool CanSupport(const LaunchParam &launchParam) const override - { - MKI_CHECK(PlatformInfo::Instance().GetPlatformType() == PlatformType::ASCEND_910B, "Platform not support.", - return false); - MKI_CHECK(launchParam.GetInTensorCount() == 3, "Check inTensor count failed.", return false); - MKI_CHECK(launchParam.GetOutTensorCount() == 1, "Check outTensor count failed.", return false); - - const auto &descA = launchParam.GetInTensor(0).desc; - const auto &descB = launchParam.GetInTensor(1).desc; - const auto &descC = launchParam.GetInTensor(2).desc; - - MKI_CHECK(descA.format == TENSOR_FORMAT_ND, "Tensor format is invalid.", return false); - MKI_CHECK(descA.dtype == TENSOR_DTYPE_BF16 || descA.dtype == TENSOR_DTYPE_FLOAT16, - "Tensor dtype is invalid.", return false); - MKI_CHECK(descB.format == TENSOR_FORMAT_ND, "Tensor format is invalid.", return false); - MKI_CHECK(descB.dtype == TENSOR_DTYPE_BF16 || descB.dtype == TENSOR_DTYPE_FLOAT16, - "Tensor dtype is invalid.", return false); - MKI_CHECK(descA.dtype == descB.dtype, "Input tensor's dtype must be the same.", return false); - MKI_CHECK(descC.format == TENSOR_FORMAT_ND, "Tensor format is invalid.", return false); - MKI_CHECK(descC.dtype == TENSOR_DTYPE_FLOAT, "Tensor dtype is invalid.", return false); - - return true; - } - - uint64_t GetTilingSize(const LaunchParam &launchParam) const override - { - (void)launchParam; - constexpr uint32_t CONST_256 = 256; - return Round(sizeof(PpMatmulTilingData)); - } -}; - -class PpMatmulAccumAtomicKernel : public PpMatmulAccumCommonKernel { -public: - explicit PpMatmulAccumAtomicKernel(const std::string &kernelName, const BinHandle *handle) noexcept - : PpMatmulAccumCommonKernel(kernelName, handle) - { - } - - Status InitImpl(const LaunchParam &launchParam) override - { - return PpMatmulTiling(launchParam, kernelInfo_); - } -}; - -REG_KERNEL_BASE(PpMatmulAccumAtomicKernel); -} // namespace AsdOps \ No newline at end of file +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#include +#include +#include +#include "asdops/params/params.h" +#include "kernels/matmul/tiling/pp_matmul_tiling.h" +#include "kernels/matmul/tiling/tiling_data.h" +#include "kernels/matmul/common/common.h" +#include "kernels/matmul/common/common_tiling.h" +#include "sink_common.h" + +namespace AsdOps { +class PpMatMulBF16NDBF16NDBF16NDKernel : public KernelBase { +public: + explicit PpMatMulBF16NDBF16NDBF16NDKernel(const std::string &kernelName, const BinHandle *handle) noexcept + : KernelBase(kernelName, handle) + { + } + + bool CanSupport(const LaunchParam &launchParam) const override + { + MKI_CHECK(PlatformInfo::Instance().GetPlatformType() == PlatformType::ASCEND_910B, "platform not support", + return false); + MKI_CHECK(launchParam.GetInTensorCount() == 4, "check inTensor count failed", return false); + MKI_CHECK(launchParam.GetOutTensorCount() == 1, "check outTensor count failed", return false); + + const auto &descA = launchParam.GetInTensor(0).desc; + const auto &descB = launchParam.GetInTensor(1).desc; + + MKI_CHECK(descA.format == TENSOR_FORMAT_ND, "tensor format invalid", return false); + MKI_CHECK(descA.dtype == TENSOR_DTYPE_BF16, "tensor dtype invalid", return false); + + MKI_CHECK(descB.format == TENSOR_FORMAT_ND, "tensor format invalid", return false); + MKI_CHECK(descB.dtype == TENSOR_DTYPE_BF16, "tensor dtype invalid", return false); + + MKI_CHECK(descA.dims.size() == descB.dims.size(), "tensor dims invalid", return false); + auto opParam = AnyCast(launchParam.GetParam()); + if (opParam.matmulType == OpParam::MatMul::MatMulType::MATMUL_EIN_SUM) { + MKI_CHECK(descA.dims.size() == 3, "Matrix A must be a 3D-tensor in MATMUL_EIN_SUM scenario.", return false); + MKI_CHECK(descA.dims[1] == descB.dims[0], "tensor dims invalid: batchA != batchB", return false); + } else { + MKI_CHECK((descA.dims.size() == 2) || (descA.dims[0] == descB.dims[0]), "tensor dims invalid: batchA != batchB", return false); + } + + return true; + } + + Status InitImpl(const LaunchParam &launchParam) override + { + return optiling::CallGeTiling("PpMatMul", *GetBinHandle(), launchParam, + AsdOps::GetMkiSpecificAttr, kernelInfo_); + } +}; +REG_KERNEL_BASE(PpMatMulBF16NDBF16NDBF16NDKernel); +} // namespace AsdOps diff --git a/src/kernels/kernels/matmul/pp_matmul_bf16_kernel/pp_matmul_bf16_kernel.cpp b/src/kernels/kernels/matmul/pp_mat_mul_bf16nd_bf16nd_f32nd_kernel/pp_mat_mul_bf16nd_bf16nd_f32nd_kernel.cpp similarity index 48% rename from src/kernels/kernels/matmul/pp_matmul_bf16_kernel/pp_matmul_bf16_kernel.cpp rename to src/kernels/kernels/matmul/pp_mat_mul_bf16nd_bf16nd_f32nd_kernel/pp_mat_mul_bf16nd_bf16nd_f32nd_kernel.cpp index 8691eb15..b7d2b18c 100644 --- a/src/kernels/kernels/matmul/pp_matmul_bf16_kernel/pp_matmul_bf16_kernel.cpp +++ b/src/kernels/kernels/matmul/pp_mat_mul_bf16nd_bf16nd_f32nd_kernel/pp_mat_mul_bf16nd_bf16nd_f32nd_kernel.cpp @@ -1,60 +1,55 @@ -/* - * Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - */ -#include -#include -#include -#include -#include "asdops/params/params.h" -#include "kernels/matmul/tiling/tiling_data.h" -#include "kernels/matmul/tiling/pp_matmul_tiling.h" -#include "kernels/matmul/common/common_tiling.h" - -namespace AsdOps { -using namespace Mki; -class PpMatMulBf16Kernel : public KernelBase { -public: - explicit PpMatMulBf16Kernel(const std::string &kernelName, const BinHandle *handle) noexcept - : KernelBase(kernelName, handle) - { - } - - bool CanSupport(const LaunchParam &launchParam) const override - { - MKI_CHECK(PlatformInfo::Instance().GetPlatformType() == PlatformType::ASCEND_910B, - "platformtype is invalid", return false); - MKI_CHECK(launchParam.GetInTensorCount() == 2, "in tensor num invalid", return false); - MKI_CHECK(launchParam.GetOutTensorCount() == 1, "out tensor num invalid", return false); - - const auto &descA = launchParam.GetInTensor(0).desc; - const auto &descB = launchParam.GetInTensor(1).desc; - - MKI_CHECK(descA.format == TENSOR_FORMAT_ND, "in tensor0 format invalid", return false); - MKI_CHECK(descA.dtype == TENSOR_DTYPE_BF16, "in tensor0 dtype invalid", return false); - - MKI_CHECK(descB.format == TENSOR_FORMAT_ND, "in tensor1 format invalid", return false); - MKI_CHECK(descB.dtype == TENSOR_DTYPE_BF16, "in tensor1 dtype invalid", return false); - - return true; - } - - uint64_t GetTilingSize(const LaunchParam &launchParam) const override - { - (void)launchParam; - constexpr uint32_t CONST_256 = 256; - return RoundUp(sizeof(PpMatmulTilingData), CONST_256); - } - - Status InitImpl(const LaunchParam &launchParam) override - { - return PpMatmulTiling(launchParam, kernelInfo_); - } -}; -REG_KERNEL_BASE(PpMatMulBf16Kernel); -} // namespace AsdOps \ No newline at end of file +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#include +#include +#include +#include "asdops/params/params.h" +#include "kernels/matmul/tiling/pp_matmul_tiling.h" +#include "kernels/matmul/tiling/tiling_data.h" +#include "kernels/matmul/common/common.h" +#include "kernels/matmul/common/common_tiling.h" +#include "sink_common.h" + +namespace AsdOps { +class PpMatMulBF16NDBF16NDF32NDKernel : public KernelBase { +public: + explicit PpMatMulBF16NDBF16NDF32NDKernel(const std::string &kernelName, const BinHandle *handle) noexcept + : KernelBase(kernelName, handle) + { + } + + bool CanSupport(const LaunchParam &launchParam) const override + { + MKI_CHECK(PlatformInfo::Instance().GetPlatformType() == PlatformType::ASCEND_910B, "platform not support", + return false); + MKI_CHECK(launchParam.GetInTensorCount() == 4, "check inTensor count failed", return false); + MKI_CHECK(launchParam.GetOutTensorCount() == 1, "check outTensor count failed", return false); + + const auto &descA = launchParam.GetInTensor(0).desc; + const auto &descB = launchParam.GetInTensor(1).desc; + + MKI_CHECK(descA.format == TENSOR_FORMAT_ND, "tensor format invalid", return false); + MKI_CHECK(descA.dtype == TENSOR_DTYPE_BF16, "tensor dtype invalid", return false); + + MKI_CHECK(descB.format == TENSOR_FORMAT_ND, "tensor format invalid", return false); + MKI_CHECK(descB.dtype == TENSOR_DTYPE_BF16, "tensor dtype invalid", return false); + + MKI_CHECK((descA.dims.size() == 2) || (descA.dims[0] == descB.dims[0]), "tensor dims invalid: batchA != batchB", return false); + return true; + } + + Status InitImpl(const LaunchParam &launchParam) override + { + return optiling::CallGeTiling("PpMatMul", *GetBinHandle(), launchParam, + AsdOps::GetMkiSpecificAttr, kernelInfo_); + } +}; +REG_KERNEL_BASE(PpMatMulBF16NDBF16NDF32NDKernel); +} // namespace AsdOps diff --git a/src/kernels/kernels/matmul/pp_matmul_with_bias_kernel/pp_matmul_with_bias_kernel.cpp b/src/kernels/kernels/matmul/pp_mat_mul_bf16nd_bf16nz_bf16nd_kernel/pp_mat_mul_bf16nd_bf16nz_bf16nd_kernel.cpp similarity index 38% rename from src/kernels/kernels/matmul/pp_matmul_with_bias_kernel/pp_matmul_with_bias_kernel.cpp rename to src/kernels/kernels/matmul/pp_mat_mul_bf16nd_bf16nz_bf16nd_kernel/pp_mat_mul_bf16nd_bf16nz_bf16nd_kernel.cpp index bac9a200..3deb5ad2 100644 --- a/src/kernels/kernels/matmul/pp_matmul_with_bias_kernel/pp_matmul_with_bias_kernel.cpp +++ b/src/kernels/kernels/matmul/pp_mat_mul_bf16nd_bf16nz_bf16nd_kernel/pp_mat_mul_bf16nd_bf16nz_bf16nd_kernel.cpp @@ -1,66 +1,67 @@ -/* - * Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - */ -#include -#include -#include -#include -#include "asdops/params/params.h" -#include "kernels/matmul/common/common.h" -#include "kernels/matmul/tiling/pp_matmul_tiling.h" -#include "kernels/matmul/tiling/tiling_data.h" - -namespace AsdOps { -class PpMatmulWithBiasKernel : public KernelBase { -public: - explicit PpMatmulWithBiasKernel(const std::string &kernelName, const BinHandle *handle) noexcept - : KernelBase(kernelName, handle) - { - } - - bool CanSupport(const LaunchParam &launchParam) const override - { - MKI_CHECK(PlatformInfo::Instance().GetPlatformType() == PlatformType::ASCEND_910B, "Platform not support.", - return false); - MKI_CHECK(launchParam.GetInTensorCount() == 3, "Check inTensor count failed.", return false); - MKI_CHECK(launchParam.GetOutTensorCount() == 1, "Check outTensor count failed.", return false); - - const auto &descA = launchParam.GetInTensor(0).desc; - const auto &descB = launchParam.GetInTensor(1).desc; - const auto &descBias = launchParam.GetInTensor(2).desc; - const auto &descC = launchParam.GetOutTensor(0).desc; - - MKI_CHECK(descA.format == TENSOR_FORMAT_ND, "Tensor format is invalid.", return false); - MKI_CHECK(descA.dtype == TENSOR_DTYPE_BF16 || descA.dtype == TENSOR_DTYPE_FLOAT16, - "Tensor dtype is invalid.", return false); - MKI_CHECK(descB.format == TENSOR_FORMAT_ND, "Tensor format is invalid.", return false); - MKI_CHECK(descB.dtype == TENSOR_DTYPE_BF16 || descB.dtype == TENSOR_DTYPE_FLOAT16, - "Tensor dtype is invalid.", return false); - MKI_CHECK(descBias.format == TENSOR_FORMAT_ND, "Tensor format is invalid.", return false); - MKI_CHECK(descBias.dtype == TENSOR_DTYPE_FLOAT, "Tensor dtype is invalid.", return false); - - MKI_CHECK(descA.dtype == descB.dtype, "Input dtype must be the same.", return false); - MKI_CHECK(descA.dtype == descC.dtype, "Output dtype must be the same as input dtype.", return false); - MKI_CHECK(descC.format == TENSOR_FORMAT_ND, "Tensor format is invalid.", return false); - return true; - } - - uint64_t GetTilingSize(const LaunchParam &launchParam) const override - { - return GetMatmulTilingSizeCommon(launchParam); - } - - Status InitImpl(const LaunchParam &launchParam) override - { - return PpMatmulTiling(launchParam, kernelInfo_); - } -}; - -REG_KERNEL_BASE(PpMatmulWithBiasKernel); -} // namespace AsdOps \ No newline at end of file +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#include +#include +#include +#include "asdops/params/params.h" +#include "kernels/matmul/tiling/pp_matmul_tiling.h" +#include "kernels/matmul/tiling/tiling_data.h" +#include "kernels/matmul/common/common.h" +#include "kernels/matmul/common/common_tiling.h" +#include "sink_common.h" + +namespace AsdOps { +class PpMatMulBF16NDBF16NZBF16NDKernel : public KernelBase { +public: + explicit PpMatMulBF16NDBF16NZBF16NDKernel(const std::string &kernelName, const BinHandle *handle) noexcept + : KernelBase(kernelName, handle) + { + } + + bool CanSupport(const LaunchParam &launchParam) const override + { + MKI_CHECK(PlatformInfo::Instance().GetPlatformType() == PlatformType::ASCEND_910B, "platform not support", + return false); + MKI_CHECK(launchParam.GetInTensorCount() == 4, "check inTensor count failed", return false); + MKI_CHECK(launchParam.GetOutTensorCount() == 1, "check outTensor count failed", return false); + + const auto &descA = launchParam.GetInTensor(0).desc; + const auto &descB = launchParam.GetInTensor(1).desc; + + MKI_CHECK(descA.format == TENSOR_FORMAT_ND, "tensor format invalid", return false); + MKI_CHECK(descA.dtype == TENSOR_DTYPE_BF16, "tensor dtype invalid", return false); + + MKI_CHECK(descB.format == TENSOR_FORMAT_FRACTAL_NZ, "tensor format invalid", return false); + MKI_CHECK(descB.dtype == TENSOR_DTYPE_BF16, "tensor dtype invalid", return false); + + auto opParam = AnyCast(launchParam.GetParam()); + if (opParam.matmulType == OpParam::MatMul::MatMulType::MATMUL_EIN_SUM) { + MKI_CHECK(descA.dims.size() == 3, "Matrix A must be a 3D-tensor in MATMUL_EIN_SUM scenario.", return false); + MKI_CHECK(descA.dims[1] == descB.dims[0], "tensor dims invalid: batchA != batchB", return false); + } else { + if (descA.dims.size() == 2) { // 2: The first input is a two-dimensional tensor while batch size is one. + MKI_CHECK(descB.dims[0] == 1, "tensor dims invalid", return false); + } else { + MKI_CHECK(descA.dims.size() == 3, "tensor dims invalid", return false); + MKI_CHECK(descA.dims[0] == descB.dims[0], "tensor dims invalid", return false); + } + } + + return true; + } + + Status InitImpl(const LaunchParam &launchParam) override + { + return optiling::CallGeTiling("PpMatMul", *GetBinHandle(), launchParam, + AsdOps::GetMkiSpecificAttr, kernelInfo_); + } +}; +REG_KERNEL_BASE(PpMatMulBF16NDBF16NZBF16NDKernel); +} // namespace AsdOps diff --git a/src/kernels/kernels/matmul/pp_mat_mul_f16nd_f16nd_f16nd_kernel/pp_mat_mul_f16nd_f16nd_f16nd_kernel.cpp b/src/kernels/kernels/matmul/pp_mat_mul_f16nd_f16nd_f16nd_kernel/pp_mat_mul_f16nd_f16nd_f16nd_kernel.cpp new file mode 100644 index 00000000..11abda49 --- /dev/null +++ b/src/kernels/kernels/matmul/pp_mat_mul_f16nd_f16nd_f16nd_kernel/pp_mat_mul_f16nd_f16nd_f16nd_kernel.cpp @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#include +#include +#include +#include "asdops/params/params.h" +#include "kernels/matmul/tiling/pp_matmul_tiling.h" +#include "kernels/matmul/tiling/tiling_data.h" +#include "kernels/matmul/common/common.h" +#include "kernels/matmul/common/common_tiling.h" +#include "sink_common.h" + +namespace AsdOps { +class PpMatMulF16NDF16NDF16NDKernel : public KernelBase { +public: + explicit PpMatMulF16NDF16NDF16NDKernel(const std::string &kernelName, const BinHandle *handle) noexcept + : KernelBase(kernelName, handle) + { + } + + bool CanSupport(const LaunchParam &launchParam) const override + { + MKI_CHECK(PlatformInfo::Instance().GetPlatformType() == PlatformType::ASCEND_910B, "platform not support", + return false); + MKI_CHECK(launchParam.GetInTensorCount() == 4, "check inTensor count failed", return false); + MKI_CHECK(launchParam.GetOutTensorCount() == 1, "check outTensor count failed", return false); + + const auto &descA = launchParam.GetInTensor(0).desc; + const auto &descB = launchParam.GetInTensor(1).desc; + + MKI_CHECK(descA.format == TENSOR_FORMAT_ND, "tensor format invalid", return false); + MKI_CHECK(descA.dtype == TENSOR_DTYPE_FLOAT16, "tensor dtype invalid", return false); + + MKI_CHECK(descB.format == TENSOR_FORMAT_ND, "tensor format invalid", return false); + MKI_CHECK(descB.dtype == TENSOR_DTYPE_FLOAT16, "tensor dtype invalid", return false); + + auto opParam = AnyCast(launchParam.GetParam()); + if (opParam.matmulType == OpParam::MatMul::MatMulType::MATMUL_EIN_SUM) { + MKI_CHECK(descA.dims.size() == 3, "Matrix A must be a 3D-tensor in MATMUL_EIN_SUM scenario.", return false); + MKI_CHECK(descA.dims[1] == descB.dims[0], "tensor dims invalid: batchA != batchB", return false); + } else { + MKI_CHECK((descA.dims.size() == 2) || (descA.dims[0] == descB.dims[0]), "tensor dims invalid: batchA != batchB", return false); + } + + return true; + } + + Status InitImpl(const LaunchParam &launchParam) override + { + return optiling::CallGeTiling("PpMatMul", *GetBinHandle(), launchParam, + AsdOps::GetMkiSpecificAttr, kernelInfo_); + } +}; +REG_KERNEL_BASE(PpMatMulF16NDF16NDF16NDKernel); +} // namespace AsdOps diff --git a/src/kernels/kernels/matmul/pp_matmul_bf16_nd_nz_nd_kernel/pp_matmul_bf16_nd_nz_nd_kernel.cpp b/src/kernels/kernels/matmul/pp_mat_mul_f16nd_f16nd_f32nd_kernel/pp_mat_mul_f16nd_f16nd_f32nd_kernel.cpp similarity index 47% rename from src/kernels/kernels/matmul/pp_matmul_bf16_nd_nz_nd_kernel/pp_matmul_bf16_nd_nz_nd_kernel.cpp rename to src/kernels/kernels/matmul/pp_mat_mul_f16nd_f16nd_f32nd_kernel/pp_mat_mul_f16nd_f16nd_f32nd_kernel.cpp index 0bc16dc3..59c0aedf 100644 --- a/src/kernels/kernels/matmul/pp_matmul_bf16_nd_nz_nd_kernel/pp_matmul_bf16_nd_nz_nd_kernel.cpp +++ b/src/kernels/kernels/matmul/pp_mat_mul_f16nd_f16nd_f32nd_kernel/pp_mat_mul_f16nd_f16nd_f32nd_kernel.cpp @@ -1,59 +1,53 @@ -/* - * Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - */ -#include -#include -#include -#include -#include "asdops/params/params.h" -#include "kernels/matmul/tiling/tiling_data.h" -#include "kernels/matmul/tiling/pp_matmul_tiling.h" - -namespace AsdOps { -using namespace Mki; -class PpMatMulBf16NdNzNdKernel : public KernelBase { -public: - explicit PpMatMulBf16NdNzNdKernel(const std::string &kernelName, const BinHandle *handle) noexcept - : KernelBase(kernelName, handle) - { - } - - bool CanSupport(const LaunchParam &launchParam) const override - { - MKI_CHECK(PlatformInfo::Instance().GetPlatformType() == PlatformType::ASCEND_910B, "platformtype is invalid", - return false); - MKI_CHECK(launchParam.GetInTensorCount() == 2, "in tensor num invalid", return false); - MKI_CHECK(launchParam.GetOutTensorCount() == 1, "out tensor num invalid", return false); - - const auto &descA = launchParam.GetInTensor(0).desc; - const auto &descB = launchParam.GetInTensor(1).desc; - - MKI_CHECK(descA.format == TENSOR_FORMAT_ND, "in tensor0 format invalid", return false); - MKI_CHECK(descA.dtype == TENSOR_DTYPE_BF16, "in tensor0 dtype invalid", return false); - - MKI_CHECK(descB.format == TENSOR_FORMAT_FRACTAL_NZ, "in tensor1 format invalid", return false); - MKI_CHECK(descB.dtype == TENSOR_DTYPE_BF16, "in tensor1 dtype invalid", return false); - - return true; - } - - uint64_t GetTilingSize(const LaunchParam &launchParam) const override - { - (void)launchParam; - constexpr uint32_t CONST_256 = 256; - return (sizeof(PpMatmulTilingData) + CONST_256 - 1) / CONST_256 * CONST_256; - } - - Status InitImpl(const LaunchParam &launchParam) override - { - return PpMatmulTiling(launchParam, kernelInfo_); - } -}; -REG_KERNEL_BASE(PpMatMulBf16NdNzNdKernel); -} // namespace AsdOps \ No newline at end of file +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#include +#include +#include +#include "asdops/params/params.h" +#include "kernels/matmul/tiling/pp_matmul_tiling.h" +#include "kernels/matmul/tiling/tiling_data.h" +#include "kernels/matmul/common/common.h" +#include "kernels/matmul/common/common_tiling.h" +#include "sink_common.h" + +namespace AsdOps { +class PpMatMulF16NDF16NDF32NDKernel : public KernelBase { +public: + explicit PpMatMulF16NDF16NDF32NDKernel(const std::string &kernelName, const BinHandle *handle) noexcept + : KernelBase(kernelName, handle) + { + } + bool CanSupport(const LaunchParam &launchParam) const override + { + MKI_CHECK(PlatformInfo::Instance().GetPlatformType() == PlatformType::ASCEND_910B, "platform not support", + return false); + MKI_CHECK(launchParam.GetInTensorCount() == 4, "check inTensor count failed", return false); + MKI_CHECK(launchParam.GetOutTensorCount() == 1, "check outTensor count failed", return false); + + const auto &descA = launchParam.GetInTensor(0).desc; + const auto &descB = launchParam.GetInTensor(1).desc; + + MKI_CHECK(descA.format == TENSOR_FORMAT_ND, "tensor format invalid", return false); + MKI_CHECK(descA.dtype == TENSOR_DTYPE_FLOAT16, "tensor dtype invalid", return false); + + MKI_CHECK(descB.format == TENSOR_FORMAT_ND, "tensor format invalid", return false); + MKI_CHECK(descB.dtype == TENSOR_DTYPE_FLOAT16, "tensor dtype invalid", return false); + MKI_CHECK((descA.dims.size() == 2) || (descA.dims[0] == descB.dims[0]), "tensor dims invalid: batchA != batchB", return false); + return true; + } + + Status InitImpl(const LaunchParam &launchParam) override + { + return optiling::CallGeTiling("PpMatMul", *GetBinHandle(), launchParam, + AsdOps::GetMkiSpecificAttr, kernelInfo_); + } +}; +REG_KERNEL_BASE(PpMatMulF16NDF16NDF32NDKernel); +} // namespace AsdOps diff --git a/src/kernels/kernels/matmul/pp_mat_mul_f16nd_f16nz_f16nd_kernel/pp_mat_mul_f16nd_f16nz_f16nd_kernel.cpp b/src/kernels/kernels/matmul/pp_mat_mul_f16nd_f16nz_f16nd_kernel/pp_mat_mul_f16nd_f16nz_f16nd_kernel.cpp new file mode 100644 index 00000000..374a2b2b --- /dev/null +++ b/src/kernels/kernels/matmul/pp_mat_mul_f16nd_f16nz_f16nd_kernel/pp_mat_mul_f16nd_f16nz_f16nd_kernel.cpp @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#include +#include +#include +#include "asdops/params/params.h" +#include "kernels/matmul/tiling/pp_matmul_tiling.h" +#include "kernels/matmul/tiling/tiling_data.h" +#include "kernels/matmul/common/common.h" +#include "kernels/matmul/common/common_tiling.h" +#include "sink_common.h" + +namespace AsdOps { +class PpMatMulF16NDF16NZF16NDKernel : public KernelBase { +public: + explicit PpMatMulF16NDF16NZF16NDKernel(const std::string &kernelName, const BinHandle *handle) noexcept + : KernelBase(kernelName, handle) + { + } + + bool CanSupport(const LaunchParam &launchParam) const override + { + MKI_CHECK(PlatformInfo::Instance().GetPlatformType() == PlatformType::ASCEND_910B, "platform not support", + return false); + MKI_CHECK(launchParam.GetInTensorCount() == 4, "check inTensor count failed", return false); + MKI_CHECK(launchParam.GetOutTensorCount() == 1, "check outTensor count failed", return false); + + const auto &descA = launchParam.GetInTensor(0).desc; + const auto &descB = launchParam.GetInTensor(1).desc; + + MKI_CHECK(descA.format == TENSOR_FORMAT_ND, "tensor format invalid", return false); + MKI_CHECK(descA.dtype == TENSOR_DTYPE_FLOAT16, "tensor dtype invalid", return false); + + MKI_CHECK(descB.format == TENSOR_FORMAT_FRACTAL_NZ, "tensor format invalid", return false); + MKI_CHECK(descB.dims.size() == 4, "tensor dims invalid", return false); + MKI_CHECK(descB.dtype == TENSOR_DTYPE_FLOAT16, "tensor dtype invalid", return false); + + auto opParam = AnyCast(launchParam.GetParam()); + if (opParam.matmulType == OpParam::MatMul::MatMulType::MATMUL_EIN_SUM) { + MKI_CHECK(descA.dims.size() == 3, "Matrix A must be a 3D-tensor in MATMUL_EIN_SUM scenario.", return false); + MKI_CHECK(descA.dims[1] == descB.dims[0], "tensor dims invalid: batchA != batchB", return false); + } else { + if (descA.dims.size() == 2) { // 2: The first input is a two-dimensional tensor while batch size is one. + MKI_CHECK(descB.dims[0] == 1, "tensor dims invalid", return false); + } else { + MKI_CHECK(descA.dims.size() == 3, "tensor dims invalid", return false); + MKI_CHECK(descA.dims[0] == descB.dims[0], "tensor dims invalid", return false); + } + } + + return true; + } + + Status InitImpl(const LaunchParam &launchParam) override + { + return optiling::CallGeTiling("PpMatMul", *GetBinHandle(), launchParam, + AsdOps::GetMkiSpecificAttr, kernelInfo_); + } +}; +REG_KERNEL_BASE(PpMatMulF16NDF16NZF16NDKernel); +} // namespace AsdOps diff --git a/src/kernels/kernels/matmul/pp_matmul_f16_opt_kernel/pp_matmul_f16_opt_kernel.cpp b/src/kernels/kernels/matmul/pp_mat_mul_f16nz_f16nz_f16nz_kernel/pp_mat_mul_f16nz_f16nz_f16nz_kernel.cpp similarity index 60% rename from src/kernels/kernels/matmul/pp_matmul_f16_opt_kernel/pp_matmul_f16_opt_kernel.cpp rename to src/kernels/kernels/matmul/pp_mat_mul_f16nz_f16nz_f16nz_kernel/pp_mat_mul_f16nz_f16nz_f16nz_kernel.cpp index 81feefa5..6430a2d4 100644 --- a/src/kernels/kernels/matmul/pp_matmul_f16_opt_kernel/pp_matmul_f16_opt_kernel.cpp +++ b/src/kernels/kernels/matmul/pp_mat_mul_f16nz_f16nz_f16nz_kernel/pp_mat_mul_f16nz_f16nz_f16nz_kernel.cpp @@ -1,68 +1,55 @@ -/* - * Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - */ -#include -#include -#include -#include -#include "asdops/params/params.h" -#include "kernels/matmul/tiling/tiling_data.h" -#include "kernels/matmul/tiling/pp_matmul_tiling.h" -#include "kernels/matmul/common/common_tiling.h" - -namespace AsdOps { -using namespace Mki; -class PpMatMulF16OptKernel : public KernelBase { -public: - explicit PpMatMulF16OptKernel(const std::string &kernelName, const BinHandle *handle) noexcept - : KernelBase(kernelName, handle) - { - } - - bool CanSupport(const LaunchParam &launchParam) const override - { - MKI_CHECK(PlatformInfo::Instance().GetPlatformType() == PlatformType::ASCEND_910B, "platform not support", - return false); - MKI_CHECK(launchParam.GetInTensorCount() == 2, "check inTensor count failed", return false); - MKI_CHECK(launchParam.GetOutTensorCount() == 1, "check outTensor count failed", return false); - - const auto &descA = launchParam.GetInTensor(0).desc; - const auto &descB = launchParam.GetInTensor(1).desc; - - MKI_CHECK(descA.format == TENSOR_FORMAT_ND, "tensor format invalid", return false); - MKI_CHECK(descA.dtype == TENSOR_DTYPE_FLOAT16, "tensor dtype invalid", return false); - - MKI_CHECK(descB.format == TENSOR_FORMAT_FRACTAL_NZ, "tensor format invalid", return false); - MKI_CHECK(descB.dims.size() == 4, "tensor dims invalid", return false); - MKI_CHECK(descB.dtype == TENSOR_DTYPE_FLOAT16, "tensor dtype invalid", return false); - - if (descA.dims.size() == 2) { // 2: The first input is a two-dimensional tensor while batch size is one. - MKI_CHECK(descB.dims[0] == 1, "tensor dims invalid", return false); - } else { - MKI_CHECK(descA.dims.size() == 3, "tensor dims invalid", return false); - MKI_CHECK(descA.dims[0] == descB.dims[0], "tensor dims invalid", return false); - } - - return true; - } - - uint64_t GetTilingSize(const LaunchParam &launchParam) const override - { - (void)launchParam; - constexpr uint32_t CONST_256 = 256; - return Round(sizeof(PpMatmulTilingData)); - } - - Status InitImpl(const LaunchParam &launchParam) override - { - return PpMatmulTiling(launchParam, kernelInfo_); - } -}; -REG_KERNEL_BASE(PpMatMulF16OptKernel); -} // namespace AsdOps +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#include +#include +#include +#include "asdops/params/params.h" +#include "kernels/matmul/tiling/pp_matmul_tiling.h" +#include "kernels/matmul/tiling/tiling_data.h" +#include "kernels/matmul/common/common.h" +#include "kernels/matmul/common/common_tiling.h" +#include "sink_common.h" + +namespace AsdOps { +class PpMatMulF16NZF16NZF16NZKernel : public KernelBase { +public: + explicit PpMatMulF16NZF16NZF16NZKernel(const std::string &kernelName, const BinHandle *handle) noexcept + : KernelBase(kernelName, handle) + { + } + + bool CanSupport(const LaunchParam &launchParam) const override + { + MKI_CHECK(PlatformInfo::Instance().GetPlatformType() == PlatformType::ASCEND_310P, "platform not support", + return false); + MKI_CHECK(launchParam.GetInTensorCount() == 4, "check inTensor count failed", return false); + MKI_CHECK(launchParam.GetOutTensorCount() == 1, "check outTensor count failed", return false); + + const auto &descA = launchParam.GetInTensor(0).desc; + const auto &descB = launchParam.GetInTensor(1).desc; + + MKI_CHECK(descA.format == TENSOR_FORMAT_FRACTAL_NZ, "tensor format invalid", return false); + MKI_CHECK(descA.dtype == TENSOR_DTYPE_FLOAT16, "tensor dtype invalid", return false); + + MKI_CHECK(descB.format == TENSOR_FORMAT_FRACTAL_NZ, "tensor format invalid", return false); + MKI_CHECK(descB.dims.size() == 4, "tensor dims invalid", return false); + MKI_CHECK(descB.dtype == TENSOR_DTYPE_FLOAT16, "tensor dtype invalid", return false); + + return true; + } + + Status InitImpl(const LaunchParam &launchParam) override + { + return optiling::CallGeTiling("PpMatMul", *GetBinHandle(), launchParam, + AsdOps::GetMkiSpecificAttr, kernelInfo_); + } +}; +REG_KERNEL_BASE(PpMatMulF16NZF16NZF16NZKernel); +} // namespace AsdOps diff --git a/src/kernels/kernels/matmul/pp_matmul_ein_sum_kernel/pp_matmul_ein_sum_kernel.cpp b/src/kernels/kernels/matmul/pp_matmul_ein_sum_kernel/pp_matmul_ein_sum_kernel.cpp deleted file mode 100644 index d8e25bfa..00000000 --- a/src/kernels/kernels/matmul/pp_matmul_ein_sum_kernel/pp_matmul_ein_sum_kernel.cpp +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - */ -#include -#include -#include -#include "asdops/params/params.h" -#include "kernels/matmul/common/common.h" -#include "kernels/matmul/tiling/pp_matmul_tiling.h" -#include "kernels/matmul/tiling/tiling_data.h" - -namespace AsdOps { -using namespace Mki; -class PpMatmulEinSumKernel : public KernelBase { -public: - explicit PpMatmulEinSumKernel(const std::string &kernelName, const BinHandle *handle) noexcept - : KernelBase(kernelName, handle) - { - } - - bool CanSupport(const LaunchParam &launchParam) const override - { - auto attr = launchParam.GetParam(); - MKI_CHECK(attr.Type() == typeid(OpParam::MatMul), "Invalid launch param type.", return false); - auto mmType = AnyCast(attr).matmulType; - MKI_CHECK(mmType == OpParam::MatMul::MatMulType::MATMUL_EIN_SUM, "Invalid matmul type.", return false); - const auto &inTensor1 = launchParam.GetInTensor(1); - if (inTensor1.desc.format == TENSOR_FORMAT_FRACTAL_NZ) { - return CheckAsdOpsWeightNZ(launchParam, 2); // 输入参数数量为2 - } else { - return CheckAsdOpsND(launchParam, 2); // 输入参数数量为2 - } - } - - uint64_t GetTilingSize(const LaunchParam &launchParam) const override - { - (void)launchParam; - constexpr uint32_t CONST_256 = 256; - return (sizeof(PpMatmulTilingData) + CONST_256 - 1) / CONST_256 * CONST_256; - } - - Status InitImpl(const LaunchParam &launchParam) override { return PpMatmulTiling(launchParam, kernelInfo_); } -}; -REG_KERNEL_BASE(PpMatmulEinSumKernel); -} // namespace AsdOps diff --git a/src/kernels/kernels/matmul/pp_matmul_f16_kernel/pp_matmul_f16_kernel.cpp b/src/kernels/kernels/matmul/pp_matmul_f16_kernel/pp_matmul_f16_kernel.cpp deleted file mode 100644 index 16414010..00000000 --- a/src/kernels/kernels/matmul/pp_matmul_f16_kernel/pp_matmul_f16_kernel.cpp +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - */ -#include -#include -#include -#include "asdops/params/params.h" -#include "kernels/matmul/tiling/pp_matmul_tiling.h" -#include "kernels/matmul/tiling/tiling_data.h" -#include "kernels/matmul/common/common.h" -#include "kernels/matmul/common/common_tiling.h" - -namespace AsdOps { -class PpMatMulF16Kernel : public KernelBase { -public: - explicit PpMatMulF16Kernel(const std::string &kernelName, const BinHandle *handle) noexcept - : KernelBase(kernelName, handle) - { - } - - bool CanSupport(const LaunchParam &launchParam) const override - { - return CheckAsdOpsND(launchParam, 2); // 输入参数数量为2 - } - - uint64_t GetTilingSize(const LaunchParam &launchParam) const override - { - (void)launchParam; - constexpr uint32_t CONST_256 = 256; - return Round(sizeof(PpMatmulTilingData)); - } - - Status InitImpl(const LaunchParam &launchParam) override - { - return PpMatmulTiling(launchParam, kernelInfo_); - } -}; -REG_KERNEL_BASE(PpMatMulF16Kernel); -} // namespace AsdOps diff --git a/src/ops_infer/linear/linear_ops_runner.cpp b/src/ops_infer/linear/linear_ops_runner.cpp index 023f9678..f46e7e40 100644 --- a/src/ops_infer/linear/linear_ops_runner.cpp +++ b/src/ops_infer/linear/linear_ops_runner.cpp @@ -149,7 +149,7 @@ Status LinearOpsRunner::SetupKernelGraphMatmul910B() KernelGraphNode &matmulNode = kernelGraph_.nodes.at(0); matmulNode.opDesc = {0, "MatMulOperation", matmulParam_}; - matmulNode.inTensors = {&xTensor, &weightTensor}; + matmulNode.inTensors = {&xTensor, &weightTensor, &nullTensor_, &nullTensor_ }; matmulNode.outTensors = {&outTensor}; matmulNode.inTensorViewFuncs.resize(matmulNode.inTensors.size()); if (xNeedMergeAxis_) { @@ -203,7 +203,7 @@ Status LinearOpsRunner::SetupKernelGraphMatmulWeightNdNot910B() } matmulNode.opDesc = {0, "MatMulOperation", matmulParam_}; - matmulNode.inTensors = {&transdataAOutTensor, &transdataBOutTensor}; + matmulNode.inTensors = {&transdataAOutTensor, &transdataBOutTensor, &nullTensor_, &nullTensor_ }; matmulNode.outTensors = {&matmulOutTensor}; transdataOutNode.opDesc = {0, "TransdataOperation", transdataNzToNdParam_}; @@ -241,7 +241,7 @@ Status LinearOpsRunner::SetupKernelGraphMatmulWeightNzNot910B() } matmulNode.opDesc = {0, "MatMulOperation", matmulParam_}; - matmulNode.inTensors = {&transdataAOutTensor, &weightTensor}; + matmulNode.inTensors = {&transdataAOutTensor, &weightTensor, &nullTensor_, &nullTensor_ }; matmulNode.outTensors = {&matmulOutTensor}; matmulNode.inTensorViewFuncs.resize(matmulNode.inTensors.size()); if (isWeightNz_) { @@ -274,7 +274,7 @@ Status LinearOpsRunner::SetupKernelGraphMatmulElewiseAdd910B() KernelGraphNode &addNode = kernelGraph_.nodes.at(1); matmulNode.opDesc = {0, "MatMulOperation", matmulParam_}; - matmulNode.inTensors = {&xTensor, &weightTensor}; + matmulNode.inTensors = {&xTensor, &weightTensor, &nullTensor_, &nullTensor_}; matmulNode.outTensors = {&matmulOutTensor}; matmulNode.inTensorViewFuncs.resize(matmulNode.inTensors.size()); if (xNeedMergeAxis_) { @@ -338,7 +338,7 @@ Status LinearOpsRunner::SetupKernelGraphMatmulElewiseAddWeightNdNot910B() } matmulNode.opDesc = {0, "MatMulOperation", matmulParam_}; - matmulNode.inTensors = {&transdataAOutTensor, &transdataBOutTensor}; + matmulNode.inTensors = {&transdataAOutTensor, &transdataBOutTensor, &nullTensor_, &nullTensor_}; matmulNode.outTensors = {&matmulOutTensor}; transdataOutNode.opDesc = {0, "TransdataOperation", transdataNzToNdParam_}; @@ -387,7 +387,7 @@ Status LinearOpsRunner::SetupKernelGraphMatmulElewiseAddWeightNzNot910B() } matmulNode.opDesc = {0, "MatMulOperation", matmulParam_}; - matmulNode.inTensors = {&transdataAOutTensor, &weightTensor}; + matmulNode.inTensors = {&transdataAOutTensor, &weightTensor, &nullTensor_, &nullTensor_ }; matmulNode.outTensors = {&matmulOutTensor}; matmulNode.inTensorViewFuncs.resize(matmulNode.inTensors.size()); if (isWeightNz_) { @@ -425,7 +425,7 @@ Status LinearOpsRunner::SetupKernelGraphMatmulWithBias() matmulParam_.withBias = true; matmulParam_.matmulType = AsdOps::OpParam::MatMul::MatMulType::MATMUL_WITH_BIAS; matmulNode.opDesc = {0, "MatMulOperation", matmulParam_}; - matmulNode.inTensors = {&xTensor, &weightTensor, &biasTensor}; + matmulNode.inTensors = {&xTensor, &weightTensor, &biasTensor, &nullTensor_}; matmulNode.outTensors = {&outTensor}; matmulNode.inTensorViewFuncs.resize(matmulNode.inTensors.size()); if (xNeedMergeAxis_) { @@ -474,7 +474,7 @@ Status LinearOpsRunner::SetupKernelGraphMatmulAccum() matmulParam_.matmulType = AsdOps::OpParam::MatMul::MatMulType::MATMUL_ACCUM_ATOMIC; matmulNode.opDesc = {1, "MatMulOperation", matmulParam_}; - matmulNode.inTensors = {&transposedXtensor, &weightTensor, &accumTensor}; + matmulNode.inTensors = {&transposedXtensor, &weightTensor, &accumTensor, &nullTensor_}; matmulNode.outTensors = {&outTensor}; matmulNode.inTensorViewFuncs.resize(matmulNode.inTensors.size()); if (xNeedMergeAxis_) { @@ -489,7 +489,7 @@ Status LinearOpsRunner::SetupKernelGraphMatmulAccum() matmulParam_.matmulType = AsdOps::OpParam::MatMul::MatMulType::MATMUL_ACCUM_ATOMIC; matmulNode.opDesc = {0, "MatMulOperation", matmulParam_}; - matmulNode.inTensors = {&xTensor, &weightTensor, &accumTensor}; + matmulNode.inTensors = {&xTensor, &weightTensor, &accumTensor, &nullTensor_}; matmulNode.outTensors = {&outTensor}; matmulNode.inTensorViewFuncs.resize(matmulNode.inTensors.size()); if (xNeedMergeAxis_) { @@ -515,7 +515,7 @@ Status LinearOpsRunner::SetupKernelGraphMatmulEin() matmulParam_.matmulType = AsdOps::OpParam::MatMul::MatMulType::MATMUL_EIN_SUM; matmulNode.opDesc = { 0, "MatMulOperation", matmulParam_ }; - matmulNode.inTensors = { &xTensor, &weightTensor }; + matmulNode.inTensors = { &xTensor, &weightTensor, &nullTensor_, &nullTensor_ }; matmulNode.outTensors = { &outTensor }; matmulNode.inTensorViewFuncs.resize(matmulNode.inTensors.size()); if (isWeightNz_) { @@ -545,7 +545,7 @@ Status LinearOpsRunner::SetupKernelGraphMatmulEinElewiseAdd() matmulParam_.matmulType = AsdOps::OpParam::MatMul::MatMulType::MATMUL_EIN_SUM; matmulNode.opDesc = { 0, "MatMulOperation", matmulParam_ }; - matmulNode.inTensors = { &xTensor, &weightTensor }; + matmulNode.inTensors = { &xTensor, &weightTensor, &nullTensor_, &nullTensor_ }; matmulNode.outTensors = { &matmuloutTensor }; matmulNode.inTensorViewFuncs.resize(matmulNode.inTensors.size()); if (isWeightNz_) { diff --git a/tests/apitest/kernelstest/matmul/test_pp_matmul_accum.py b/tests/apitest/kernelstest/matmul/test_pp_matmul_accum.py index ad6c50c8..e30e049a 100644 --- a/tests/apitest/kernelstest/matmul/test_pp_matmul_accum.py +++ b/tests/apitest/kernelstest/matmul/test_pp_matmul_accum.py @@ -96,11 +96,11 @@ class TestPpMatmulAccum(op_test.OpTest): "matmulType": MATMUL_ACCUM_ATOMIC, }, ) - self.set_input_formats([self.format_nd, self.format_nd, self.format_nd]) + self.set_input_formats([self.format_nd, self.format_nd, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize), ta, tb, torch.bfloat16) self.execute( - [self.bat_A.bfloat16(), self.bat_B.bfloat16(), self.bat_C.float()], + [self.bat_A.bfloat16(), self.bat_B.bfloat16(), self.bat_C.float(), torch.Tensor()], [2], ) @@ -117,11 +117,11 @@ class TestPpMatmulAccum(op_test.OpTest): "matmulType": MATMUL_ACCUM_ATOMIC, }, ) - self.set_input_formats([self.format_nd, self.format_nd, self.format_nd]) + self.set_input_formats([self.format_nd, self.format_nd, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize), ta, tb, torch.bfloat16) self.execute( - [self.bat_A.bfloat16(), self.bat_B.bfloat16(), self.bat_C.float()], + [self.bat_A.bfloat16(), self.bat_B.bfloat16(), self.bat_C.float(), torch.Tensor()], [2], ) @@ -138,11 +138,11 @@ class TestPpMatmulAccum(op_test.OpTest): "matmulType": MATMUL_ACCUM_ATOMIC, }, ) - self.set_input_formats([self.format_nd, self.format_nd, self.format_nd]) + self.set_input_formats([self.format_nd, self.format_nd, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize), ta, tb, torch.bfloat16) self.execute( - [self.bat_A.bfloat16(), self.bat_B.bfloat16(), self.bat_C.float()], + [self.bat_A.bfloat16(), self.bat_B.bfloat16(), self.bat_C.float(), torch.Tensor()], [2], ) @@ -159,11 +159,11 @@ class TestPpMatmulAccum(op_test.OpTest): "matmulType": MATMUL_ACCUM_ATOMIC, }, ) - self.set_input_formats([self.format_nd, self.format_nd, self.format_nd]) + self.set_input_formats([self.format_nd, self.format_nd, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize), ta, tb, torch.bfloat16) self.execute( - [self.bat_A.bfloat16(), self.bat_B.bfloat16(), self.bat_C.float()], + [self.bat_A.bfloat16(), self.bat_B.bfloat16(), self.bat_C.float(), torch.Tensor()], [2], ) @@ -180,11 +180,11 @@ class TestPpMatmulAccum(op_test.OpTest): "matmulType": MATMUL_ACCUM_ATOMIC, }, ) - self.set_input_formats([self.format_nd, self.format_nd, self.format_nd]) + self.set_input_formats([self.format_nd, self.format_nd, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize), ta, tb, torch.float16) self.execute( - [self.bat_A.half(), self.bat_B.half(), self.bat_C.float()], + [self.bat_A.half(), self.bat_B.half(), self.bat_C.float(), torch.Tensor()], [2], ) @@ -201,11 +201,11 @@ class TestPpMatmulAccum(op_test.OpTest): "matmulType": MATMUL_ACCUM_ATOMIC, }, ) - self.set_input_formats([self.format_nd, self.format_nd, self.format_nd]) + self.set_input_formats([self.format_nd, self.format_nd, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize), ta, tb, torch.float16) self.execute( - [self.bat_A.half(), self.bat_B.half(), self.bat_C.float()], + [self.bat_A.half(), self.bat_B.half(), self.bat_C.float(), torch.Tensor()], [2], ) @@ -222,11 +222,11 @@ class TestPpMatmulAccum(op_test.OpTest): "matmulType": MATMUL_ACCUM_ATOMIC, }, ) - self.set_input_formats([self.format_nd, self.format_nd, self.format_nd]) + self.set_input_formats([self.format_nd, self.format_nd, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize), ta, tb, torch.float16) self.execute( - [self.bat_A.half(), self.bat_B.half(), self.bat_C.float()], + [self.bat_A.half(), self.bat_B.half(), self.bat_C.float(), torch.Tensor()], [2], ) @@ -243,11 +243,11 @@ class TestPpMatmulAccum(op_test.OpTest): "matmulType": MATMUL_ACCUM_ATOMIC, }, ) - self.set_input_formats([self.format_nd, self.format_nd, self.format_nd]) + self.set_input_formats([self.format_nd, self.format_nd, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize), ta, tb, torch.float16) self.execute( - [self.bat_A.half(), self.bat_B.half(), self.bat_C.float()], + [self.bat_A.half(), self.bat_B.half(), self.bat_C.float(), torch.Tensor()], [2], ) @@ -264,11 +264,11 @@ class TestPpMatmulAccum(op_test.OpTest): "matmulType": MATMUL_ACCUM_ATOMIC, }, ) - self.set_input_formats([self.format_nd, self.format_nd, self.format_nd]) + self.set_input_formats([self.format_nd, self.format_nd, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize), ta, tb, torch.float16) self.execute( - [self.bat_A.half(), self.bat_B.half(), self.bat_C.float()], + [self.bat_A.half(), self.bat_B.half(), self.bat_C.float(), torch.Tensor()], [2], ) diff --git a/tests/apitest/kernelstest/matmul/test_pp_matmul_bf16.py b/tests/apitest/kernelstest/matmul/test_pp_matmul_bf16.py index 2a68fd99..8758fb5a 100644 --- a/tests/apitest/kernelstest/matmul/test_pp_matmul_bf16.py +++ b/tests/apitest/kernelstest/matmul/test_pp_matmul_bf16.py @@ -106,11 +106,11 @@ class TestPpMatmulBf16(op_test.OpTest): "MatMulOperation", {"transposeA": trans_A, "transposeB": trans_B, "oriShape": [msize, ksize, nsize]}, ) - self.set_input_formats([self.format_nd, self.format_nd]) + self.set_input_formats([self.format_nd, self.format_nd, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize), trans_A, trans_B) self.execute( - [self.bat_A, self.bat_B], + [self.bat_A, self.bat_B, torch.Tensor(), torch.Tensor()], [torch.zeros(self.bat_C.shape).bfloat16()], ) @@ -123,11 +123,11 @@ class TestPpMatmulBf16(op_test.OpTest): "MatMulOperation", {"transposeA": trans_A, "transposeB": trans_B, "oriShape": [msize, ksize, nsize]}, ) - self.set_input_formats([self.format_nd, self.format_nd]) + self.set_input_formats([self.format_nd, self.format_nd, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize), trans_A, trans_B) self.execute( - [self.bat_A, self.bat_B], + [self.bat_A, self.bat_B, torch.Tensor(), torch.Tensor()], [torch.zeros(self.bat_C.shape).bfloat16()], ) @@ -140,11 +140,11 @@ class TestPpMatmulBf16(op_test.OpTest): "MatMulOperation", {"transposeA": trans_A, "transposeB": trans_B, "oriShape": [msize, ksize, nsize]}, ) - self.set_input_formats([self.format_nd, self.format_nd]) + self.set_input_formats([self.format_nd, self.format_nd, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize), trans_A, trans_B) self.execute( - [self.bat_A, self.bat_B], + [self.bat_A, self.bat_B, torch.Tensor(), torch.Tensor()], [torch.zeros(self.bat_C.shape).bfloat16()], ) @@ -156,11 +156,11 @@ class TestPpMatmulBf16(op_test.OpTest): "MatMulOperation", {"transposeA": self.trans_A, "transposeB": self.trans_B, "oriShape": [msize, ksize, nsize]}, ) - self.set_input_formats([self.format_nd, self.format_nz]) + self.set_input_formats([self.format_nd, self.format_nz, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize), self.trans_A, self.trans_B) self.execute( - [self.bat_A, self.bat_B], + [self.bat_A, self.bat_B, torch.Tensor(), torch.Tensor()], [torch.zeros(self.bat_C.shape).bfloat16()], ) @@ -172,11 +172,11 @@ class TestPpMatmulBf16(op_test.OpTest): "MatMulOperation", {"transposeA": self.trans_A, "transposeB": self.trans_B, "oriShape": [msize, ksize, nsize]}, ) - self.set_input_formats([self.format_nd, self.format_nz]) + self.set_input_formats([self.format_nd, self.format_nz, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize), self.trans_A, self.trans_B) self.execute( - [self.bat_A, self.bat_B], + [self.bat_A, self.bat_B, torch.Tensor(), torch.Tensor()], [torch.zeros(self.bat_C.shape).bfloat16()], ) @@ -188,11 +188,11 @@ class TestPpMatmulBf16(op_test.OpTest): "MatMulOperation", {"transposeA": self.trans_A, "transposeB": self.trans_B, "oriShape": [msize, ksize, nsize]}, ) - self.set_input_formats([self.format_nd, self.format_nz]) + self.set_input_formats([self.format_nd, self.format_nz, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize), self.trans_A, self.trans_B) self.execute( - [self.bat_A, self.bat_B], + [self.bat_A, self.bat_B, torch.Tensor(), torch.Tensor()], [torch.zeros(self.bat_C.shape).bfloat16()], ) @@ -204,11 +204,11 @@ class TestPpMatmulBf16(op_test.OpTest): "MatMulOperation", {"transposeA": self.trans_A, "transposeB": self.trans_B, "oriShape": [msize, ksize, nsize]}, ) - self.set_input_formats([self.format_nd, self.format_nz]) + self.set_input_formats([self.format_nd, self.format_nz, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize), self.trans_A, self.trans_B) self.execute( - [self.bat_A, self.bat_B], + [self.bat_A, self.bat_B, torch.Tensor(), torch.Tensor()], [torch.zeros(self.bat_C.shape).bfloat16()], ) @@ -220,11 +220,11 @@ class TestPpMatmulBf16(op_test.OpTest): "MatMulOperation", {"transposeA": self.trans_A, "transposeB": self.trans_B, "oriShape": [msize, ksize, nsize]}, ) - self.set_input_formats([self.format_nd, self.format_nz]) + self.set_input_formats([self.format_nd, self.format_nz, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize), self.trans_A, self.trans_B) self.execute( - [self.bat_A, self.bat_B], + [self.bat_A, self.bat_B, torch.Tensor(), torch.Tensor()], [torch.zeros(self.bat_C.shape).bfloat16()], ) @@ -236,11 +236,11 @@ class TestPpMatmulBf16(op_test.OpTest): "MatMulOperation", {"transposeA": self.trans_A, "transposeB": self.trans_B, "oriShape": [msize, ksize, nsize]}, ) - self.set_input_formats([self.format_nd, self.format_nz]) + self.set_input_formats([self.format_nd, self.format_nz, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize), self.trans_A, self.trans_B) self.execute( - [self.bat_A, self.bat_B], + [self.bat_A, self.bat_B, torch.Tensor(), torch.Tensor()], [torch.zeros(self.bat_C.shape).bfloat16()], ) @@ -252,11 +252,11 @@ class TestPpMatmulBf16(op_test.OpTest): "MatMulOperation", {"transposeA": self.trans_A, "transposeB": self.trans_B, "oriShape": [msize, ksize, nsize]}, ) - self.set_input_formats([self.format_nd, self.format_nz]) + self.set_input_formats([self.format_nd, self.format_nz, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize), self.trans_A, self.trans_B) self.execute( - [self.bat_A, self.bat_B], + [self.bat_A, self.bat_B, torch.Tensor(), torch.Tensor()], [torch.zeros(self.bat_C.shape).bfloat16()], ) @@ -268,11 +268,11 @@ class TestPpMatmulBf16(op_test.OpTest): "MatMulOperation", {"transposeA": self.trans_A, "transposeB": self.trans_B, "oriShape": [msize, ksize, nsize]}, ) - self.set_input_formats([self.format_nd, self.format_nz]) + self.set_input_formats([self.format_nd, self.format_nz, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize), self.trans_A, self.trans_B) self.execute( - [self.bat_A, self.bat_B], + [self.bat_A, self.bat_B, torch.Tensor(), torch.Tensor()], [torch.zeros(self.bat_C.shape).bfloat16()], ) @@ -285,11 +285,11 @@ class TestPpMatmulBf16(op_test.OpTest): "MatMulOperation", {"transposeA": trans_A, "transposeB": trans_B, "oriShape": [msize, ksize, nsize]}, ) - self.set_input_formats([self.format_nd, self.format_nz]) + self.set_input_formats([self.format_nd, self.format_nz, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize), trans_A, trans_B) self.execute( - [self.bat_A, self.bat_B], + [self.bat_A, self.bat_B, torch.Tensor(), torch.Tensor()], [torch.zeros(self.bat_C.shape).bfloat16()], ) @@ -302,11 +302,11 @@ class TestPpMatmulBf16(op_test.OpTest): "MatMulOperation", {"transposeA": trans_A, "transposeB": trans_B, "oriShape": [msize, ksize, nsize]}, ) - self.set_input_formats([self.format_nd, self.format_nz]) + self.set_input_formats([self.format_nd, self.format_nz, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize), trans_A, trans_B) self.execute( - [self.bat_A, self.bat_B], + [self.bat_A, self.bat_B, torch.Tensor(), torch.Tensor()], [torch.zeros(self.bat_C.shape).bfloat16()], ) @@ -319,11 +319,11 @@ class TestPpMatmulBf16(op_test.OpTest): "MatMulOperation", {"transposeA": trans_A, "transposeB": trans_B, "oriShape": [msize, ksize, nsize]}, ) - self.set_input_formats([self.format_nd, self.format_nz]) + self.set_input_formats([self.format_nd, self.format_nz, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize), trans_A, trans_B) self.execute( - [self.bat_A, self.bat_B], + [self.bat_A, self.bat_B, torch.Tensor(), torch.Tensor()], [torch.zeros(self.bat_C.shape).bfloat16()], ) diff --git a/tests/apitest/kernelstest/matmul/test_pp_matmul_ein_sum_f16.py b/tests/apitest/kernelstest/matmul/test_pp_matmul_ein_sum_f16.py index 46c66a30..d3654f80 100644 --- a/tests/apitest/kernelstest/matmul/test_pp_matmul_ein_sum_f16.py +++ b/tests/apitest/kernelstest/matmul/test_pp_matmul_ein_sum_f16.py @@ -104,11 +104,11 @@ class TestPpMatmulEinSumFp16(op_test.OpTest): "matmulType": 4, }, ) - self.set_input_formats([self.format_nd, self.format_nd]) + self.set_input_formats([self.format_nd, self.format_nd, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize), dtype) self.execute( - [self.bat_A, self.bat_B], + [self.bat_A, self.bat_B, torch.Tensor(), torch.Tensor()], [torch.zeros(self.bat_C.shape).to(dtype)], ) @@ -126,11 +126,11 @@ class TestPpMatmulEinSumFp16(op_test.OpTest): "matmulType": 4, }, ) - self.set_input_formats([self.format_nd, self.format_nd]) + self.set_input_formats([self.format_nd, self.format_nd, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize), dtype) self.execute( - [self.bat_A, self.bat_B], + [self.bat_A, self.bat_B, torch.Tensor(), torch.Tensor()], [torch.zeros(self.bat_C.shape).to(dtype)], ) @@ -148,11 +148,11 @@ class TestPpMatmulEinSumFp16(op_test.OpTest): "matmulType": 4, }, ) - self.set_input_formats([self.format_nd, self.format_nd]) + self.set_input_formats([self.format_nd, self.format_nd, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize), dtype) self.execute( - [self.bat_A, self.bat_B], + [self.bat_A, self.bat_B, torch.Tensor(), torch.Tensor()], [torch.zeros(self.bat_C.shape).to(dtype)], ) @@ -170,11 +170,11 @@ class TestPpMatmulEinSumFp16(op_test.OpTest): "matmulType": 4, }, ) - self.set_input_formats([self.format_nd, self.format_nd]) + self.set_input_formats([self.format_nd, self.format_nd, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize), dtype) self.execute( - [self.bat_A, self.bat_B], + [self.bat_A, self.bat_B, torch.Tensor(), torch.Tensor()], [torch.zeros(self.bat_C.shape).to(dtype)], ) @@ -192,11 +192,11 @@ class TestPpMatmulEinSumFp16(op_test.OpTest): "matmulType": 4, }, ) - self.set_input_formats([self.format_nd, self.format_nd]) + self.set_input_formats([self.format_nd, self.format_nd, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize), dtype) self.execute( - [self.bat_A, self.bat_B], + [self.bat_A, self.bat_B, torch.Tensor(), torch.Tensor()], [torch.zeros(self.bat_C.shape).to(dtype)], ) @@ -214,11 +214,11 @@ class TestPpMatmulEinSumFp16(op_test.OpTest): "matmulType": 4, }, ) - self.set_input_formats([self.format_nd, self.format_nd]) + self.set_input_formats([self.format_nd, self.format_nd, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize), dtype) self.execute( - [self.bat_A, self.bat_B], + [self.bat_A, self.bat_B, torch.Tensor(), torch.Tensor()], [torch.zeros(self.bat_C.shape).to(dtype)], ) @@ -237,11 +237,11 @@ class TestPpMatmulEinSumFp16(op_test.OpTest): "enShuffleK": True, }, ) - self.set_input_formats([self.format_nd, self.format_nz]) + self.set_input_formats([self.format_nd, self.format_nz, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize), dtype) self.execute( - [self.bat_A, self.bat_B], + [self.bat_A, self.bat_B, torch.Tensor(), torch.Tensor()], [torch.zeros(self.bat_C.shape).to(dtype)], ) @@ -260,11 +260,11 @@ class TestPpMatmulEinSumFp16(op_test.OpTest): "enShuffleK": True, }, ) - self.set_input_formats([self.format_nd, self.format_nz]) + self.set_input_formats([self.format_nd, self.format_nz, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize), dtype) self.execute( - [self.bat_A, self.bat_B], + [self.bat_A, self.bat_B, torch.Tensor(), torch.Tensor()], [torch.zeros(self.bat_C.shape).to(dtype)], ) @@ -283,11 +283,11 @@ class TestPpMatmulEinSumFp16(op_test.OpTest): "enShuffleK": True, }, ) - self.set_input_formats([self.format_nd, self.format_nz]) + self.set_input_formats([self.format_nd, self.format_nz, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize), dtype) self.execute( - [self.bat_A, self.bat_B], + [self.bat_A, self.bat_B, torch.Tensor(), torch.Tensor()], [torch.zeros(self.bat_C.shape).to(dtype)], ) @@ -306,11 +306,11 @@ class TestPpMatmulEinSumFp16(op_test.OpTest): "enShuffleK": True, }, ) - self.set_input_formats([self.format_nd, self.format_nz]) + self.set_input_formats([self.format_nd, self.format_nz, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize), dtype) self.execute( - [self.bat_A, self.bat_B], + [self.bat_A, self.bat_B, torch.Tensor(), torch.Tensor()], [torch.zeros(self.bat_C.shape).to(dtype)], ) diff --git a/tests/apitest/kernelstest/matmul/test_pp_matmul_f16.py b/tests/apitest/kernelstest/matmul/test_pp_matmul_f16.py index a72ac5f8..500c9d07 100644 --- a/tests/apitest/kernelstest/matmul/test_pp_matmul_f16.py +++ b/tests/apitest/kernelstest/matmul/test_pp_matmul_f16.py @@ -83,11 +83,11 @@ class TestPpMatmulF16(op_test.OpTest): "MatMulOperation", {"transposeA": self.trans_A, "transposeB": self.trans_B, "oriShape": [msize, ksize, nsize]}, ) - self.set_input_formats([self.format_nd, self.format_nz]) + self.set_input_formats([self.format_nd, self.format_nz, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize)) self.execute( - [torch.tensor(self.bat_A).half(), torch.tensor(self.bat_B).half()], + [torch.tensor(self.bat_A).half(), torch.tensor(self.bat_B).half(), torch.Tensor(), torch.Tensor()], [torch.zeros(self.bat_C.shape).half()], ) @@ -99,11 +99,11 @@ class TestPpMatmulF16(op_test.OpTest): "MatMulOperation", {"transposeA": self.trans_A, "transposeB": self.trans_B, "oriShape": [msize, ksize, nsize]}, ) - self.set_input_formats([self.format_nd, self.format_nz]) + self.set_input_formats([self.format_nd, self.format_nz, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize)) self.execute( - [torch.tensor(self.bat_A).half(), torch.tensor(self.bat_B).half()], + [torch.tensor(self.bat_A).half(), torch.tensor(self.bat_B).half(), torch.Tensor(), torch.Tensor()], [torch.zeros(self.bat_C.shape).half()], ) @@ -115,11 +115,11 @@ class TestPpMatmulF16(op_test.OpTest): "MatMulOperation", {"transposeA": self.trans_A, "transposeB": self.trans_B, "oriShape": [msize, ksize, nsize]}, ) - self.set_input_formats([self.format_nd, self.format_nz]) + self.set_input_formats([self.format_nd, self.format_nz, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize)) self.execute( - [torch.tensor(self.bat_A).half(), torch.tensor(self.bat_B).half()], + [torch.tensor(self.bat_A).half(), torch.tensor(self.bat_B).half(), torch.Tensor(), torch.Tensor()], [torch.zeros(self.bat_C.shape).half()], ) @@ -131,11 +131,11 @@ class TestPpMatmulF16(op_test.OpTest): "MatMulOperation", {"transposeA": self.trans_A, "transposeB": self.trans_B, "oriShape": [msize, ksize, nsize]}, ) - self.set_input_formats([self.format_nd, self.format_nz]) + self.set_input_formats([self.format_nd, self.format_nz, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize)) self.execute( - [torch.tensor(self.bat_A).half(), torch.tensor(self.bat_B).half()], + [torch.tensor(self.bat_A).half(), torch.tensor(self.bat_B).half(), torch.Tensor(), torch.Tensor()], [torch.zeros(self.bat_C.shape).half()], ) @op_test.only_910b @@ -146,11 +146,11 @@ class TestPpMatmulF16(op_test.OpTest): "MatMulOperation", {"transposeA": self.trans_A, "transposeB": self.trans_B, "oriShape": [msize, ksize, nsize]}, ) - self.set_input_formats([self.format_nd, self.format_nz]) + self.set_input_formats([self.format_nd, self.format_nz, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize)) self.execute( - [torch.tensor(self.bat_A).half(), torch.tensor(self.bat_B).half()], + [torch.tensor(self.bat_A).half(), torch.tensor(self.bat_B).half(), torch.Tensor(), torch.Tensor()], [torch.zeros(self.bat_C.shape).half()], ) @@ -162,11 +162,11 @@ class TestPpMatmulF16(op_test.OpTest): "MatMulOperation", {"transposeA": self.trans_A, "transposeB": self.trans_B, "oriShape": [msize, ksize, nsize]}, ) - self.set_input_formats([self.format_nd, self.format_nz]) + self.set_input_formats([self.format_nd, self.format_nz, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize)) self.execute( - [torch.tensor(self.bat_A).half(), torch.tensor(self.bat_B).half()], + [torch.tensor(self.bat_A).half(), torch.tensor(self.bat_B).half(), torch.Tensor(), torch.Tensor()], [torch.zeros(self.bat_C.shape).half()], ) @@ -178,11 +178,11 @@ class TestPpMatmulF16(op_test.OpTest): "MatMulOperation", {"transposeA": self.trans_A, "transposeB": self.trans_B, "oriShape": [msize, ksize, nsize]}, ) - self.set_input_formats([self.format_nd, self.format_nz]) + self.set_input_formats([self.format_nd, self.format_nz, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize)) self.execute( - [torch.tensor(self.bat_A).half(), torch.tensor(self.bat_B).half()], + [torch.tensor(self.bat_A).half(), torch.tensor(self.bat_B).half(), torch.Tensor(), torch.Tensor()], [torch.zeros(self.bat_C.shape).half()], ) @@ -194,11 +194,11 @@ class TestPpMatmulF16(op_test.OpTest): "MatMulOperation", {"transposeA": self.trans_A, "transposeB": self.trans_B, "oriShape": [msize, ksize, nsize]}, ) - self.set_input_formats([self.format_nd, self.format_nz]) + self.set_input_formats([self.format_nd, self.format_nz, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize)) self.execute( - [torch.tensor(self.bat_A).half(), torch.tensor(self.bat_B).half()], + [torch.tensor(self.bat_A).half(), torch.tensor(self.bat_B).half(), torch.Tensor(), torch.Tensor()], [torch.zeros(self.bat_C.shape).half()], ) @@ -210,11 +210,11 @@ class TestPpMatmulF16(op_test.OpTest): "MatMulOperation", {"transposeA": self.trans_A, "transposeB": self.trans_B, "oriShape": [msize, ksize, nsize]}, ) - self.set_input_formats([self.format_nd, self.format_nz]) + self.set_input_formats([self.format_nd, self.format_nz, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize)) self.execute( - [torch.tensor(self.bat_A).half(), torch.tensor(self.bat_B).half()], + [torch.tensor(self.bat_A).half(), torch.tensor(self.bat_B).half(), torch.Tensor(), torch.Tensor()], [torch.zeros(self.bat_C.shape).half()], ) @@ -226,11 +226,11 @@ class TestPpMatmulF16(op_test.OpTest): "MatMulOperation", {"transposeA": self.trans_A, "transposeB": self.trans_B, "oriShape": [msize, ksize, nsize]}, ) - self.set_input_formats([self.format_nd, self.format_nz]) + self.set_input_formats([self.format_nd, self.format_nz, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize)) self.execute( - [torch.tensor(self.bat_A).half(), torch.tensor(self.bat_B).half()], + [torch.tensor(self.bat_A).half(), torch.tensor(self.bat_B).half(), torch.Tensor(), torch.Tensor()], [torch.zeros(self.bat_C.shape).half()], ) @@ -242,11 +242,11 @@ class TestPpMatmulF16(op_test.OpTest): "MatMulOperation", {"transposeA": self.trans_A, "transposeB": self.trans_B, "oriShape": [msize, ksize, nsize]}, ) - self.set_input_formats([self.format_nd, self.format_nz]) + self.set_input_formats([self.format_nd, self.format_nz, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize)) self.execute( - [torch.tensor(self.bat_A).half(), torch.tensor(self.bat_B).half()], + [torch.tensor(self.bat_A).half(), torch.tensor(self.bat_B).half(), torch.Tensor(), torch.Tensor()], [torch.zeros(self.bat_C.shape).half()], ) @@ -258,11 +258,11 @@ class TestPpMatmulF16(op_test.OpTest): "MatMulOperation", {"transposeA": self.trans_A, "transposeB": self.trans_B, "oriShape": [msize, ksize, nsize]}, ) - self.set_input_formats([self.format_nd, self.format_nz]) + self.set_input_formats([self.format_nd, self.format_nz, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize)) self.execute( - [torch.tensor(self.bat_A).half(), torch.tensor(self.bat_B).half()], + [torch.tensor(self.bat_A).half(), torch.tensor(self.bat_B).half(), torch.Tensor(), torch.Tensor()], [torch.zeros(self.bat_C.shape).half()], ) @@ -274,11 +274,11 @@ class TestPpMatmulF16(op_test.OpTest): "MatMulOperation", {"transposeA": self.trans_A, "transposeB": self.trans_B, "oriShape": [msize, ksize, nsize]}, ) - self.set_input_formats([self.format_nd, self.format_nz]) + self.set_input_formats([self.format_nd, self.format_nz, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize)) self.execute( - [torch.tensor(self.bat_A).half(), torch.tensor(self.bat_B).half()], + [torch.tensor(self.bat_A).half(), torch.tensor(self.bat_B).half(), torch.Tensor(), torch.Tensor()], [torch.zeros(self.bat_C.shape).half()], ) @@ -290,11 +290,11 @@ class TestPpMatmulF16(op_test.OpTest): "MatMulOperation", {"transposeA": self.trans_A, "transposeB": self.trans_B, "oriShape": [msize, ksize, nsize]}, ) - self.set_input_formats([self.format_nd, self.format_nd]) + self.set_input_formats([self.format_nd, self.format_nd, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize)) self.execute( - [torch.tensor(self.bat_A).half(), torch.tensor(self.bat_B).half()], + [torch.tensor(self.bat_A).half(), torch.tensor(self.bat_B).half(), torch.Tensor(), torch.Tensor()], [torch.zeros(self.bat_C.shape).half()], ) @@ -306,11 +306,11 @@ class TestPpMatmulF16(op_test.OpTest): "MatMulOperation", {"transposeA": self.trans_A, "transposeB": self.trans_B, "oriShape": [msize, ksize, nsize]}, ) - self.set_input_formats([self.format_nd, self.format_nd]) + self.set_input_formats([self.format_nd, self.format_nd, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize)) self.execute( - [torch.tensor(self.bat_A).half(), torch.tensor(self.bat_B).half()], + [torch.tensor(self.bat_A).half(), torch.tensor(self.bat_B).half(), torch.Tensor(), torch.Tensor()], [torch.zeros(self.bat_C.shape).half()], ) @@ -322,11 +322,11 @@ class TestPpMatmulF16(op_test.OpTest): "MatMulOperation", {"transposeA": self.trans_A, "transposeB": self.trans_B, "oriShape": [msize, ksize, nsize]}, ) - self.set_input_formats([self.format_nd, self.format_nz]) + self.set_input_formats([self.format_nd, self.format_nz, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize)) self.execute( - [torch.tensor(self.bat_A).half(), torch.tensor(self.bat_B).half()], + [torch.tensor(self.bat_A).half(), torch.tensor(self.bat_B).half(), torch.Tensor(), torch.Tensor()], [torch.zeros(self.bat_C.shape).half()], ) @@ -338,11 +338,11 @@ class TestPpMatmulF16(op_test.OpTest): "MatMulOperation", {"transposeA": self.trans_A, "transposeB": self.trans_B, "oriShape": [msize, ksize, nsize]}, ) - self.set_input_formats([self.format_nd, self.format_nz]) + self.set_input_formats([self.format_nd, self.format_nz, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize)) self.execute( - [torch.tensor(self.bat_A).half(), torch.tensor(self.bat_B).half()], + [torch.tensor(self.bat_A).half(), torch.tensor(self.bat_B).half(), torch.Tensor(), torch.Tensor()], [torch.zeros(self.bat_C.shape).half()], ) @@ -354,11 +354,11 @@ class TestPpMatmulF16(op_test.OpTest): "MatMulOperation", {"transposeA": self.trans_A, "transposeB": self.trans_B, "oriShape": [msize, ksize, nsize]}, ) - self.set_input_formats([self.format_nd, self.format_nz]) + self.set_input_formats([self.format_nd, self.format_nz, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize)) self.execute( - [torch.tensor(self.bat_A).half(), torch.tensor(self.bat_B).half()], + [torch.tensor(self.bat_A).half(), torch.tensor(self.bat_B).half(), torch.Tensor(), torch.Tensor()], [torch.zeros(self.bat_C.shape).half()], ) @@ -370,11 +370,11 @@ class TestPpMatmulF16(op_test.OpTest): "MatMulOperation", {"transposeA": self.trans_A, "transposeB": self.trans_B, "oriShape": [msize, ksize, nsize]}, ) - self.set_input_formats([self.format_nd, self.format_nz]) + self.set_input_formats([self.format_nd, self.format_nz, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize)) self.execute( - [torch.tensor(self.bat_A).half(), torch.tensor(self.bat_B).half()], + [torch.tensor(self.bat_A).half(), torch.tensor(self.bat_B).half(), torch.Tensor(), torch.Tensor()], [torch.zeros(self.bat_C.shape).half()], ) @@ -386,11 +386,11 @@ class TestPpMatmulF16(op_test.OpTest): "MatMulOperation", {"transposeA": self.trans_A, "transposeB": self.trans_B, "oriShape": [msize, ksize, nsize]}, ) - self.set_input_formats([self.format_nd, self.format_nd]) + self.set_input_formats([self.format_nd, self.format_nd, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize)) self.execute( - [torch.tensor(self.bat_A).half(), torch.tensor(self.bat_B).half()], + [torch.tensor(self.bat_A).half(), torch.tensor(self.bat_B).half(), torch.Tensor(), torch.Tensor()], [torch.zeros(self.bat_C.shape).half()], ) @@ -402,11 +402,11 @@ class TestPpMatmulF16(op_test.OpTest): "MatMulOperation", {"transposeA": self.trans_A, "transposeB": self.trans_B, "oriShape": [msize, ksize, nsize]}, ) - self.set_input_formats([self.format_nd, self.format_nd]) + self.set_input_formats([self.format_nd, self.format_nd, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize)) self.execute( - [torch.tensor(self.bat_A).half(), torch.tensor(self.bat_B).half()], + [torch.tensor(self.bat_A).half(), torch.tensor(self.bat_B).half(), torch.Tensor(), torch.Tensor()], [torch.zeros(self.bat_C.shape).half()], ) @@ -418,11 +418,11 @@ class TestPpMatmulF16(op_test.OpTest): "MatMulOperation", {"transposeA": self.trans_A, "transposeB": self.trans_B, "oriShape": [msize, ksize, nsize]}, ) - self.set_input_formats([self.format_nd, self.format_nd]) + self.set_input_formats([self.format_nd, self.format_nd, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize)) self.execute( - [torch.tensor(self.bat_A).half(), torch.tensor(self.bat_B).half()], + [torch.tensor(self.bat_A).half(), torch.tensor(self.bat_B).half(), torch.Tensor(), torch.Tensor()], [torch.zeros(self.bat_C.shape).half()], ) @@ -434,11 +434,11 @@ class TestPpMatmulF16(op_test.OpTest): "MatMulOperation", {"transposeA": self.trans_A, "transposeB": self.trans_B, "oriShape": [msize, ksize, nsize]}, ) - self.set_input_formats([self.format_nd, self.format_nd]) + self.set_input_formats([self.format_nd, self.format_nd, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize)) self.execute( - [torch.tensor(self.bat_A).half(), torch.tensor(self.bat_B).half()], + [torch.tensor(self.bat_A).half(), torch.tensor(self.bat_B).half(), torch.Tensor(), torch.Tensor()], [torch.zeros(self.bat_C.shape).half()], ) @@ -450,11 +450,11 @@ class TestPpMatmulF16(op_test.OpTest): "MatMulOperation", {"transposeA": self.trans_A, "transposeB": self.trans_B, "oriShape": [msize, ksize, nsize]}, ) - self.set_input_formats([self.format_nd, self.format_nd]) + self.set_input_formats([self.format_nd, self.format_nd, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize)) self.execute( - [torch.tensor(self.bat_A).half(), torch.tensor(self.bat_B).half()], + [torch.tensor(self.bat_A).half(), torch.tensor(self.bat_B).half(), torch.Tensor(), torch.Tensor()], [torch.zeros(self.bat_C.shape).half()], ) @@ -466,11 +466,11 @@ class TestPpMatmulF16(op_test.OpTest): "MatMulOperation", {"transposeA": self.trans_A, "transposeB": self.trans_B, "oriShape": [msize, ksize, nsize]}, ) - self.set_input_formats([self.format_nd, self.format_nd]) + self.set_input_formats([self.format_nd, self.format_nd, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize)) self.execute( - [torch.tensor(self.bat_A).half(), torch.tensor(self.bat_B).half()], + [torch.tensor(self.bat_A).half(), torch.tensor(self.bat_B).half(), torch.Tensor(), torch.Tensor()], [torch.zeros(self.bat_C.shape).half()], ) diff --git a/tests/apitest/kernelstest/matmul/test_pp_matmul_with_bias.py b/tests/apitest/kernelstest/matmul/test_pp_matmul_with_bias.py index cfc36974..d87b4d60 100644 --- a/tests/apitest/kernelstest/matmul/test_pp_matmul_with_bias.py +++ b/tests/apitest/kernelstest/matmul/test_pp_matmul_with_bias.py @@ -87,11 +87,11 @@ class TestPpMatmulF16(op_test.OpTest): "matmulType": MATMUL_WITH_BIAS, }, ) - self.set_input_formats([self.format_nd, self.format_nd, self.format_nd]) + self.set_input_formats([self.format_nd, self.format_nd, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize)) self.execute( - [self.bat_A.half(), self.bat_B.half(), self.bat_bias.float()], + [self.bat_A.half(), self.bat_B.half(), self.bat_bias.float(), torch.Tensor()], [torch.zeros(self.bat_C.shape).half()], ) @@ -108,11 +108,11 @@ class TestPpMatmulF16(op_test.OpTest): "matmulType": MATMUL_WITH_BIAS, }, ) - self.set_input_formats([self.format_nd, self.format_nd, self.format_nd]) + self.set_input_formats([self.format_nd, self.format_nd, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize)) self.execute( - [self.bat_A.half(), self.bat_B.half(), self.bat_bias.float()], + [self.bat_A.half(), self.bat_B.half(), self.bat_bias.float(), torch.Tensor()], [torch.zeros(self.bat_C.shape).half()], ) @@ -129,11 +129,11 @@ class TestPpMatmulF16(op_test.OpTest): "matmulType": MATMUL_WITH_BIAS, }, ) - self.set_input_formats([self.format_nd, self.format_nd, self.format_nd]) + self.set_input_formats([self.format_nd, self.format_nd, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize)) self.execute( - [self.bat_A.half(), self.bat_B.half(), self.bat_bias.float()], + [self.bat_A.half(), self.bat_B.half(), self.bat_bias.float(), torch.Tensor()], [torch.zeros(self.bat_C.shape).half()], ) @@ -150,11 +150,11 @@ class TestPpMatmulF16(op_test.OpTest): "matmulType": MATMUL_WITH_BIAS, }, ) - self.set_input_formats([self.format_nd, self.format_nd, self.format_nd]) + self.set_input_formats([self.format_nd, self.format_nd, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize)) self.execute( - [self.bat_A.half(), self.bat_B.half(), self.bat_bias.float()], + [self.bat_A.half(), self.bat_B.half(), self.bat_bias.float(), torch.Tensor()], [torch.zeros(self.bat_C.shape).half()], ) @@ -171,11 +171,11 @@ class TestPpMatmulF16(op_test.OpTest): "matmulType": MATMUL_WITH_BIAS, }, ) - self.set_input_formats([self.format_nd, self.format_nd, self.format_nd]) + self.set_input_formats([self.format_nd, self.format_nd, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize), dtype=torch.bfloat16) self.execute( - [self.bat_A.bfloat16(), self.bat_B.bfloat16(), self.bat_bias.float()], + [self.bat_A.bfloat16(), self.bat_B.bfloat16(), self.bat_bias.float(), torch.Tensor()], [torch.zeros(self.bat_C.shape).bfloat16()], ) @@ -192,11 +192,11 @@ class TestPpMatmulF16(op_test.OpTest): "matmulType": MATMUL_WITH_BIAS, }, ) - self.set_input_formats([self.format_nd, self.format_nd, self.format_nd]) + self.set_input_formats([self.format_nd, self.format_nd, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize), dtype=torch.bfloat16) self.execute( - [self.bat_A.bfloat16(), self.bat_B.bfloat16(), self.bat_bias.float()], + [self.bat_A.bfloat16(), self.bat_B.bfloat16(), self.bat_bias.float(), torch.Tensor()], [torch.zeros(self.bat_C.shape).bfloat16()], ) @@ -213,11 +213,11 @@ class TestPpMatmulF16(op_test.OpTest): "matmulType": MATMUL_WITH_BIAS, }, ) - self.set_input_formats([self.format_nd, self.format_nd, self.format_nd]) + self.set_input_formats([self.format_nd, self.format_nd, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize), dtype=torch.bfloat16) self.execute( - [self.bat_A.bfloat16(), self.bat_B.bfloat16(), self.bat_bias.float()], + [self.bat_A.bfloat16(), self.bat_B.bfloat16(), self.bat_bias.float(), torch.Tensor()], [torch.zeros(self.bat_C.shape).bfloat16()], ) @@ -234,11 +234,11 @@ class TestPpMatmulF16(op_test.OpTest): "matmulType": MATMUL_WITH_BIAS, }, ) - self.set_input_formats([self.format_nd, self.format_nd, self.format_nd]) + self.set_input_formats([self.format_nd, self.format_nd, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize), dtype=torch.bfloat16) self.execute( - [self.bat_A.bfloat16(), self.bat_B.bfloat16(), self.bat_bias.float()], + [self.bat_A.bfloat16(), self.bat_B.bfloat16(), self.bat_bias.float(), torch.Tensor()], [torch.zeros(self.bat_C.shape).bfloat16()], ) @@ -255,11 +255,11 @@ class TestPpMatmulF16(op_test.OpTest): "matmulType": MATMUL_WITH_BIAS, }, ) - self.set_input_formats([self.format_nd, self.format_nd, self.format_nd]) + self.set_input_formats([self.format_nd, self.format_nd, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nd]) self.__gen_test_data((bsize, msize, ksize, nsize), dtype=torch.bfloat16) self.execute( - [self.bat_A.bfloat16(), self.bat_B.bfloat16(), self.bat_bias.float()], + [self.bat_A.bfloat16(), self.bat_B.bfloat16(), self.bat_bias.float(), torch.Tensor()], [torch.zeros(self.bat_C.shape).bfloat16()], ) diff --git a/tests/apitest/kernelstest/matmul/test_ppmatmul_310p.py b/tests/apitest/kernelstest/matmul/test_ppmatmul_310p.py index faf7cbfc..78f0c42e 100644 --- a/tests/apitest/kernelstest/matmul/test_ppmatmul_310p.py +++ b/tests/apitest/kernelstest/matmul/test_ppmatmul_310p.py @@ -76,11 +76,11 @@ class TestPpMatmul310p(op_test.OpTest): "MatMulOperation", {"transposeA": self.trans_A, "transposeB": self.trans_B, "oriShape": [msize, ksize, nsize]}, ) - self.set_input_formats([self.format_nz, self.format_nz]) + self.set_input_formats([self.format_nz, self.format_nz, self.format_nd, self.format_nd]) self.set_output_formats([self.format_nz]) self.__gen_test_data((bsize, msize, ksize, nsize)) self.execute( - [torch.tensor(self.bat_A).half(), torch.tensor(self.bat_B).half()], + [torch.tensor(self.bat_A).half(), torch.tensor(self.bat_B).half(), torch.Tensor(), torch.Tensor()], [torch.zeros(self.bat_C.shape).half()], ) -- Gitee From a72ac4ca645f0bcb3109efa8055ffa9067dd02b1 Mon Sep 17 00:00:00 2001 From: w30042780 Date: Thu, 31 Jul 2025 21:57:48 +0800 Subject: [PATCH 2/2] =?UTF-8?q?=E8=A7=A3=E5=86=B3=E5=90=8C=E6=AD=A5?= =?UTF-8?q?=E5=A4=B1=E8=B4=A5=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/kernels/kernels/matmul/matmul_operation.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/src/kernels/kernels/matmul/matmul_operation.cpp b/src/kernels/kernels/matmul/matmul_operation.cpp index 5bfc849e..6ad3f272 100644 --- a/src/kernels/kernels/matmul/matmul_operation.cpp +++ b/src/kernels/kernels/matmul/matmul_operation.cpp @@ -108,7 +108,6 @@ public: kernelKey = (kernelKey << INPUT_BIT_COUNT) + (inTensorCount - DIM_2); MKI_LOG(INFO) << "kernelKey: " << kernelKey; MKI_LOG(INFO) << ">>> PpMatmulType:" << static_cast(opParam.matmulType); - // 先判断w8a8compress if (isSparseDequant) { switch (kernelKey) { -- Gitee