diff --git a/src/kernels/kernels/matmul/matmul_operation.cpp b/src/kernels/kernels/matmul/matmul_operation.cpp
index 83a28833f522bcbd33242e0f95a0d38b5de4ec8a..9b687840f8c5e3114a2fba2d0003c0309f19bdca 100644
--- a/src/kernels/kernels/matmul/matmul_operation.cpp
+++ b/src/kernels/kernels/matmul/matmul_operation.cpp
@@ -24,6 +24,8 @@ namespace {
 constexpr uint32_t DTYPE_BIT_COUNT = 2;
 constexpr uint32_t FORMAT_BIT_COUNT = 1;
 constexpr uint32_t INPUT_BIT_COUNT = 2;
+constexpr uint32_t ALIGNMENT_4 = 4;
+constexpr uint32_t ALIGNMENT_16 = 16;
 constexpr uint32_t PP_MATMUL_I8_BF16_KERNEL_KEY = 0b1'00'00'10'0'0'0'10;
 constexpr uint32_t PP_MATMUL_I8_BF16_WEIGHT_NZ_KERNEL_KEY = 0b1'00'00'10'0'1'0'10;
 constexpr uint32_t PP_MATMUL_I8_KERNEL_KEY = 0b1'00'00'01'0'0'0'10;
@@ -124,7 +126,14 @@ public:
         case PP_MATMUL_I8_BF16_KERNEL_KEY:
         case PP_MATMUL_I8_BF16_WEIGHT_NZ_KERNEL_KEY:
             return GetKernelByName("PpMatMulI8Bf16Kernel");
-        case PP_MATMUL_I8_KERNEL_KEY: return GetKernelByName("PpMatMulI8Kernel");
+        case PP_MATMUL_I8_KERNEL_KEY: {
+            const uint32_t nModResult = opParam.oriShape[DIM_2] % ALIGNMENT_16;
+            if (nModResult < ALIGNMENT_4) {
+                return GetKernelByName("QuantBatchMatmulI8Kernel");
+            } else {
+                return GetKernelByName("PpMatMulI8Kernel");
+            }
+        }
         case PP_MATMUL_I8_WEIGHT_NZ_KERNEL_KEY: return GetKernelByName("PpMatMulI8WeightNzKernel");
         case PP_MATMUL_F16_KERNEL_KEY: return GetKernelByName("PpMatMulF16Kernel");
         case PP_MATMUL_F16_MIX_KERNEL_KEY: return GetKernelByName("PpMatMulF16MixKernel");
diff --git a/src/kernels/kernels/matmul/pp_matmul_i8_kernel/pp_matmul_i8_kernel.cpp b/src/kernels/kernels/matmul/pp_matmul_i8_kernel/pp_matmul_i8_kernel.cpp
index eaefb5b7e92b0dff808575ac58911cbe0e8de2cb..8ce1f4b0fe4bee87c53f5a90ed26a8ae0d8e9a58 100644
--- a/src/kernels/kernels/matmul/pp_matmul_i8_kernel/pp_matmul_i8_kernel.cpp
+++ b/src/kernels/kernels/matmul/pp_matmul_i8_kernel/pp_matmul_i8_kernel.cpp
@@ -64,6 +64,25 @@ public:
 };
 REG_KERNEL_BASE(PpMatMulI8Kernel);
 
+class QuantBatchMatmulI8Kernel : public PpMatMulI8CommKernel {
+public:
+    explicit QuantBatchMatmulI8Kernel(const std::string &kernelName, const BinHandle *handle) noexcept
+        : PpMatMulI8CommKernel(kernelName, handle)
+    {
+    }
+
+    bool CanSupport(const LaunchParam &launchParam) const override
+    {
+        MKI_CHECK_NO_LOG(PpMatMulI8CommKernel::CanSupport(launchParam), return false);
+        const auto &inTensor3 = launchParam.GetInTensor(3);
+        MKI_CHECK(inTensor3.desc.format == TENSOR_FORMAT_ND, "inTensor3 format invalid", return false);
+        MKI_CHECK(inTensor3.desc.dtype == TENSOR_DTYPE_INT64 || inTensor3.desc.dtype == TENSOR_DTYPE_UINT64,
+            "inTensor3 dtype invalid", return false);
+        return true;
+    }
+};
+REG_KERNEL_BASE(QuantBatchMatmulI8Kernel);
+
 class PpMatMulI8NdNzKernel : public PpMatMulI8Kernel {
 public:
     explicit PpMatMulI8NdNzKernel(const std::string &kernelName, const BinHandle *handle) noexcept