From 4935017be080235abc0f1d8e3443e123c0404fe3 Mon Sep 17 00:00:00 2001
From: donghan
Date: Thu, 27 Mar 2025 11:22:09 +0800
Subject: [PATCH] ds fused_infer_attention_score interface add param: key_rope_antiquant_scale

---
 test/torch_npu_schema.json         | 2 +-
 third_party/op-plugin              | 2 +-
 third_party/torchair/torchair      | 2 +-
 torch_npu/onnx/wrapper_onnx_ops.py | 7 ++++---
 4 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/test/torch_npu_schema.json b/test/torch_npu_schema.json
index e715642f2..35f6c7475 100644
--- a/test/torch_npu_schema.json
+++ b/test/torch_npu_schema.json
@@ -2586,7 +2586,7 @@
         "signature": "(query_layer, key_layer, value_layer, attention_mask, scale, keep_prob, query_transpose=False, key_transpose=False, bmm_score_transpose_a=False, bmm_score_transpose_b=False, value_transpose=False, dx_transpose=False)"
     },
     "torch_npu.npu_fused_infer_attention_score": {
-        "signature": "(self, query, key, value, pse_shift, atten_mask, actual_seq_lengths, actual_seq_lengths_kv, dequant_scale1, quant_scale1, dequant_scale2, quant_scale2, quant_offset2, antiquant_scale, antiquant_offset, block_table, query_padding_size, kv_padding_size, num_heads, scale, pre_tokens, next_tokens, input_layout, key_antiquant_scale, key_antiquant_offset, value_antiquant_scale, value_antiquant_offset, key_shared_prefix, value_shared_prefix, actual_shared_prefix_len, query_rope, key_rope, num_key_value_heads, sparse_mode, inner_precise, block_size, antiquant_mode, softmax_lse_flag, key_antiquant_mode, value_antiquant_mode)"
+        "signature": "(self, query, key, value, pse_shift, atten_mask, actual_seq_lengths, actual_seq_lengths_kv, dequant_scale1, quant_scale1, dequant_scale2, quant_scale2, quant_offset2, antiquant_scale, antiquant_offset, block_table, query_padding_size, kv_padding_size, num_heads, scale, pre_tokens, next_tokens, input_layout, key_antiquant_scale, key_antiquant_offset, value_antiquant_scale, value_antiquant_offset, key_shared_prefix, value_shared_prefix, actual_shared_prefix_len, query_rope, key_rope, key_rope_antiquant_scale, num_key_value_heads, sparse_mode, inner_precise, block_size, antiquant_mode, softmax_lse_flag, key_antiquant_mode, value_antiquant_mode)"
     },
     "torch_npu.npu_fusion_attention": {
         "signature": "(*args, **kwargs)"
diff --git a/third_party/op-plugin b/third_party/op-plugin
index e17bd2f79..91e5f9a62 160000
--- a/third_party/op-plugin
+++ b/third_party/op-plugin
@@ -1 +1 @@
-Subproject commit e17bd2f792880b53d638f0049bcdcf75b299f02c
+Subproject commit 91e5f9a6244a3e8faa3a0b1fb4a40fb8a97c204d
diff --git a/third_party/torchair/torchair b/third_party/torchair/torchair
index e54f651b5..018466a50 160000
--- a/third_party/torchair/torchair
+++ b/third_party/torchair/torchair
@@ -1 +1 @@
-Subproject commit e54f651b5dde9494012d83d4e18f8a336b1120f8
+Subproject commit 018466a50b704ffe0df5a38dd77e2d50a2d27afe
diff --git a/torch_npu/onnx/wrapper_onnx_ops.py b/torch_npu/onnx/wrapper_onnx_ops.py
index 4b2348d87..c8ac02825 100644
--- a/torch_npu/onnx/wrapper_onnx_ops.py
+++ b/torch_npu/onnx/wrapper_onnx_ops.py
@@ -713,6 +713,7 @@ class _NPUFusedInferAttentionScoreOP(torch.autograd.Function):
                 actual_shared_prefix_len: Optional[Tensor],
                 query_rope: Optional[Tensor],
                 key_rope: Optional[Tensor],
+                key_rope_antiquant_scale: Optional[Tensor],
                 num_heads: int = 1, scale: float = 1.0, pre_tokens: int = 2147483647,
                 next_tokens: int = 2147483647, input_layout: str = "BSH",
                 num_key_value_heads: int = 0,
@@ -724,7 +725,7 @@ class _NPUFusedInferAttentionScoreOP(torch.autograd.Function):
                                              dequant_scale1,
                                              quant_scale1, dequant_scale2, quant_scale2, quant_offset2,
                                              antiquant_scale, antiquant_offset, block_table,
                                              query_padding_size, kv_padding_size, key_antiquant_scale,
                                              key_antiquant_offset, value_antiquant_scale, value_antiquant_offset,
-                                             key_shared_prefix, value_shared_prefix, actual_shared_prefix_len, query_rope, key_rope,
+                                             key_shared_prefix, value_shared_prefix, actual_shared_prefix_len, query_rope, key_rope, key_rope_antiquant_scale,
                                              num_heads, scale, pre_tokens, next_tokens, input_layout,
                                              num_key_value_heads, sparse_mode, inner_precise, block_size,
                                              antiquant_mode, softmax_lse_flag, key_antiquant_mode, value_antiquant_mode)
@@ -1297,7 +1298,7 @@ def _wrapper_npu_fused_infer_attention_score(self, query, key, value, pse_shift,
                                              antiquant_offset, block_table, query_padding_size, kv_padding_size,
                                              num_heads, scale, pre_tokens, next_tokens, input_layout,
                                              key_antiquant_scale, key_antiquant_offset, value_antiquant_scale, value_antiquant_offset,
-                                             key_shared_prefix, value_shared_prefix, actual_shared_prefix_len, query_rope, key_rope,
+                                             key_shared_prefix, value_shared_prefix, actual_shared_prefix_len, query_rope, key_rope, key_rope_antiquant_scale,
                                              num_key_value_heads, sparse_mode, inner_precise, block_size,
                                              antiquant_mode, softmax_lse_flag, key_antiquant_mode,
                                              value_antiquant_mode):
@@ -1307,7 +1308,7 @@ def _wrapper_npu_fused_infer_attention_score(self, query, key, value, pse_shift,
                                              quant_offset2, antiquant_scale, antiquant_offset, block_table,
                                              query_padding_size, kv_padding_size, key_antiquant_scale,
                                              key_antiquant_offset, value_antiquant_scale, value_antiquant_offset,
-                                             key_shared_prefix, value_shared_prefix, actual_shared_prefix_len, query_rope, key_rope,
+                                             key_shared_prefix, value_shared_prefix, actual_shared_prefix_len, query_rope, key_rope, key_rope_antiquant_scale,
                                              num_heads, scale, pre_tokens, next_tokens, input_layout,
                                              num_key_value_heads, sparse_mode, inner_precise, block_size,
                                              antiquant_mode, softmax_lse_flag,
-- 
Gitee
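For context, a minimal sketch of where the new argument sits in a call to `torch_npu.npu_fused_infer_attention_score`. This is not part of the patch: the tensor shapes, dtypes, rope dimension, and the BSH layout below are assumptions chosen only to illustrate passing `key_rope_antiquant_scale` alongside `key_rope`; the exact shape and semantics of the scale tensor are defined by the op-plugin kernel referenced above.

```python
import torch
import torch_npu

# Illustrative sizes only (assumptions, not taken from the patch).
batch, seq_len, num_heads, head_dim, rope_dim = 1, 1024, 32, 128, 64

query = torch.randn(batch, seq_len, num_heads * head_dim, dtype=torch.float16).npu()
key = torch.randn(batch, seq_len, num_heads * head_dim, dtype=torch.float16).npu()
value = torch.randn(batch, seq_len, num_heads * head_dim, dtype=torch.float16).npu()

# Rotary position components for query/key, plus the scale tensor carried by the
# new key_rope_antiquant_scale parameter (shape here is an assumption).
query_rope = torch.randn(batch, seq_len, num_heads * rope_dim, dtype=torch.float16).npu()
key_rope = torch.randn(batch, seq_len, num_heads * rope_dim, dtype=torch.float16).npu()
key_rope_antiquant_scale = torch.ones(num_heads * rope_dim, dtype=torch.float16).npu()

out = torch_npu.npu_fused_infer_attention_score(
    query, key, value,
    query_rope=query_rope,
    key_rope=key_rope,
    key_rope_antiquant_scale=key_rope_antiquant_scale,  # parameter added by this patch
    num_heads=num_heads,
    scale=head_dim ** -0.5,
    input_layout="BSH",
)
```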