diff --git a/models/nlp/language_model/bert_base_squad/ixrt/src/qkv_to_context/qkvToContextInt8Plugin.cu b/models/nlp/language_model/bert_base_squad/ixrt/src/qkv_to_context/qkvToContextInt8Plugin.cu
index 5c4d5c5331077a490b1f6feb32d30596500fa23c..2330debf3e1bee647c70336b35729699b90ad06e 100644
--- a/models/nlp/language_model/bert_base_squad/ixrt/src/qkv_to_context/qkvToContextInt8Plugin.cu
+++ b/models/nlp/language_model/bert_base_squad/ixrt/src/qkv_to_context/qkvToContextInt8Plugin.cu
@@ -284,7 +284,7 @@ cudaError_t fused_multihead_attetion_int8(int8_t* qkv_buffer, int8_t* mask, int8
         case 64:
         case 128:
         case 192:
-        case 256:
+        case 256: {
             cuinferFlashAttnConfigInfo flashAttnInfo;
             flashAttnInfo.scaling = sqrt(1.f / (head_dim * 1.0));
             flashAttnInfo.quantParam.q_amax = arrange_qkv_amax;
@@ -318,7 +318,8 @@ cudaError_t fused_multihead_attetion_int8(int8_t* qkv_buffer, int8_t* mask, int8
             CUINFER_CHECK(cuinferFMHAForwardEx(cuinfer_handle, flashAttnInfo, qDesc, q_buffer, kDesc, k_buffer, vDesc,
                                                v_buffer, maskDesc, mask, oDesc, qk_buffer));
             break;
-        default:
+        }
+        default: {
             cuinfer_i8_gemm(k_buffer, q_buffer, nullptr, qkv_buffer, batch_size * head_num, batch_seq_len, batch_seq_len,
                             head_dim, batch_seq_len * head_dim, batch_seq_len * head_dim, batch_seq_len * batch_seq_len,
                             scaleBmm1, 0.0, 0, cuinfer_handle, stream);
@@ -330,6 +331,7 @@ cudaError_t fused_multihead_attetion_int8(int8_t* qkv_buffer, int8_t* mask, int8
                             batch_seq_len, batch_seq_len * head_dim, batch_seq_len * batch_seq_len, batch_seq_len * head_dim,
                             scaleBmm2, cuinfer_handle, stream);
             break;
+        }
     }

     IxinferArrangeAttenOutputI8II8O(batch_token_num, hidden_size, stream, qk_buffer, qkv_out, batch_seq_len, head_dim,
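For context on the pattern this patch applies: the added braces give the case-local declarations (such as flashAttnInfo) their own scope, so that later labels like default: no longer jump past their initialization, which C++ rejects for non-vacuously initialized locals. Below is a minimal standalone sketch of the same brace-scoping pattern, assuming hypothetical names (dispatch, scaling) that are not part of the cuinfer API.

// scoped_case.cpp -- illustrative only; compile with any C++/CUDA host compiler.
#include <cmath>
#include <cstdio>

void dispatch(int head_dim) {
    switch (head_dim) {
        case 64:
        case 128: {
            // The braces open a scope for 'scaling'. Without them, the
            // 'default:' label below would jump past this initialization
            // ("jump to case label crosses initialization of ...").
            const float scaling = std::sqrt(1.f / static_cast<float>(head_dim));
            std::printf("fused path, scaling=%.4f\n", scaling);
            break;
        }  // 'scaling' goes out of scope before the next label
        default: {
            std::printf("fallback path for head_dim=%d\n", head_dim);
            break;
        }
    }
}

int main() {
    dispatch(128);  // takes the fused branch
    dispatch(96);   // takes the fallback branch
    return 0;
}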