diff --git a/tests/st/python/cases_parallel/vllm_deepseek_bf16_part_v1.py b/tests/st/python/cases_parallel/vllm_deepseek_bf16_part_v1.py
index 0a85b1caf2b7432209bcffdefcf45abed98947ae..4d4fb5c0f9782e296da5553f5bc3037ee67ed3dc 100644
--- a/tests/st/python/cases_parallel/vllm_deepseek_bf16_part_v1.py
+++ b/tests/st/python/cases_parallel/vllm_deepseek_bf16_part_v1.py
@@ -1,4 +1,5 @@
 #!/usr/bin/env python3
+# isort:skip_file
 # encoding: utf-8
 # Copyright 2025 Huawei Technologies Co., Ltd
 #
@@ -53,8 +54,12 @@ def test_deepseek_r1_bf16():
     sampling_params = SamplingParams(temperature=0.0, max_tokens=10, top_k=1)
 
     # Create an LLM.
-    llm = LLM(model="/home/workspace/mindspore_dataset/weight/DeepSeek-R1-bf16",
-              trust_remote_code=True, gpu_memory_utilization=0.9, tensor_parallel_size=2, max_model_len=4096)
+    llm = LLM(
+        model="/home/workspace/mindspore_dataset/weight/DeepSeek-R1-bf16",
+        trust_remote_code=True,
+        gpu_memory_utilization=0.9,
+        tensor_parallel_size=2,
+        max_model_len=33 * 1024)
     # Generate texts from the prompts. The output is a list of RequestOutput objects
     # that contain the prompt, generated text, and other information.
     outputs = llm.generate(prompts, sampling_params)
diff --git a/vllm_mindspore/model_executor/models/attention_mask.py b/vllm_mindspore/model_executor/models/attention_mask.py
index 0df3c30ab3caa290aae2b93501c4d8415d84819a..6b2b992ce8a3ed8b232b603a4be115e4d2bfd88f 100644
--- a/vllm_mindspore/model_executor/models/attention_mask.py
+++ b/vllm_mindspore/model_executor/models/attention_mask.py
@@ -33,6 +33,8 @@ FA:ASD-V2.1.5
 2.normal: mask BF16(0/1), FP16 mask(0/-10000);
 """
 
+MAX_MODEL_LEN_32K = 32 * 1024
+
 
 class LowerTriangularMask:
     r"""
@@ -53,13 +55,32 @@ class LowerTriangularMask:
                                   prefill_mask_coeff,
                                   dtype=self.dtype)
 
-        self.decode_mask = Tensor(np.triu(np.ones(
-            shape=(self.max_model_len, self.max_model_len), dtype=np.int8),
-                                          k=1),
-                                  dtype=self.dtype) * -10000
+        if self.max_model_len > MAX_MODEL_LEN_32K:
+            self.decode_mask = np.triu(np.ones(
+                shape=(self.max_model_len, self.max_model_len),
+                dtype=np.float16),
+                                       k=1) * -10000
+        else:
+            self.decode_mask = Tensor(np.triu(np.ones(
+                shape=(self.max_model_len, self.max_model_len), dtype=np.int8),
+                                              k=1),
+                                      dtype=self.dtype) * -10000
 
         self.hard_mask = mint.zeros((1, 1), dtype=dtype)
 
+    def gen_attention_decode_mask(self, position_ids):
+        if isinstance(self.decode_mask, ms.Tensor):
+            attention_mask = mint.index_select(self.decode_mask, 0,
+                                               position_ids)
+        elif isinstance(self.decode_mask, np.ndarray):
+            attention_mask = self.decode_mask[position_ids.asnumpy()]
+            attention_mask = ms.Tensor(attention_mask, dtype=self.dtype)
+        else:
+            raise ValueError(
+                f"Decode mask type:{type(self.decode_mask)} is not supported.")
+
+        return attention_mask
+
     def gen_attention_mask(self,
                            is_prefill,
                            position_ids,
@@ -69,8 +90,7 @@ class LowerTriangularMask:
             attention_mask = self.prefill_mask
         else:
             if max(query_lens) > 1:
-                attention_mask = mint.index_select(self.decode_mask, 0,
-                                                   position_ids)
+                attention_mask = self.gen_attention_decode_mask(position_ids)
             else:
                 attention_mask = self.hard_mask
         return attention_mask
@@ -88,10 +108,16 @@ class MLALowerTriangularMask(LowerTriangularMask):
         super().__init__(dtype, max_model_len)
         decode_mask_coeff = 1.0 if self.dtype == mstype.bfloat16 else -10000.0
 
-        self.decode_mask = Tensor(np.triu(np.ones(
-            shape=(self.max_model_len, self.max_model_len), dtype=np.int8),
-                                          k=1),
-                                  dtype=self.dtype) * decode_mask_coeff
+        if self.max_model_len > MAX_MODEL_LEN_32K:
+            self.decode_mask = np.triu(np.ones(
+                shape=(self.max_model_len, self.max_model_len),
+                dtype=np.float16),
+                                       k=1) * decode_mask_coeff
+        else:
+            self.decode_mask = Tensor(np.triu(np.ones(
+                shape=(self.max_model_len, self.max_model_len), dtype=np.int8),
+                                              k=1),
+                                      dtype=self.dtype) * decode_mask_coeff
 
 
 class MultiModalLowerTriangularMask(LowerTriangularMask):
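Taken together, the attention_mask.py hunks switch the decode mask to a host-side numpy float16 array once max_model_len exceeds MAX_MODEL_LEN_32K: the full (max_model_len, max_model_len) mask stays in host memory, and gen_attention_decode_mask converts only the rows gathered for the current positions into a device tensor. The test's max_model_len=33 * 1024 sits just past the threshold, so it exercises the new numpy path. Below is a minimal, numpy-only sketch of the idea; build_decode_mask and gather_decode_rows are hypothetical names standing in for the ms.Tensor / mint.index_select code paths in the diff.

```python
import numpy as np

# Threshold mirroring MAX_MODEL_LEN_32K in the diff.
MAX_MODEL_LEN_32K = 32 * 1024


def build_decode_mask(max_model_len: int) -> np.ndarray:
    """Additive causal mask: 0 on/below the diagonal, -10000 above it.

    For max_model_len = 33 * 1024 this float16 array takes about
    (33 * 1024)**2 * 2 bytes ~= 2.1 GiB, which is why the patch keeps
    it in host memory rather than materializing it as a device Tensor.
    """
    return np.triu(
        np.ones((max_model_len, max_model_len), dtype=np.float16),
        k=1) * -10000


def gather_decode_rows(decode_mask: np.ndarray,
                       position_ids: np.ndarray) -> np.ndarray:
    """Select one mask row per decoding position.

    Numpy equivalent of mint.index_select(mask, 0, position_ids); only
    this (batch, max_model_len) slice would be copied to the device.
    """
    return decode_mask[position_ids]


if __name__ == "__main__":
    mask = build_decode_mask(8)  # tiny stand-in for a >32K mask
    rows = gather_decode_rows(mask, np.array([0, 3, 7]))
    print(rows.shape)   # (3, 8): one row per decoding position
    print(rows[1, :5])  # position 3: [0, 0, 0, 0, -10000]
```

The tradeoff, under these assumptions, is a small per-step host-to-device copy of the gathered rows in exchange for not allocating a multi-GiB mask on the device; sequences at or below 32K keep the original device-resident index_select path unchanged.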