From 971ea6d0010d7adc6a3169da582fa4e6bf892465 Mon Sep 17 00:00:00 2001
From: uh
Date: Fri, 8 Aug 2025 23:32:34 +0800
Subject: [PATCH] feat: optimize attention mask creation for long sequences

---
 .../model_executor/models/attention_mask.py   | 97 +++++++++++++++----
 .../model_executor/models/mf_models/qwen3.py  |  2 +-
 .../models/mindone_models/qwen2.py            |  2 +-
 .../models/mindone_models/qwen2_5_vl.py       |  2 +-
 .../models/mindone_models/qwen3.py            |  2 +-
 .../model_executor/models/model_base.py       |  9 +-
 6 files changed, 87 insertions(+), 27 deletions(-)

diff --git a/vllm_mindspore/model_executor/models/attention_mask.py b/vllm_mindspore/model_executor/models/attention_mask.py
index 9ec0c57d..63162299 100644
--- a/vllm_mindspore/model_executor/models/attention_mask.py
+++ b/vllm_mindspore/model_executor/models/attention_mask.py
@@ -13,14 +13,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 """
 infer attention mask.
 """
 import numpy as np
-
-from mindspore import Tensor, mint
+from mindspore import Tensor
 from mindspore import dtype as mstype
+from mindspore import mint
+
+# yapf conflicts with isort
+# yapf: disable
 
 r"""
 PA:ASD-V2.1.5
@@ -34,6 +36,8 @@ FA:ASD-V2.1.5
 2.normal: mask BF16(0/1), FP16 mask(0/-10000);
 """
 
+# yapf: enable
+
 
 class LowerTriangularMask:
     r"""
@@ -43,28 +47,85 @@ class LowerTriangularMask:
         max_model_len (int): The max model length of Infer model.
     """
 
-    def __init__(self, dtype, max_model_len):
+    def __init__(self, dtype, max_model_len, decode_mask_coeff=-10000.0):
         self.dtype = dtype
         self.max_model_len = max_model_len
+        self.cached_mask_len = 8 * 1024
+        self.decode_mask_coeff = decode_mask_coeff
 
         prefill_mask_coeff = 1.0 if self.dtype == mstype.bfloat16 else -10000.0
-
-        self.prefill_mask = Tensor(np.triu(np.ones(shape=(128, 128), dtype=np.float16), k=1) * prefill_mask_coeff,
-                                   dtype=self.dtype)
-
-        self.decode_mask = Tensor(np.triu(np.ones(shape=(self.max_model_len, self.max_model_len), dtype=np.int8), k=1),
-                                  dtype=self.dtype) * -10000
+        self.prefill_mask = Tensor(
+            np.triu(np.ones(shape=(128, 128), dtype=np.float16), k=1) *
+            prefill_mask_coeff,
+            dtype=self.dtype)
 
         self.hard_mask = mint.zeros((1, 1), dtype=dtype)
+        self.decode_mask = Tensor(np.triu(np.ones(
+            shape=(self.cached_mask_len, self.cached_mask_len), dtype=np.int8),
+                                          k=1),
+                                  dtype=self.dtype) * self.decode_mask_coeff
+
+    def create_mask(self, query_lens_np, seq_lens_np):
+        '''
+        when query_lens_np = [3], seq_lens_np = [6] and decode_mask_coeff = 1,
+        the initial attention mask is
+        0 0 0 0 0 0
+        0 0 0 0 0 0
+        0 0 0 0 0 0
+        '''
+        max_seq_len = seq_lens_np.max().item()
+        total_q_len = query_lens_np.sum().item()
+        attention_mask = mint.zeros((total_q_len, max_seq_len),
+                                    dtype=self.dtype)
+
+        req_num = query_lens_np.shape[0]
+        # skip leading requests with q_len = 0 to reduce execution time
+        current_row = np.argmax(query_lens_np != 0).item()
+        for i in range(current_row, req_num):
+            q_len = query_lens_np[i].item()
+            seq_len = seq_lens_np[i].item()
+            context_len = seq_len - q_len
+            '''
+            set the right half to decode_mask_coeff (1 here)
+            0 0 0 1 1 1
+            0 0 0 1 1 1
+            0 0 0 1 1 1
+            '''
+            attention_mask[current_row:current_row + q_len,
+                           context_len:] = self.decode_mask_coeff
+            '''
+            set the lower triangle of the right half back to 0
+            0 0 0 0 1 1
+            0 0 0 0 0 1
+            0 0 0 0 0 0
+            '''
+            right_tensor = attention_mask[current_row:current_row + q_len,
+                                          context_len:seq_len]
+            # use masked_fill_ to modify attention_mask in place
+            right_tensor.masked_fill_(
+                right_tensor.tril() == self.decode_mask_coeff, 0)
+            current_row += q_len
+
+        return attention_mask
 
-    def gen_attention_mask(self, is_prefill, position_ids, query_lens):
+    def gen_attention_mask(self,
+                           is_prefill: bool,
+                           position_ids: Tensor,
+                           query_lens_np: np.ndarray,
+                           seq_lens_np: np.ndarray,
+                           attn_metadata=None):
+        max_query_len = query_lens_np.max()
+        max_seq_len = seq_lens_np.max()
         if is_prefill:
             attention_mask = self.prefill_mask
-        else:
-            if max(query_lens) > 1:
-                attention_mask = mint.index_select(self.decode_mask, 0, position_ids)
+        elif max_query_len > 1:
+            if max_seq_len <= self.cached_mask_len:
+                attention_mask = mint.index_select(self.decode_mask, 0,
+                                                   position_ids)
             else:
-                attention_mask = self.hard_mask
+                attention_mask = self.create_mask(query_lens_np, seq_lens_np)
+        else:
+            attention_mask = self.hard_mask
         return attention_mask
 
@@ -77,8 +138,6 @@ class MLALowerTriangularMask(LowerTriangularMask):
     """
 
     def __init__(self, dtype, max_model_len):
+        decode_mask_coeff = 1.0 if dtype == mstype.bfloat16 else -10000.0
+        super().__init__(dtype, max_model_len, decode_mask_coeff)
-        super().__init__(dtype, max_model_len)
-        decode_mask_coeff = 1.0 if self.dtype == mstype.bfloat16 else -10000.0
-        self.decode_mask = Tensor(np.triu(np.ones(shape=(self.max_model_len, self.max_model_len), dtype=np.int8), k=1),
-                                  dtype=self.dtype) * decode_mask_coeff
 
diff --git a/vllm_mindspore/model_executor/models/mf_models/qwen3.py b/vllm_mindspore/model_executor/models/mf_models/qwen3.py
index 5ee8ce48..9b6c1549 100644
--- a/vllm_mindspore/model_executor/models/mf_models/qwen3.py
+++ b/vllm_mindspore/model_executor/models/mf_models/qwen3.py
@@ -131,7 +131,7 @@ class Qwen3ForCausalLM(MsModelBase):
         q_seq_lens = ms.Tensor(query_lens_np, dtype=ms.int32)
         position_ids = ms.Tensor(positions, dtype=ms.int32)
         attention_mask = self.casual_mask.gen_attention_mask(
-            is_prefill, positions, query_lens_np)
+            is_prefill, position_ids, query_lens_np, seq_lens_np)
 
         model_inputs = {}
         model_inputs["input_ids"] = input_ids.astype(ms.int32) * 1
diff --git a/vllm_mindspore/model_executor/models/mindone_models/qwen2.py b/vllm_mindspore/model_executor/models/mindone_models/qwen2.py
index b14e5df6..edddf9f6 100644
--- a/vllm_mindspore/model_executor/models/mindone_models/qwen2.py
+++ b/vllm_mindspore/model_executor/models/mindone_models/qwen2.py
@@ -292,7 +292,7 @@ class Qwen2ForCausalLM(MindONEModelBase):
         slot_mapping = attn_metadata.slot_mapping
 
         attn_mask = self.casual_mask.gen_attention_mask(
-            is_prefill, positions, query_lens)
+            is_prefill, positions, query_lens_np, seq_lens_np)
         seq_lens_np = np.array(attn_metadata.seq_lens, dtype=np.int32)
         batch_valid_length = Tensor.from_numpy(seq_lens_np)
         q_seq_lens = Tensor.from_numpy(
diff --git a/vllm_mindspore/model_executor/models/mindone_models/qwen2_5_vl.py b/vllm_mindspore/model_executor/models/mindone_models/qwen2_5_vl.py
index 79f6b20c..973ac3cf 100644
--- a/vllm_mindspore/model_executor/models/mindone_models/qwen2_5_vl.py
+++ b/vllm_mindspore/model_executor/models/mindone_models/qwen2_5_vl.py
@@ -396,7 +396,7 @@ class Qwen2_5_VLForConditionalGeneration(MindONEModelBase, SupportsMultiModal):
         slot_mapping = attn_metadata.slot_mapping
 
         attn_mask = self.casual_mask.gen_attention_mask(
-            is_prefill, positions, query_lens)
+            is_prefill, positions, query_lens_np, seq_lens_np)
         seq_lens_np = np.array(attn_metadata.seq_lens, dtype=np.int32)
         batch_valid_length = Tensor.from_numpy(seq_lens_np)
         q_seq_lens = Tensor.from_numpy(
diff --git a/vllm_mindspore/model_executor/models/mindone_models/qwen3.py b/vllm_mindspore/model_executor/models/mindone_models/qwen3.py
index bbc1cc97..de0c3f7d 100644
--- a/vllm_mindspore/model_executor/models/mindone_models/qwen3.py
+++ b/vllm_mindspore/model_executor/models/mindone_models/qwen3.py
@@ -295,7 +295,7 @@ class Qwen3ForCausalLM(MindONEModelBase):
         slot_mapping = attn_metadata.slot_mapping
 
         attn_mask = self.casual_mask.gen_attention_mask(
-            is_prefill, positions, query_lens)
+            is_prefill, positions, query_lens_np, seq_lens_np)
         seq_lens_np = np.array(attn_metadata.seq_lens, dtype=np.int32)
         batch_valid_length = Tensor.from_numpy(seq_lens_np)
         q_seq_lens = Tensor.from_numpy(
diff --git a/vllm_mindspore/model_executor/models/model_base.py b/vllm_mindspore/model_executor/models/model_base.py
index bb2213cc..0f8da220 100644
--- a/vllm_mindspore/model_executor/models/model_base.py
+++ b/vllm_mindspore/model_executor/models/model_base.py
@@ -292,10 +292,11 @@ class MsModelBase:
         is_prefill = attn_metadata.max_context_lens == 0
         query_lens_np = attn_metadata.q_seq_lens_np
         seq_lens_np = attn_metadata.seq_lens_np
-
+
         q_seq_lens = ms.Tensor(query_lens_np, dtype=ms.int32)
         position_ids = ms.Tensor(positions, dtype=ms.int32)
-        attention_mask = self.casual_mask.gen_attention_mask(is_prefill, positions, query_lens_np)
+        attention_mask = self.casual_mask.gen_attention_mask(
+            is_prefill, position_ids, query_lens_np, seq_lens_np)
 
         model_inputs = {}
         # Convert input_ids and block_tables into contiguous tensors.
@@ -330,7 +331,7 @@ class NativeModel(MsModelBase):
 
     def common_preprocess(self, vllm_config, prefix = ""):
         self.set_modules({"model": self.model, "lm_head": self.lm_head})
 
-        self.casual_mask = LowerTriangularMask(dtype=self.model_config.dtype,
+        self.casual_mask = LowerTriangularMask(dtype=self.model_config.dtype,
                                                max_model_len=self.model_config.max_model_len)
         self.kv_caches = [AttentionWrapper() for i in range(self.config.num_hidden_layers)]
@@ -407,7 +408,7 @@ class NativeModel(MsModelBase):
             self.set_model_inputs(is_prefill)
             self.prev_prefill = is_prefill
 
-        # for dummy_attention_metadata
+        # for dummy_attention_metadata
         if is_prefill and not self.set_flags:
             self.set_flags = True
 
-- 
Gitee
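
Reviewer note (not part of the patch): the chunked mask that create_mask builds can be reproduced with plain NumPy. The sketch below is illustrative only; the function name build_chunked_causal_mask and the pure-NumPy types are assumptions of this note, while the real method works on MindSpore tensors through mint.zeros and masked_fill_.

import numpy as np

def build_chunked_causal_mask(query_lens_np, seq_lens_np,
                              mask_coeff=-10000.0, dtype=np.float32):
    """Rows stack the new (query) tokens of every request in batch order;
    columns cover the longest sequence in the batch. A position is masked
    (set to mask_coeff) when it lies beyond the causal frontier of its row."""
    max_seq_len = int(seq_lens_np.max())
    total_q_len = int(query_lens_np.sum())
    mask = np.zeros((total_q_len, max_seq_len), dtype=dtype)

    row = 0
    for q_len, seq_len in zip(query_lens_np.tolist(), seq_lens_np.tolist()):
        context_len = seq_len - q_len
        for j in range(q_len):
            # query token j of this request may see context_len + j + 1 positions
            visible = context_len + j + 1
            mask[row + j, visible:] = mask_coeff
        row += q_len
    return mask

if __name__ == "__main__":
    # one request with 3 new tokens on top of 3 already-cached tokens,
    # matching the example in the create_mask docstring (coefficient 1)
    print(build_chunked_causal_mask(np.array([3]), np.array([6]), mask_coeff=1.0))
    # expected:
    # [[0. 0. 0. 0. 1. 1.]
    #  [0. 0. 0. 0. 0. 1.]
    #  [0. 0. 0. 0. 0. 0.]]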
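
A rough size check also shows why the precomputed decode mask is now capped at cached_mask_len = 8 * 1024 rather than allocated at max_model_len x max_model_len. The 2 bytes per element (bfloat16/float16) and the 128K max_model_len below are assumptions for illustration, not values taken from the patch:

def mask_bytes(n, bytes_per_elem=2):
    # dense n x n mask in a 2-byte dtype such as bfloat16 or float16
    return n * n * bytes_per_elem

print(mask_bytes(8 * 1024) / 2**20)    # 8K cached mask: 128.0 MiB
print(mask_bytes(128 * 1024) / 2**30)  # hypothetical 128K max_model_len: 32.0 GiB

Batches whose longest sequence exceeds the 8K threshold fall through to create_mask, which builds only a (sum of query lens) x (max seq len) mask for the current step.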