diff --git a/tests/apitest/kernelstest/ring_mla/ring_mla.py b/tests/apitest/kernelstest/ring_mla/ring_mla.py
index 25626823ea266321ca7ef9789b2efed4d83511aa..200d0a28b69228fec0cc4e60e868ce0cdd5b8367 100644
--- a/tests/apitest/kernelstest/ring_mla/ring_mla.py
+++ b/tests/apitest/kernelstest/ring_mla/ring_mla.py
@@ -48,22 +48,22 @@ def softmax1(
     is_first,
     gm
 ):
-    sim = qk_result.numpy()
-    lm = np.max(sim, axis=-1, keepdims=True)
+    sim = qk_result
+    lm = torch.max(sim, dim=-1, keepdim=True)[0]
+
     if is_first:
         hm = lm
-        dm = 0
+        dm = torch.zeros_like(hm)
     else:
-        hm = np.maximum(gm, lm)
+        hm = torch.max(gm, lm)
         dm = gm - hm
+
     gm = hm
     sim_sub = sim - hm
-    sim_sub = np.exp(sim_sub.astype(np.float32))
-    if qk_result.dtype == torch.float16:
-        sim_sub = sim_sub.astype(np.float16)
+    sim_sub = torch.exp(sim_sub.to(torch.float32))
-    row_sum = np.sum(sim_sub, axis=-1, keepdims=True)
-    return torch.from_numpy(sim_sub), row_sum, dm, gm
+    row_sum = torch.sum(sim_sub, dim=-1, keepdim=True)
+    return sim_sub, row_sum, dm, gm
 
 
 
 def qkMM1(
@@ -143,18 +143,18 @@ def ref_flash_attention(
         gm_high = None
         p_result_high, row_sum_high, dm_high, gm_high = softmax1(qk_result_high, kv_start == 0, gm_high)
         lo = torch.matmul(p_result.to(torch.float32), sub_value.to(torch.float32))
-        lo = lo.to(data_type)
+        # lo = lo.to(data_type)
         lo_high = torch.matmul(p_result_high, sub_value.to(torch.float32))
-        lo = lo.numpy()
-        lo_high = lo_high.numpy()
+        # lo = lo.numpy()
+        # lo_high = lo_high.numpy()
 
         if kv_start == 0:
             gl = row_sum
             gl_high = row_sum_high
             go = lo
             go_high = lo_high
         else:
-            dm = np.exp(dm)
-            dm_high = np.exp(dm_high)
+            dm = torch.exp(dm)
+            dm_high = torch.exp(dm_high)
             gl = gl * dm
             gl = gl + row_sum
@@ -168,9 +168,9 @@
             go_high = go_high + lo_high
     go = go / gl
     go_high = go_high / gl_high
-    go = np.transpose(go, (1, 0, 2))
-    go_high = np.transpose(go_high, (1, 0, 2))
-    return torch.from_numpy(go), torch.from_numpy(go_high), gl, gm
+    go = torch.permute(go, (1, 0, 2))
+    go_high = torch.permute(go_high, (1, 0, 2))
+    return go, go_high, gl, gm
 
 
 
@@ -778,7 +778,7 @@ class TestMLAPrefill(op_test.OpTest):
         self.encoder_logN.uniform_(1, 2)
         self.decoder_logN = torch.tensor([2.0] * batch).to(torch.float32)
         self.decoder_logN.uniform_(1, 2)
-        self.new_les = None
+        self.new_lse = None
         self.new_lse_height = None
         out_height = None
 
@@ -804,20 +804,20 @@
             kv_s = kv_seqlen[idx]
             if kv_s == 0:
                 o = torch.zeros(size=(q_s, heads, embedv), dtype=self.data_type)
-                if out is None:
+                if out == None:
                     out = o
-                    if not self.fav3:
-                        out_true = o_true
                 else:
                     out = torch.cat((out, o), 0)
-                    if not self.fav3:
-                        out_true = torch.cat((out_true, o_true), 0)
+                if out_height == None:
+                    out_height = o
+                else:
+                    out_height = torch.cat((out_height, o), 0)
                 q_offset += q_s
-                lse = torch.zeros(size=(self.heads, q_s), dtype = torch.float32)
-                if self.new_les is None:
-                    self.new_les = lse
+                lse = torch.zeros(size=(self.heads, q_s, 1), dtype = torch.float32)
+                if self.new_lse is None:
+                    self.new_lse = lse
                 else:
-                    self.new_les = np.concatenate((self.new_les, lse), axis=-1) # shape is (heads, bs)
+                    self.new_lse = np.concatenate((self.new_lse, lse), axis=1) # shape is (heads, bs)
                 continue
             q_slice = q[q_offset:q_offset + q_s][:]
             q_slice_ori = q_slice.view(q_s, heads, embed)
@@ -836,7 +836,7 @@
             temp_mask = self.mask.repeat(self.heads, 1, 1)[:, :q_s, :]
             context_len = k_slice_ori.shape[0]
 
-            new_out, new_out_height, gl, gm = ref_flash_attention(q_slice_ori, k_slice_ori, v_slice_ori, self.tor, temp_mask, context_len=context_len, mask_type=self.mask_type)
+            new_out, new_out_height, gl, gm = ref_flash_attention(q_slice_ori, k_slice_ori, v_slice_ori, self.tor, temp_mask, context_len=context_len, mask_type=self.mask_type, data_type=self.data_type)
 
             if out == None:
                 out = new_out
@@ -848,22 +848,21 @@
             else:
                 out_height = torch.cat((out_height, new_out_height), 0)
 
-            lse = np.log(gl) + gm
+            lse = torch.log(gl) + gm
 
-            if self.new_les is None:
-                self.new_les = lse
+            if self.new_lse is None:
+                self.new_lse = lse
             else:
-                self.new_les = np.concatenate((self.new_les, lse), axis=1) # shape is (heads, bs)
+                self.new_lse = torch.cat((self.new_lse, lse), dim=1) # shape is (heads, bs)
             q_offset += q_s
             k_offset += max_seq
             v_offset += max_seq
 
 
         # golden data
-        self.new_les = torch.from_numpy(np.array(self.new_les)) # (bs, head, 1)
-        self.new_les = torch.squeeze(self.new_les, dim=-1)
-        self.new_les = self.new_les.permute(1, 0).contiguous()
-        self.out_lse = self.new_les.reshape(self.q_ntokens * self.heads, 1) # (bs * head, 1)
+        self.new_lse = torch.squeeze(self.new_lse, dim=-1)
+        self.new_lse = self.new_lse.permute(1, 0).contiguous()
+        self.out_lse = self.new_lse.reshape(self.q_ntokens * self.heads, 1) # (bs * head, 1)
 
         if self.is_int8_flag:
             ans_concat = ans_concat.view(q_ntokens, heads * embedv)
@@ -877,15 +876,15 @@
             self.golden_out = out.to(self.data_type)
             out_true = out_height.view(q_ntokens, heads * embedv)
             self.golden_out_true = out_true.to(torch.float32)
-            self.new_les = self.new_les.transpose(1, 0)
+            self.new_lse = self.new_lse.transpose(1, 0)
         else:
             out = out.view(q_ntokens, heads * embedv)
-            new_out, new_lse = self.update_out(out, self.new_les)
+            new_out, new_lse = self.update_out(out, self.new_lse)
             self.golden_out = new_out.to(self.data_type)
-            self.new_les_low = new_lse
+            self.new_lse_low = new_lse
 
             out_true = out_height.view(q_ntokens, heads * embedv)
-            new_out, new_lse_height = self.update_out(out_true, self.new_les)
+            new_out, new_lse_height = self.update_out(out_true, self.new_lse)
             self.golden_out_true = new_out.to(torch.float32)
             self.new_lse_height = new_lse_height
 
@@ -980,7 +979,7 @@
     def golden_calc(self, in_tensors):
         golden_out = torch.tensor(self.golden_out)
         if self.isring == 0:
-            lse_result = self.new_les
+            lse_result = self.new_lse
         else:
             lse_result = self.new_lse_height
         return [golden_out, lse_result]
@@ -1403,7 +1402,5 @@
             [torch.tensor(attention_out, dtype=data_type), output_lse])
 
 
-
-
 if __name__ == '__main__':
     unittest.main()