diff --git a/mindformers/pynative/base_models/gpt/gpt_model.py b/mindformers/pynative/base_models/gpt/gpt_model.py index bcc978c5d2e7c0de38f2ba49f760dbc604782f5f..5645724d83b1f4320a16bbf67c5966a0b8695dc2 100644 --- a/mindformers/pynative/base_models/gpt/gpt_model.py +++ b/mindformers/pynative/base_models/gpt/gpt_model.py @@ -282,6 +282,9 @@ class GPTModel(nn.Cell): logits = self.reshape(logits, (-1, logits.shape[-1])) logits = self.cast(logits, dtype.float32) + if labels is None: + return logits + if not self.training: return logits.contiguous() @@ -395,8 +398,9 @@ class GPTModel(nn.Cell): """ if loss_mask is None: loss_mask = self.cast(self.not_equal(input_ids, self.pad_token_id), dtype.float32) - label_mask = self.cast(self.not_equal(labels, self.ignore_token_id), dtype.float32) - loss_mask = self.mul(loss_mask, label_mask) + if labels is not None: + label_mask = self.cast(self.not_equal(labels, self.ignore_token_id), dtype.float32) + loss_mask = self.mul(loss_mask, label_mask) if self.use_attn_mask_compression: attention_mask = self.casual_mask() elif attention_mask is None: