From ed0a58076723caf5d419c92140b5df4bfabbc54e Mon Sep 17 00:00:00 2001 From: lingchongfeng Date: Wed, 4 Jun 2025 20:20:21 +0800 Subject: [PATCH] =?UTF-8?q?=E8=B0=83=E7=94=A8canndev=E4=BB=93=E7=AE=97?= =?UTF-8?q?=E5=AD=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../kernels/matmul/matmul_operation.cpp | 11 +++++++++- .../pp_matmul_i8_kernel.cpp | 21 +++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/src/kernels/kernels/matmul/matmul_operation.cpp b/src/kernels/kernels/matmul/matmul_operation.cpp index 83a28833..9b687840 100644 --- a/src/kernels/kernels/matmul/matmul_operation.cpp +++ b/src/kernels/kernels/matmul/matmul_operation.cpp @@ -24,6 +24,8 @@ namespace { constexpr uint32_t DTYPE_BIT_COUNT = 2; constexpr uint32_t FORMAT_BIT_COUNT = 1; constexpr uint32_t INPUT_BIT_COUNT = 2; +constexpr uint32_t ALIGNment_4 = 4; +constexpr uint32_t ALIGNment_16 = 16; constexpr uint32_t PP_MATMUL_I8_BF16_KERNEL_KEY = 0b1'00'00'10'0'0'0'10; constexpr uint32_t PP_MATMUL_I8_BF16_WEIGHT_NZ_KERNEL_KEY = 0b1'00'00'10'0'1'0'10; constexpr uint32_t PP_MATMUL_I8_KERNEL_KEY = 0b1'00'00'01'0'0'0'10; @@ -124,7 +126,14 @@ public: case PP_MATMUL_I8_BF16_KERNEL_KEY: case PP_MATMUL_I8_BF16_WEIGHT_NZ_KERNEL_KEY: return GetKernelByName("PpMatMulI8Bf16Kernel"); - case PP_MATMUL_I8_KERNEL_KEY: return GetKernelByName("PpMatMulI8Kernel"); + case PP_MATMUL_I8_KERNEL_KEY: { + const uint32_t nModResult = opParam.oriShape[DIM_2] % ALIGNment_16; + if (nModResult < ALIGNment_4) { + return GetKernelByName("QuantBatchMatmulI8Kernel"); + } else { + return GetKernelByName("PpMatMulI8Kernel"); + } + } case PP_MATMUL_I8_WEIGHT_NZ_KERNEL_KEY: return GetKernelByName("PpMatMulI8WeightNzKernel"); case PP_MATMUL_F16_KERNEL_KEY: return GetKernelByName("PpMatMulF16Kernel"); case PP_MATMUL_F16_MIX_KERNEL_KEY: return GetKernelByName("PpMatMulF16MixKernel"); diff --git a/src/kernels/kernels/matmul/pp_matmul_i8_kernel/pp_matmul_i8_kernel.cpp b/src/kernels/kernels/matmul/pp_matmul_i8_kernel/pp_matmul_i8_kernel.cpp index eaefb5b7..8ce1f4b0 100644 --- a/src/kernels/kernels/matmul/pp_matmul_i8_kernel/pp_matmul_i8_kernel.cpp +++ b/src/kernels/kernels/matmul/pp_matmul_i8_kernel/pp_matmul_i8_kernel.cpp @@ -64,6 +64,27 @@ public: }; REG_KERNEL_BASE(PpMatMulI8Kernel); +class QuantBatchMatmulI8Kernel : public PpMatMulI8CommKernel { +public: + explicit QuantBatchMatmulI8Kernel(const std::string &kernelName, const BinHandle *handle) noexcept + : PpMatMulI8CommKernel(kernelName, handle) + { + } + + bool CanSupport(const LaunchParam &launchParam) const override + { + MKI_CHECK_NO_LOG(PpMatMulI8CommKernel::CanSupport(launchParam), return false); + const auto &inTensor3 = launchParam.GetInTensor(3); + MKI_CHECK(inTensor3.desc.format == TENSOR_FORMAT_ND, "inTensor3 format invalid", return false); + MKI_CHECK(inTensor3.desc.dtype == TENSOR_DTYPE_INT64 || inTensor3.desc.dtype == TENSOR_DTYPE_UINT64, + "inTensor3 dtype invalid", return false); + return true; + } +}; +REG_KERNEL_BASE(QuantBatchMatmulI8Kernel); + + + class PpMatMulI8NdNzKernel : public PpMatMulI8Kernel { public: explicit PpMatMulI8NdNzKernel(const std::string &kernelName, const BinHandle *handle) noexcept -- Gitee