From 474e6ba6e0358b7dd4922ca7de937740915f43dd Mon Sep 17 00:00:00 2001 From: wusimin Date: Thu, 22 May 2025 10:02:31 +0800 Subject: [PATCH] =?UTF-8?q?[0.8.3=20v1]=E9=80=82=E9=85=8D=E5=8E=9F?= =?UTF-8?q?=E7=94=9FQwen,=E4=BF=AE=E5=A4=8D=E7=B2=BE=E5=BA=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- vllm_mindspore/model_executor/models/attention_mask.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm_mindspore/model_executor/models/attention_mask.py b/vllm_mindspore/model_executor/models/attention_mask.py index 40be1f4..42d6e62 100644 --- a/vllm_mindspore/model_executor/models/attention_mask.py +++ b/vllm_mindspore/model_executor/models/attention_mask.py @@ -46,7 +46,7 @@ class LowerTriangularMask: self.dtype = dtype self.max_model_len = max_model_len - prefill_mask_coeff = 1.0 if self.dtype is mstype.bfloat16 else -10000.0 + prefill_mask_coeff = 1.0 if self.dtype == mstype.bfloat16 else -10000.0 self.prefill_mask = Tensor(np.triu(np.ones(shape=(128, 128), dtype=np.float16), k=1) * prefill_mask_coeff, dtype=self.dtype) @@ -78,6 +78,6 @@ class MLALowerTriangularMask(LowerTriangularMask): def __init__(self, dtype, max_model_len): super().__init__(dtype, max_model_len) - decode_mask_coeff = 1.0 if self.dtype is mstype.bfloat16 else -10000.0 + decode_mask_coeff = 1.0 if self.dtype == mstype.bfloat16 else -10000.0 self.decode_mask = Tensor(np.triu(np.ones(shape=(self.max_model_len, self.max_model_len), dtype=np.int8), k=1), dtype=self.dtype) * decode_mask_coeff -- Gitee