From 50e382ec4f7b2d1ce0f3e48abef743efe3938481 Mon Sep 17 00:00:00 2001
From: "zhaozhiyong15@hisilicon.com"
Date: Fri, 12 Apr 2024 11:36:13 +0800
Subject: [PATCH] Support BNSD_BSND input layout in flash attention meta
 registrations

---
 third_party/op-plugin                |  2 +-
 torch_npu/meta/meta_registrations.py | 20 ++++++++++++++------
 2 files changed, 15 insertions(+), 7 deletions(-)

diff --git a/third_party/op-plugin b/third_party/op-plugin
index 7c0d60b6fa..1f8963baa2 160000
--- a/third_party/op-plugin
+++ b/third_party/op-plugin
@@ -1 +1 @@
-Subproject commit 7c0d60b6fa8d2f43c1c036a48e0bd6bd291a177e
+Subproject commit 1f8963baa2eff2e101d472f9f98e34b5ae414e6e
diff --git a/torch_npu/meta/meta_registrations.py b/torch_npu/meta/meta_registrations.py
index 5df8ebd955..74579723bb 100644
--- a/torch_npu/meta/meta_registrations.py
+++ b/torch_npu/meta/meta_registrations.py
@@ -29,12 +29,16 @@ def npu_incre_flash_attention_forward(query, key, value, *, padding_mask=None, a
 @impl(m, "npu_prompt_flash_attention")
 def npu_prompt_flash_attention_forward(query, key, value, *, padding_mask=None, atten_mask=None, pse_shift=None, actual_seq_lengths=None, deq_scale1=None, quant_scale1=None, deq_scale2=None, quant_scale2=None, quant_offset2=None, num_heads=1, scale_value=1.0,
                                        pre_tokens=2147473647, next_tokens=0, input_layout="BSH", num_key_value_heads=0, actual_seq_lengths_kv=None, sparse_mode=0):
+    tmp_out = torch.empty_like(query, dtype=query.dtype, device='meta')
+    if input_layout == "BNSD_BSND":
+        tmp_out = torch.empty([query.size(0), query.size(2), query.size(1), query.size(3)], dtype=query.dtype, device='meta')
+
     if quant_scale2 is not None:
-        return torch.empty_like(query, dtype=torch.int8)
+        return torch.empty_like(tmp_out, dtype=torch.int8)
     elif query.dtype == torch.int8:
-        return torch.empty_like(query, dtype=torch.half)
+        return torch.empty_like(tmp_out, dtype=torch.half)
     else:
-        return torch.empty_like(query, dtype=query.dtype)
+        return torch.empty_like(tmp_out, dtype=query.dtype)
 
 
 @impl(m, "npu_fused_infer_attention_score")
@@ -44,12 +48,16 @@ def npu_fused_infer_attention_score_forward(query, key, value, *, pse_shift = No
                                             query_padding_size = None, kv_padding_size = None, num_heads = 1, scale = 1.0, pre_tokens = 2147483647,
                                             next_tokens = 2147483647, input_layout = "BSH", num_key_value_heads = 0, sparse_mode = 0, inner_precise = 0,
                                             block_size = 0, antiquant_mode = 0, softmax_lse_flag = False):
+    tmp_out = torch.empty_like(query, dtype=query.dtype, device='meta')
+    if input_layout == "BNSD_BSND":
+        tmp_out = torch.empty([query.size(0), query.size(2), query.size(1), query.size(3)], dtype=query.dtype, device='meta')
+
     if quant_scale2 is not None:
-        return (torch.empty_like(query, dtype=torch.int8), torch.empty_like(query))
+        return (torch.empty_like(tmp_out, dtype=torch.int8), torch.empty_like(query))
     elif query.dtype == torch.int8:
-        return (torch.empty_like(query, dtype=torch.half), torch.empty_like(query))
+        return (torch.empty_like(tmp_out, dtype=torch.half), torch.empty_like(query))
     else:
-        return (torch.empty_like(query), torch.empty_like(query))
+        return (torch.empty_like(tmp_out), torch.empty_like(query))
 
 
 @impl(m, "npu_fusion_attention")
-- 
Gitee
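
For reference, a minimal sketch (not part of the patch) of the rule the new branches implement: with input_layout="BNSD_BSND" the query arrives as BNSD [B, N, S, D] but the op emits BSND [B, S, N, D], so the meta output swaps dims 1 and 2; every other layout keeps the query's shape. The helper name and tensor sizes below are illustrative only, not from the patch, and it condenses the prompt-flash-attention meta function; the fused-infer-attention-score variant returns this same tensor plus a second query-shaped tensor.

import torch

def pfa_meta_output(query, input_layout, quant_scale2=None):
    # Mirrors the patched meta logic for npu_prompt_flash_attention.
    # Shape: "BNSD_BSND" means BNSD input, BSND output, so N (dim 1)
    # and S (dim 2) swap in the output; other layouts keep query's shape.
    if input_layout == "BNSD_BSND":
        tmp_out = torch.empty([query.size(0), query.size(2), query.size(1), query.size(3)],
                              dtype=query.dtype, device='meta')
    else:
        tmp_out = torch.empty_like(query, dtype=query.dtype, device='meta')
    # Dtype: int8 when int8 output quantization is requested via quant_scale2,
    # half when the query itself is int8, otherwise the query's dtype.
    if quant_scale2 is not None:
        return torch.empty_like(tmp_out, dtype=torch.int8)
    elif query.dtype == torch.int8:
        return torch.empty_like(tmp_out, dtype=torch.half)
    return torch.empty_like(tmp_out, dtype=query.dtype)

# BNSD query: batch=2, num_heads=8, seq_len=16, head_dim=64
q = torch.empty(2, 8, 16, 64, dtype=torch.half, device='meta')
out = pfa_meta_output(q, "BNSD_BSND")
assert out.shape == (2, 16, 8, 64) and out.dtype == torch.half  # BSND, half

Materializing the layout-dependent shape in tmp_out first, as the patch does, keeps the existing dtype branches untouched: each branch only re-types tmp_out rather than repeating the shape logic.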