diff --git a/tests/apitest/kernelstest/mix/test_flash_attention.py b/tests/apitest/kernelstest/mix/test_flash_attention.py
index 2338633c52fb859f46fc783f9a820c95df56b9c8..d27913ae982787d203b85a605e8177bdc0a9cbed 100644
--- a/tests/apitest/kernelstest/mix/test_flash_attention.py
+++ b/tests/apitest/kernelstest/mix/test_flash_attention.py
@@ -41,6 +41,10 @@ MASK_TYPE_ALIBI_WITH_PREFIX_BATCH = 8
 MASK_TYPE_NO_BATCH_WITH_PREFIX = 9
 MASK_TYPE_ALIBI_NO_BATCH_WITH_PREFIX = 10
 MASK_TYPE_RAZOR_FUSION = 11
+UNPAD_FLASH_ATTENTION_ND = 1
+UNPAD_DYNAMIC_BATCH_FLASH_ATTENTION = 4
+UNPAD_FLASH_ATTENTION_ENCODER_ND = 10
+UNPAD_ALIBI_FLASH_ATTENTION_ND = 11
 
 class TestFlashAttention(op_test.OpTest):
     def close_pack(self, in_data, seq_len):
@@ -113,6 +117,7 @@ class TestFlashAttention(op_test.OpTest):
         self.is_compress = is_compress
         self.cache_type = cache_type
         self.q_seqlens = q_seqlens if q_seqlens is not None else kv_seqLen
+        self.op_type = op_type
 
         if self.embeddimv == 0:
             self.embeddimv = self.embeddim
@@ -128,20 +133,20 @@ class TestFlashAttention(op_test.OpTest):
         self.layer_id = torch.from_numpy(np.array([0], dtype=np.int32)).to(torch.int32)
         self.q_max_seq = np.max(self.q_seqlen)
         self.kv_max_seq = np.max(self.kv_seqlen)
-        q = torch.from_numpy(np.random.uniform(-1.0, 1.0, size=(self.q_ntokens, heads * self.embeddim)))
+        q = torch.from_numpy(np.random.uniform(-5.0, 5.0, size=(self.q_ntokens, heads * self.embeddim)))
         self.q = q.to(data_type)
 
         if num_blocks is None:
-            self.k = torch.from_numpy(np.random.uniform(-1.0, 1.0, size=(self.layer_id[0] + 1, batch, self.max_seq, kv_head * self.embeddim))).to(data_type)
-            self.v = torch.from_numpy(np.random.uniform(-1.0, 1.0, size=(self.layer_id[0] + 1, batch, self.max_seq, kv_head * self.embeddimv))).to(data_type)
+            self.k = torch.from_numpy(np.random.uniform(-5.0, 5.0, size=(self.layer_id[0] + 1, batch, self.max_seq, kv_head * self.embeddim))).to(data_type)
+            self.v = torch.from_numpy(np.random.uniform(-5.0, 5.0, size=(self.layer_id[0] + 1, batch, self.max_seq, kv_head * self.embeddimv))).to(data_type)
             if is_splitm:
                 maxKvSeqlen = max(self.kv_seqlen)
-                self.k = torch.from_numpy(np.random.uniform(-1.0, 1.0, size=(self.layer_id[0] + 1, batch, maxKvSeqlen, kv_head * self.embeddim))).to(data_type)
-                self.v = torch.from_numpy(np.random.uniform(-1.0, 1.0, size=(self.layer_id[0] + 1, batch, maxKvSeqlen, kv_head * self.embeddimv))).to(data_type)
+                self.k = torch.from_numpy(np.random.uniform(-5.0, 5.0, size=(self.layer_id[0] + 1, batch, maxKvSeqlen, kv_head * self.embeddim))).to(data_type)
+                self.v = torch.from_numpy(np.random.uniform(-5.0, 5.0, size=(self.layer_id[0] + 1, batch, maxKvSeqlen, kv_head * self.embeddimv))).to(data_type)
         else:
             # kv cache shape: (num_blocks, block_size, num_heads, head_size)
-            self.k_cache = torch.from_numpy(np.random.uniform(-1.0, 1.0, size=(num_blocks, block_size, kv_head, embeddim))).to(data_type)
-            self.v_cache = torch.from_numpy(np.random.uniform(-1.0, 1.0, size=(num_blocks, block_size, kv_head, embeddim))).to(data_type)
+            self.k_cache = torch.from_numpy(np.random.uniform(-5.0, 5.0, size=(num_blocks, block_size, kv_head, embeddim))).to(data_type)
+            self.v_cache = torch.from_numpy(np.random.uniform(-5.0, 5.0, size=(num_blocks, block_size, kv_head, embeddim))).to(data_type)
 
         batch = len(kv_seqLen)
         max_context_len = max(kv_seqLen)
@@ -974,6 +979,684 @@ class TestFlashAttention(op_test.OpTest):
             self.k = self.close_pack(self.k.to(torch.float32), kv_seqlen).to(self.data_type)
             self.v = self.close_pack(self.v.to(torch.float32), kv_seqlen).to(self.data_type)
 
+    def qkMM1(
+        self,
+        query,
+        key,
+        dtype = torch.float32
+    ):
+        result = None
+        qk_k = key.shape[0]
+        for qk_k_split in range(0, qk_k, 128):
+            sub_k = 128
+            if qk_k_split == 512:
+                sub_k = 64
+            query_k = query[:, qk_k_split : qk_k_split + sub_k]
+            key_k = key[qk_k_split : qk_k_split + sub_k, :]
+            result_split = torch.matmul(query_k.to(dtype), key_k.to(dtype))
+            if result is None:
+                result = result_split
+            else:
+                result = result + result_split
+        return result
+
+    def softmax(
+        self,
+        qk_result,
+        is_first,
+        gm
+    ):
+        sim = qk_result
+        lm = torch.max(sim, dim=-1, keepdim=True).values
+
+        if is_first:
+            hm = lm
+            dm = torch.zeros_like(lm)
+        else:
+            hm = torch.maximum(gm, lm)
+            dm = gm - hm
+        gm = hm
+
+        sim_sub = sim - hm
+        sim_sub = torch.exp(sim_sub.to(torch.float32))
+
+        if qk_result.dtype != torch.float32:
+            sim_sub = sim_sub.to(self.data_type)
+
+        row_sum = torch.sum(sim_sub, dim=-1, keepdim=True)
+
+        return sim_sub, row_sum, dm, gm
+
+    def online_softmax_quant(
+        self,
+        qk_result,
+        sub_value,
+        head_idx,
+        is_tail,
+        mi,
+        li,
+        Oi,
+        pp_max_num,
+        online,
+        data_type
+    ):
+        mi_new = torch.max(
+            torch.column_stack([mi, torch.max(qk_result, dim=1).values[:, None]]), dim=1
+        ).values[:, None].to(data_type)
+        Pij_hat = torch.exp((qk_result - mi_new).to(torch.float32))
+        Pij_hat = Pij_hat.to(data_type)
+        li = torch.exp((mi - mi_new).to(torch.float32)).to(data_type) * li + torch.sum(Pij_hat, dim=1)[:, None]
+        if self.is_int8_flag:
+            if online:
+                x_q, scales, pp_max_num = self.quantize_tensor_symmetric(Pij_hat, pp_max_num)
+                if pp_max_num == None:
+                    pp_max_num = pp_max_num
+                pv = x_q.to(torch.int32) @ sub_value.to(torch.int32)
+                Oi = Oi * torch.exp((mi - mi_new).to(torch.float32)).to(data_type) + self.dequantize_tensor(pv, scales, self.v_scale[head_idx]).to(data_type)
+            else:
+                x_q = Pij_hat / self.offline_scale[head_idx]
+                x_q = torch.round(x_q.to(torch.float32))
+                pv = x_q.to(torch.int32) @ sub_value.to(torch.int32)
+                pv = pv.to(torch.float32)
+                value = self.v_scale[head_idx] * self.offline_scale[head_idx]
+                Oi = Oi * torch.exp((mi - mi_new).to(torch.float32)).to(data_type) + (pv * value).to(data_type)
+
+        if is_tail:
+            mi = mi_new
+
+        return mi, li, Oi, pp_max_num
+
+    def gen_out_tensor_4_stage(self, online=False):
+        q_offset = 0
+        k_offset = 0
+        v_offset = 0
+        batch = self.batch
+        dynamic_batch = self.dynamic_batch
+        batch_state = self.batch_state
+        heads = self.heads
+        is_decoder = self.is_decoder
+        embed = self.embeddim
+        embedv = self.embeddimv
+        max_seq = self.max_seq
+        q_seqlen = self.q_seqlen
+        kv_seqlen = self.kv_seqLen
+        kv_head = self.kv_head
+        mask = self.mask
+        is_mask = self.is_mask
+        is_razor_fusion = self.is_razor_fusion
+        q = self.q
+        k = self.k
+        v = self.v
+        if self.fav3:
+            q = self.q_int8
+            k = self.k_int8
+            v = self.v_int8
+        q_ntokens = self.q_ntokens
+        kv_ntokens = self.kv_ntokens
+        layer_id = self.layer_id[0]
+        out = None
+        out_true = None
+
+        self.encoder_logN = torch.tensor([2.0] * self.max_seq).to(torch.float32)
+        self.encoder_logN.uniform_(1, 2)
+        self.decoder_logN = torch.tensor([2.0] * batch).to(torch.float32)
+        self.decoder_logN.uniform_(1, 2)
+        for idx in range(batch):
+            if dynamic_batch and batch_state[idx] == 0 and not is_decoder:
+                continue
+            if dynamic_batch and batch_state[idx] == 0:
+                output = torch.zeros([heads, q_s, embedv])
+                output = torch.permute(output, (1, 0, 2))
+                if out is None:
+                    out = output
+                    if not self.fav3:
+                        out_true = output
+                else:
+                    out = torch.cat((out, output), 0)
+                    if not self.fav3:
+                        out_true = torch.cat((out_true, output), 0)
+                q_offset += q_s
+                k_offset += max_seq
+                v_offset += max_seq
+                continue
+            q_s = q_seqlen[idx]
+            kv_s = kv_seqlen[idx]
+            q_slice = q[q_offset:q_offset + q_s][:]
+            q_slice = q_slice.view(q_s, heads, embed)
+            q_slice = torch.permute(q_slice, (1, 0, 2))
+            k_slice = k[layer_id][idx][:kv_s][:]
+            k_slice = k_slice.view(kv_s, kv_head, embed)
+            k_slice_t = torch.permute(k_slice, (1, 2, 0)) # get K^T
+            v_slice = v[layer_id][idx][:kv_s][:]
+            v_slice = v_slice.view(kv_s, kv_head, embedv)
+            v_slice = torch.permute(v_slice, (1, 0, 2))
+            context_size = 128
+            group_num = self.heads // self.kv_head
+            if group_num != 1:
+                k_slice_t = k_slice_t.repeat_interleave(group_num, dim=0)
+                v_slice = v_slice.repeat_interleave(group_num, dim=0)
+            out_B = torch.zeros([q_s, heads, embedv], dtype = self.data_type)
+            out_true_B = torch.zeros([q_s, heads, embedv], dtype = torch.float32)
+            for head_idx in range(heads):
+                q_slice_N = q_slice[head_idx, :, :]
+                k_slice_t_N = k_slice_t[head_idx, :, :]
+                v_slice_N = v_slice[head_idx, :, :]
+                gl = None
+                gl_high = None
+                go = None
+                go_high = None
+                if self.is_int8_flag:
+                    Oi = torch.zeros((q_s, embed)).to(self.data_type)
+                    li = torch.zeros((q_s, 1)).to(self.data_type)
+                    mi = torch.full((q_s, 1), -torch.inf).to(self.data_type)
+                    Oi_high = torch.zeros((q_s, embed)).to(torch.float32)
+                    li_high = torch.zeros((q_s, 1)).to(torch.float32)
+                    mi_high = torch.full((q_s, 1), -torch.inf).to(torch.float32)
+                    pp_max_num = None
+                    pp_max_num_high = None
+                for kv_start in range(0, v_slice_N.shape[0], context_size):
+                    sub_len = context_size
+                    if kv_start + context_size > v_slice_N.shape[0]:
+                        sub_len = v_slice_N.shape[0] - kv_start
+                    sub_key = k_slice_t_N[:, kv_start : kv_start + sub_len]
+                    sub_value = v_slice_N[kv_start : kv_start + sub_len, :]
+
+                    if self.fav3:
+                        qk_result = self.qkMM1(q_slice_N, sub_key, dtype = torch.int32)
+                        qk_result_high = self.qkMM1(q_slice_N.to(torch.float32), sub_key.to(torch.float32), dtype = torch.int32)
+                    else:
+                        qk_result = self.qkMM1(q_slice_N, sub_key)
+                        qk_result_high = self.qkMM1(q_slice_N.to(torch.float32), sub_key.to(torch.float32))
+
+                    if self.fav3:
+                        # score:[heads,m,n]
+                        qk_result = qk_result.to(torch.float32)
+                        qk_result_high = qk_result_high.to(torch.float32)
+                        qk_result = qk_result * self.q_scale[head_idx]
+                        qk_result_high = qk_result_high * self.q_scale[head_idx]
+
+                    if self.op_type in {UNPAD_FLASH_ATTENTION_ND,
+                                        UNPAD_DYNAMIC_BATCH_FLASH_ATTENTION,
+                                        UNPAD_FLASH_ATTENTION_ENCODER_ND,
+                                        UNPAD_ALIBI_FLASH_ATTENTION_ND}:
+                        qk_result = qk_result.to(torch.float16)
+                    else:
+                        qk_result = qk_result.to(torch.float32)
+
+                    if self.scaleType == ScaleType.SCALE_LOGN_FP32.value:
+                        if is_decoder:
+                            qk_result = qk_result * self.decoder_logN[idx]
+                            qk_result_high = qk_result_high * self.decoder_logN[idx]
+                        else:
+                            qk_result = qk_result * self.encoder_logN[None, :q_s, None]
+                            qk_result_high = qk_result_high * self.encoder_logN[None, :q_s, None]
+
+                    qk_result = qk_result.to(self.data_type) * self.tor
+                    qk_result_high = qk_result_high * self.tor
+
+                    if self.is_clamp == 1:
+                        qk_result = torch.clamp(qk_result, min=self.clamp_min, max=self.clamp_max)
+                        qk_result_high = torch.clamp(qk_result_high, min=self.clamp_min, max=self.clamp_max)
+
+                    #score + mask
+                    temp_mask = self.mask_info[1](self.mask, idx, q_s, kv_s) * self.post_mask_coff
+                    if is_mask or is_razor_fusion:
+                        if isinstance(temp_mask, torch.Tensor):
+                            if self.mask_type in {MASK_TYPE_ALIBI_WITH_BATCH,
+                                                  MASK_TYPE_ALIBI_NO_BATCH,
+                                                  MASK_TYPE_ALIBI_WITH_PREFIX_BATCH,
MASK_TYPE_ALIBI_NO_BATCH_WITH_PREFIX}: + qk_result = qk_result + temp_mask[head_idx, :, kv_start:kv_start + sub_len] + qk_result_high = qk_result_high + temp_mask[head_idx, :, kv_start:kv_start + sub_len].to(torch.float32) + else: + if len(temp_mask.shape) == 2: + qk_result = qk_result + temp_mask[:, kv_start:kv_start + sub_len] + qk_result_high = qk_result_high + temp_mask[:, kv_start:kv_start + sub_len].to(torch.float32) + elif len(temp_mask.shape) == 3: + qk_result = qk_result + temp_mask[0, :, kv_start:kv_start + sub_len] + qk_result_high = qk_result_high + temp_mask[0, :, kv_start:kv_start + sub_len].to(torch.float32) + else: + qk_result = qk_result + temp_mask[0, 0, :, kv_start:kv_start + sub_len] + qk_result_high = qk_result_high + temp_mask[0, 0, :, kv_start:kv_start + sub_len].to(torch.float32) + + if self.is_int8_flag: + is_tail = (kv_start + context_size <= v_slice_N.shape[0]) + mi, li, Oi, pp_max_num = self.online_softmax_quant(qk_result, sub_value, head_idx, + is_tail, mi, li, Oi, pp_max_num, online, self.data_type) + mi_high, li_high, Oi_high, pp_max_num_high = self.online_softmax_quant(qk_result_high, sub_value, head_idx, + is_tail, mi_high, li_high, Oi_high, pp_max_num_high, online, torch.float32) + go = Oi + gl = li + go_high = Oi_high + gl_high = li_high + else: + if kv_start == 0: + gm = None + p_result, row_sum, dm, gm = self.softmax(qk_result, kv_start == 0, gm) + if kv_start == 0: + gm_high = None + p_result_high, row_sum_high, dm_high, gm_high = self.softmax(qk_result_high, kv_start == 0, gm_high) + lo = torch.matmul(p_result.to(torch.float32), sub_value.to(torch.float32)) + lo = lo.to(self.data_type) + lo_high = torch.matmul(p_result_high, sub_value.to(torch.float32)) + if kv_start == 0: + gl = row_sum + gl_high = row_sum_high + go = lo + go_high = lo_high + else: + dm = torch.exp(dm) + dm_high = torch.exp(dm_high) + gl = gl * dm + gl = gl + row_sum + + go = go * dm + go = go + lo + + gl_high = gl_high * dm_high + gl_high = gl_high + row_sum_high + + go_high = go_high * dm_high + go_high = go_high + lo_high + go = go / gl + go_high = go_high / gl_high + go_high = go_high.contiguous() + go = go.contiguous() + out_B[:, head_idx, :] = go + out_true_B[:, head_idx, :] = go_high + if idx == 0: + out = out_B + out_true = out_true_B + else: + out = torch.cat((out, out_B), 0) + out_true = torch.cat((out_true, out_true_B), 0) + + q_offset += q_s + k_offset += max_seq + v_offset += max_seq + # golden data + out = out.view(q_ntokens, heads * embedv) + out_true = out_true.view(q_ntokens, heads * embedv) + self.golden_out = out.to(self.data_type) + self.golden_out_true = out_true.to(torch.float32) + + if self.no_cache: + self.k = self.close_pack(self.k.to(torch.float32), kv_seqlen).to(self.data_type) + self.v = self.close_pack(self.v.to(torch.float32), kv_seqlen).to(self.data_type) + if self.fav3: + self.k_int8 = self.close_pack(self.k_int8.to(torch.float32), kv_seqlen).to(torch.int8) + self.v_int8 = self.close_pack(self.v_int8.to(torch.float32), kv_seqlen).to(torch.int8) + if self.long_seq: + self.max_seq = 128 + self.gen_mask(self.batch, self.heads, self.data_type, self.mask_type, 0, False, 0) + + def gen_out_tensor_bnsd_4_stage(self, online=False): + q_offset = 0 + k_offset = 0 + v_offset = 0 + batch = self.batch + dynamic_batch = self.dynamic_batch + batch_state = self.batch_state + heads = self.heads + is_decoder = self.is_decoder + embed = self.embeddim + embedv = self.embeddimv + max_seq = self.max_seq + q_seqlen = self.q_seqlen + kv_seqlen = self.kv_seqLen + kv_head = 
self.kv_head + mask = self.mask + is_mask = self.is_mask + q = self.q + k = self.k + v = self.v + q_ntokens = self.q_ntokens + kv_ntokens = self.kv_ntokens + layer_id = self.layer_id[0] + s = None + _p = None + out = None + out_true = None + obsnd = torch.zeros(batch, max_seq, heads, embedv) + out_true_bnsd = torch.zeros(batch, max_seq, heads, embedv) + kbsnd=k.view(layer_id+1,batch,max_seq,kv_head,embed) + vbsnd=v.view(layer_id+1,batch,max_seq,kv_head,embedv) + qbsnd = torch.zeros(batch, max_seq, heads, embed) + + self.encoder_logN = torch.tensor([2.0] * self.max_seq).to(torch.float32) + self.encoder_logN.uniform_(1, 2) + self.decoder_logN = torch.tensor([2.0] * batch).to(torch.float32) + self.decoder_logN.uniform_(1, 2) + for idx in range(batch): + if dynamic_batch and batch_state[idx] == 0 and not is_decoder: + continue + if dynamic_batch and batch_state[idx] == 0: + output = torch.zeros([heads, q_s, embedv]) + output = torch.permute(output, (1, 0, 2)) + if out is None: + out = output + else: + out = torch.cat((out, output), 0) + q_offset += q_s + k_offset += max_seq + v_offset += max_seq + continue + q_s = q_seqlen[idx] + kv_s = kv_seqlen[idx] + q_slice = q[q_offset:q_offset + q_s][:] + q_slice = q_slice.view(q_s, heads, embed) + for q_s_idx in range(q_s): + qbsnd[idx][q_s_idx] = q_slice[q_s_idx][:] + q_slice = torch.permute(q_slice, (1, 0, 2)) + k_slice = k[layer_id][idx][:kv_s][:] + k_slice = k_slice.view(kv_s, kv_head, embed) + k_slice_t = torch.permute(k_slice, (1, 2, 0)) + v_slice = v[layer_id][idx][:kv_s][:] + v_slice = v_slice.view(kv_s, kv_head, embedv) + v_slice = torch.permute(v_slice, (1, 0, 2)) + context_size = 128 + group_num = self.heads // self.kv_head + if group_num != 1: + k_slice_t = k_slice_t.repeat_interleave(group_num, dim=0) + v_slice = v_slice.repeat_interleave(group_num, dim=0) + out_B = torch.zeros([q_s, heads, embedv], dtype = self.data_type) + out_true_B = torch.zeros([q_s, heads, embedv], dtype = torch.float32) + for head_idx in range(heads): + q_slice_N = q_slice[head_idx, :, :] + k_slice_t_N = k_slice_t[head_idx, :, :] + v_slice_N = v_slice[head_idx, :, :] + gl = None + gl_high = None + go = None + go_high = None + for kv_start in range(0, v_slice_N.shape[0], context_size): + sub_len = context_size + if kv_start + context_size > v_slice_N.shape[0]: + sub_len = v_slice_N.shape[0] - kv_start + sub_key = k_slice_t_N[:, kv_start : kv_start + sub_len] + sub_value = v_slice_N[kv_start : kv_start + sub_len, :] + + qk_result = self.qkMM1(q_slice_N, sub_key) + qk_result_high = self.qkMM1(q_slice_N.to(torch.float32), sub_key.to(torch.float32)) + + if self.op_type in {UNPAD_FLASH_ATTENTION_ND, + UNPAD_DYNAMIC_BATCH_FLASH_ATTENTION, + UNPAD_FLASH_ATTENTION_ENCODER_ND, + UNPAD_ALIBI_FLASH_ATTENTION_ND}: + qk_result = qk_result.to(torch.float16) + else: + qk_result = qk_result.to(torch.float32) + + if self.scaleType == ScaleType.SCALE_LOGN_FP32.value: + if is_decoder: + qk_result = qk_result * self.decoder_logN[idx] + qk_result_high = qk_result_high * self.decoder_logN[idx] + else: + qk_result = qk_result * self.encoder_logN[None, :q_s, None] + qk_result_high = qk_result_high * self.encoder_logN[None, :q_s, None] + + qk_result = qk_result.to(self.data_type) * self.tor + qk_result_high = qk_result_high * self.tor + + if self.is_clamp == 1: + qk_result = torch.clamp(qk_result, min=self.clamp_min, max=self.clamp_max) + qk_result_high = torch.clamp(qk_result_high, min=self.clamp_min, max=self.clamp_max) + + #score + mask + temp_mask = self.mask_info[1](self.mask, idx, q_s, 
kv_s) * self.post_mask_coff + if is_mask and isinstance(temp_mask, torch.Tensor): + if self.mask_type in {MASK_TYPE_ALIBI_WITH_BATCH, + MASK_TYPE_ALIBI_NO_BATCH, + MASK_TYPE_ALIBI_WITH_PREFIX_BATCH, + MASK_TYPE_ALIBI_NO_BATCH_WITH_PREFIX}: + qk_result = qk_result + temp_mask[head_idx, :, kv_start:kv_start + sub_len] + qk_result_high = qk_result_high + temp_mask[head_idx, :, kv_start:kv_start + sub_len].to(torch.float32) + else: + if len(temp_mask.shape) == 2: + qk_result = qk_result + temp_mask[:, kv_start:kv_start + sub_len] + qk_result_high = qk_result_high + temp_mask[:, kv_start:kv_start + sub_len].to(torch.float32) + elif len(temp_mask.shape) == 3: + qk_result = qk_result + temp_mask[0, :, kv_start:kv_start + sub_len] + qk_result_high = qk_result_high + temp_mask[0, :, kv_start:kv_start + sub_len].to(torch.float32) + else: + qk_result = qk_result + temp_mask[0, 0, :, kv_start:kv_start + sub_len] + qk_result_high = qk_result_high + temp_mask[0, 0, :, kv_start:kv_start + sub_len].to(torch.float32) + + if kv_start == 0: + gm = None + p_result, row_sum, dm, gm = self.softmax(qk_result, kv_start == 0, gm) + if kv_start == 0: + gm_high = None + p_result_high, row_sum_high, dm_high, gm_high = self.softmax(qk_result_high, kv_start == 0, gm_high) + lo = torch.matmul(p_result.to(torch.float32), sub_value.to(torch.float32)) + lo = lo.to(self.data_type) + lo_high = torch.matmul(p_result_high, sub_value.to(torch.float32)) + + if kv_start == 0: + gl = row_sum + gl_high = row_sum_high + go = lo + go_high = lo_high + else: + dm = torch.exp(dm) + dm_high = torch.exp(dm_high) + gl = gl * dm + gl = gl + row_sum + + go = go * dm + go = go + lo + + gl_high = gl_high * dm_high + gl_high = gl_high + row_sum_high + + go_high = go_high * dm_high + go_high = go_high + lo_high + go = go / gl + go_high = go_high / gl_high + go_high = go_high.contiguous() + go = go.contiguous() + out_B[:, head_idx, :] = go + out_true_B[:, head_idx, :] = go_high + if idx == 0: + out = out_B + out_true = out_true_B + else: + out = torch.cat((out, out_B), 0) + out_true = torch.cat((out_true, out_true_B), 0) + + for i in range(0, q_s): + obsnd[idx][i] = out_B[i] + out_true_bnsd[idx][i] =out_true_B[i] + + q_offset += q_s + k_offset += max_seq + v_offset += max_seq + obnsd = torch.permute(obsnd, (0, 2, 1, 3)) + out_true_bnsd = torch.permute(out_true_bnsd, (0, 2, 1, 3)) + self.qbnsd = torch.permute(qbsnd, (0, 2, 1, 3)).to(self.data_type) + self.kbnsd = torch.permute(kbsnd, (0, 1, 3, 2, 4)).to(self.data_type) + self.vbnsd = torch.permute(vbsnd, (0, 1, 3, 2, 4)).to(self.data_type) + # golden data + out = out.view(q_ntokens, heads * embedv) + out_true = out_true.view(q_ntokens, heads * embedv) + if(self.is_decoder == 1): + self.golden_out = out + self.golden_out_true = out_true.to(torch.float32) + else: + self.golden_out = obnsd.to(self.data_type) + self.golden_out_true = out_true_bnsd.to(torch.float32) + logging.debug(f"golden_out shape: {self.golden_out.shape}") + + if self.no_cache: + self.k = self.close_pack(self.k.to(torch.float32), kv_seqlen).to(self.data_type) + self.v = self.close_pack(self.v.to(torch.float32), kv_seqlen).to(self.data_type) + if self.long_seq: + self.max_seq = 128 + self.gen_mask(self.batch, self.heads, self.data_type, self.mask_type) + + def gen_out_tensor_bnsd_splitm_4_stage(self, online=False): + q_offset = 0 + k_offset = 0 + v_offset = 0 + batch = self.batch + dynamic_batch = self.dynamic_batch + batch_state = self.batch_state + heads = self.heads + is_decoder = self.is_decoder + embed = self.embeddim + embedv 
= self.embeddimv + max_seq = self.max_seq + q_seqlen = self.q_seqlen + kv_seqlen = self.kv_seqLen + kv_head = self.kv_head + mask = self.mask + is_mask = self.is_mask + q = self.q + k = self.k + v = self.v + q_ntokens = self.q_ntokens + kv_ntokens = self.kv_ntokens + layer_id = self.layer_id[0] + out = None + out_true = None + maxQSeqlen = max(q_seqlen) + obsnd = torch.zeros(batch, maxQSeqlen, heads, embedv) + out_true_bnsd = torch.zeros(batch, maxQSeqlen, heads, embedv) + maxKvSeqlen = max(kv_seqlen) + kbsnd=k.view(layer_id+1,batch,maxKvSeqlen,kv_head,embed) + vbsnd=v.view(layer_id+1,batch,maxKvSeqlen,kv_head,embedv) + qbsnd = torch.zeros(batch, maxQSeqlen, heads, embed) + for idx in range(batch): + if dynamic_batch and batch_state[idx] == 0 and not is_decoder: + continue + if dynamic_batch and batch_state[idx] == 0: + output = torch.zeros([heads, q_s, embedv]) + output = torch.permute(output, (1, 0, 2)) + if out is None: + out = output + else: + out = torch.cat((out, output), 0) + q_offset += q_s + k_offset += max_seq + v_offset += max_seq + continue + q_s = q_seqlen[idx] + kv_s = kv_seqlen[idx] + q_slice = q[q_offset:q_offset + q_s][:] + q_slice = q_slice.view(q_s, heads, embed) + for q_s_idx in range(q_s): + qbsnd[idx][q_s_idx] = q_slice[q_s_idx][:] + q_slice = torch.permute(q_slice, (1, 0, 2)) + k_slice = k[layer_id][idx][:kv_s][:] + k_slice = k_slice.view(kv_s, kv_head, embed) + k_slice_t = torch.permute(k_slice, (1, 2, 0)) + v_slice = v[layer_id][idx][:kv_s][:] + v_slice = v_slice.view(kv_s, kv_head, embedv) + v_slice = torch.permute(v_slice, (1, 0, 2)) + context_size = 128 + group_num = self.heads // self.kv_head + if group_num != 1: + k_slice_t = k_slice_t.repeat_interleave(group_num, dim=0) + v_slice = v_slice.repeat_interleave(group_num, dim=0) + out_B = torch.zeros([q_s, heads, embedv], dtype = self.data_type) + out_true_B = torch.zeros([q_s, heads, embedv], dtype = torch.float32) + for head_idx in range(heads): + q_slice_N = q_slice[head_idx, :, :] + k_slice_t_N = k_slice_t[head_idx, :, :] + v_slice_N = v_slice[head_idx, :, :] + gl = None + gl_high = None + go = None + go_high = None + for kv_start in range(0, v_slice_N.shape[0], context_size): + sub_len = context_size + if kv_start + context_size > v_slice_N.shape[0]: + sub_len = v_slice_N.shape[0] - kv_start + sub_key = k_slice_t_N[:, kv_start : kv_start + sub_len] + sub_value = v_slice_N[kv_start : kv_start + sub_len, :] + + qk_result = self.qkMM1(q_slice_N, sub_key) + qk_result_high = self.qkMM1(q_slice_N.to(torch.float32), sub_key.to(torch.float32)) + + if self.op_type in {UNPAD_FLASH_ATTENTION_ND, + UNPAD_DYNAMIC_BATCH_FLASH_ATTENTION, + UNPAD_FLASH_ATTENTION_ENCODER_ND, + UNPAD_ALIBI_FLASH_ATTENTION_ND}: + qk_result = qk_result.to(torch.float16) + else: + qk_result = qk_result.to(torch.float32) + + qk_result = qk_result.to(self.data_type) * self.tor + qk_result_high = qk_result_high * self.tor + + if kv_start == 0: + gm = None + p_result, row_sum, dm, gm = self.softmax(qk_result, kv_start == 0, gm) + if kv_start == 0: + gm_high = None + p_result_high, row_sum_high, dm_high, gm_high = self.softmax(qk_result_high, kv_start == 0, gm_high) + lo = torch.matmul(p_result.to(torch.float32), sub_value.to(torch.float32)) + lo = lo.to(self.data_type) + lo_high = torch.matmul(p_result_high, sub_value.to(torch.float32)) + + if kv_start == 0: + gl = row_sum + gl_high = row_sum_high + go = lo + go_high = lo_high + else: + dm = torch.exp(dm) + dm_high = torch.exp(dm_high) + gl = gl * dm + gl = gl + row_sum + + go = go * dm + go = go + 
lo + + gl_high = gl_high * dm_high + gl_high = gl_high + row_sum_high + + go_high = go_high * dm_high + go_high = go_high + lo_high + + go = go / gl + go_high = go_high / gl_high + go_high = go_high.contiguous() + go = go.contiguous() + + out_B[:, head_idx, :] = go + out_true_B[:, head_idx, :] = go_high + if idx == 0: + out = out_B + out_true = out_true_B + else: + out = torch.cat((out, out_B), 0) + out_true = torch.cat((out_true, out_true_B), 0) + + for i in range(0, q_s): + obsnd[idx][i] = out_B[i] + out_true_bnsd[idx][i] =out_true_B[i] + + q_offset += q_s + k_offset += kv_s + v_offset += kv_s + obnsd = torch.permute(obsnd, (0, 2, 1, 3)) + out_true_bnsd = torch.permute(out_true_bnsd, (0, 2, 1, 3)) + self.qbnsd = torch.permute(qbsnd, (0, 2, 1, 3)).to(self.data_type) + self.kbnsd = torch.permute(kbsnd, (0, 1, 3, 2, 4)).to(self.data_type) + self.vbnsd = torch.permute(vbsnd, (0, 1, 3, 2, 4)).to(self.data_type) + + out = out.view(q_ntokens, heads * embedv) + out_true = out_true.view(q_ntokens, heads * embedv) + + self.golden_out = obnsd.to(self.data_type) + self.golden_out_true = out_true_bnsd.to(torch.float32) + logging.debug(f"golden_out shape: {self.golden_out.shape}") + + if self.no_cache: + self.k = self.close_pack(self.k.to(torch.float32), kv_seqlen).to(self.data_type) + self.v = self.close_pack(self.v.to(torch.float32), kv_seqlen).to(self.data_type) + def gen_seq_len(self, batch, seq_len): ntokens = sum(seq_len) return seq_len, ntokens @@ -1081,7 +1764,7 @@ class TestFlashAttention(op_test.OpTest): is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, data_type = data_type, is_alibi = False,scaleType = scaleType, is_splitm = True, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_MASK, is_mask = False, tor = tor) - self.gen_out_tensor_bnsd_splitm() + self.gen_out_tensor_bnsd_splitm_4_stage() logging.debug("**********input shape***********") logging.debug(f"q shape: {self.q.shape}") logging.debug(f"k shape: {self.k.shape}") @@ -1123,7 +1806,7 @@ class TestFlashAttention(op_test.OpTest): is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, data_type = data_type, is_alibi = False, is_splitm = True, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_MASK, is_mask = False, tor = tor) - self.gen_out_tensor_bnsd_splitm() + self.gen_out_tensor_bnsd_splitm_4_stage() logging.debug("**********input shape***********") logging.debug(f"q shape: {self.q.shape}") logging.debug(f"k shape: {self.k.shape}") @@ -1168,7 +1851,7 @@ class TestFlashAttention(op_test.OpTest): is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, data_type = data_type, is_alibi = False, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_HEAD, long_seq = True, scaleType = scaleType, tor = tor) - self.gen_out_tensor() + self.gen_out_tensor_4_stage() self.mask = self.mask.view(batch, 128, 128)[0, :, :] attention_out = np.zeros_like(self.q.to(torch.float16)) return self.execute([self.q, self.k, self.v, self.layer_id, self.mask.to(data_type), torch.tensor([], dtype=torch.float), @@ -1216,7 +1899,7 @@ class TestFlashAttention(op_test.OpTest): data_type = data_type, is_alibi = False, is_mask =True, tor = tor, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_SWA_DECODER, is_multi_layer = False) - self.gen_out_tensor() + self.gen_out_tensor_4_stage() logging.debug("**********input shape***********") logging.debug(f"q shape: {self.q.shape}") @@ -1266,7 +1949,7 @@ class TestFlashAttention(op_test.OpTest): is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, data_type = 
data_type, is_alibi = False, is_mask =True, tor = tor, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_SWA_DECODER, is_multi_layer = False) - self.gen_out_tensor() + self.gen_out_tensor_4_stage() logging.debug("**********input shape***********") logging.debug(f"q shape: {self.q.shape}") @@ -1319,7 +2002,7 @@ class TestFlashAttention(op_test.OpTest): data_type = data_type, is_alibi = False, window_size = window_size, no_cache=True, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_SWA) # self.golden_out = np.zeros_like(self.q) - self.gen_out_tensor() + self.gen_out_tensor_4_stage() self.mask = np.reshape(self.mask, (max_seq, max_seq)) logging.debug("**********input shape***********") @@ -1367,7 +2050,7 @@ class TestFlashAttention(op_test.OpTest): embeddim = embeddim, max_seq = max_seq, kv_seqLen = kv_seqLen, tor = tor, data_type = data_type, is_alibi = False, window_size = window_size, no_cache=True, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_SWA) - self.gen_out_tensor() + self.gen_out_tensor_4_stage() if self.is_compress: self.mask = self.gen_swa_cmp(max_seq, window_size) @@ -1424,7 +2107,7 @@ class TestFlashAttention(op_test.OpTest): data_type = data_type, is_alibi = False, window_size = window_size, no_cache=True, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_SWA) # self.golden_out = np.zeros_like(self.q) - self.gen_out_tensor() + self.gen_out_tensor_4_stage() self.mask = np.reshape(self.mask, (max_seq, max_seq)) logging.debug("**********input shape***********") @@ -1473,7 +2156,7 @@ class TestFlashAttention(op_test.OpTest): embeddim = embeddim, max_seq = max_seq, kv_seqLen = kv_seqLen, tor = tor, data_type = data_type, is_alibi = False, window_size = window_size, no_cache=True, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_SWA) - self.gen_out_tensor() + self.gen_out_tensor_4_stage() if self.is_compress: self.mask = self.gen_swa_cmp(max_seq, window_size) @@ -1523,7 +2206,7 @@ class TestFlashAttention(op_test.OpTest): is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, data_type = data_type, is_alibi = False, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_BATCH, tor = tor) - self.gen_out_tensor_bnsd() + self.gen_out_tensor_bnsd_4_stage() self.mask = np.reshape(self.mask, (max_seq, max_seq)) logging.debug("**********input shape***********") @@ -1555,7 +2238,7 @@ class TestFlashAttention(op_test.OpTest): dynamic_batch = False OP_NAME = "UnpadFlashAttentionOperation" OP_PARAM = {"type": 2001, "qSeqLen": kv_seqLen, "kvSeqLen": kv_seqLen, "headSize": heads, "tor": tor, "kvHead":kv_head, - "isClamp" : is_clamp, "clampMin" : clamp_min, "maskType": 0, "clampMax" : clamp_max, "isTriuMask": 1, + "isClamp" : is_clamp, "clampMin" : clamp_min, "maskType": 0, "clampMax" : clamp_max, "isTriuMask": 0, "dataShapeType":1} self.set_param(OP_NAME, OP_PARAM) self.set_input_formats([self.format_nd] * 12) @@ -1570,10 +2253,10 @@ class TestFlashAttention(op_test.OpTest): data_type = data_type, is_alibi = False, is_mask = False, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_MASK, tor = tor) - self.gen_out_tensor_bnsd() + self.gen_out_tensor_bnsd_4_stage() attention_out = np.zeros_like(self.qbnsd.to(torch.float16)) - return self.execute([self.qbnsd, self.kbnsd, self.vbnsd, self.layer_id, self.mask.to(data_type), torch.tensor([], dtype=torch.float), + return self.execute([self.qbnsd, self.kbnsd, self.vbnsd, self.layer_id, torch.tensor([], dtype=data_type), torch.tensor([], dtype=torch.float), torch.tensor([], dtype=torch.float), torch.tensor([], dtype=torch.int32), 
torch.tensor([], dtype=torch.float), torch.tensor([], dtype=torch.int32), torch.tensor([], dtype=torch.float), torch.tensor([], dtype=torch.float)], @@ -1609,7 +2292,7 @@ class TestFlashAttention(op_test.OpTest): data_type = data_type, is_alibi = False, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_HEAD, long_seq = True, scaleType = scaleType, tor = tor) - self.gen_out_tensor() + self.gen_out_tensor_4_stage() self.mask = self.mask.view(batch, 128, 128)[0, :, :] attention_out = np.zeros_like(self.q.to(torch.float16)) @@ -1650,7 +2333,7 @@ class TestFlashAttention(op_test.OpTest): data_type = data_type, is_alibi = False, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_MASK, long_seq = True, scaleType = scaleType, tor = tor) - self.gen_out_tensor() + self.gen_out_tensor_4_stage() attention_out = np.zeros_like(self.q.to(torch.float16)) return self.execute([self.q, self.k, self.v, self.layer_id, torch.tensor([], dtype=torch.bfloat16), torch.tensor([], dtype=torch.float), @@ -1690,7 +2373,7 @@ class TestFlashAttention(op_test.OpTest): data_type = data_type, is_alibi = False, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_HEAD, long_seq = True, scaleType = scaleType, tor = tor) - self.gen_out_tensor() + self.gen_out_tensor_4_stage() self.mask = self.mask.view(batch, 128, 128)[0, :, :] logging.debug("**********input shape***********") logging.debug(f"q shape: {self.q.shape}") @@ -1733,7 +2416,7 @@ class TestFlashAttention(op_test.OpTest): data_type = data_type, is_alibi = False, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_HEAD, scaleType=scaleType, tor = tor) - self.gen_out_tensor() + self.gen_out_tensor_4_stage() logging.debug("**********input shape***********") logging.debug(f"q shape: {self.q.shape}") @@ -1777,7 +2460,7 @@ class TestFlashAttention(op_test.OpTest): is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, data_type = data_type, is_alibi = False, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_BATCH ,tor = tor) - self.gen_out_tensor() + self.gen_out_tensor_4_stage() self.mask = np.reshape(self.mask, (max_seq, max_seq)) logging.debug("**********input shape***********") @@ -1822,7 +2505,7 @@ class TestFlashAttention(op_test.OpTest): is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, data_type = data_type, is_alibi = False, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_BATCH, tor = tor) - self.gen_out_tensor_bnsd() + self.gen_out_tensor_bnsd_4_stage() self.mask = np.reshape(self.mask, (max_seq, max_seq)) logging.debug("**********input shape***********") @@ -1866,7 +2549,7 @@ class TestFlashAttention(op_test.OpTest): is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, data_type = data_type, is_alibi = False, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_BATCH, tor = tor) - self.gen_out_tensor_bnsd() + self.gen_out_tensor_bnsd_4_stage() self.mask = np.reshape(self.mask, (max_seq, max_seq)) logging.debug("**********input shape***********") @@ -1911,7 +2594,7 @@ class TestFlashAttention(op_test.OpTest): is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, data_type = data_type, is_alibi = False, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_HEAD, tor = tor) - self.gen_out_tensor() + self.gen_out_tensor_4_stage() self.k = np.reshape(self.k, (batch, max_seq, heads * embeddim)) self.v = np.reshape(self.v, (batch, max_seq, heads * embeddim)) @@ -1961,7 +2644,7 @@ class TestFlashAttention(op_test.OpTest): is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, data_type = 
data_type, is_alibi = False, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_HEAD, tor = tor) - self.gen_out_tensor() + self.gen_out_tensor_4_stage() self.k = np.reshape(self.k, (batch, max_seq, heads * embeddim)) self.v = np.reshape(self.v, (batch, max_seq, heads * embeddim)) @@ -2009,7 +2692,7 @@ class TestFlashAttention(op_test.OpTest): is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, data_type = data_type, is_alibi = True, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_ALIBI_WITH_BATCH, tor = tor) - self.gen_out_tensor() + self.gen_out_tensor_4_stage() self.mask = np.reshape(self.mask, (batch, heads, max_seq, max_seq)) @@ -2056,7 +2739,7 @@ class TestFlashAttention(op_test.OpTest): is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, data_type = data_type, is_alibi = False, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_BATCH, tor = tor) - self.gen_out_tensor() + self.gen_out_tensor_4_stage() self.mask = np.reshape(self.mask, (max_seq, max_seq)) @@ -2102,7 +2785,7 @@ class TestFlashAttention(op_test.OpTest): is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, data_type = data_type, is_alibi = False, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_BATCH, tor = tor) - self.gen_out_tensor_bnsd() + self.gen_out_tensor_bnsd_4_stage() self.mask = np.reshape(self.mask, (max_seq, max_seq)) logging.debug("**********input shape***********") @@ -2148,7 +2831,7 @@ class TestFlashAttention(op_test.OpTest): data_type = data_type, is_alibi = False, is_mask = False, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_HEAD_DECODER, is_multi_layer = True, tor = tor) - self.gen_out_tensor() + self.gen_out_tensor_4_stage() self.mask = np.reshape(self.mask, (batch, 1, max_seq)) @@ -2195,7 +2878,7 @@ class TestFlashAttention(op_test.OpTest): is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, data_type = data_type, is_alibi = False, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_HEAD, tor = tor) - self.gen_out_tensor() + self.gen_out_tensor_4_stage() self.k = np.reshape(self.k, (batch, max_seq, kv_head * embeddim)) self.v = np.reshape(self.v, (batch, max_seq, kv_head * embeddim)) @@ -2243,7 +2926,7 @@ class TestFlashAttention(op_test.OpTest): is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, data_type = data_type, is_alibi = False, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_BATCH, tor = tor) - self.gen_out_tensor() + self.gen_out_tensor_4_stage() self.mask = np.reshape(self.mask, (max_seq, max_seq)) @@ -2290,7 +2973,7 @@ class TestFlashAttention(op_test.OpTest): is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, data_type = data_type, is_alibi = False, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_HEAD_DECODER, tor = tor) - self.gen_out_tensor() + self.gen_out_tensor_4_stage() self.mask = np.reshape(self.mask, (batch, 1, max_seq)) @@ -2338,7 +3021,7 @@ class TestFlashAttention(op_test.OpTest): is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, data_type = data_type, is_alibi = False, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_BATCH, long_seq = True, tor = tor) - self.gen_out_tensor() + self.gen_out_tensor_4_stage() self.mask = self.mask.view(128, 128).to(data_type) @@ -2386,7 +3069,7 @@ class TestFlashAttention(op_test.OpTest): is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, data_type = data_type, is_alibi = False, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_MASK, is_mask = False, tor = tor) - self.gen_out_tensor() + 
self.gen_out_tensor_4_stage() self.k = self.k.view(batch, max_seq, heads * embeddim) self.v = self.v.view(batch, max_seq, heads * embeddim) @@ -2435,7 +3118,7 @@ class TestFlashAttention(op_test.OpTest): is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, data_type = data_type, is_alibi = False, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_MASK, is_mask = False, tor = tor) - self.gen_out_tensor_bnsd() + self.gen_out_tensor_bnsd_4_stage() self.k = self.k.view(batch, max_seq, heads * embeddim) self.v = self.v.view(batch, max_seq, heads * embeddim) @@ -2483,7 +3166,7 @@ class TestFlashAttention(op_test.OpTest): is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, data_type = data_type, is_alibi = False, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_HEAD, long_seq = True, tor = tor) - self.gen_out_tensor() + self.gen_out_tensor_4_stage() self.k = self.k.view(batch, max_seq, heads * embeddim) self.v = self.v.view(batch, max_seq, heads * embeddim) @@ -2524,7 +3207,7 @@ class TestFlashAttention(op_test.OpTest): self.set_param(OP_NAME, OP_PARAM) self.set_input_formats([self.format_nd] * 12) self.set_output_formats([self.format_nd]) - data_type = torch.float32 + data_type = torch.float16 self.set_data_params(dynamic_batch = dynamic_batch, is_decoder = isdecoder, batch = batch, kv_head = kv_head, heads = heads, @@ -2532,9 +3215,10 @@ class TestFlashAttention(op_test.OpTest): is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, data_type = data_type, is_alibi = False, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_BATCH, tor = tor) - self.gen_out_tensor() + self.gen_out_tensor_4_stage() self.mask = np.reshape(self.mask, (max_seq, max_seq)) + self.mask[self.mask == -10000] = 1 logging.debug("**********input shape***********") logging.debug(f"q shape: {self.q.shape}") @@ -2571,7 +3255,7 @@ class TestFlashAttention(op_test.OpTest): self.set_param(OP_NAME, OP_PARAM) self.set_input_formats([self.format_nd] * 12) self.set_output_formats([self.format_nd]) - data_type = torch.float32 + data_type = torch.float16 self.set_data_params(dynamic_batch = dynamic_batch, is_decoder = isdecoder, batch = batch, kv_head = kv_head, heads = heads, @@ -2579,9 +3263,10 @@ class TestFlashAttention(op_test.OpTest): is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, data_type = data_type, is_alibi = False, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_BATCH, no_cache = True, tor = tor) - self.gen_out_tensor() + self.gen_out_tensor_4_stage() self.mask = np.reshape(self.mask, (max_seq, max_seq)) + self.mask[self.mask == -10000] = 1 logging.debug("**********input shape***********") logging.debug(f"q shape: {self.q.shape}") @@ -2627,7 +3312,7 @@ class TestFlashAttention(op_test.OpTest): is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, data_type = data_type, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_HEAD, tor = tor) - self.gen_out_tensor() + self.gen_out_tensor_4_stage() self.k = self.k.view(batch, max_seq, heads * embeddim) self.v = self.v.view(batch, max_seq, heads * embeddim) @@ -2668,7 +3353,7 @@ class TestFlashAttention(op_test.OpTest): self.set_param(OP_NAME, OP_PARAM) self.set_input_formats([self.format_nd] * 12) self.set_output_formats([self.format_nd]) - data_type = torch.float32 + data_type = torch.float16 self.set_data_params(dynamic_batch = dynamic_batch, is_decoder = isdecoder, batch = batch, kv_head = kv_head, heads = heads, @@ -2676,9 +3361,10 @@ class TestFlashAttention(op_test.OpTest): is_clamp = 
is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, data_type = data_type, is_alibi = False, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_HEAD_DECODER, tor = tor) - self.gen_out_tensor() + self.gen_out_tensor_4_stage() self.mask = np.reshape(self.mask, (batch, 1, max_seq)) + self.mask[self.mask == -10000] = 1 logging.debug("**********input shape***********") logging.debug(f"q shape: {self.q.shape}") @@ -2723,7 +3409,7 @@ class TestFlashAttention(op_test.OpTest): is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, data_type = data_type, is_alibi = False, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_MASK, no_cache = True, tor = tor) - self.gen_out_tensor() + self.gen_out_tensor_4_stage() attention_out = np.zeros_like(self.q) return self.execute([self.q, self.k, self.v, torch.tensor([], dtype=torch.int), torch.tensor([], dtype=torch.half), torch.tensor([], dtype=torch.float), @@ -2763,7 +3449,7 @@ class TestFlashAttention(op_test.OpTest): data_type = data_type, is_alibi = False, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_HEAD, no_cache = True, is_triu_mask = True , tor = tor) - self.gen_out_tensor() + self.gen_out_tensor_4_stage() attention_out = np.zeros_like(self.q) return self.execute([self.q, self.k, self.v, torch.tensor([], dtype=torch.int), self.mask.to(data_type), torch.tensor([], dtype=torch.float), torch.tensor([], dtype=torch.float), torch.tensor([], dtype=torch.int32), @@ -2802,7 +3488,7 @@ class TestFlashAttention(op_test.OpTest): data_type = data_type, is_alibi = False, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_HEAD, no_cache = True, long_seq= True, tor = tor) - self.gen_out_tensor() + self.gen_out_tensor_4_stage() self.mask = self.mask.view(batch, 128, 128)[0, :, :] attention_out = np.zeros_like(self.q) return self.execute([self.q, self.k, self.v, torch.tensor([], dtype=torch.int), self.mask.to(data_type), torch.tensor([], dtype=torch.float), @@ -2839,7 +3525,7 @@ class TestFlashAttention(op_test.OpTest): is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, data_type = data_type, is_alibi = False, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_HEAD, no_cache = True, tor = tor) - self.gen_out_tensor() + self.gen_out_tensor_4_stage() attention_out = np.zeros_like(self.q) return self.execute([self.q, self.k, self.v, torch.tensor([], dtype=torch.int), self.mask.to(data_type), torch.tensor([], dtype=torch.float), @@ -2878,7 +3564,7 @@ class TestFlashAttention(op_test.OpTest): data_type = data_type, is_alibi = True, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_ALIBI_WITH_BATCH, no_cache = True, tor = tor) - self.gen_out_tensor() + self.gen_out_tensor_4_stage() attention_out = np.zeros_like(self.q) return self.execute([self.q, self.k, self.v, torch.tensor([], dtype=torch.int), self.mask.to(data_type), torch.tensor([], dtype=torch.float), @@ -2917,7 +3603,7 @@ class TestFlashAttention(op_test.OpTest): is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, data_type = data_type, is_alibi = False, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_HEAD, no_cache = True, tor = tor) - self.gen_out_tensor() + self.gen_out_tensor_4_stage() attention_out = np.zeros_like(self.q.to(torch.float16)) return self.execute([self.q, self.k, self.v, torch.tensor([], dtype=torch.int), self.mask.to(data_type), torch.tensor([], dtype=torch.float), torch.tensor([], dtype=torch.float), torch.tensor([], dtype=torch.int32), @@ -2955,7 +3641,7 @@ class TestFlashAttention(op_test.OpTest): is_clamp = is_clamp, clamp_max 
= clamp_max, clamp_min = clamp_min, data_type = data_type, is_alibi = False, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_HEAD, no_cache = True, tor = tor) - self.gen_out_tensor() + self.gen_out_tensor_4_stage() attention_out = np.zeros_like(self.q.to(torch.float16)) return self.execute([self.q, self.k, self.v, torch.tensor([], dtype=torch.int), self.mask.to(data_type), torch.tensor([], dtype=torch.float), torch.tensor([], dtype=torch.float), torch.tensor([], dtype=torch.int32), @@ -2995,7 +3681,7 @@ class TestFlashAttention(op_test.OpTest): is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, data_type = data_type, is_alibi = False, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_BATCH, no_cache = True, tor = tor) - self.gen_out_tensor() + self.gen_out_tensor_4_stage() attention_out = np.zeros_like(self.q.to(torch.float16)) self.q = self.q.view(sum(kv_seqLen), heads, embeddim) self.k = self.k.view(sum(kv_seqLen), heads, embeddim) @@ -3037,7 +3723,7 @@ class TestFlashAttention(op_test.OpTest): data_type = data_type, is_alibi = False, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_HEAD, no_cache = True, long_seq = True, tor = tor) - self.gen_out_tensor() + self.gen_out_tensor_4_stage() self.mask = self.mask.view(batch, 128, 128)[0, :, :] attention_out = np.zeros_like(self.q.to(torch.float16)) return self.execute([self.q, self.k, self.v, torch.tensor([], dtype=torch.int), self.mask.to(data_type), torch.tensor([], dtype=torch.float), @@ -3075,7 +3761,7 @@ class TestFlashAttention(op_test.OpTest): is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, data_type = data_type, is_alibi = True, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_ALIBI_NO_BATCH, no_cache = True, tor = tor) - self.gen_out_tensor() + self.gen_out_tensor_4_stage() attention_out = np.zeros_like(self.q) return self.execute([self.q, self.k, self.v, torch.tensor([], dtype=torch.int), self.mask.to(data_type), torch.tensor([], dtype=torch.float), torch.tensor([], dtype=torch.float), torch.tensor([], dtype=torch.int32), @@ -3111,7 +3797,7 @@ class TestFlashAttention(op_test.OpTest): is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, data_type = data_type, is_alibi = False, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_BATCH, tor = tor) - self.gen_out_tensor() + self.gen_out_tensor_4_stage() self.mask = self.mask.view(max_seq, max_seq).to(data_type) @@ -3157,7 +3843,7 @@ class TestFlashAttention(op_test.OpTest): data_type = data_type, is_alibi = False, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_BATCH, long_seq = True, is_multi_layer = True, tor = tor) - self.gen_out_tensor() + self.gen_out_tensor_4_stage() self.mask = self.mask.view(128, 128).to(data_type) @@ -3202,7 +3888,7 @@ class TestFlashAttention(op_test.OpTest): is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, data_type = data_type, is_alibi = False, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_HEAD_DECODER, tor = tor) - self.gen_out_tensor() + self.gen_out_tensor_4_stage() self.mask = self.mask.view(batch, 1, max_seq).to(data_type) logging.debug("**********input shape***********") @@ -3247,7 +3933,7 @@ class TestFlashAttention(op_test.OpTest): is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, data_type = data_type, is_alibi = False, is_mask = False, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_HEAD_DECODER, tor = tor) - self.gen_out_tensor() + self.gen_out_tensor_4_stage() self.mask = self.mask.view(batch, 1, max_seq).to(data_type) 
logging.debug("**********input shape***********") @@ -3292,7 +3978,7 @@ class TestFlashAttention(op_test.OpTest): is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, data_type = data_type, is_alibi = True, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_ALIBI_WITH_BATCH, tor = tor) - self.gen_out_tensor() + self.gen_out_tensor_4_stage() logging.debug("**********input shape***********") logging.debug(f"q shape: {self.q.shape}") @@ -3339,7 +4025,7 @@ class TestFlashAttention(op_test.OpTest): is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, data_type = data_type, is_alibi = True, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_ALIBI_WITH_BATCH, tor = tor) - self.gen_out_tensor() + self.gen_out_tensor_4_stage() self.k = np.reshape(self.k, (batch, max_seq, heads * embeddim)) self.v = np.reshape(self.v, (batch, max_seq, heads * embeddim)) @@ -3387,7 +4073,7 @@ class TestFlashAttention(op_test.OpTest): is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, data_type = data_type, is_alibi = True, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_ALIBI_WITH_BATCH, no_cache = True, tor = tor) - self.gen_out_tensor() + self.gen_out_tensor_4_stage() attention_out = np.zeros_like(self.q.to(torch.float16)) return self.execute([self.q, self.k, self.v, torch.tensor([], dtype=torch.int), self.mask.to(data_type), torch.tensor([], dtype=torch.float), @@ -3417,7 +4103,7 @@ class TestFlashAttention(op_test.OpTest): self.set_param(OP_NAME, OP_PARAM) self.set_input_formats([self.format_nd] * 12) self.set_output_formats([self.format_nd]) - data_type = torch.float32 + data_type = torch.float16 self.set_data_params(dynamic_batch = dynamic_batch, is_decoder = isdecoder, batch = batch, kv_head = kv_head, heads = heads, @@ -3425,7 +4111,7 @@ class TestFlashAttention(op_test.OpTest): is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, data_type = data_type, is_alibi = True, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_ALIBI_NO_BATCH, tor = tor) - self.gen_out_tensor() + self.gen_out_tensor_4_stage() logging.debug("**********input shape***********") logging.debug(f"q shape: {self.q.shape}") logging.debug(f"k shape: {self.k.shape}") @@ -3471,7 +4157,7 @@ class TestFlashAttention(op_test.OpTest): is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, data_type = data_type, is_alibi = True, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_ALIBI_WITH_BATCH, tor = tor) - self.gen_out_tensor() + self.gen_out_tensor_4_stage() self.k = np.reshape(self.k, (batch, max_seq, heads * embeddim)) self.v = np.reshape(self.v, (batch, max_seq, heads * embeddim)) @@ -3519,7 +4205,7 @@ class TestFlashAttention(op_test.OpTest): is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, data_type = data_type, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_HEAD, tor = tor) - self.gen_out_tensor() + self.gen_out_tensor_4_stage() self.k = self.k.view(batch, max_seq, heads * embeddim) self.v = self.v.view(batch, max_seq, heads * embeddim) @@ -3569,7 +4255,7 @@ class TestFlashAttention(op_test.OpTest): is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, data_type = data_type, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_HEAD, tor = tor) - self.gen_out_tensor() + self.gen_out_tensor_4_stage() self.k = self.k.view(batch, max_seq, heads * embeddim) self.v = self.v.view(batch, max_seq, heads * embeddim) @@ -3620,7 +4306,7 @@ class TestFlashAttention(op_test.OpTest): is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, 
data_type = data_type, is_alibi = is_alibi, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_ALIBI_WITH_BATCH, tor = tor) - self.gen_out_tensor() + self.gen_out_tensor_4_stage() self.k = self.k.view(batch, max_seq, heads * embeddim) self.v = self.v.view(batch, max_seq, heads * embeddim) @@ -3667,7 +4353,7 @@ class TestFlashAttention(op_test.OpTest): embeddim = embeddim, max_seq = max_seq, kv_seqLen = kv_seqLen, is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, data_type = data_type, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_HEAD, tor = tor) - self.gen_out_tensor() + self.gen_out_tensor_4_stage() self.k = self.k.view(batch, max_seq, heads * embeddim) self.v = self.v.view(batch, max_seq, heads * embeddim) self.mask = self.mask.view(batch, max_seq, max_seq).to(data_type) @@ -3714,7 +4400,7 @@ class TestFlashAttention(op_test.OpTest): embeddim = embeddim, max_seq = max_seq, kv_seqLen = kv_seqLen, is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, data_type = data_type, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_HEAD, tor = tor) - self.gen_out_tensor() + self.gen_out_tensor_4_stage() self.k = self.k.view(batch, max_seq, kv_head * embeddim) self.v = self.v.view(batch, max_seq, kv_head * embeddim) self.mask = self.mask.view(batch, max_seq, max_seq).to(data_type) @@ -3762,7 +4448,7 @@ class TestFlashAttention(op_test.OpTest): embeddim = embeddim, max_seq = max_seq, kv_seqLen = kv_seqLen, is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, data_type = data_type, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_HEAD, tor = tor) - self.gen_out_tensor() + self.gen_out_tensor_4_stage() self.k = self.k.view(batch, max_seq, heads * embeddim) self.v = self.v.view(batch, max_seq, heads * embeddim) self.mask = self.mask.view(batch, max_seq, max_seq).to(data_type) @@ -3809,7 +4495,7 @@ class TestFlashAttention(op_test.OpTest): is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, data_type = data_type, is_alibi = False, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_BATCH, tor = tor) - self.gen_out_tensor() + self.gen_out_tensor_4_stage() self.mask = np.reshape(self.mask, (max_seq, max_seq)) logging.debug("**********input shape***********") @@ -3854,7 +4540,7 @@ class TestFlashAttention(op_test.OpTest): embeddim = embeddim, max_seq = max_seq, kv_seqLen = kv_seqLen, is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, data_type = data_type, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_HEAD, tor = tor) - self.gen_out_tensor() + self.gen_out_tensor_4_stage() self.k = self.k.view(batch, max_seq, heads * embeddim) self.v = self.v.view(batch, max_seq, heads * embeddim) self.mask = self.mask.view(batch, max_seq, max_seq).to(data_type) @@ -3904,7 +4590,7 @@ class TestFlashAttention(op_test.OpTest): data_type = data_type, is_alibi = True, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_ALIBI_WITH_BATCH, no_cache = True, tor = tor) - self.gen_out_tensor() + self.gen_out_tensor_4_stage() self.alibi_slopes *= -1 mask = np.ones((256,256)) * 60000 @@ -3950,7 +4636,7 @@ class TestFlashAttention(op_test.OpTest): data_type = data_type, is_alibi = True, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_ALIBI_WITH_BATCH, no_cache = True, is_sqrt = True, tor = tor) - self.gen_out_tensor() + self.gen_out_tensor_4_stage() self.alibi_slopes *= -1 mask = np.ones((256,256)) * 60000 @@ -3996,7 +4682,7 @@ class TestFlashAttention(op_test.OpTest): data_type = data_type, is_alibi = True, op_type = OP_PARAM["type"], 
mask_type = MASK_TYPE_ALIBI_WITH_BATCH, no_cache = True, is_sqrt = True, tor = tor) - self.gen_out_tensor() + self.gen_out_tensor_4_stage() self.mask = self.mask[0, :, :, :128].to(torch.bfloat16) attention_out = np.zeros_like(self.q.to(torch.float16)) return self.execute([self.q, self.k, self.v, torch.tensor([], dtype=torch.int), self.mask, torch.tensor([], dtype=torch.float32), @@ -4036,7 +4722,7 @@ class TestFlashAttention(op_test.OpTest): data_type = data_type, is_alibi = True, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_ALIBI_WITH_BATCH, no_cache = True, left_align = True, tor = tor) - self.gen_out_tensor() + self.gen_out_tensor_4_stage() mask = np.ones((256,256)) * -(float("inf")) mask = np.triu(mask, 1) self.mask = self.bias[0, :256, :256] + mask @@ -4078,7 +4764,7 @@ class TestFlashAttention(op_test.OpTest): is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, data_type = data_type, is_alibi = False, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_HEAD, no_cache = True, tor = tor) - self.gen_out_tensor() + self.gen_out_tensor_4_stage() attention_out = np.zeros_like(self.q.to(torch.float16)) return self.execute([self.q, self.k, self.v, torch.tensor([], dtype=torch.int), self.mask.to(data_type), torch.tensor([], dtype=torch.float), torch.tensor([], dtype=torch.float), torch.tensor([], dtype=torch.int32), @@ -4115,7 +4801,7 @@ class TestFlashAttention(op_test.OpTest): # embeddim = embeddim,embeddimv = embeddimv, max_seq = max_seq, kv_seqLen = kv_seqLen, # is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, data_type = data_type, # op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_HEAD, tor = tor) - # self.gen_out_tensor() + # self.gen_out_tensor_4_stage() # self.k = self.k.view(batch, max_seq, heads * embeddim) # self.v = self.v.view(batch, max_seq, heads * embeddimv) # self.mask = self.mask.view(batch, max_seq, max_seq).to(data_type) @@ -4163,7 +4849,7 @@ class TestFlashAttention(op_test.OpTest): is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, data_type = data_type, is_alibi = False, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_BATCH, tor = tor) - self.gen_out_tensor() + self.gen_out_tensor_4_stage() self.mask = np.reshape(self.mask, (max_seq, max_seq)) logging.debug("**********input shape***********") @@ -4209,7 +4895,7 @@ class TestFlashAttention(op_test.OpTest): is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, data_type = data_type, is_alibi = False, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_BATCH, tor = tor) - self.gen_out_tensor() + self.gen_out_tensor_4_stage() self.mask = np.reshape(self.mask, (max_seq, max_seq)) logging.debug("**********input shape***********") @@ -4256,7 +4942,7 @@ class TestFlashAttention(op_test.OpTest): is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, data_type = data_type, is_alibi = False, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_HEAD, no_cache = True, long_seq = True, tor = tor) - self.gen_out_tensor() + self.gen_out_tensor_4_stage() self.mask = self.mask.view(batch, 128, 128)[0, :, :] attention_out = np.zeros_like(self.q.to(torch.float16)) @@ -4296,7 +4982,7 @@ class TestFlashAttention(op_test.OpTest): is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, data_type = data_type, is_alibi = False, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_HEAD, no_cache = True, long_seq = True, tor = tor) - self.gen_out_tensor() + self.gen_out_tensor_4_stage() self.mask = self.mask.view(batch, 128, 128)[0, :, :] 
 attention_out = np.zeros_like(self.q.to(torch.float16))
@@ -4336,7 +5022,7 @@ class TestFlashAttention(op_test.OpTest):
 is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, data_type = data_type, is_alibi = False, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_HEAD, no_cache = True, long_seq= True, tor = tor)
- self.gen_out_tensor()
+ self.gen_out_tensor_4_stage()
 self.mask = self.mask.view(batch, 128, 128)[0, :, :]
 attention_out = np.zeros_like(self.golden_out)
@@ -4369,7 +5055,7 @@ class TestFlashAttention(op_test.OpTest):
 self.set_param(OP_NAME, OP_PARAM)
 self.set_input_formats([self.format_nd] * 12)
 self.set_output_formats([self.format_nd])
- data_type = torch.float32
+ data_type = torch.float16
 self.set_data_params(dynamic_batch = dynamic_batch, is_decoder = isdecoder,
 batch = batch, kv_head = kv_head, heads = heads,
@@ -4377,7 +5063,7 @@ class TestFlashAttention(op_test.OpTest):
 is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, data_type = data_type, is_alibi = True, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_ALIBI_NO_BATCH, tor = tor)
- self.gen_out_tensor()
+ self.gen_out_tensor_4_stage()
 logging.debug("**********input shape***********")
 logging.debug(f"q shape: {self.q.shape}")
 logging.debug(f"k shape: {self.k.shape}")
@@ -4422,7 +5108,7 @@ class TestFlashAttention(op_test.OpTest):
 is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, data_type = data_type, is_alibi = False, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_HEAD, tor = tor)
- self.gen_out_tensor()
+ self.gen_out_tensor_4_stage()
 logging.debug("**********input shape***********")
 logging.debug(f"q shape: {self.q.shape}")
@@ -4467,7 +5153,7 @@ class TestFlashAttention(op_test.OpTest):
 is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, data_type = data_type, is_alibi = False, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_HEAD, tor = tor)
- self.gen_out_tensor()
+ self.gen_out_tensor_4_stage()
 logging.debug("**********input shape***********")
 logging.debug(f"q shape: {self.q.shape}")
@@ -4515,7 +5201,7 @@ class TestFlashAttention(op_test.OpTest):
 is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, data_type = data_type, is_alibi = False, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_HEAD, no_cache = True, fav3 = fav3, is_mask = True, tor = tor)
- self.gen_out_tensor()
+ self.gen_out_tensor_4_stage()
 attention_out = np.zeros_like(self.q.to(torch.float16))
 if not self.fav3:
 return self.execute([self.q, self.k, self.v, torch.tensor([], dtype=torch.int), self.mask.to(data_type), torch.tensor([], dtype=torch.float),
@@ -4564,7 +5250,7 @@ class TestFlashAttention(op_test.OpTest):
 data_type = data_type, is_alibi = False, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_HEAD, no_cache = True, fav3 = fav3, is_mask = True, tor = tor)
- self.gen_out_tensor()
+ self.gen_out_tensor_4_stage()
 attention_out = np.zeros_like(self.q.to(torch.float16))
 if not self.fav3:
 return self.execute([self.q, self.k, self.v, torch.tensor([], dtype=torch.int), self.mask.to(data_type), torch.tensor([], dtype=torch.float),
@@ -4614,7 +5300,7 @@ class TestFlashAttention(op_test.OpTest):
 is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, data_type = data_type, is_alibi = False, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_HEAD, no_cache = True, fav3 = fav3, is_mask = True, tor = tor)
- self.gen_out_tensor(online)
+ self.gen_out_tensor_4_stage(online)
 attention_out = np.zeros_like(self.q.to(torch.float16))
 if not self.fav3:
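Aside, not part of the patch: the hunks above route these cases through the staged golden generator and, in one decoder case, drop the reference dtype from torch.float32 to torch.float16. A looser, dtype-aware comparison is the usual consequence of such a change. The sketch below is only an illustration of that idea; the helper name `check_close`, the thresholds, and the stand-in tensors are assumptions, not values or code from this suite.

```python
import torch

def check_close(actual: torch.Tensor, golden: torch.Tensor) -> bool:
    # Hypothetical tolerances: looser when the golden data is half precision,
    # tighter when the reference was computed in float32.
    tol = 1e-3 if golden.dtype in (torch.float16, torch.bfloat16) else 1e-4
    diff = (actual.to(torch.float32) - golden.to(torch.float32)).abs()
    denom = golden.to(torch.float32).abs().clamp_min(1.0)   # relative error with a floor
    return bool((diff / denom).max() <= tol)

# Usage sketch with random stand-in tensors.
golden = torch.randn(4, 8, dtype=torch.float16)
actual = golden + 1e-4 * torch.randn_like(golden)
print(check_close(actual, golden))
```

In practice the half-precision threshold would be tuned against the kernel's observed error distribution rather than fixed up front.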
@@ -4666,7 +5352,7 @@ class TestFlashAttention(op_test.OpTest):
 data_type = data_type, is_alibi = False, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_HEAD, no_cache = True, fav3 = fav3, is_mask = True, tor = tor)
- self.gen_out_tensor()
+ self.gen_out_tensor_4_stage()
 attention_out = np.zeros_like(self.q.to(torch.float16))
 if not self.fav3:
@@ -4713,7 +5399,7 @@ class TestFlashAttention(op_test.OpTest):
 is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, data_type = data_type, is_alibi = False, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_MASK, long_seq = True, scaleType = scaleType, tor = tor)
- self.gen_out_tensor()
+ self.gen_out_tensor_4_stage()
 attention_out = np.zeros_like(self.q.to(torch.float16))
 return self.execute([self.q, self.k, self.v, self.layer_id, torch.tensor([], dtype=torch.bfloat16), torch.tensor([], dtype=torch.float), torch.tensor([], dtype=torch.float), torch.tensor([], dtype=torch.int32),
@@ -4754,7 +5440,7 @@ class TestFlashAttention(op_test.OpTest):
 op_type = OP_PARAM["type"], mask_type = MASK_TYPE_ALIBI_NO_BATCH_WITH_PREFIX, tor = tor, q_seqlens = q_seqlens, num_blocks = num_blocks, block_size = block_size)
- self.gen_out_tensor()
+ self.gen_out_tensor_4_stage()
 self.mask = self.mask.to(torch.float16)
 attention_out = np.zeros_like(self.q.to(torch.float16))
 logging.info(f"self.mask: {self.mask}")
@@ -4797,7 +5483,7 @@ class TestFlashAttention(op_test.OpTest):
 op_type = OP_PARAM["type"], mask_type = MASK_TYPE_ALIBI_WITH_PREFIX_BATCH, no_cache = True, is_sqrt = True, tor = tor, q_seqlens = q_seqlens, num_blocks = num_blocks, block_size = block_size)
- self.gen_out_tensor()
+ self.gen_out_tensor_4_stage()
 self.alibi_slopes *= -1
 mask = np.ones((256,256)) * float("inf")
 mask = np.triu(mask, 1)
@@ -4843,7 +5529,7 @@ class TestFlashAttention(op_test.OpTest):
 op_type = OP_PARAM["type"], mask_type = MASK_TYPE_ALIBI_WITH_PREFIX_BATCH, no_cache = True, is_sqrt = True, tor = tor, q_seqlens = q_seqlens, num_blocks = num_blocks, block_size = block_size)
- self.gen_out_tensor()
+ self.gen_out_tensor_4_stage()
 self.alibi_slopes *= -1
 mask = np.ones((256,256)) * float("inf")
 mask = np.triu(mask, 1)
@@ -4889,7 +5575,7 @@ class TestFlashAttention(op_test.OpTest):
 op_type = OP_PARAM["type"], mask_type = MASK_TYPE_ALIBI_WITH_PREFIX_BATCH, no_cache = True, is_sqrt = True, tor = tor, q_seqlens = q_seqlens, num_blocks = num_blocks, block_size = block_size)
- self.gen_out_tensor()
+ self.gen_out_tensor_4_stage()
 self.alibi_slopes *= -1
 mask = np.ones((256,256)) * float("inf")
 mask = np.triu(mask, 1)
@@ -4933,7 +5619,7 @@ class TestFlashAttention(op_test.OpTest):
 op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_MASK, no_cache = True, tor = tor, q_seqlens = q_seqlens, num_blocks = num_blocks, block_size = block_size)
- self.gen_out_tensor()
+ self.gen_out_tensor_4_stage()
 attention_out = np.zeros_like(self.q.to(torch.float16))
 return self.execute([self.q, self.k_cache, self.v_cache, self.block_tables, torch.tensor([], dtype=torch.float), torch.tensor([], dtype=torch.float)],
@@ -4973,7 +5659,7 @@ class TestFlashAttention(op_test.OpTest):
 op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_MASK, no_cache = True, tor = tor, q_seqlens = q_seqlens, num_blocks = num_blocks, block_size = block_size)
- self.gen_out_tensor()
+ self.gen_out_tensor_4_stage()
 attention_out = np.zeros_like(self.q.to(torch.float16))
 return self.execute([self.q, self.k_cache, self.v_cache, self.block_tables, torch.tensor([], dtype=torch.float), torch.tensor([], dtype=torch.float)],
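Aside, not part of the patch: the cases above exercise the paged-KV path, passing `self.k_cache`, `self.v_cache`, and `self.block_tables` to `execute` instead of contiguous k/v tensors. The sketch below shows the generic block-table lookup such a layout implies; the helper name, the (num_blocks, block_size, kv_head, head_dim) cache layout, and the toy sizes are assumptions for illustration, not code from this file.

```python
import torch

def gather_kv_for_sequence(cache: torch.Tensor, block_table: torch.Tensor,
                           seq_len: int, block_size: int) -> torch.Tensor:
    """Gather the first seq_len tokens of one sequence from a paged cache.

    cache:       (num_blocks, block_size, kv_head, head_dim)
    block_table: (max_blocks_per_seq,) physical block ids for this sequence
    """
    pos = torch.arange(seq_len)
    block_ids = block_table[pos // block_size]   # which physical block holds each token
    offsets = pos % block_size                   # slot inside that block
    return cache[block_ids, offsets]             # (seq_len, kv_head, head_dim)

# Usage sketch with toy sizes.
num_blocks, block_size, kv_head, head_dim = 8, 4, 2, 16
cache = torch.randn(num_blocks, block_size, kv_head, head_dim)
table = torch.tensor([5, 2, 7])                  # three blocks cover up to 12 tokens
print(gather_kv_for_sequence(cache, table, seq_len=10, block_size=block_size).shape)
```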
@@ -5013,7 +5699,7 @@ class TestFlashAttention(op_test.OpTest):
 op_type = OP_PARAM["type"], mask_type = MASK_TYPE_ALIBI_WITH_PREFIX_BATCH, no_cache = True, left_align = True, tor = tor, q_seqlens = q_seqlens, num_blocks = num_blocks, block_size = block_size)
- self.gen_out_tensor()
+ self.gen_out_tensor_4_stage()
 mask = np.ones((256,256)) * -(float("inf"))
 mask = np.triu(mask, 1)
 self.mask = self.bias[0, :256, :256] + mask
@@ -5057,7 +5743,7 @@ class TestFlashAttention(op_test.OpTest):
 op_type = OP_PARAM["type"], mask_type = MASK_TYPE_ALIBI_WITH_PREFIX_BATCH, no_cache = True, left_align = True, tor = tor, q_seqlens = q_seqlens, num_blocks = num_blocks, block_size = block_size)
- self.gen_out_tensor()
+ self.gen_out_tensor_4_stage()
 mask = np.ones((256,256)) * -(float("inf"))
 mask = np.triu(mask, 1)
 self.mask = self.bias[0, :256, :256] + mask
@@ -5100,7 +5786,7 @@ class TestFlashAttention(op_test.OpTest):
 op_type = OP_PARAM["type"], mask_type = MASK_TYPE_ALIBI_WITH_PREFIX_BATCH, tor = tor, q_seqlens = q_seqlens, num_blocks = num_blocks, block_size = block_size)
- self.gen_out_tensor()
+ self.gen_out_tensor_4_stage()
 self.mask = np.reshape(self.mask, (batch, heads, max_seq, max_seq))
 self.mask = self.mask.to(torch.bfloat16)
 attention_out = np.zeros_like(self.q.to(torch.float16))
@@ -5142,7 +5828,7 @@ class TestFlashAttention(op_test.OpTest):
 op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_BATCH_WITH_PREFIX, no_cache = True, tor = tor, q_seqlens = q_seqlens, num_blocks = num_blocks, block_size = block_size)
- self.gen_out_tensor()
+ self.gen_out_tensor_4_stage()
 self.mask = self.mask.view(128, 128).to(data_type)
 attention_out = np.zeros_like(self.q.to(torch.float16))
 logging.info(f"self.q: {self.q}")
@@ -5188,7 +5874,7 @@ class TestFlashAttention(op_test.OpTest):
 op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_BATCH_WITH_PREFIX, no_cache = True, tor = tor, q_seqlens = q_seqlens, num_blocks = num_blocks, block_size = block_size)
- self.gen_out_tensor()
+ self.gen_out_tensor_4_stage()
 self.mask = self.mask.view(128, 128).to(data_type)
 attention_out = np.zeros_like(self.q.to(torch.float16))
 return self.execute([self.q, self.k_cache, self.v_cache, self.block_tables,
@@ -5226,7 +5912,7 @@ class TestFlashAttention(op_test.OpTest):
 op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_BATCH_WITH_PREFIX, no_cache = True, tor = tor, q_seqlens = q_seqlens, num_blocks = num_blocks, block_size = block_size)
- self.gen_out_tensor()
+ self.gen_out_tensor_4_stage()
 self.mask = np.reshape(self.mask, (max_seq, max_seq))
 logging.debug("**********input shape***********")
@@ -5271,7 +5957,7 @@ class TestFlashAttention(op_test.OpTest):
 op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_BATCH_WITH_PREFIX, no_cache = True, tor = tor, q_seqlens = q_seqlens, num_blocks = num_blocks, block_size = block_size)
- self.gen_out_tensor()
+ self.gen_out_tensor_4_stage()
 self.mask = np.reshape(self.mask, (max_seq, max_seq))
 logging.debug("**********input shape***********")
@@ -5320,7 +6006,7 @@ class TestFlashAttention(op_test.OpTest):
 op_type = OP_PARAM["type"], mask_type = MASK_TYPE_ALIBI_WITH_PREFIX_BATCH, no_cache = True, is_sqrt = True, tor = tor, q_seqlens = q_seqlens, num_blocks = num_blocks, block_size = block_size)
- self.gen_out_tensor()
+ self.gen_out_tensor_4_stage()
 self.alibi_slopes *= -1
 mask = np.ones((256,256)) * float("inf")
 mask = np.triu(mask, 1)
@@ -5366,7 +6052,7 @@ class TestFlashAttention(op_test.OpTest):
 op_type = OP_PARAM["type"], mask_type = MASK_TYPE_ALIBI_WITH_PREFIX_BATCH, no_cache = True, is_sqrt = False, tor = tor, q_seqlens = q_seqlens, num_blocks = num_blocks, block_size = block_size)
- self.gen_out_tensor()
+ self.gen_out_tensor_4_stage()
 self.alibi_slopes *= -1
 mask = np.ones((256,256)) * float("inf")
 mask = np.triu(mask, 1)
@@ -5412,7 +6098,7 @@ class TestFlashAttention(op_test.OpTest):
 op_type = OP_PARAM["type"], mask_type = MASK_TYPE_ALIBI_WITH_PREFIX_BATCH, no_cache = True, is_sqrt = False, tor = tor, q_seqlens = q_seqlens, num_blocks = num_blocks, block_size = block_size)
- self.gen_out_tensor()
+ self.gen_out_tensor_4_stage()
 self.alibi_slopes *= -1
 mask = np.ones((256,256)) * float("inf")
 mask = np.triu(mask, 1)
@@ -5458,7 +6144,7 @@ class TestFlashAttention(op_test.OpTest):
 op_type = OP_PARAM["type"], mask_type = MASK_TYPE_ALIBI_WITH_PREFIX_BATCH, no_cache = True, is_sqrt = False, tor = tor, q_seqlens = q_seqlens, num_blocks = num_blocks, block_size = block_size)
- self.gen_out_tensor()
+ self.gen_out_tensor_4_stage()
 self.alibi_slopes *= -1
 mask = np.ones((256,256)) * float("inf")
 mask = np.triu(mask, 1)
@@ -5504,7 +6190,7 @@ class TestFlashAttention(op_test.OpTest):
 op_type = OP_PARAM["type"], mask_type = MASK_TYPE_ALIBI_WITH_PREFIX_BATCH, no_cache = True, is_sqrt = False, tor = tor, q_seqlens = q_seqlens, num_blocks = num_blocks, block_size = block_size)
- self.gen_out_tensor()
+ self.gen_out_tensor_4_stage()
 self.alibi_slopes *= -1
 mask = np.ones((256,256)) * float("inf")
 mask = np.triu(mask, 1)
@@ -5550,7 +6236,7 @@ class TestFlashAttention(op_test.OpTest):
 op_type = OP_PARAM["type"], mask_type = MASK_TYPE_ALIBI_WITH_PREFIX_BATCH, no_cache = True, is_sqrt = False, tor = tor, q_seqlens = q_seqlens, num_blocks = num_blocks, block_size = block_size)
- self.gen_out_tensor()
+ self.gen_out_tensor_4_stage()
 self.alibi_slopes *= -1
 self.mask = self.mask[0, :, :, :128].to(torch.bfloat16)
 attention_out = np.zeros_like(self.q.to(torch.float16))
@@ -5593,7 +6279,7 @@ class TestFlashAttention(op_test.OpTest):
 op_type = OP_PARAM["type"], mask_type = MASK_TYPE_ALIBI_WITH_PREFIX_BATCH, no_cache = True, is_sqrt = False, tor = tor, q_seqlens = q_seqlens, num_blocks = num_blocks, block_size = block_size)
- self.gen_out_tensor()
+ self.gen_out_tensor_4_stage()
 self.alibi_slopes *= -1
 self.mask = self.mask[0, :, :, :128].to(torch.bfloat16)
 attention_out = np.zeros_like(self.q.to(torch.float16))
@@ -5636,7 +6322,7 @@ class TestFlashAttention(op_test.OpTest):
 op_type = OP_PARAM["type"], mask_type = MASK_TYPE_ALIBI_WITH_PREFIX_BATCH, no_cache = True, is_sqrt = False, tor = tor, q_seqlens = q_seqlens, num_blocks = num_blocks, block_size = block_size)
- self.gen_out_tensor()
+ self.gen_out_tensor_4_stage()
 self.alibi_slopes *= -1
 self.mask = self.mask[0, :, :, :128].to(torch.bfloat16)
 attention_out = np.zeros_like(self.q.to(torch.float16))
@@ -5679,7 +6365,7 @@ class TestFlashAttention(op_test.OpTest):
 op_type = OP_PARAM["type"], mask_type = MASK_TYPE_ALIBI_WITH_PREFIX_BATCH, no_cache = True, is_sqrt = False, tor = tor, q_seqlens = q_seqlens, num_blocks = num_blocks, block_size = block_size)
- self.gen_out_tensor()
+ self.gen_out_tensor_4_stage()
 self.alibi_slopes *= -1
 self.mask = self.mask[0, :, :, :128].to(torch.bfloat16)
 attention_out = np.zeros_like(self.q.to(torch.float16))
@@ -5734,7 +6420,7 @@ class TestFlashAttention(op_test.OpTest):
 data_type = data_type, is_alibi = True, is_mask = False, q_seqlens=q_seqLen, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_RAZOR_FUSION, no_cache = True, tor = tor)
- self.gen_out_tensor()
+ self.gen_out_tensor_4_stage()
 attention_out = np.zeros_like(self.q.to(torch.float16))
 return self.execute([self.q, self.k, self.v],
@@ -5768,7 +6454,7 @@ class TestFlashAttention(op_test.OpTest):
 is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, data_type = data_type, is_alibi = False, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_HEAD, long_seq= True, tor = tor)
- self.gen_out_tensor()
+ self.gen_out_tensor_4_stage()
 self.mask = self.mask.view(batch, 128, 128)[0, :, :]
 self.q = self.q.reshape(3072, 128, 192)
 logging.debug("**********input shape***********")
@@ -5820,7 +6506,7 @@ class TestFlashAttention(op_test.OpTest):
 op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_BATCH_WITH_PREFIX, no_cache = True, tor = tor, q_seqlens = q_seqlens, num_blocks = num_blocks, block_size = block_size, is_triu_mask = True, is_mask=True)
- self.gen_out_tensor()
+ self.gen_out_tensor_4_stage()
 attention_out = np.zeros_like(self.q.to(torch.float16))
@@ -5863,7 +6549,7 @@ class TestFlashAttention(op_test.OpTest):
 op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_BATCH_WITH_PREFIX, no_cache = True, tor = tor, q_seqlens = q_seqlens, num_blocks = num_blocks, block_size = block_size, is_triu_mask = True, is_mask=True)
- self.gen_out_tensor()
+ self.gen_out_tensor_4_stage()
 attention_out = np.zeros_like(self.q.to(torch.float16))
 return self.execute([self.q, self.k_cache, self.v_cache, self.block_tables, torch.tensor([], dtype=torch.float), torch.tensor([], dtype=torch.float)],
@@ -5900,7 +6586,7 @@ class TestFlashAttention(op_test.OpTest):
 is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, data_type = data_type, is_alibi = False, window_size = window_size, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_BATCH, tor = tor)
- self.gen_out_tensor()
+ self.gen_out_tensor_4_stage()
 self.mask = np.reshape(self.mask, (max_seq, max_seq))
 logging.debug("**********input shape***********")
@@ -5948,7 +6634,7 @@ class TestFlashAttention(op_test.OpTest):
 is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, data_type = data_type, is_alibi = False, window_size = window_size, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_BATCH, tor = tor)
- self.gen_out_tensor()
+ self.gen_out_tensor_4_stage()
 self.mask = np.reshape(self.mask, (max_seq, max_seq))
 logging.debug("**********input shape***********")
@@ -5996,7 +6682,7 @@ class TestFlashAttention(op_test.OpTest):
 is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, data_type = data_type, is_alibi = False, window_size = window_size, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_BATCH, tor = tor)
- self.gen_out_tensor()
+ self.gen_out_tensor_4_stage()
 if self.is_compress:
 self.mask = self.gen_swa_cmp(max_seq, window_size)
 self.mask = np.reshape(self.mask, (512, 512))
@@ -6046,7 +6732,7 @@ class TestFlashAttention(op_test.OpTest):
 is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, data_type = data_type, is_alibi = False, window_size = window_size, op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_BATCH, tor = tor)
- self.gen_out_tensor()
+ self.gen_out_tensor_4_stage()
 if self.is_compress:
 self.mask = self.gen_swa_cmp(max_seq, window_size)
 self.mask = np.reshape(self.mask, (512, 512))
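Aside, not part of the patch: the final hunks cover the sliding-window cases, where the reference only admits keys inside a trailing window of `window_size` positions and the compressed mask from `gen_swa_cmp` is reshaped to (512, 512). The sketch below reconstructs a generic full sliding-window mask for context; the helper name, the -10000.0 fill value, and the toy sizes are assumptions, not this file's implementation.

```python
import torch

def sliding_window_mask(seq_len: int, window_size: int,
                        neg_value: float = -10000.0) -> torch.Tensor:
    """Causal mask that also blocks keys older than window_size positions."""
    q_idx = torch.arange(seq_len).unsqueeze(1)   # query positions (rows)
    k_idx = torch.arange(seq_len).unsqueeze(0)   # key positions (columns)
    visible = (k_idx <= q_idx) & (k_idx > q_idx - window_size)
    return torch.where(visible, torch.zeros(()), torch.full((), neg_value))

# A query at position i may attend to keys i-window_size+1 .. i.
print(sliding_window_mask(6, window_size=3))
```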