diff --git a/examples/cpp/ms/initialize.h b/examples/cpp/ms/initialize.h index 32a4d5ae9c8f3253554e6fe34e74359297e7a097..6a6221d5bb45f1c1572c99f157e28ed705ec0662 100644 --- a/examples/cpp/ms/initialize.h +++ b/examples/cpp/ms/initialize.h @@ -401,14 +401,22 @@ void InitializeEncoder(opt_arg* opt_a, false, // free buffer after fwd true, // is_qk_buf_float_ false); // sparse + bool compress=true; desc.input_tensors.push_back(Tensor{ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->hidden_size}, 0}); desc.input_tensors.push_back(Tensor{ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->seq_len}, 0}); + if(compress) + desc.input_tensors.push_back(Tensor{ + MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->hidden_size}, 0}); desc.input_python_tensors.push_back(Tensor{ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->hidden_size}, 0}); desc.input_python_tensors.push_back(Tensor{ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->seq_len}, 0}); + if(compress) + desc.input_python_tensors.push_back(Tensor{ + MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->hidden_size}, 0}); + desc.output_tensors.push_back(Tensor{ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->hidden_size}, 0}); @@ -473,6 +481,11 @@ void InitializeEncoderT5(opt_arg* opt_a, getTensorType(), std::vector{opt_a->batch_size, opt_a->head_num, opt_a->seq_len, opt_a->tgt_seq_len}, 0}); + desc.input_tensors.push_back( + Tensor{MEMORY_GPU, + getTensorType(), + std::vector{opt_a->batch_size, opt_a->seq_len}, + 0}); desc.input_python_tensors.push_back(Tensor{ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->hidden_size}, 0}); desc.input_python_tensors.push_back(Tensor{ @@ -482,6 +495,11 @@ void InitializeEncoderT5(opt_arg* opt_a, getTensorType(), std::vector{opt_a->batch_size, opt_a->head_num, opt_a->seq_len, opt_a->tgt_seq_len}, 0}); + desc.input_python_tensors.push_back( + Tensor{MEMORY_CPU, + getTensorType(), + std::vector{opt_a->batch_size, opt_a->seq_len}, + 0}); desc.output_tensors.push_back(Tensor{ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, opt_a->hidden_size}, 0}); diff --git a/examples/cpp/ms/ms.cc b/examples/cpp/ms/ms.cc index 8df6c5fe4074d62d2ee3f2ef01de83b87a034398..d5267f31e11316c07157b984a2b39682a5be0986 100644 --- a/examples/cpp/ms/ms.cc +++ b/examples/cpp/ms/ms.cc @@ -263,12 +263,12 @@ static float CompareData(const T* refOutput, int size, const T* msTensorData) static int x = 0; int s = std::min(10, size); if (x == 0) { - for (int j = 0; j < s; j++) { // std::min(50, size) + for (int j = 0; j < std::min(50, size); j++) { std::cout << static_cast(msTensorData[j]) << " "; } std::cout << std::endl; std::cout << "Data of Ref output : "; - for (int j = 0; j < s; j++) { // std::min(50, size) + for (int j = 0; j < std::min(50, size); j++) { std::cout << static_cast(refOutput[j]) << " "; } std::cout << std::endl; diff --git a/src/fastertransformer/kernels/bert_preprocess_kernels.cu b/src/fastertransformer/kernels/bert_preprocess_kernels.cu index 9976d501930807fa24137578654d4535ffdb6771..a9b11a51e0a5b447e681bfd6db0da818fc135449 100644 --- a/src/fastertransformer/kernels/bert_preprocess_kernels.cu +++ b/src/fastertransformer/kernels/bert_preprocess_kernels.cu @@ -58,7 +58,7 @@ void invokeGetPaddingOffset(size_t* h_token_num, } template 
-__global__ void buildSequnceLength(const T * input, int *sequnce_length, const int max_seq_length, const int hidden_size) { +__global__ void buildSequnceLength(const T * input, int *sequence_length, const int max_seq_length, const int hidden_size) { __shared__ int s_max_val; int bid = blockIdx.x; const T * seq_base = input + bid* max_seq_length * hidden_size; @@ -75,52 +75,73 @@ __global__ void buildSequnceLength(const T * input, int *sequnce_length, const i s_max_val = max_val; } __syncthreads(); - sequnce_length[bid] = -s_max_val; + sequence_length[bid] = -s_max_val; } +__global__ void buildSequnceLength(const int *input, int *sequence_length, const int max_seq_length) { + __shared__ int s_max_val; + int bid = blockIdx.x; + int last = 0; + const int *base = input + bid * max_seq_length; + for (int i=threadIdx.x ; i < max_seq_length; i += blockDim.x) { + const int *ptr = base + i; + if (*ptr != 0){ + last = i; + } + } + int max_val = blockReduceMax(last); + if (threadIdx.x == 0) { + s_max_val = max_val + 1; + } + __syncthreads(); + sequence_length[bid] = s_max_val; +} template -__global__ void buildEncoderAttentionMaskKernel(T* attention_mask, const int* sequence_lengths, const int max_seq_len) +__global__ void buildEncoderAttentionMaskKernel(T* attention_mask, const int* q_sequence_lengths, const int* kv_sequence_lengths, const int src_seq_len, const int tgt_seq_len) { // sequence_lengths: [batch_size] // attention_mask: [batch_size, 1, max_seq_len, max_seq_len] - attention_mask += blockIdx.x * max_seq_len * max_seq_len; - const int length = sequence_lengths[blockIdx.x]; - for (int i = threadIdx.x; i < max_seq_len * max_seq_len; i += blockDim.x) { - // int row_id = i / max_seq_len; - int col_id = i % max_seq_len; + attention_mask += blockIdx.x * src_seq_len * tgt_seq_len; + const int q_length = q_sequence_lengths[blockIdx.x]; + const int kv_length = kv_sequence_lengths[blockIdx.x]; + for (int i = threadIdx.x; i < src_seq_len * tgt_seq_len; i += blockDim.x) { + int row_id = i / tgt_seq_len; + int col_id = i % tgt_seq_len; // if (row_id < length && col_id < length) { // TODO (bhsueh) check this modification is ok or not on other rmodel - if (col_id >= length) { + if (col_id >= q_length || row_id >= kv_length) { attention_mask[i] = (T)(0.0f); } } } - - template void invokeBuildEncoderAttentionMask( - T* attention_mask, const int* sequence_lengths, const int batch_size, const int max_seq_len, cudaStream_t stream) + T* attention_mask, const int* q_sequence_lengths, const int* kv_sequence_lengths, const int batch_size, const int src_seq_len, const int tgt_seq_len, cudaStream_t stream) { - buildEncoderAttentionMaskKernel<<>>(attention_mask, sequence_lengths, max_seq_len); + buildEncoderAttentionMaskKernel<<>>(attention_mask, q_sequence_lengths, kv_sequence_lengths, src_seq_len, tgt_seq_len); } + template void invokeBuildEncoderAttentionMask(float* attention_mask, - const int* sequence_lengths, + const int* q_sequence_lengths, + const int* kv_sequence_lengths, const int batch_size, - const int max_seq_len, + const int src_seq_len, + const int tgt_seq_len, cudaStream_t stream); template void invokeBuildEncoderAttentionMask(half* attention_mask, - const int* sequence_lengths, + const int* q_sequence_lengths, + const int* kv_sequence_lengths, const int batch_size, - const int max_seq_len, + const int src_seq_len, + const int tgt_seq_len, cudaStream_t stream); - __global__ void getTrtPaddingOffsetKernel(int* trt_mha_padding_offset, const int* sequence_length, const int batch_size) { // use for get 
tensorrt fused mha padding offset @@ -143,10 +164,14 @@ __global__ void getTrtPaddingOffsetKernel(int* trt_mha_padding_offset, const int template -void invokeBuildSequnceLength(const T * input, int batch_size, int *sequnce_length, int max_seq_length, int hidden_size,cudaStream_t stream) { +void invokeBuildSequenceLength(const T * input, int batch_size, int *sequnce_length, int max_seq_length, int hidden_size,cudaStream_t stream) { buildSequnceLength<<>>(input,sequnce_length, max_seq_length,hidden_size); } +void invokeBuildSequenceLength(const int * input, int batch_size, int *sequnce_length, int max_seq_length,cudaStream_t stream) { + buildSequnceLength<<>>(input,sequnce_length, max_seq_length); +} + @@ -339,9 +364,8 @@ void invokeBuildRelativeAttentionBias(T* relative_attention_bias, is_bidirectional, max_distance); } -template void invokeBuildSequnceLength(const float * input, int batch_size, int *sequnce_length, int max_seq_length, int hidden_size,cudaStream_t stream); -template void invokeBuildSequnceLength(const half * input, int batch_size, int *sequnce_length, int max_seq_length, int hidden_size,cudaStream_t stream); - +template void invokeBuildSequenceLength(const float * input, int batch_size, int *sequnce_length, int max_seq_length, int hidden_size,cudaStream_t stream); +template void invokeBuildSequenceLength(const half * input, int batch_size, int *sequnce_length, int max_seq_length, int hidden_size,cudaStream_t stream); template void invokeBuildRelativeAttentionBias(float* relative_attention_bias, const float* relative_attention_bias_table, diff --git a/src/fastertransformer/kernels/bert_preprocess_kernels.h b/src/fastertransformer/kernels/bert_preprocess_kernels.h index dca4ef5c2d3b89cfcc7ab2e7fae4a38a41ba0545..19c1b288391b21f0f389442ebd49508ba3ef7cf9 100644 --- a/src/fastertransformer/kernels/bert_preprocess_kernels.h +++ b/src/fastertransformer/kernels/bert_preprocess_kernels.h @@ -32,10 +32,13 @@ void invokeGetPaddingOffset(size_t* h_token_num, template void invokeBuildEncoderAttentionMask( - T* attention_mask, const int* sequence_lengths, const int batch_size, const int max_seq_len, cudaStream_t stream); + T* attention_mask, const int* q_sequence_lengths, const int* kv_sequence_lengths, const int batch_size, const int src_seq_len, const int tgt_seq_len, cudaStream_t stream); template -void invokeBuildSequnceLength(const T * input, int batch_size, int *sequnce_length, int max_seq_length, int hidden_size,cudaStream_t stream); +void invokeBuildSequenceLength(const T * input, int batch_size, int *sequnce_length, int max_seq_length, int hidden_size,cudaStream_t stream); + +void invokeBuildSequenceLength(const int* input, int batch_size, int *sequnce_length, int max_seq_length,cudaStream_t stream); + void invokeGetTrtPaddingOffset(int* trt_mha_padding_offset, const int* sequence_length, diff --git a/src/fastertransformer/kernels/unfused_attention_kernels.cu b/src/fastertransformer/kernels/unfused_attention_kernels.cu index 119b2b765c574dfe83f87543d898710921ce37cd..dd1062c3203682cd1b7c40031aa3498f494c853f 100644 --- a/src/fastertransformer/kernels/unfused_attention_kernels.cu +++ b/src/fastertransformer/kernels/unfused_attention_kernels.cu @@ -1692,7 +1692,7 @@ void invokeAddFusedZP_QKVBiasTranspose(T* q_buf, dim3 block(384); dim3 grid((int)(ceil(1.0 * m * n / 384))); add_fusedQKV_ZP_bias_transpose_kernel<<>>( - q_buf, k_buf, v_buf, QKV, qkv_bias, batch_size, seq_len, head_num, size_per_head,h_token, padding_mask); + q_buf, k_buf, v_buf, QKV, qkv_bias, batch_size, seq_len, 
head_num, size_per_head, h_token, padding_mask); } @@ -1723,6 +1723,37 @@ template void invokeAddFusedZP_QKVBiasTranspose(half* q_buf, int *padding_mask, cudaStream_t stream); +template void invokeCrossAddFusedZP_QKVBiasTranspose(float* q_buf, + float* k_buf, + float* v_buf, + float* QKV, + const float* qkv_bias, + const int batch_size, + const int seq_len, + const int tgt_seq_len, + const int head_num, + const int size_per_head, + const int h_token, + const int h_token2, + int *padding_mask, + int *padding_mask2, + cudaStream_t stream); + +template void invokeCrossAddFusedZP_QKVBiasTranspose(half* q_buf, + half* k_buf, + half* v_buf, + half* QKV, + const half* qkv_bias, + const int batch_size, + const int seq_len, + const int tgt_seq_len, + const int head_num, + const int size_per_head, + const int h_token, + const int h_token2, + int *padding_mask, + int *padding_mask2, + cudaStream_t stream); @@ -1762,6 +1793,44 @@ __global__ void invokeCrossAddFusedQKVBiasTransposeQ(T* q_buf, + seq_id * size_per_head + size_id] = val; } } +template +__global__ void add_fusedQKV_ZP_bias_transpose_kernel_q(T* q_buf, + const T* __restrict QKV, + const U* __restrict qkv_bias, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const int token_num, + int *mask_offset) +{ + // QKV: [m, 1, n] + // qkv_bias: [1, n] + // q_buf: [batch, head_num, seq_len, size_per_head] + + T* qkv_ptr[1] = {q_buf}; + const int n = head_num * size_per_head; + for (int index = blockDim.x * blockIdx.x + threadIdx.x; index < token_num * 1 * n; + index += gridDim.x * blockDim.x) { + int bias_id = index % (1 * n); + T val = ldg(&QKV[index]); + if (qkv_bias != nullptr) + val += (T)ldg(&qkv_bias[bias_id]); + int tmp_index = index; + // const int target_batch_id = tmp_index / (seq_len * 1 * n); + int token_id = tmp_index / (1 *n); + int batch_id = (token_id + ldg(&mask_offset[token_id])) / seq_len; + int seq_id = (token_id + ldg(&mask_offset[token_id])) % seq_len; + tmp_index -= token_id * 1 * n; + int qkv_id = tmp_index / n; + tmp_index -= qkv_id * n; + int head_id = tmp_index / size_per_head; + int size_id = tmp_index - head_id * size_per_head; + int dst_id = batch_id * head_num * seq_len * size_per_head + head_id * seq_len * size_per_head + + seq_id * size_per_head + size_id; + qkv_ptr[qkv_id][dst_id] = val; + } +} template __global__ void invokeCrossTransposeQ(T* q_buf, @@ -1831,6 +1900,46 @@ __global__ void invokeCrossAddFusedQKVBiasTransposeKV(T* k_buf, T* v_buf, } +template +__global__ void add_fusedQKV_ZP_bias_transpose_kernel_kv(T* k_buf, T* v_buf, + const T* __restrict QKV, + const U* __restrict qkv_bias, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head, + const int token_num, + int *mask_offset) +{ + // QKV: [m, 2, n] + // qkv_bias: [2, n] + // k_buf, v_buf: [batch, head_num, seq_len, size_per_head] + + T* qkv_ptr[2] = {k_buf, v_buf}; + const int n = head_num * size_per_head; + for (int index = blockDim.x * blockIdx.x + threadIdx.x; index < token_num * 2 * n; + index += gridDim.x * blockDim.x) { + int bias_id = index % (2 * n); + T val = ldg(&QKV[index]); + if (qkv_bias != nullptr) + val += (T)ldg(&qkv_bias[bias_id]); + int tmp_index = index; + int token_id = tmp_index / (2 *n); + int batch_id = (token_id + ldg(&mask_offset[token_id])) / seq_len; + int seq_id = (token_id + ldg(&mask_offset[token_id])) % seq_len; + tmp_index -= token_id * 2 * n; + int qkv_id = tmp_index / n; + tmp_index -= qkv_id * n; + int head_id = tmp_index / 
size_per_head; + int size_id = tmp_index - head_id * size_per_head; + int dst_id = batch_id * head_num * seq_len * size_per_head + head_id * seq_len * size_per_head + + seq_id * size_per_head + size_id; + //printf("%d %d\n", head_id, size_id); + qkv_ptr[qkv_id][dst_id] = val; + } +} + + template __global__ void invokeCrossTransposeKV(T* k_buf, T* v_buf, const T* __restrict QKV, @@ -1907,7 +2016,42 @@ void invokeCrossAddFusedQKVBiasTranspose(T* q_buf, } } +template +void invokeCrossAddFusedZP_QKVBiasTranspose(T* q_buf, + T* k_buf, + T* v_buf, + T* QKV, + const U* qkv_bias, + const int batch_size, + const int seq_len, + const int tgt_seq_len, + const int head_num, + const int size_per_head, + const int h_token, + const int h_token2, + int *padding_mask, + int *padding_mask2, + cudaStream_t stream) +{ + const int m = h_token; + const int n = head_num * size_per_head; + cudaMemsetAsync(q_buf, 0, batch_size * seq_len * n * sizeof(T), stream); + dim3 block(384); + dim3 grid((int)(ceil(1.0 * m * n / 384))); + add_fusedQKV_ZP_bias_transpose_kernel_q<<>>( + q_buf, QKV, qkv_bias, batch_size, seq_len, head_num, size_per_head, h_token, padding_mask); + + const int m2 = h_token2; + const int n2 = head_num * size_per_head; + cudaMemsetAsync(k_buf, 0, batch_size * tgt_seq_len * n2 * sizeof(T), stream); + cudaMemsetAsync(v_buf, 0, batch_size * tgt_seq_len * n2 * sizeof(T), stream); + dim3 block2(384); + dim3 grid2((int)(ceil(1.0 * m2 * n2 / 384))); + qkv_bias = (qkv_bias == nullptr) ? nullptr : qkv_bias + n2; + add_fusedQKV_ZP_bias_transpose_kernel_kv<<>>( + k_buf, v_buf, QKV + m * n, qkv_bias, batch_size, tgt_seq_len, head_num, size_per_head, h_token2, padding_mask2); +} template void invokeCrossAddFusedQKVBiasTranspose(float* q_buf, float* k_buf, float* v_buf, diff --git a/src/fastertransformer/kernels/unfused_attention_kernels.h b/src/fastertransformer/kernels/unfused_attention_kernels.h index 2ccdbaa78c37736bcf3d3e483ed256ae125d2c70..92aa995e45f53515a8f553098d414c122f91e050 100644 --- a/src/fastertransformer/kernels/unfused_attention_kernels.h +++ b/src/fastertransformer/kernels/unfused_attention_kernels.h @@ -122,8 +122,22 @@ void invokeAddFusedZP_QKVBiasTranspose(T* q_buf, const int h_token, int *padding_mask, cudaStream_t stream); - - +template +void invokeCrossAddFusedZP_QKVBiasTranspose(T* q_buf, + T* k_buf, + T* v_buf, + T* QKV, + const U* qkv_bias, + const int batch_size, + const int seq_len, + const int tgt_seq_len, + const int head_num, + const int size_per_head, + const int h_token, + const int h_token2, + int *padding_mask, + int *padding_mask2, + cudaStream_t stream); template void invokeCrossAddFusedQKVBiasTranspose(T* q_buf, diff --git a/src/fastertransformer/layers/ms_layers/MSEncoderLayer.cc b/src/fastertransformer/layers/ms_layers/MSEncoderLayer.cc index bc5934692d3c0a0b23f567d09fd545c1f561f896..fc5aa29ac57c385419063927ff87becf1f49af77 100644 --- a/src/fastertransformer/layers/ms_layers/MSEncoderLayer.cc +++ b/src/fastertransformer/layers/ms_layers/MSEncoderLayer.cc @@ -122,10 +122,11 @@ int MSELayer::forward(std::vector* output_tensors, (void*)encoder_weights->attention.attention_output_weight.kernel, (void*)encoder_weights->layernorm2.gamma, (void*)encoder_weights->encoder_output_mapping.kernel, - (void*)encoder_weights->encoder_output_projection.kernel + (void*)encoder_weights->encoder_output_projection.kernel, + (void*)input_tensors->at(3).data }; - forwardEncoder(inputs, 9, outputs, 1, &encoder_param_, buf_); + forwardEncoder(inputs, 10, outputs, 1, &encoder_param_, buf_); } else 
{ void* inputs[] = {(void*)input_tensors->at(0).data, @@ -141,8 +142,9 @@ int MSELayer::forward(std::vector* output_tensors, (void*)encoder_weights->encoder_output_mapping.kernel, (void*)encoder_weights->encoder_output_mapping.bias, (void*)encoder_weights->encoder_output_projection.kernel, - (void*)encoder_weights->encoder_output_projection.bias}; - forwardEncoder(inputs, 14, outputs, 1, &encoder_param_, buf_); + (void*)encoder_weights->encoder_output_projection.bias, + (void*)input_tensors->at(2).data}; + forwardEncoder(inputs, 15, outputs, 1, &encoder_param_, buf_); } } else { @@ -155,10 +157,11 @@ int MSELayer::forward(std::vector* output_tensors, (void*)encoder_weights->attention.attention_output_weight.kernel, (void*)encoder_weights->layernorm2.gamma, (void*)encoder_weights->encoder_output_mapping.kernel, - (void*)encoder_weights->encoder_output_projection.kernel + (void*)encoder_weights->encoder_output_projection.kernel, + (void*)input_tensors->at(3).data }; - forwardEncoder(inputs, 9, outputs, 1, &encoder_param_, buf_); + forwardEncoder(inputs, 10, outputs, 1, &encoder_param_, buf_); } else { void* inputs[] = {(void*)input_tensors->at(0).data, @@ -174,8 +177,9 @@ int MSELayer::forward(std::vector* output_tensors, (void*)encoder_weights->encoder_output_projection.kernel, (void*)encoder_weights->encoder_output_projection.bias, (void*)encoder_weights->layernorm2.gamma, - (void*)encoder_weights->layernorm2.beta}; - forwardEncoder(inputs, 3, outputs, 1, &encoder_param_, buf_); + (void*)encoder_weights->layernorm2.beta, + (void*)input_tensors->at(2).data}; + forwardEncoder(inputs, 15, outputs, 1, &encoder_param_, buf_); } } diff --git a/src/fastertransformer/layers/ms_layers/attention.cc b/src/fastertransformer/layers/ms_layers/attention.cc index b129388e009b6ca5d2c833e29a917c9c620fbdb2..b3d20ce9993eada02c690ef6f1fc034a6f895083 100644 --- a/src/fastertransformer/layers/ms_layers/attention.cc +++ b/src/fastertransformer/layers/ms_layers/attention.cc @@ -5,7 +5,7 @@ #include "src/fastertransformer/kernels/unfused_attention_kernels.h" #include "src/fastertransformer/layers/ms_layers/debug_utils.h" #include "src/fastertransformer/layers/ms_layers/opt_allocator.h" -// #include "src/fastertransformer/kernels/bert_preprocess_kernels.h" +#include "src/fastertransformer/kernels/bert_preprocess_kernels.h" #include "fmha_cutlass.h" #include @@ -21,46 +21,38 @@ public: virtual size_t getWorkspaceSize(attentionParamRun* param) = 0; virtual void runMha(T* inputs[], int in_len, T* output[], int out_len, attentionParamRun* param, void* ws) = 0; }; - template class unfusedMhaDispatch: public mhaDispatch { public: size_t getWorkspaceSize(attentionParamRun* param) override { - size_t attn_out_size = param->common_param->batch_size * param->common_param->head_num - * param->common_param->head_size * param->common_param->tgt_seq_len; - size_t size_q = - param->common_param->batch_size * param->common_param->src_seq_len * param->common_param->hidden_size; - size_t size_k = - param->common_param->batch_size * param->common_param->tgt_seq_len * param->common_param->hidden_size; - size_t size_v = size_k; - size_t qkv_len = size_q + size_k + size_v; - size_t q_buf_2_len = size_q; - size_t qk_buf_len = param->common_param->batch_size * param->common_param->head_num - * param->common_param->src_seq_len * param->common_param->tgt_seq_len; - size_t qkv_buf_2_len = - param->common_param->batch_size * param->common_param->src_seq_len * param->common_param->hidden_size; - size_t qkv_buf_3_len = qkv_buf_2_len; - - 
OptAllocator allocator(ALIGN_SIZE); - param->attn.qkv_buf = allocator.Malloc(qkv_len * sizeof(T)); - param->attn.q_buf_2 = allocator.Malloc(q_buf_2_len * sizeof(T)); - param->attn.output1 = allocator.Malloc(attn_out_size * sizeof(T)); - param->attn.output2 = allocator.Malloc(attn_out_size * sizeof(T)); - allocator.Free(param->attn.qkv_buf); - param->attn.qk_buf = allocator.Malloc(qk_buf_len * sizeof(T)); - allocator.Free(param->attn.q_buf_2); - allocator.Free(param->attn.output1); - param->attn.qkv_buf_2 = allocator.Malloc(qkv_buf_2_len * sizeof(T)); - allocator.Free(param->attn.output2); - allocator.Free(param->attn.qk_buf); - param->attn.qkv_buf_3 = allocator.Malloc(qkv_buf_3_len * sizeof(T)); - allocator.Free(param->attn.qkv_buf_2); - allocator.Free(param->attn.qkv_buf_3); - return allocator.total_size(); - } - - void runMha(T* inputs[], int in_len, T* output[], int out_len, attentionParamRun* param, void* ws) override + size_t attn_out_size = param->common_param->batch_size * param->common_param->head_num * param->common_param->head_size * param->common_param->tgt_seq_len; + size_t size_q = param->common_param->batch_size * param->common_param->src_seq_len * param->common_param->hidden_size; + size_t size_k = param->common_param->batch_size * param->common_param->tgt_seq_len * param->common_param->hidden_size; + size_t size_v = size_k; + size_t qkv_len = size_q + size_k + size_v; + size_t q_buf_2_len = size_q; + size_t qk_buf_len = param->common_param->batch_size * param->common_param->head_num * param->common_param->src_seq_len * param->common_param->tgt_seq_len; + size_t qkv_buf_2_len = param->common_param->batch_size * param->common_param->src_seq_len * param->common_param->hidden_size; + size_t qkv_buf_3_len = qkv_buf_2_len; + OptAllocator allocator(ALIGN_SIZE); + param->attn.qkv_buf = allocator.Malloc(qkv_len * sizeof(T)); + param->attn.q_buf_2 = allocator.Malloc(q_buf_2_len * sizeof(T)); + param->attn.output1 = allocator.Malloc(attn_out_size * sizeof(T)); + param->attn.output2 = allocator.Malloc(attn_out_size * sizeof(T)); + allocator.Free(param->attn.qkv_buf); + param->attn.qk_buf = allocator.Malloc(qk_buf_len * sizeof(T)); + allocator.Free(param->attn.q_buf_2); + allocator.Free(param->attn.output1); + param->attn.qkv_buf_2 = allocator.Malloc(qkv_buf_2_len * sizeof(T)); + allocator.Free(param->attn.output2); + allocator.Free(param->attn.qk_buf); + param->attn.qkv_buf_3 = allocator.Malloc(qkv_buf_3_len * sizeof(T)); + allocator.Free(param->attn.qkv_buf_2); + allocator.Free(param->attn.qkv_buf_3); + return allocator.total_size(); +} + void runMha(T* inputs[], int in_len, T* output[], int out_len, attentionParamRun* param, void* ws) override { // setup inputs T* q_buf_2 = inputs[0]; @@ -143,17 +135,29 @@ public: param->common_param->batch_size * param->common_param->head_num, param->common_param->cublas_handle, param->common_param->algo); - invokeTransposeQKV(static_cast(output[0]), + if (param->attn.padding_offset == nullptr) { + invokeTransposeQKV(static_cast(output[0]), static_cast(qkv_buf_2), param->common_param->batch_size, param->common_param->src_seq_len, param->common_param->head_num, param->common_param->head_size, param->common_param->stream); + } + else { + invokeTransposeAttentionOutRemovePadding(qkv_buf_2, + output[0], + param->common_param->h_token_num, + param->common_param->batch_size, + param->common_param->src_seq_len, + param->common_param->head_num, + param->common_param->head_size, + param->attn.padding_offset, + param->common_param->stream); + } return; } }; - 
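For reference, a minimal CPU sketch (not part of the patch, illustrative names only) of the padding-offset index math that the new add_fusedQKV_ZP_bias_transpose_kernel_q/_kv and the invokeTransposeAttentionOutRemovePadding path above rely on: padding_offset[t] holds the number of padded slots skipped before compressed token t, so t + padding_offset[t] recovers the position in the padded [batch, seq_len] layout.

#include <cstdio>
#include <vector>

// Illustrative sketch only: mirrors
//   batch_id = (token_id + mask_offset[token_id]) / seq_len;
//   seq_id   = (token_id + mask_offset[token_id]) % seq_len;
// from the kernels in this patch. Values below are a made-up example
// (batch 0 keeps 3 tokens, batch 1 keeps 2, seq_len = 4).
int main() {
    const int seq_len = 4;
    const std::vector<int> padding_offset = {0, 0, 0, 1, 1};  // pads skipped before each kept token
    for (int t = 0; t < (int)padding_offset.size(); ++t) {
        int global   = t + padding_offset[t];  // position in the padded layout
        int batch_id = global / seq_len;
        int seq_id   = global % seq_len;
        std::printf("token %d -> batch %d, seq %d\n", t, batch_id, seq_id);
    }
    return 0;
}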
template class fusedCutlassMhaDispatch: public mhaDispatch { public: @@ -186,7 +190,7 @@ public: param->attn.output2 = allocator.Malloc(attn_out_size * sizeof(T)); allocator.Free(param->attn.qkv_buf); param->attn.qk_buf = 0; // not in use - param->attn.qkv_buf_2 = 0; // not in use + param->attn.qkv_buf_2 = allocator.Malloc(qkv_buf_2_len * sizeof(T)); param->attn.qkv_buf_3 = allocator.Malloc(qkv_buf_3_len * sizeof(T)); size_t size = 0; typedef typename std::conditional::value, cutlass::half_t, float>::type Type; @@ -203,7 +207,19 @@ public: void runMha(T* inputs[], int in_len, T* output[], int out_len, attentionParamRun* param, void* ws) { typedef typename std::conditional::value, cutlass::half_t, float>::type Type; - forward_fmha(reinterpret_cast(inputs), out_len, reinterpret_cast(output), 1, param, ws); + T* qkv_buf_2 = reinterpret_cast(static_cast(ws) + param->attn.qkv_buf_2); + forward_fmha(reinterpret_cast(inputs), out_len, reinterpret_cast(&qkv_buf_2), 1, param, ws); + if(param->attn.padding_offset != nullptr) + { + invokeRemovePadding((float*)(*output), + (const float*)qkv_buf_2, + param->attn.padding_offset, + param->common_param->h_token_num, + param->common_param->head_num * param->common_param->head_size, + param->common_param->stream); + } else { + *output = qkv_buf_2; + } } }; @@ -239,12 +255,9 @@ void forward_attn(T* inputs[], int in_len, T* output[], int out_len, attentionPa T* output2 = reinterpret_cast(static_cast(ws) + param->attn.output2); T* mha = reinterpret_cast(static_cast(ws) + param->attn.mha); - int gemm_dims[] = {3 * (int)param->common_param->hidden_size, - (int)param->common_param->batch_size * (int)param->common_param->src_seq_len, - (int)param->common_param->hidden_size}; - int gemm_lds[] = {3 * (int)param->common_param->hidden_size, - (int)param->common_param->hidden_size, - 3 * (int)param->common_param->hidden_size}; + int gemm_dims[] = { + 3 * (int)param->common_param->hidden_size, (int)param->common_param->h_token_num, (int)param->common_param->hidden_size}; + int gemm_lds[] = {3 * (int)param->common_param->hidden_size, (int)param->common_param->hidden_size, 3 * (int)param->common_param->hidden_size}; T* from_tensor = reinterpret_cast(inputs[param->common_param->in_idx++]); cublasOperation_t gemm_ops[] = {CUBLAS_OP_N, CUBLAS_OP_N}; cudaDataType gemm_data_types[] = {CUDA_R_32F, CUDA_R_32F, CUDA_R_32F}; @@ -257,7 +270,7 @@ void forward_attn(T* inputs[], int in_len, T* output[], int out_len, attentionPa T beta = 0.0f; if (param->attn.is_cross) { gemm_dims[0] = param->common_param->hidden_size; - gemm_dims[1] = param->common_param->batch_size * param->common_param->src_seq_len; + gemm_dims[1] = param->common_param->h_token_num; gemm_dims[2] = param->common_param->hidden_size; gemm_lds[0] = param->common_param->hidden_size; gemm_lds[1] = param->common_param->hidden_size; @@ -275,17 +288,19 @@ void forward_attn(T* inputs[], int in_len, T* output[], int out_len, attentionPa &beta, param->common_param->cublas_handle, param->common_param->algo); - + gemm_dims[0] = 2 * param->common_param->hidden_size; - gemm_dims[1] = param->common_param->batch_size * param->common_param->tgt_seq_len; + gemm_dims[1] = param->common_param->h_token_num2; + gemm_dims[2] = param->common_param->hidden_size; + gemm_lds[0] = 2 * param->common_param->hidden_size; + gemm_lds[1] = param->common_param->hidden_size; gemm_lds[2] = 2 * param->common_param->hidden_size; + T* weight_kv = reinterpret_cast(inputs[param->common_param->in_idx++]); CublasGemmWrapper(weight_kv, encoder_output, - qkv_buf 
- + (param->common_param->batch_size * param->common_param->src_seq_len) - * param->common_param->hidden_size, + qkv_buf + param->common_param->h_token_num * param->common_param->hidden_size, gemm_dims, gemm_lds, gemm_ops, @@ -295,7 +310,8 @@ void forward_attn(T* inputs[], int in_len, T* output[], int out_len, attentionPa param->common_param->cublas_handle, param->common_param->algo); T* bias_qkv = (param->attn.qkv_bias) ? reinterpret_cast(inputs[param->common_param->in_idx++]) : nullptr; - invokeCrossAddFusedQKVBiasTranspose(q_buf_2, + if (param->attn.padding_offset == nullptr) { + invokeCrossAddFusedQKVBiasTranspose(q_buf_2, output1, output2, qkv_buf, @@ -306,6 +322,24 @@ void forward_attn(T* inputs[], int in_len, T* output[], int out_len, attentionPa param->common_param->head_num, param->common_param->head_size, param->common_param->stream); + } + else{ + invokeCrossAddFusedZP_QKVBiasTranspose(q_buf_2, + output1, + output2, + qkv_buf, + bias_qkv, + param->common_param->batch_size, + param->common_param->src_seq_len, + param->common_param->tgt_seq_len, + param->common_param->head_num, + param->common_param->head_size, + param->common_param->h_token_num, + param->common_param->h_token_num2, + param->attn.padding_offset, + param->attn.padding_offset2, + param->common_param->stream); + } } else { T* weight_qkv = reinterpret_cast(inputs[param->common_param->in_idx++]); @@ -321,23 +355,50 @@ void forward_attn(T* inputs[], int in_len, T* output[], int out_len, attentionPa param->common_param->cublas_handle, param->common_param->algo); T* bias_qkv = (param->attn.qkv_bias) ? reinterpret_cast(inputs[param->common_param->in_idx++]) : nullptr; - invokeAddFusedQKVBiasTranspose(static_cast(q_buf_2), - static_cast(output1), - static_cast(output2), - static_cast(qkv_buf), - bias_qkv, - param->common_param->batch_size, - param->common_param->src_seq_len, - param->common_param->head_num, - param->common_param->head_size, - 0, - param->common_param->stream); + if (param->attn.padding_offset == nullptr) { + invokeAddFusedQKVBiasTranspose(static_cast(q_buf_2), + static_cast(output1), + static_cast(output2), + static_cast(qkv_buf), + bias_qkv, + param->common_param->batch_size, + param->common_param->src_seq_len, + param->common_param->head_num, + param->common_param->head_size, + 0, + param->common_param->stream); + } else { + invokeAddFusedZP_QKVBiasTranspose(static_cast(q_buf_2), + static_cast(output1), + static_cast(output2), + static_cast(qkv_buf), + bias_qkv, + param->common_param->batch_size, + param->common_param->src_seq_len, + param->common_param->head_num, + param->common_param->head_size, + param->common_param->h_token_num, + param->attn.padding_offset, + param->common_param->stream); + } } + gemm_ops[0] = CUBLAS_OP_T; + gemm_ops[1] = CUBLAS_OP_N; + + gemm_lds[0] = param->common_param->head_size; + gemm_lds[1] = param->common_param->head_size; + gemm_lds[2] = param->common_param->tgt_seq_len; + int gemm_strides[] = {(int)(param->common_param->tgt_seq_len * param->common_param->head_size), + (int)(param->common_param->src_seq_len * param->common_param->head_size), + (int)(param->common_param->src_seq_len * param->common_param->tgt_seq_len)}; // Do Softmax(Q*Kt + Bias + Mask) T* attention_mask = (param->attn.mask) ? reinterpret_cast(inputs[param->common_param->in_idx++]) : nullptr; - T* position_bias = - (param->attn.position_bias) ? reinterpret_cast(inputs[param->common_param->in_idx++]) : nullptr; + T* position_bias = (param->attn.position_bias) ? 
reinterpret_cast(inputs[param->common_param->in_idx++]) : nullptr; + if (param->attn.padding_offset != nullptr){ + invokeBuildEncoderAttentionMask( + attention_mask, param->attn.d_sequence_length2, param->attn.d_sequence_length, param->common_param->batch_size, param->common_param->src_seq_len, param->common_param->tgt_seq_len, param->common_param->stream); + } T* in[] = {q_buf_2, output1, output2, attention_mask, position_bias}; if (param->attn.fmha_type == FmhaType_CutlassFix) { fusedCutlassMhaDispatch dispatch; @@ -347,10 +408,11 @@ void forward_attn(T* inputs[], int in_len, T* output[], int out_len, attentionPa dispatch.runMha(in, sizeof(in) / sizeof(in[0]), &qkv_buf_3, 1, param, ws); } + gemm_ops[0] = CUBLAS_OP_N; gemm_ops[1] = CUBLAS_OP_N; gemm_dims[0] = param->common_param->hidden_size; - gemm_dims[1] = param->common_param->batch_size * param->common_param->src_seq_len; + gemm_dims[1] = param->common_param->h_token_num; gemm_dims[2] = param->common_param->hidden_size; gemm_lds[0] = param->common_param->hidden_size; @@ -369,13 +431,11 @@ void forward_attn(T* inputs[], int in_len, T* output[], int out_len, attentionPa param->common_param->algo); if (param->attn.projection_bias) { - int len = param->common_param->batch_size * param->common_param->src_seq_len; - invokeAddBias(static_cast(output[0]), - (const T*)(inputs[param->common_param->in_idx++]), - len, - param->common_param->hidden_size, - param->common_param->stream); + int len = param->common_param->h_token_num; + invokeAddBias( + static_cast(output[0]), (const T*)(inputs[param->common_param->in_idx++]), len, param->common_param->hidden_size, param->common_param->stream); } + return; } diff --git a/src/fastertransformer/layers/ms_layers/debug_utils.cc b/src/fastertransformer/layers/ms_layers/debug_utils.cc index d78de583f50fd48631cd2f1b7e6295d0647933be..4f9e5bdb43fc9cbcbeb8fe8f494110d51e049260 100644 --- a/src/fastertransformer/layers/ms_layers/debug_utils.cc +++ b/src/fastertransformer/layers/ms_layers/debug_utils.cc @@ -2,7 +2,7 @@ #include #include #include "src/fastertransformer/utils/memory_utils.h" - +#include "src/fastertransformer/layers/ms_layers/debug_utils.h" namespace fastertransformer { @@ -29,6 +29,83 @@ void printTensor(char* str, T* input, int size) free(input_host); } +void printCommonParam(CommonParam param) +{ + std::cout<<"print common Param\n"; + std::cout<<"batch_size = "<(0.f); + } + bool compressed = param->common_param.eft && zp; + if (!compressed) { + if (cross) { + std::cout << "cross sum:" << std::endl; + for (int i = 0; i < param->common_param.batch_size; i++) { + for (int j = 0; j < param->common_param.head_num; j++) { + for (int k = 0; k < param->common_param.src_seq_len / 2; k++) { + for (int l = 0; l < head_size; l++) { + sum += ptr[(((i * param->common_param.head_num) + j) * param->common_param.src_seq_len + k) * head_size + l]; + } + } + } + } + } + else { + std::cout << "grid sum:" << std::endl; + for (int i = 0; i < param->common_param.batch_size; i++) { + for (int j = 0; j < param->common_param.src_seq_len / 2; j++) { + for (int k = 0; k < hidden_size; k++) { + sum += ptr[((i * param->common_param.src_seq_len) + j) * hidden_size + k]; + } + } + } + } + } + else { + std::cout << "compress sum:" << std::endl; + for (int i = 0; i < param->common_param.h_token_num * hidden_size; i++) { + sum += ptr[i]; + } + } + return static_cast(sum); + } + else { + return static_cast(0.f); + } +} + template void printTensor(char* str, float* input, int size); template void isNan(char* str, float* input, int 
size); template float checksum(const float* tensor, int size); template void saveTensor(const std::string& name, float* tensor, int size); +template float checksumGrid(const float* tensor, const encoderParamRun* param, bool zp, bool cross, bool ffn); template void printTensor(char* str, half* input, int size); template void isNan(char* str, half* input, int size); template half checksum(const half* tensor, int size); template void saveTensor(const std::string& name, half* tensor, int size); +template half checksumGrid(const half* tensor, const encoderParamRun* param, bool zp, bool cross, bool ffn); + +template void printTensor(char* str, int* input, int size); + } \ No newline at end of file diff --git a/src/fastertransformer/layers/ms_layers/debug_utils.h b/src/fastertransformer/layers/ms_layers/debug_utils.h index d4c493539e416d1462426e5bb76bd5a3edfebaa8..d461cedb71449becd04e27cccd9c3bdc17fecfaf 100644 --- a/src/fastertransformer/layers/ms_layers/debug_utils.h +++ b/src/fastertransformer/layers/ms_layers/debug_utils.h @@ -1,4 +1,6 @@ #pragma once +#include "src/fastertransformer/layers/ms_layers/param.h" + namespace fastertransformer { #define UP_DIV(x, y) (((x) + (y) - (1)) / (y)) @@ -13,5 +15,9 @@ template T checksum(const T* tensor, int size); template void saveTensor(const std::string& name, T* tensor, int size); - +template +T checksumGrid(const T* tensor, const encoderParamRun* param, bool zp = false, bool cross = false, bool ffn = false); +void printEncoderParamRunParam(encoderParamRun param); +void printDecoderParamRunParam(decoderParamRun param); +void printAttnParamRunParam(attentionParamRun param); } \ No newline at end of file diff --git a/src/fastertransformer/layers/ms_layers/decoder.cc b/src/fastertransformer/layers/ms_layers/decoder.cc index 22f89c457d16fcea999902f3a527b8eebcb86d4b..80f3f76f33d46085d98b21aa5d86b3ac13b47aae 100644 --- a/src/fastertransformer/layers/ms_layers/decoder.cc +++ b/src/fastertransformer/layers/ms_layers/decoder.cc @@ -21,9 +21,23 @@ size_t GetDecoderLayerWorkspaceSize(decoderParamRun* param) param->common_param.batch_size * param->common_param.src_seq_len * param->ffn_param.ffn_param.ffn_hidden_size; size_t ffn_size = attn_out_size + ffn_ws_size; size_t tmp_out_size = attn_out_size; + size_t max_hidden = std::max(param->common_param.hidden_size, param->ffn_param.ffn_param.ffn_hidden_size); + size_t compress_buffer_len = param->common_param.batch_size * param->common_param.src_seq_len * max_hidden; + size_t padding_len = param->common_param.batch_size * param->common_param.src_seq_len; OptAllocator allocator(ALIGN_SIZE); + param->decoder.d_sequence_lengths_offset_buf = allocator.Malloc(param->common_param.batch_size * sizeof(int)); + param->decoder.padding_offset_buf = allocator.Malloc(padding_len * sizeof(int)); + param->decoder.d_token_num_buf = allocator.Malloc(1 * sizeof(size_t)); + param->decoder.d_sequence_lengths_offset_buf2 = allocator.Malloc(param->common_param.batch_size * sizeof(int)); + param->decoder.padding_offset_buf2 = allocator.Malloc(padding_len * sizeof(int)); + param->decoder.d_token_num_buf2 = allocator.Malloc(1 * sizeof(size_t)); + param->decoder.compress_buf = allocator.Malloc(compress_buffer_len * sizeof(T)); + param->decoder.compress_buf2 = allocator.Malloc(compress_buffer_len * sizeof(T)); param->decoder.normed_from_tensor_buf = allocator.Malloc(attn_out_size * sizeof(T)); + int tgt_seq_len =param->common_param.tgt_seq_len; + param->common_param.tgt_seq_len = param->common_param.src_seq_len; param->decoder.attn_ws_buf = 
allocator.Malloc(GetAttnWorkspaceSize(&(param->attn1))); + param->common_param.tgt_seq_len = tgt_seq_len; param->decoder.attn_out_buf = allocator.Malloc(attn_out_size * sizeof(T)); allocator.Free(param->decoder.attn_ws_buf); if (!param->decoder.layernorm_post) @@ -39,7 +53,7 @@ size_t GetDecoderLayerWorkspaceSize(decoderParamRun* param) allocator.Malloc(attn_out_size * sizeof(T)); allocator.Free(param->decoder.attn_out_buf); param->decoder.tmp_out_buf = - param->ffn_param.ffn_param.ffn_fp16 ? allocator.Malloc(tmp_out_size * sizeof(half)) : -1; + param->ffn_param.ffn_param.ffn_fp16 ? allocator.Malloc(tmp_out_size * sizeof(half)) : allocator.Malloc(tmp_out_size * sizeof(T)); param->decoder.ffn_ws_buf = param->ffn_param.ffn_param.ffn_fp16 ? allocator.Malloc(ffn_ws_size * sizeof(half)) : allocator.Malloc(ffn_ws_size * sizeof(T)); allocator.Free(param->decoder.ffn_ws_buf); @@ -50,14 +64,73 @@ size_t GetDecoderLayerWorkspaceSize(decoderParamRun* param) template size_t GetDecoderLayerWorkspaceSize(decoderParamRun* param); template size_t GetDecoderLayerWorkspaceSize(decoderParamRun* param); +template +void GetCompressBuffer(T* compress_buffer, T* from_tensor, int *input_ids, int* padding_offset, int* d_sequence_lengths, size_t &h_token_num, size_t* d_token_num, size_t seq_len, decoderParamRun* param) +{ + // if(seq_len == param->common_param.src_seq_len) + // invokeBuildSequenceLength( + // input_ids, param->common_param.batch_size, d_sequence_lengths, seq_len, param->common_param.hidden_size, param->common_param.stream); + // else + invokeBuildSequenceLength( + input_ids, param->common_param.batch_size, d_sequence_lengths, seq_len, param->common_param.stream); + invokeGetPaddingOffset(&h_token_num, + d_token_num, + padding_offset, + d_sequence_lengths, + param->common_param.batch_size, + seq_len, + param->common_param.stream); + if (h_token_num * 2 <= param->common_param.batch_size * seq_len) { + param->common_param.eft = true; + invokeRemovePadding(compress_buffer, + (const T*)from_tensor, + padding_offset, + h_token_num, + param->common_param.head_num * param->common_param.head_size, + param->common_param.stream); + } +} template void forwardDecoder(void* inputs[], int in_len, void* output[], int out_len, decoderParamRun* param, void* ws) { param->common_param.in_idx = 0; - size_t h_token_num = param->common_param.batch_size * param->common_param.src_seq_len; - param->common_param.h_token_num = h_token_num; - T* from_tensor = reinterpret_cast(inputs[param->common_param.in_idx++]); + size_t h_token_num = param->common_param.h_token_num = param->common_param.batch_size * param->common_param.src_seq_len; + size_t h_token_num2 = param->common_param.h_token_num2 = param->common_param.batch_size * param->common_param.tgt_seq_len; + int* d_sequence_lengths = nullptr; + int* d_sequence_lengths2 = nullptr; + int* padding_offset = nullptr; + int* padding_offset2= nullptr; + T* input_tensor = reinterpret_cast(inputs[param->common_param.in_idx++]); + T* from_tensor = input_tensor; + int idx_encoder_out = param->attn1.attn.position_bias ? 
7 : 10; + T* encoder_output = reinterpret_cast(inputs[idx_encoder_out]); + int *input_ids = reinterpret_cast(inputs[in_len-1]); + int *input_ids2 = reinterpret_cast(inputs[in_len-2]); + T* compress_buffer = reinterpret_cast(static_cast(ws) + param->decoder.compress_buf); + T* compress_buffer2 = reinterpret_cast(static_cast(ws) + param->decoder.compress_buf2); + padding_offset = reinterpret_cast(static_cast(ws) + param->decoder.padding_offset_buf); + padding_offset2 = reinterpret_cast(static_cast(ws) + param->decoder.padding_offset_buf2); + d_sequence_lengths = reinterpret_cast(static_cast(ws) + param->decoder.d_sequence_lengths_offset_buf); + d_sequence_lengths2 = reinterpret_cast(static_cast(ws) + param->decoder.d_sequence_lengths_offset_buf2); + size_t* d_token_num = reinterpret_cast(static_cast(ws) + param->decoder.d_token_num_buf); + size_t* d_token_num2 = reinterpret_cast(static_cast(ws) + param->decoder.d_token_num_buf2); + GetCompressBuffer(compress_buffer, from_tensor, input_ids, padding_offset, d_sequence_lengths, h_token_num, d_token_num, param->common_param.src_seq_len, param); + if (h_token_num * 2 <= param->common_param.batch_size * param->common_param.src_seq_len) { + param->common_param.h_token_num = h_token_num; + from_tensor = compress_buffer; + } else { + padding_offset = nullptr; + } + GetCompressBuffer(compress_buffer2, encoder_output, input_ids2, padding_offset2, d_sequence_lengths2, h_token_num2, d_token_num2, param->common_param.tgt_seq_len, param); + if (h_token_num2 * 2 <= param->common_param.batch_size * param->common_param.tgt_seq_len) { + param->common_param.h_token_num2 = h_token_num2; + inputs[idx_encoder_out] = compress_buffer2; + } else { + padding_offset2 = nullptr; + } + h_token_num = param->common_param.h_token_num; + h_token_num2 = param->common_param.h_token_num2; T* attn_out = reinterpret_cast(static_cast(ws) + param->decoder.attn_out_buf); T* normed_from_tensor = reinterpret_cast(static_cast(ws) + param->decoder.normed_from_tensor_buf); T* attn_ws = reinterpret_cast(static_cast(ws) + param->decoder.attn_ws_buf); @@ -67,9 +140,16 @@ void forwardDecoder(void* inputs[], int in_len, void* output[], int out_len, dec T* normed_attn2_out = reinterpret_cast(static_cast(ws) + param->decoder.normed_attn2_out_buf); T* ffn_ws = reinterpret_cast(static_cast(ws) + param->decoder.ffn_ws_buf); T* tmp_out = reinterpret_cast(output[0]); - if ((std::is_same::value && param->ffn_param.ffn_param.ffn_fp16 == true)) { + if ((padding_offset != nullptr || std::is_same::value && param->ffn_param.ffn_param.ffn_fp16 == true)) { tmp_out = reinterpret_cast(static_cast(ws) + param->decoder.tmp_out_buf); } + T* tmp_out1 = reinterpret_cast(output[0]); + T* tmp_out2; + T* out_buf = tmp_out; + if (padding_offset != nullptr) { + tmp_out1 = compress_buffer; + tmp_out2 = compress_buffer2; + } T* gamma1 = reinterpret_cast(inputs[param->common_param.in_idx++]); T* beta1 = (param->decoder.has_beta) ? 
reinterpret_cast(inputs[param->common_param.in_idx++]) : nullptr; invokeGeneralT5LayerNorm(normed_from_tensor, @@ -85,8 +165,15 @@ void forwardDecoder(void* inputs[], int in_len, void* output[], int out_len, dec // if attention is embedded inside an decoder - fuse the bias to next layer normalization bool is_projection_bias = param->attn1.attn.projection_bias; param->attn1.attn.projection_bias = false; + param->attn1.attn.d_sequence_length = d_sequence_lengths; + param->attn1.attn.padding_offset = padding_offset; + param->attn1.attn.d_sequence_length2 = d_sequence_lengths; + param->attn1.attn.padding_offset2 = padding_offset; + int tgt_seq_len = param->common_param.tgt_seq_len; + param->common_param.tgt_seq_len = param->common_param.src_seq_len; forward_attn( reinterpret_cast(&inputs[param->common_param.in_idx]), in_len, &attn_out, 1, &(param->attn1), attn_ws); + param->common_param.tgt_seq_len = tgt_seq_len; param->attn1.attn.projection_bias = is_projection_bias; param->common_param.in_idx = param->common_param.in_idx + in_idx; T* projection_bias = @@ -104,11 +191,14 @@ void forwardDecoder(void* inputs[], int in_len, void* output[], int out_len, dec param->common_param.hidden_size, param->common_param.stream, param->decoder.eps2); - inputs[--param->common_param.in_idx] = normed_attn_out; in_idx = param->common_param.in_idx; is_projection_bias = param->attn2.attn.projection_bias; param->attn2.attn.projection_bias = false; + param->attn2.attn.d_sequence_length = d_sequence_lengths; + param->attn2.attn.padding_offset = padding_offset; + param->attn2.attn.d_sequence_length2 = d_sequence_lengths2; + param->attn2.attn.padding_offset2 = padding_offset2; forward_attn( reinterpret_cast(&inputs[param->common_param.in_idx]), in_len, &attn2_out, 1, &(param->attn2), attn2_ws); param->attn2.attn.projection_bias = is_projection_bias; @@ -131,7 +221,6 @@ void forwardDecoder(void* inputs[], int in_len, void* output[], int out_len, dec param->decoder.eps3); } else { - invokeGeneralAddBiasResidualT5PreLayerNormCast(attn2_out, reinterpret_cast(normed_attn2_out), attn_out, @@ -150,6 +239,7 @@ void forwardDecoder(void* inputs[], int in_len, void* output[], int out_len, dec else { forward_ffn(reinterpret_cast(inputs), in_len, &tmp_out, 1, &(param->ffn_param), ffn_ws); } + attn2_out = (param->decoder.layernorm_post) ? normed_attn2_out : attn2_out; T* ffn_bias = (param->ffn_param.ffn_param.ffn_bias) ? reinterpret_cast(inputs[param->common_param.in_idx++]) : nullptr; @@ -165,7 +255,7 @@ void forwardDecoder(void* inputs[], int in_len, void* output[], int out_len, dec if (param->decoder.layernorm_post) { invokeAddBiasResidualSameTypeCast(reinterpret_cast(tmp_out), reinterpret_cast(attn2_out), - reinterpret_cast(output[0]), + reinterpret_cast(tmp_out2), ffn_bias, h_token_num, param->common_param.hidden_size, @@ -174,12 +264,37 @@ void forwardDecoder(void* inputs[], int in_len, void* output[], int out_len, dec else { invokeAddBiasResidualCast(reinterpret_cast(tmp_out), reinterpret_cast(attn2_out), - reinterpret_cast(output[0]), + reinterpret_cast(tmp_out2), ffn_bias, h_token_num, param->common_param.hidden_size, param->common_param.stream); } + tmp_out = tmp_out2; + } + if(param->decoder.is_layernorm){ + std::cout<<"decoder is_layernorm\n"; + T* gamma4 = reinterpret_cast(inputs[param->common_param.in_idx++]); + T* beta4 = (param->decoder.has_beta) ? 
reinterpret_cast(inputs[param->common_param.in_idx++]) : nullptr; + invokeGeneralT5LayerNorm(tmp_out1, + tmp_out, + gamma4, + beta4, + h_token_num, + param->common_param.hidden_size, + param->common_param.stream, + param->decoder.eps4); + out_buf = tmp_out1; + } else { + out_buf = tmp_out; + } + if (padding_offset != nullptr) { + cudaMemsetAsync(output[0], + 0, + param->common_param.batch_size * param->common_param.src_seq_len * param->common_param.head_size * param->common_param.head_num * sizeof(T), + param->common_param.stream); + invokeRebuildPadding( + (T*)output[0], out_buf, padding_offset, h_token_num, param->common_param.hidden_size, param->common_param.stream); } return; } @@ -188,4 +303,6 @@ template void forwardDecoder(void* inputs[], int in_len, void* output[], int out_len, decoderParamRun* param, void* ws); template void forwardDecoder(void* inputs[], int in_len, void* output[], int out_len, decoderParamRun* param, void* ws); +template void GetCompressBuffer(float* compress_buffer, float* from_tensor, int *input_ids, int* padding_offset, int* d_sequence_lengths, size_t &h_token_num, size_t* d_token_num, size_t seq_len, decoderParamRun* param); +template void GetCompressBuffer(half* compress_buffer, half* from_tensor, int *input_ids, int* padding_offset, int* d_sequence_lengths, size_t &h_token_num, size_t* d_token_num, size_t seq_len, decoderParamRun* param); } // namespace fastertransformer diff --git a/src/fastertransformer/layers/ms_layers/decoder.h b/src/fastertransformer/layers/ms_layers/decoder.h index e1e47f3e86980f8b2cd6882d3e384cb3385bf705..61af13a9f1ba4596636b59b93d7d6ae630e7fd25 100644 --- a/src/fastertransformer/layers/ms_layers/decoder.h +++ b/src/fastertransformer/layers/ms_layers/decoder.h @@ -12,4 +12,6 @@ template size_t GetDecoderLayerWorkspaceSize(decoderParamRun* param); template void forwardDecoder(void* inputs[], int in_len, void* output[], int out_len, decoderParamRun* param, void* ws); +template +void GetCompressBuffer(T* compress_buffer2, T* from_tensor, int *input_ids, int* padding_offset, int* d_sequence_lengths, size_t &h_token_num, size_t* d_token_num, size_t seq_len, decoderParamRun* param); } // namespace fastertransformer diff --git a/src/fastertransformer/layers/ms_layers/encoder.cc b/src/fastertransformer/layers/ms_layers/encoder.cc index 179d70509c1fd3a0651cd6453240c1e2e59dd992..bab1589772c9ee74c0e8cb30f884aaa62614e4d7 100644 --- a/src/fastertransformer/layers/ms_layers/encoder.cc +++ b/src/fastertransformer/layers/ms_layers/encoder.cc @@ -25,12 +25,21 @@ size_t GetEncoderLayerWorkspaceSize(encoderParamRun* param) size_t normed_attn_out_len = param->common_param.batch_size * param->common_param.src_seq_len * param->common_param.hidden_size; size_t tmp_out_size = attn_out_len; + size_t max_hidden = std::max(param->common_param.hidden_size, param->ffn_param.ffn_param.ffn_hidden_size); + size_t compress_buffer_len = param->common_param.batch_size * param->common_param.src_seq_len * max_hidden; + size_t padding_len = param->common_param.batch_size * param->common_param.src_seq_len; OptAllocator allocator(ALIGN_SIZE); + param->encoder.d_sequence_lengths_offset_buf = allocator.Malloc(param->common_param.batch_size * sizeof(int)); + param->encoder.padding_offset_buf = allocator.Malloc(padding_len * sizeof(int)); + param->encoder.d_token_num_buf = allocator.Malloc(1 * sizeof(size_t)); + param->encoder.compress_buf = allocator.Malloc(compress_buffer_len * sizeof(T)); param->encoder.normed_from_tensor_buf = (!param->encoder.layernorm_post || 
param->attn.attn.position_bias) ? allocator.Malloc(normed_from_tensor_len * sizeof(T)) : - -1; + 0; param->encoder.attn_ws_buf = allocator.Malloc(GetAttnWorkspaceSize(&(param->attn))); param->encoder.attn_out_buf = allocator.Malloc(attn_out_len * sizeof(T)); + allocator.Free(param->encoder.d_token_num_buf); + allocator.Free(param->encoder.d_sequence_lengths_offset_buf); allocator.Free(param->encoder.attn_ws_buf); if (!param->encoder.layernorm_post || param->attn.attn.position_bias) allocator.Free(param->encoder.normed_from_tensor_buf); @@ -38,14 +47,36 @@ size_t GetEncoderLayerWorkspaceSize(encoderParamRun* param) ((!param->encoder.layernorm_post || param->attn.attn.position_bias) || param->ffn_param.ffn_param.ffn_fp16) ? param->ffn_param.ffn_param.ffn_fp16 ? allocator.Malloc(normed_attn_out_len * sizeof(half)) : allocator.Malloc(normed_attn_out_len * sizeof(T)) : - -1; + 0; param->encoder.ffn_ws_buf = param->ffn_param.ffn_param.ffn_fp16 ? allocator.Malloc(ffn_len * sizeof(half)) : allocator.Malloc(ffn_len * sizeof(T)); param->encoder.tmp_out_buf = param->ffn_param.ffn_param.ffn_fp16 ? allocator.Malloc(tmp_out_size * sizeof(half)) : allocator.Malloc(tmp_out_size * sizeof(T)); + param->encoder.norm_out_buf = allocator.Malloc(tmp_out_size * sizeof(T)); return allocator.total_size(); } - +template +void GetCompressBuffer(T* compress_buffer, T* from_tensor, int *input_ids, int* padding_offset, int* d_sequence_lengths, size_t &h_token_num, size_t* d_token_num, size_t seq_len, encoderParamRun* param) +{ + invokeBuildSequenceLength( + input_ids, param->common_param.batch_size, d_sequence_lengths, seq_len, param->common_param.stream); + invokeGetPaddingOffset(&h_token_num, + d_token_num, + padding_offset, + d_sequence_lengths, + param->common_param.batch_size, + seq_len, + param->common_param.stream); + if (h_token_num * 2 <= param->common_param.batch_size * seq_len) { + param->common_param.eft = true; + invokeRemovePadding(compress_buffer, + (const T*)from_tensor, + padding_offset, + h_token_num, + param->common_param.head_num * param->common_param.head_size, + param->common_param.stream); + } +} template size_t GetEncoderLayerWorkspaceSize(encoderParamRun* param); template size_t GetEncoderLayerWorkspaceSize(encoderParamRun* param); template @@ -53,17 +84,38 @@ void forwardEncoder(void* inputs[], int in_len, void* output[], int out_len, enc { param->common_param.in_idx = 0; size_t h_token_num = param->common_param.batch_size * param->common_param.src_seq_len; - param->common_param.h_token_num = h_token_num; - T* from_tensor = reinterpret_cast(inputs[param->common_param.in_idx++]); + param->common_param.h_token_num = param->common_param.h_token_num2 = h_token_num; + T* input_tensor = reinterpret_cast(inputs[param->common_param.in_idx++]); + T* from_tensor = input_tensor; + int *input_ids = reinterpret_cast(inputs[in_len-1]); + T* compress_buffer = reinterpret_cast(static_cast(ws) + param->encoder.compress_buf); + int* padding_offset = reinterpret_cast(static_cast(ws) + param->encoder.padding_offset_buf); + int* d_sequence_lengths = reinterpret_cast(static_cast(ws) + param->encoder.d_sequence_lengths_offset_buf); + size_t* d_token_num = reinterpret_cast(static_cast(ws) + param->encoder.d_token_num_buf); + param->common_param.eft = false; + GetCompressBuffer(compress_buffer, from_tensor, input_ids, padding_offset, d_sequence_lengths, h_token_num, d_token_num, param->common_param.src_seq_len, param); + if (h_token_num * 2 <= param->common_param.batch_size * param->common_param.src_seq_len) { + 
param->common_param.h_token_num = h_token_num; + from_tensor = compress_buffer; + } else { + padding_offset = nullptr; + } + h_token_num = param->common_param.h_token_num; T* attn_out = reinterpret_cast(static_cast(ws) + param->encoder.attn_out_buf); T* normed_from_tensor = reinterpret_cast(static_cast(ws) + param->encoder.normed_from_tensor_buf); T* attn_ws = reinterpret_cast(static_cast(ws) + param->encoder.attn_ws_buf); T* normed_attn_out = reinterpret_cast(static_cast(ws) + param->encoder.normed_attn_out_buf); T* ffn_ws = reinterpret_cast(static_cast(ws) + param->encoder.ffn_ws_buf); T* tmp_out = reinterpret_cast(output[0]); - if ((std::is_same::value && param->ffn_param.ffn_param.ffn_fp16 == true)) { + if (padding_offset != nullptr || (std::is_same::value && param->ffn_param.ffn_param.ffn_fp16 == true)) { tmp_out = reinterpret_cast(static_cast(ws) + param->encoder.tmp_out_buf); } + T* tmp_out1 = reinterpret_cast(output[0]); + T* out_buf = tmp_out; + if (padding_offset != nullptr) { + tmp_out1 = compress_buffer; + } + T* norm_out = reinterpret_cast(static_cast(ws) + param->encoder.norm_out_buf); if (param->encoder.layernorm_post == false || param->attn.attn.position_bias) { T* gamma1 = reinterpret_cast(inputs[param->common_param.in_idx++]); T* beta1 = (param->encoder.has_beta) ? reinterpret_cast(inputs[param->common_param.in_idx++]) : nullptr; @@ -79,10 +131,13 @@ void forwardEncoder(void* inputs[], int in_len, void* output[], int out_len, enc normed_from_tensor = from_tensor; } inputs[--param->common_param.in_idx] = normed_from_tensor; - // if attention is embedded inside an encoder - fuse the bias to next layer normalization bool is_projection_bias = param->attn.attn.projection_bias; param->attn.attn.projection_bias = false; int in_idx = param->common_param.in_idx; + param->attn.attn.d_sequence_length = d_sequence_lengths; + param->attn.attn.padding_offset = padding_offset; + param->attn.attn.d_sequence_length2 = d_sequence_lengths; + param->attn.attn.padding_offset2 = padding_offset; forward_attn( reinterpret_cast(&inputs[param->common_param.in_idx]), in_len, &attn_out, 1, &(param->attn), attn_ws); param->common_param.in_idx = param->attn.common_param->in_idx + in_idx; @@ -144,8 +199,6 @@ void forwardEncoder(void* inputs[], int in_len, void* output[], int out_len, enc param->encoder.eps1); } } - // forward ffn - // simulate attention inputs inputs[--param->common_param.in_idx] = normed_attn_out; if (param->ffn_param.ffn_param.ffn_fp16 == false) { forward_ffn(reinterpret_cast(inputs), in_len, &tmp_out, 1, ¶m->ffn_param, ffn_ws); @@ -171,7 +224,7 @@ void forwardEncoder(void* inputs[], int in_len, void* output[], int out_len, enc } else { invokeAddBiasResidualLayerNormCast(reinterpret_cast(tmp_out), - reinterpret_cast(output[0]), + reinterpret_cast(tmp_out1), reinterpret_cast(normed_attn_out), ffn_bias, gamma3, @@ -180,6 +233,7 @@ void forwardEncoder(void* inputs[], int in_len, void* output[], int out_len, enc param->common_param.hidden_size, param->common_param.stream, param->encoder.eps2); + out_buf = tmp_out1; } } else { @@ -196,7 +250,7 @@ void forwardEncoder(void* inputs[], int in_len, void* output[], int out_len, enc if (param->encoder.layernorm_post) { invokeAddBiasResidualSameTypeCast(reinterpret_cast(tmp_out), reinterpret_cast(attn_out), - reinterpret_cast(output[0]), + reinterpret_cast(tmp_out1), ffn_bias, h_token_num, param->common_param.hidden_size, @@ -205,14 +259,37 @@ void forwardEncoder(void* inputs[], int in_len, void* output[], int out_len, enc else { 
                 invokeAddBiasResidualCast(reinterpret_cast(tmp_out),
                                           reinterpret_cast(attn_out),
-                                          reinterpret_cast(output[0]),
+                                          reinterpret_cast(tmp_out1),
                                           ffn_bias,
                                           h_token_num,
                                           param->common_param.hidden_size,
                                           param->common_param.stream);
             }
+            out_buf = tmp_out1;
         }
     }
+    if(param->encoder.is_layernorm){
+        std::cout<<"encoder is_layernorm\n";
+        T* gamma4 = reinterpret_cast(inputs[param->common_param.in_idx++]);
+        T* beta4 = (param->encoder.has_beta) ? reinterpret_cast(inputs[param->common_param.in_idx++]) : nullptr;
+        invokeGeneralT5LayerNorm(tmp_out1,
+                                 tmp_out,
+                                 gamma4,
+                                 beta4,
+                                 h_token_num,
+                                 param->common_param.hidden_size,
+                                 param->common_param.stream,
+                                 param->encoder.eps3);
+        out_buf = tmp_out1;
+    }
+    if (padding_offset != nullptr) {
+        cudaMemsetAsync(output[0],
+                        0,
+                        param->common_param.batch_size * param->common_param.src_seq_len * param->common_param.head_size * param->common_param.head_num * sizeof(T),
+                        param->common_param.stream);
+        invokeRebuildPadding(
+            (T*)output[0], out_buf, padding_offset, h_token_num, param->common_param.hidden_size, param->common_param.stream);
+    }
 
     return;
 }
@@ -220,4 +297,7 @@
 template void forwardEncoder(void* inputs[], int in_len, void* output[], int out_len, encoderParamRun* param, void* ws);
 template void forwardEncoder(void* inputs[], int in_len, void* output[], int out_len, encoderParamRun* param, void* ws);
+template void GetCompressBuffer(float* compress_buffer, float* from_tensor, int *input_ids, int* padding_offset, int* d_sequence_lengths, size_t &h_token_num, size_t* d_token_num, size_t seq_len, encoderParamRun* param);
+template void GetCompressBuffer(half* compress_buffer, half* from_tensor, int *input_ids, int* padding_offset, int* d_sequence_lengths, size_t &h_token_num, size_t* d_token_num, size_t seq_len, encoderParamRun* param);
+
 
 } // namespace fastertransformer
diff --git a/src/fastertransformer/layers/ms_layers/encoder.h b/src/fastertransformer/layers/ms_layers/encoder.h
index 37b2b3c09a598ddb0e275ae88dc22bf0b5de200b..d4f4b03203d1ed589dfae236018eda1d55f0d742 100644
--- a/src/fastertransformer/layers/ms_layers/encoder.h
+++ b/src/fastertransformer/layers/ms_layers/encoder.h
@@ -12,5 +12,7 @@
 template
 size_t GetEncoderLayerWorkspaceSize(encoderParamRun* param);
 template
 void forwardEncoder(void* inputs[], int in_len, void* output[], int out_len, encoderParamRun* param, void* ws);
+template
+void GetCompressBuffer(T* compress_buffer2, T* from_tensor, int *input_ids, int* padding_offset, int* d_sequence_lengths, size_t &h_token_num, size_t* d_token_num, size_t seq_len, encoderParamRun* param);
 } // namespace fastertransformer
diff --git a/src/fastertransformer/layers/ms_layers/ffn.cc b/src/fastertransformer/layers/ms_layers/ffn.cc
index 1e39ecb7ac73f103c9ad9b1be00c3e154e962c63..ae2500bf55ca6638370d467a51abc14b1602ca2f 100755
--- a/src/fastertransformer/layers/ms_layers/ffn.cc
+++ b/src/fastertransformer/layers/ms_layers/ffn.cc
@@ -13,7 +13,7 @@ template
 void forward_ffn(T* inputs[], int in_len, T* output[], int out_len, ffnParamRun* param, void* ws)
 {
     size_t inter_size = param->ffn_param.ffn_hidden_size;
-    size_t h_token_num = param->common_param->batch_size * param->common_param->src_seq_len;
+    size_t h_token_num = param->common_param->h_token_num;
     cublasOperation_t gemm_ops[] = {CUBLAS_OP_N, CUBLAS_OP_N};
     cudaDataType gemm_data_types[] = {CUDA_R_32F, CUDA_R_32F, CUDA_R_32F};
     if ((std::is_same::value) || (std::is_same::value)) {
diff --git a/src/fastertransformer/layers/ms_layers/param.h b/src/fastertransformer/layers/ms_layers/param.h
index 0ef122ae29591c06c1cd26de72d8454f3040c035..78e5e1df6b1f94cefa1e266c7d8dae1571825d9c 100644
--- a/src/fastertransformer/layers/ms_layers/param.h
+++ b/src/fastertransformer/layers/ms_layers/param.h
@@ -27,10 +27,12 @@ typedef struct {
     size_t head_size;
     size_t hidden_size;
     size_t h_token_num;
+    size_t h_token_num2;
     cublasGemmAlgo_t algo;
     cublasHandle_t cublas_handle;
     cudaStream_t stream;
     int in_idx;
+    bool eft;
 } CommonParam;
 
 typedef struct {
@@ -62,6 +64,10 @@ typedef struct {
     size_t mha;
     bool mask;
     FmhaType fmha_type;
+    int* padding_offset;
+    int* d_sequence_length;
+    int* padding_offset2;
+    int* d_sequence_length2;
 } attentionParam;
 
 typedef struct {
@@ -73,6 +79,7 @@ typedef struct {
     float eps1;
     float eps2;
     float eps3;
+    float eps4;
     bool layernorm_post;
     bool has_beta;
     size_t normed_from_tensor_buf;
@@ -84,6 +91,15 @@ typedef struct {
     size_t ffn_ws_buf;
     size_t normed_attn_out_buf;
     size_t normed_attn2_out_buf;
+    size_t compress_buf;
+    size_t d_token_num_buf;
+    size_t padding_offset_buf;
+    size_t d_sequence_lengths_offset_buf;
+    size_t compress_buf2;
+    size_t d_token_num_buf2;
+    size_t padding_offset_buf2;
+    size_t d_sequence_lengths_offset_buf2;
+    bool is_layernorm;
 } decoderParam;
 
 typedef struct {
@@ -97,6 +113,7 @@ typedef struct {
     float eps1;
     float eps2;
+    float eps3;
     bool layernorm_post;
     bool has_beta;
     size_t normed_from_tensor_buf;
 
@@ -105,6 +122,12 @@ typedef struct {
     size_t normed_attn_out_buf;
     size_t ffn_ws_buf;
     size_t tmp_out_buf;
+    size_t compress_buf;
+    size_t d_token_num_buf;
+    size_t padding_offset_buf;
+    size_t d_sequence_lengths_offset_buf;
+    size_t norm_out_buf;
+    bool is_layernorm;
 } encoderParam;
 
 typedef struct {
@@ -112,12 +135,5 @@ typedef struct {
     ffnParamRun ffn_param;
     attentionParamRun attn;
     encoderParam encoder;
-    bool eft;
-    size_t d_sequence_lengths_offset_buf;
-    size_t padding_offset_buf;
-    size_t d_token_num_buf ;
-    size_t compress_buf;
-    int *padding_offset;
-    int* d_sequence_length;
 } encoderParamRun;
 } // namespace fastertransformer
\ No newline at end of file
diff --git a/src/fastertransformer/models/bert/Bert.cc b/src/fastertransformer/models/bert/Bert.cc
index ac727df8f7db4587a95811b35a74909450b2b39d..31ef8f0abacda45ff13368057f96800012ec4654 100644
--- a/src/fastertransformer/models/bert/Bert.cc
+++ b/src/fastertransformer/models/bert/Bert.cc
@@ -255,7 +255,7 @@ void Bert::forward(std::vector* output_tensors,
     switch (attention_type_) {
         case AttentionType::UNFUSED_MHA: {
             invokeBuildEncoderAttentionMask(
-                attention_mask_, sequence_lengths, request_batch_size, request_seq_len, stream_);
+                attention_mask_, sequence_lengths, sequence_lengths, request_batch_size, request_seq_len, request_seq_len, stream_);
             sync_check_cuda_error();
             invokeGetPaddingOffset(&h_token_num,
                                    token_num_,
@@ -281,7 +281,7 @@ void Bert::forward(std::vector* output_tensors,
         }
         case AttentionType::UNFUSED_PADDED_MHA: {
             invokeBuildEncoderAttentionMask(
-                attention_mask_, sequence_lengths, request_batch_size, request_seq_len, stream_);
+                attention_mask_, sequence_lengths, sequence_lengths, request_batch_size, request_seq_len, request_seq_len, stream_);
             sync_check_cuda_error();
             h_token_num = request_batch_size * request_seq_len;
             bert_input_ptr = (T*)input_tensors->at(0).data;
diff --git a/src/fastertransformer/models/bert_int8/BertINT8.cc b/src/fastertransformer/models/bert_int8/BertINT8.cc
index 7c6347bcfe3a03d2adc8e5494c74e25ed22f85a2..872b3839a123a497eca6c8dedd6540a1e1c7eec4 100644
--- a/src/fastertransformer/models/bert_int8/BertINT8.cc
+++ b/src/fastertransformer/models/bert_int8/BertINT8.cc
@@ -180,7 +180,7 @@ void BertINT8::forward(std::vector* output_tensors,
     switch (attention_type_) {
         case AttentionType::UNFUSED_MHA: {
             invokeBuildEncoderAttentionMask(
-                attention_mask_, sequence_lengths, request_batch_size, request_seq_len, stream_);
+                attention_mask_, sequence_lengths, sequence_lengths, request_batch_size, request_seq_len, request_seq_len, stream_);
             sync_check_cuda_error();
             invokeGetPaddingOffset(&h_token_num,
                                    token_num_,
@@ -206,7 +206,7 @@ void BertINT8::forward(std::vector* output_tensors,
         }
         case AttentionType::UNFUSED_PADDED_MHA: {
             invokeBuildEncoderAttentionMask(
-                attention_mask_, sequence_lengths, request_batch_size, request_seq_len, stream_);
+                attention_mask_, sequence_lengths, sequence_lengths, request_batch_size, request_seq_len, request_seq_len, stream_);
             sync_check_cuda_error();
             h_token_num = request_batch_size * request_seq_len;
             bert_input_ptr = (T*)input_tensors->at(0).data;
diff --git a/src/fastertransformer/models/vit/ViT.cc b/src/fastertransformer/models/vit/ViT.cc
index e785f2bc6ba355f30a1e237629353483aca31afb..1595d335745458d0d14076a9e0c514e36613b47e 100644
--- a/src/fastertransformer/models/vit/ViT.cc
+++ b/src/fastertransformer/models/vit/ViT.cc
@@ -415,7 +415,7 @@ bool ViTTransformer::setSeqLenVec(size_t batch_size)
 template
 void ViTTransformer::setDefaultMask(size_t batch_size)
 {
-    invokeBuildEncoderAttentionMask(mask_buf_, seq_len_vec_, batch_size, max_seq_len_, stream_);
+    invokeBuildEncoderAttentionMask(mask_buf_, seq_len_vec_, seq_len_vec_, batch_size, max_seq_len_, max_seq_len_, stream_);
 }
 
 template
diff --git a/src/fastertransformer/models/vit_int8/ViTINT8.cc b/src/fastertransformer/models/vit_int8/ViTINT8.cc
index f61078511e9f21e3902802f2aa3b2f17491de992..ed35a4e0ec52f81ef17b6e40388581195c75f294 100644
--- a/src/fastertransformer/models/vit_int8/ViTINT8.cc
+++ b/src/fastertransformer/models/vit_int8/ViTINT8.cc
@@ -462,7 +462,7 @@ bool ViTTransformerINT8::setSeqLenVec(size_t batch_size)
 template
 void ViTTransformerINT8::setDefaultMask(size_t batch_size)
 {
-    invokeBuildEncoderAttentionMask(mask_buf_, seq_len_vec_, batch_size, max_seq_len_, stream_);
+    invokeBuildEncoderAttentionMask(mask_buf_, seq_len_vec_, seq_len_vec_, batch_size, max_seq_len_, max_seq_len_, stream_);
 }
 
 template