From cb651965f699397c3f16b98cc6f3ddd2f6bb99b0 Mon Sep 17 00:00:00 2001 From: "xinchi.tian" Date: Thu, 16 Jan 2025 19:02:35 +0800 Subject: [PATCH] Fix compile error --- .../ixrt/src/qkv_to_context/qkvToContextInt8Plugin.cu | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/models/nlp/language_model/bert_base_squad/ixrt/src/qkv_to_context/qkvToContextInt8Plugin.cu b/models/nlp/language_model/bert_base_squad/ixrt/src/qkv_to_context/qkvToContextInt8Plugin.cu index 5c4d5c53..2330debf 100644 --- a/models/nlp/language_model/bert_base_squad/ixrt/src/qkv_to_context/qkvToContextInt8Plugin.cu +++ b/models/nlp/language_model/bert_base_squad/ixrt/src/qkv_to_context/qkvToContextInt8Plugin.cu @@ -284,7 +284,7 @@ cudaError_t fused_multihead_attetion_int8(int8_t* qkv_buffer, int8_t* mask, int8 case 64: case 128: case 192: - case 256: + case 256: { cuinferFlashAttnConfigInfo flashAttnInfo; flashAttnInfo.scaling = sqrt(1.f / (head_dim * 1.0)); flashAttnInfo.quantParam.q_amax = arrange_qkv_amax; @@ -318,7 +318,8 @@ cudaError_t fused_multihead_attetion_int8(int8_t* qkv_buffer, int8_t* mask, int8 CUINFER_CHECK(cuinferFMHAForwardEx(cuinfer_handle, flashAttnInfo, qDesc, q_buffer, kDesc, k_buffer, vDesc, v_buffer, maskDesc, mask, oDesc, qk_buffer)); break; - default: + } + default: { cuinfer_i8_gemm(k_buffer, q_buffer, nullptr, qkv_buffer, batch_size * head_num, batch_seq_len, batch_seq_len, head_dim, batch_seq_len * head_dim, batch_seq_len * head_dim, batch_seq_len * batch_seq_len, scaleBmm1, 0.0, 0, cuinfer_handle, stream); @@ -330,6 +331,7 @@ cudaError_t fused_multihead_attetion_int8(int8_t* qkv_buffer, int8_t* mask, int8 batch_seq_len, batch_seq_len * head_dim, batch_seq_len * batch_seq_len, batch_seq_len * head_dim, scaleBmm2, cuinfer_handle, stream); break; + } } IxinferArrangeAttenOutputI8II8O(batch_token_num, hidden_size, stream, qk_buffer, qkv_out, batch_seq_len, head_dim, -- Gitee