diff --git a/vllm_mindspore/model_executor/models/attention_mask.py b/vllm_mindspore/model_executor/models/attention_mask.py
index 40be1f46cb4de747f1d607241fb264e30acde3fe..42d6e6297d5da3320435e3a40afa96bf17bff389 100644
--- a/vllm_mindspore/model_executor/models/attention_mask.py
+++ b/vllm_mindspore/model_executor/models/attention_mask.py
@@ -46,7 +46,7 @@ class LowerTriangularMask:
         self.dtype = dtype
         self.max_model_len = max_model_len
 
-        prefill_mask_coeff = 1.0 if self.dtype is mstype.bfloat16 else -10000.0
+        prefill_mask_coeff = 1.0 if self.dtype == mstype.bfloat16 else -10000.0
 
         self.prefill_mask = Tensor(np.triu(np.ones(shape=(128, 128), dtype=np.float16), k=1) * prefill_mask_coeff,
                                    dtype=self.dtype)
@@ -78,6 +78,6 @@ class MLALowerTriangularMask(LowerTriangularMask):
 
     def __init__(self, dtype, max_model_len):
         super().__init__(dtype, max_model_len)
-        decode_mask_coeff = 1.0 if self.dtype is mstype.bfloat16 else -10000.0
+        decode_mask_coeff = 1.0 if self.dtype == mstype.bfloat16 else -10000.0
         self.decode_mask = Tensor(np.triu(np.ones(shape=(self.max_model_len, self.max_model_len), dtype=np.int8), k=1),
                                   dtype=self.dtype) * decode_mask_coeff