diff --git a/tests/apitest/opstest/csv/self_attention.csv b/tests/apitest/opstest/csv/self_attention.csv
index acb8c7782c56788bc7f7cafcdfddce1a3b190c7d..d69fbc4aa2235fddce64fd389f5a71b82050bbf9 100644
--- a/tests/apitest/opstest/csv/self_attention.csv
+++ b/tests/apitest/opstest/csv/self_attention.csv
@@ -1,6 +1,6 @@
 CaseNum|CaseName |OpName |OpParam |InNum|InDType |InFormat |InShape |OutNum|OutDType|OutFormat|OutShape |DataGenType |DataGenRange |InTensorFile|OutTensorFile|TestType|TestLevel|FromModel|SocVersion|ExpectedError
-1 |SelfAttentionBase |SelfAttentionOperation|{"headNum":4,"qScale":0.2,"qkScale":1,"calcType":0, "clampType": 0, "maskType":1, "kernelType":0} |9 |float16;float16;float16;float16;float16;float16;int32;int32;int32|nd;nd;nd;nd;nd;nd;nd;nd;nd |4,3,32,128;4,3,32,128;4,3,32,128;28,3,2048,4096;28,3,2048,4096;2048,2048;3;3;1 |1 |float16 |nd |3,32 |customize;customize;customize;customize;customize;customize;customize;customize;customize|-100,100;-100,100;-100,100;-100,100;-100,100;-100,100;-100,100;-100,100;-100,100| | | | | |Ascend910B|NO_ERROR
-2 |SelfAttentionEncoder |SelfAttentionOperation|{"headNum":32,"qkScale":0.08838834764831843,"kvHeadNum":32,"calcType":3, "clampType": 0, "maskType":0, "kernelType":0} |4 |float16;float16;float16;int32 |nd;nd;nd;nd |32,128;32,128;32,128;128 |1 |float16 |nd |2048,32,128|customize;customize;customize;customize |-100,100;-100,100;-100,100;-100,100;-100,100 | | | | | |Ascend910B|NO_ERROR
+1 |SelfAttentionBase |SelfAttentionOperation|{"headNum":4,"qScale":0.2,"qkScale":1,"calcType":0, "clampType": 0, "maskType":1, "kernelType":0} |9 |float16;float16;float16;float16;float16;float16;int32;int32;int32|nd;nd;nd;nd;nd;nd;nd;nd;nd |4,3,32,128;4,3,32,128;4,3,32,128;28,3,2048,4096;28,3,2048,4096;2048,2048;3;3;1 |1 |float16 |nd |3,32 |customize;customize;customize;customize;customize;customize;customize;customize;customize|-1,1;-1,1;-1,1;-1,1;-1,1;-1,1;-1,1;-1,1;-1,1| | | | | |Ascend910B|NO_ERROR
+2 |SelfAttentionEncoder |SelfAttentionOperation|{"headNum":32,"qkScale":0.08838834764831843,"kvHeadNum":32,"calcType":3, "clampType": 0, "maskType":0, "kernelType":0} |4 |float16;float16;float16;int32 |nd;nd;nd;nd |2048,32,128;2048,32,128;2048,32,128;16 |1 |float16 |nd |2048,32,128|customize;customize;customize;customize |-1,1;-1,1;-1,1;-1,1;-1,1 | | | | | |Ascend910B|NO_ERROR
 3 |SelfAttentionClamp |SelfAttentionOperation|{"headNum":4,"qScale":0.2,"headDim":8,"qkScale":1,"calcType":0,"clampMin":0.3333,"clampMax":0.5555,"clampType":1, "maskType":1, "kernelType":0} |9 |float16;float16;float16;float16;float16;float16;int32;int32;int32|nd;nd;nd;nd;nd;nd;nd;nd;nd |4,3,32,128;4,3,32,128;4,3,32,128;28,3,2048,4096;28,3,2048,4096;2048,2048;3;3;1 |1 |float16 |nd |3,32 |customize;customize;customize;customize;customize;customize;customize;customize;customize|-100,100;-100,100;-100,100;-100,100;-100,100;-100,100;-100,100;-100,100;-100,100| | | | | |Ascend910B|NO_ERROR
 4 |SelfAttentionBaseWrongDtype0 |SelfAttentionOperation|{"headNum":4,"qScale":0.2,"headDim":8,"qkScale":1,"calcType":0, "clampType": 0, "maskType":1} |9 |int32;float16;float16;float16;float16;float16;int32;int32;int32 |nd;nd;nd;nd;nd;nd;nd;nd;nd |4,3,32,128;4,3,32,128;4,3,32,128;28,3,2048,4096;28,3,2048,4096;2048,2048;3;3;1 |1 |float16 |nd |3,32 |one;one;one;one;one;one;one;one;one |-100,100;-100,100;-100,100;-100,100;-100,100;-100,100;-100,100;-100,100;-100,100| | | | | |Ascend910B|I:ERROR_INVALID_TENSOR_INI_MATCH
 5 |SelfAttentionBaseWrongDtype1 |SelfAttentionOperation|{"headNum":4,"qScale":0.2,"headDim":8,"qkScale":1,"calcType":0, "clampType": 0, "maskType":1} |9 |float16;bool;float16;float16;float16;float16;int32;int32;int32 |nd;nd;nd;nd;nd;nd;nd;nd;nd |4,3,32,128;4,3,32,128;4,3,32,128;28,3,2048,4096;28,3,2048,4096;2048,2048;3;3;1 |1 |float16 |nd |3,32 |one;one;one;one;one;one;one;one;one |-100,100;-100,100;-100,100;-100,100;-100,100;-100,100;-100,100;-100,100;-100,100| | | | | |Ascend910B|I:ERROR_INVALID_TENSOR_INI_MATCH
@@ -91,6 +91,7 @@ CaseNum|CaseName |OpName |OpParam
 90 |SelfAttentionEncoderWrongDim |SelfAttentionOperation|{"headNum":1,"qkScale":0.08838834764831843,"kvHeadNum":1,"calcType":3, "clampType": 0, "maskType":0, "kernelType":0} |4 |float16;float16;float16;int32 |nd;nd;nd;nd |1,32,1,128;1,32,1,128;1,32,1,128;32 |1 |float16 |nd |2048,32,128|one;one;one;one |-100,100;-100,100;-100,100;-100,100;-100,100 | | | | | | |I:ERROR_INVALID_TENSOR_SIZE
 91 |SelfAttentionEncoderWrongDim |SelfAttentionOperation|{"headNum":1,"qkScale":0.08838834764831843,"kvHeadNum":1,"calcType":3, "clampType": 0, "maskType":0, "kernelType":0} |4 |float16;float16;float16;int32 |nd;nd;nd;nd |32,128;1,32,128;1,32,128;32 |1 |float16 |nd |2048,32,128|one;one;one;one |-100,100;-100,100;-100,100;-100,100;-100,100 | | | | | | |I:ERROR_INVALID_TENSOR_SIZE
 92 |SelfAttentionEncoderWrongDim |SelfAttentionOperation|{"headNum":1,"qkScale":0.08838834764831843,"kvHeadNum":1,"calcType":3, "clampType": 0, "maskType":0, "kernelType":0} |4 |float16;float16;float16;int32 |nd;nd;nd;nd |32,128;32,128;1,32,128;32 |1 |float16 |nd |2048,32,128|one;one;one;one |-100,100;-100,100;-100,100;-100,100;-100,100 | | | | | |Ascend910B|I:ERROR_INVALID_TENSOR_SIZE
+93 |SelfAttentionOperationDumpTensorCase |SelfAttentionOperation|{"batchRunStatusEnable":false,"cacheType":0,"calcType":3,"clampMax":0.0,"clampMin":0.0,"clampType":0,"headNum":16,"inputLayout":0,"isTriuMask":1,"kernelType":0,"kvHeadNum":16,"kvcacheCfg":0,"maskType":1,"mlaVHeadSize":0,"outDataType":-1,"qScale":1.0,"qkScale":0.1147213876247406,"quantType":0,"scaleType":0,"windowSize":0}|5|bf16;bf16;bf16;bf16;int32|nd;nd;nd;nd;nd|1024,16,192;1024,16,192;1024,16,128;128,128;1|1|bf16|nd|1024,16,128|customize;customize;customize;customize;customize| ||| | | | |NO_ERROR
 94 |SelfAttentionEncoderWrongDim |SelfAttentionOperation|{"headNum":1,"qkScale":0.08838834764831843,"kvHeadNum":1,"calcType":3, "clampType": 0, "maskType":0, "mlaVHeadSize":64} |3 |float16;float16;int32 |nd;nd;nd;nd |32,128;1,1,1,32,128;32 |1 |float16 |nd |2048,32,128|one;one;one;one |-100,100;-100,100;-100,100;-100,100;-100,100 | | | | | |Ascend910B|I:ERROR_INVALID_TENSOR_SIZE
 95 |SelfAttentionEncoderWrongDim |SelfAttentionOperation|{"headNum":1,"qkScale":0.08838834764831843,"kvHeadNum":1,"calcType":3, "clampType": 0, "maskType":0, "mlaVHeadSize":64} |3 |float16;float16;int32 |nd;nd;nd;nd |1,1,1,32,128;1,32,128;32 |1 |float16 |nd |2048,32,128|one;one;one;one |-100,100;-100,100;-100,100;-100,100;-100,100 | | | | | |Ascend910B|I:ERROR_INVALID_TENSOR_SIZE
 97 |SelfAttentionEncoderWrongDim |SelfAttentionOperation|{"headNum":1,"qkScale":0.08838834764831843,"kvHeadNum":1,"calcType":3, "clampType": 0, "maskType":0, "kernelType":0} |4 |float16;float16;float16;int32 |nd;nd;nd;nd |32,128;32,128;64,128;1 |1 |float16 |nd |2048,32,128|one;one;one;one |-100,100;-100,100;-100,100;-100,100;-100,100 | | | | | | |I:ERROR_INVALID_TENSOR_DIM
diff --git a/tests/framework/python/CsvOpsTestTool/data_generation.py b/tests/framework/python/CsvOpsTestTool/data_generation.py
index 0f00930b8f929550fb09f192baa3fb337e50fb70..6fdb312d71d5cf804a10ac3479a8ad6626336e26 100755
--- a/tests/framework/python/CsvOpsTestTool/data_generation.py
+++ b/tests/framework/python/CsvOpsTestTool/data_generation.py
@@ -5065,6 +5065,8 @@ class PagedAttentionOperation(DataGen):
         return OpTypes.CV_FUSION
 
 class SelfAttentionOperation(DataGen):
+    MASK_TYPE_UNDEFINED = 0
+
     def gen_mask(batch, heads, max_seq,data_type, mask_type,is_decoder=False,is_triu_mask=False,is_alibi=False,dynamic_batch=False,long_seq=False):
         import random
         q_max_seq = max_seq
@@ -5276,6 +5278,115 @@ class SelfAttentionOperation(DataGen):
         ret_data = q, k, v, q_len, out
         return ret_data
 
+    def calc_expect_func_encoder(batch, seqlen, heads, embed, embed_v, group_num=32, is_mask=False):
+        logging.info(f"Encoder param: batch {batch}, seqlen {seqlen}, heads {heads}, embed {embed}, embed_v {embed_v}, group_num {group_num}")
+        is_mask = is_mask
+        variate_seq = False
+        is_decoder = False
+        max_seq = 2048
+        src_type = 'float16'
+        fp32 = True
+        logging.debug(f"group_num: {group_num}")
+        logging.debug("q_seq is:")
+        if is_decoder:
+            q_seqlen, q_seqlen_aligned, q_ntokens = SelfAttentionOperation.gen_seq_len(batch, 1, variate_seq)
+            kv_seqlen, kv_seqlen_aligned, kv_ntokens = SelfAttentionOperation.gen_seq_len(batch, seqlen, variate_seq)
+        else:
+            q_seqlen, q_seqlen_aligned, q_ntokens = SelfAttentionOperation.gen_seq_len(batch, seqlen, variate_seq)
+            kv_seqlen, kv_seqlen_aligned, kv_ntokens = q_seqlen, q_seqlen_aligned, q_ntokens  # in cross-attention, q_seqlen != kv_seqlen
+
+        logging.debug(f"q_seqlen {q_seqlen}")
+
+        max_s = np.max(q_seqlen)
+        ntokens2 = (q_seqlen * kv_seqlen).sum()
+
+        q = np.random.uniform(-1.0, 1.0, size=(q_ntokens, heads * embed)).astype(np.float16)
+        k = np.random.uniform(-1.0, 1.0, size=(kv_ntokens, group_num * embed)).astype(np.float16)
+        v = np.random.uniform(-1.0, 1.0, size=(kv_ntokens, group_num * embed_v)).astype(np.float16)
+
+        # TODO: add a compress mask based on isTriuMask = 1
+        mask = np.ones(shape=(1, max_s, max_s)).astype(np.float16)  # build the mask from the current max seqlen
+        mask = np.triu(mask, 1)
+        mask *= -10000.0
+        logging.debug(mask)
+
+        q_offset = 0
+        k_offset = 0
+        v_offset = 0
+
+        s = None
+        _p = None
+        out = None
+
+        for idx in range(batch):
+            q_s = q_seqlen[idx]
+            kv_s = kv_seqlen[idx]
+            q_slice = q[q_offset:q_offset + q_s][:]
+            q_slice = q_slice.reshape(q_s, heads, embed)
+            q_slice = np.transpose(q_slice, (1, 0, 2))  # (heads, q_seq, embed)
+            k_slice = k[k_offset:k_offset + kv_s][:]
+            k_slice = k_slice.reshape(kv_s, group_num, embed)
+            k_slice = np.transpose(k_slice, (1, 0, 2))
+            k_slice_t = np.transpose(k_slice, (0, 2, 1))  # get K^T (kv_heads, embed, k_seq)
+            v_slice = v[v_offset:v_offset + kv_s][:]
+            v_slice = v_slice.reshape(kv_s, group_num, embed_v)
+            v_slice = np.transpose(v_slice, (1, 0, 2))
+            score = SelfAttentionOperation.group_matmul(heads, group_num, q_slice, k_slice_t)
+            if s is None:
+                s = score.reshape([-1, ])
+            else:
+                s = np.concatenate((s, score.reshape([-1, ])), 0)
+
+            tor = np.float16(1.0 / math.sqrt(1.0 * embed))
+            score = score * tor
+            if is_mask:
+                score = score + mask[:, :q_s, :kv_s]
+            score_max = np.max(score, axis=-1)
+            score = score - score_max.reshape((heads, q_s, 1))
+            score_exp = np.exp(score.astype(np.float32))
+            if not fp32:
+                score_sum = np.sum(score_exp.astype(np.float16), axis=-1)
+                if _p is None:
+                    _p = score_exp.astype(np.float16).reshape([-1, ])
+                else:
+                    _p = np.concatenate((_p, score_exp.astype(np.float16).reshape([-1, ])), 0)
+                p = score_exp.astype(np.float16) / score_sum.reshape((heads, q_s, 1)).astype(np.float16)
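+                # p is already normalized in fp16 here, so the V matmul below
+                # needs no further scaling; the fp32 branch instead divides by
+                # the row sum after the matmul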
+                out_sub = SelfAttentionOperation.group_matmul(heads, group_num, p, v_slice)
+            else:
+                score_sum = np.sum(score_exp, axis=-1)
+                if _p is None:
+                    _p = score_exp.astype(np.float16).reshape([-1, ])
+                else:
+                    _p = np.concatenate((_p, score_exp.astype(np.float16).reshape([-1, ])), 0)
+                p = score_exp.astype(np.float16)
+                out_sub = SelfAttentionOperation.group_matmul(heads, group_num, p, v_slice)
+                out_sub = out_sub / score_sum.reshape((heads, q_s, 1)).astype(np.float16)
+
+            out_sub = out_sub.reshape(heads, q_s, embed_v)
+            out_sub = np.transpose(out_sub, (1, 0, 2))
+            out_sub = np.ascontiguousarray(out_sub)
+            if out is None:
+                out = out_sub
+            else:
+                out = np.concatenate((out, out_sub), 0)
+
+            q_offset += q_s
+            k_offset += kv_s
+            v_offset += kv_s
+
+        logging.info("==> data generate finished!")
+
+        q = q.astype(src_type).reshape(-1, heads, embed)
+        k = k.astype(src_type).reshape(-1, group_num, embed)
+        v = v.astype(src_type).reshape(-1, group_num, embed_v)
+        mask = mask.astype(src_type).reshape(max_s, max_s)
+        q_len = q_seqlen.astype(np.int32)
+        out = out.astype(src_type).reshape(-1, heads, embed_v)
+        logging.info(f"calc_expect_func_encoder out shape {out.shape}")
+        if is_mask:
+            return q, k, v, mask, q_len, out
+        return q, k, v, q_len, out
+
     @staticmethod
     def case_preprocess(op_params, operation, input_tensor_list):
         json_data = json.loads(op_params)
@@ -5344,18 +5455,59 @@ class SelfAttentionOperation(DataGen):
             SelfAttentionOperation.clamp_max = clamp_max
             SelfAttentionOperation.in_tensors = [q,k,v,attention_mask.to(data_type).npu(),torch.tensor(kv_seqLen).to(torch.int32).npu(),torch.tensor(q_seqlen).to(torch.int32).npu(),layer_id]
             return SelfAttentionOperation.in_tensors[i]
+
         if json_data["calcType"] == 3:
             if i == 0:
-                kv_head = 32
-                data = SelfAttentionOperation.calc_expect_func(16, 128, 32, 128, group_num=kv_head)
-                param_seqlen = data[4].tolist()
-                in_tensors = [torch.from_numpy(tensor) for tensor in data]
-                SelfAttentionOperation.in_tensors_encoder = [tensor.npu() for tensor in in_tensors]
-                for tensor in in_tensors:
+                q_idx = 0
+                v_idx = 2
+                mask_idx = -1
+                seqlen_idx = 3
+                is_mask = False
+                if json_data["maskType"] != SelfAttentionOperation.MASK_TYPE_UNDEFINED:
+                    is_mask = True
+                    mask_idx = 3
+                    seqlen_idx += 1
+                batch = shapes[seqlen_idx][0]
+                if len(shapes[seqlen_idx]) == 2:
+                    batch = shapes[seqlen_idx][1]
+                q_ntokens = shapes[q_idx][0]
+                seqlen = q_ntokens // batch
+                if q_ntokens % batch != 0:
+                    seqlen += 1
+                heads = json_data["headNum"]
+                kv_head = heads
+                if "kvHeadNum" in json_data:
+                    kv_head = json_data["kvHeadNum"]
+                embed = shapes[q_idx][-1]
+                if len(shapes[q_idx]) == 2:
+                    embed //= heads
+                embed_v = shapes[v_idx][-1]
+                if len(shapes[v_idx]) == 2:
+                    embed_v //= kv_head
+
+                # calc_expect_func_encoder(batch, seqlen, heads, embed, embed_v, group_num=32, is_mask=False)
+                tensor_list = SelfAttentionOperation.calc_expect_func_encoder(batch, seqlen, heads, embed, embed_v, group_num=kv_head, is_mask=is_mask)
+                SelfAttentionOperation.param_seqlen = tensor_list[seqlen_idx].tolist()
+                SelfAttentionOperation.in_tensors_encoder = [torch.from_numpy(tensor) for tensor in tensor_list]
+                # isTriuMask: compress mask 128 * 128, replacing the original mask
+                if "isTriuMask" in json_data and json_data["isTriuMask"] == 1:
+                    mask_type = 1
+                    if "maskType" in json_data:
+                        mask_type = json_data["maskType"]
+                    SelfAttentionOperation.is_triu_mask = 1
+                    max_seq = 128
+                    attention_mask, _ = SelfAttentionOperation.gen_mask(batch, heads, max_seq, dtype_dict[datatype], mask_type, is_decoder=False, is_triu_mask=True)
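+                    # gen_mask may return a mask with a leading batch dim of 1;
+                    # squeeze it so the compressed 128x128 mask stays 2-D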
+                    if attention_mask.shape[0] == 1:
+                        attention_mask = attention_mask.squeeze(0)
+
+                    SelfAttentionOperation.in_tensors_encoder[mask_idx] = attention_mask
+
+                for tensor in tensor_list:
                     logging.debug(tensor.dtype, tensor.shape)
-                return SelfAttentionOperation.in_tensors_encoder[0]
+                return SelfAttentionOperation.in_tensors_encoder[0].to(dtype_dict[datatype]).npu()
             else:
-                return SelfAttentionOperation.in_tensors_encoder[i]
+                return SelfAttentionOperation.in_tensors_encoder[i].to(dtype_dict[datatype]).npu()
+
         elif json_data["clampType"] == 1:
             if i != 0:
                 return SelfAttentionOperation.in_tensors_clamp[i]
@@ -5432,7 +5584,10 @@ class SelfAttentionOperation(DataGen):
         try:
             json_data = json.loads(op_params)
             if json_data["calcType"] == 3:
-                shape = SelfAttentionOperation.in_tensors_encoder[4].shape
+                output_idx = 4
+                if json_data["maskType"] != SelfAttentionOperation.MASK_TYPE_UNDEFINED:
+                    output_idx += 1
+                shape = SelfAttentionOperation.in_tensors_encoder[output_idx].shape
                 data = torch.zeros(shape, dtype=dtype_dict[datatype]).npu()
                 return torch_npu.npu_format_cast(data, format_dict[format])
             elif json_data["clampType"] == 1:
@@ -5455,6 +5610,12 @@ class SelfAttentionOperation(DataGen):
         asdops_param["head_num"] = json_data["headNum"]
         asdops_param["is_decoder"] = False
         asdops_param["embeddim"] = int(q.shape[1] / json_data["headNum"])
+        embeddimV = 0  # v headSize
+        if len(v.shape) == 2:
+            embeddimV = int(v.shape[1] / json_data["kvHeadNum"])
+        elif len(v.shape) == 3:
+            embeddimV = int(v.shape[2])
+        asdops_param["embeddimV"] = embeddimV
         asdops_param["kv_head"] = json_data["kvHeadNum"]
         asdops_param["is_mask"] = (json_data["maskType"] != 0)
         asdops_param["qk_scale"] = json_data["qkScale"]
@@ -5466,6 +5627,7 @@ class SelfAttentionOperation(DataGen):
         asdops_param["data_type"] = q.dtype
         asdops_param["q_ntokens"] = q.shape[0]
         asdops_param["kv_ntokens"] = k.shape[0]
+        # seqlen calculated from the input seq_len tensor
         asdops_param["q_seqlen"] = seq_len.tolist()
         asdops_param["maskType"] = json_data["maskType"]
 
@@ -5552,6 +5714,7 @@ class SelfAttentionOperation(DataGen):
             k_slice = torch.permute(k_slice, (1, 0, 2))
             k_slice_t =torch.permute(k_slice, (0, 2, 1)) # get K^T (kv_heads, embed, k_seq)
             v_slice = v[v_offset:v_offset + kv_s][:]
+            # TODO(ivan): slicing error: [1, 16, 192], invalid for input of size 2048
             v_slice = v_slice.view(kv_s, kv_head, embed)
             v_slice = torch.permute(v_slice, (1, 0, 2))
             score = SelfAttentionOperation.group_mm_torch_encoder(head_num, kv_head, q_slice, k_slice_t)
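SelfAttentionOperation.group_matmul is called throughout the new reference but defined outside this diff. Judging from its call sites (a heads-major Q slice against a group_num-major K^T or V slice), it presumably performs the grouped-query multiply sketched below; this is an assumption for illustration, not the tool's actual implementation:

    import numpy as np

    # Assumed semantics: the `heads` query heads are split into `group_num`
    # groups, and each group of heads // group_num query heads shares one
    # K/V head from b's leading axis.
    def group_matmul_sketch(heads, group_num, a, b):
        group_size = heads // group_num
        out = []
        for g in range(group_num):
            # a[g*gs:(g+1)*gs]: (group_size, M, K); b[g:g+1]: (1, K, N) broadcasts
            out.append(np.matmul(a[g * group_size:(g + 1) * group_size], b[g:g + 1]))
        return np.concatenate(out, axis=0)  # (heads, M, N)

    q_heads = np.random.rand(4, 8, 16).astype(np.float16)  # (heads, q_seq, embed)
    kt = np.random.rand(2, 16, 8).astype(np.float16)       # (kv_heads, embed, kv_seq)
    assert group_matmul_sketch(4, 2, q_heads, kt).shape == (4, 8, 8)

With headNum equal to kvHeadNum (as in case 2 above, 32 and 32), group_size is 1 and this reduces to an ordinary per-head batched matmul.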