From 45a4983946a6ebb4b21450770d0d4a9b65431368 Mon Sep 17 00:00:00 2001 From: yyyyrf Date: Fri, 23 Jan 2026 16:09:48 +0800 Subject: [PATCH] return logits when labels is None --- mindformers/pynative/base_models/gpt/gpt_model.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/mindformers/pynative/base_models/gpt/gpt_model.py b/mindformers/pynative/base_models/gpt/gpt_model.py index bcc978c5d..5645724d8 100644 --- a/mindformers/pynative/base_models/gpt/gpt_model.py +++ b/mindformers/pynative/base_models/gpt/gpt_model.py @@ -282,6 +282,9 @@ class GPTModel(nn.Cell): logits = self.reshape(logits, (-1, logits.shape[-1])) logits = self.cast(logits, dtype.float32) + if labels is None: + return logits + if not self.training: return logits.contiguous() @@ -395,8 +398,9 @@ class GPTModel(nn.Cell): """ if loss_mask is None: loss_mask = self.cast(self.not_equal(input_ids, self.pad_token_id), dtype.float32) - label_mask = self.cast(self.not_equal(labels, self.ignore_token_id), dtype.float32) - loss_mask = self.mul(loss_mask, label_mask) + if labels is not None: + label_mask = self.cast(self.not_equal(labels, self.ignore_token_id), dtype.float32) + loss_mask = self.mul(loss_mask, label_mask) if self.use_attn_mask_compression: attention_mask = self.casual_mask() elif attention_mask is None: -- Gitee