diff --git a/mindspeed_llm/tasks/posttrain/dpo/dpo_trainer.py b/mindspeed_llm/tasks/posttrain/dpo/dpo_trainer.py
index bf4e4e762146d07ffb2e0792b130c4ca50051c97..d5f1384f390693ff6de29151b9d0269e4b407447 100644
--- a/mindspeed_llm/tasks/posttrain/dpo/dpo_trainer.py
+++ b/mindspeed_llm/tasks/posttrain/dpo/dpo_trainer.py
@@ -55,7 +55,9 @@ class DPOTrainer(BaseTrainer):
         if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()):
             if args.variable_seq_lengths and args.pipeline_model_parallel_size > 2:
                 tokens, attention_mask = get_finetune_data_on_this_tp_rank(data_iterator)
-                return tokens, None, attention_mask, None
+                batch = {'attention_mask': attention_mask}
+                batch = get_batch_on_this_cp_rank(batch)
+                return tokens, None, batch['attention_mask'], None
             else:
                 # Broadcast data.
                 data_b = tensor_parallel.broadcast_data(keys, next(data_iterator), data_type)
@@ -134,6 +136,13 @@ class DPOTrainer(BaseTrainer):
         # Get the batch.
         self.timers('batch-generator', log_level=2).start()
         tokens, labels, attention_mask, position_ids = self.get_batch(data_iterator)
+
+        if self.args.stage in ['dpo']:
+            if attention_mask is not None:
+                if isinstance(attention_mask, list):
+                    attention_mask = [torch.cat((x, x), dim=0) for x in attention_mask]
+                else:
+                    attention_mask = torch.cat((attention_mask, attention_mask), dim=0)
         self.timers('batch-generator').stop()

         output_tensor = self.hyper_model(tokens, position_ids, attention_mask)
diff --git a/mindspeed_llm/training/utils.py b/mindspeed_llm/training/utils.py
index 3713c84b1eee9a2816003f181b8e4335f4626e59..9ab4bdaed705b0b0ee97a2f0b1b72821820af80e 100644
--- a/mindspeed_llm/training/utils.py
+++ b/mindspeed_llm/training/utils.py
@@ -177,24 +177,21 @@ def print_rank0_by_args(args, message):
 def get_tune_attention_mask(attention_mask_1d):
     args = get_args()
     micro_batch_size, seq_length = attention_mask_1d.size()
-    if args.reset_attention_mask:
-        att_mask_batch = micro_batch_size
-    else:
-        att_mask_batch = 1
+
+    if args.stage in ['dpo']:
+        micro_batch_size = attention_mask_1d.shape[0] // 2
+        attention_mask_1d = attention_mask_1d[:micro_batch_size]
+
+    attention_mask = torch.ones((micro_batch_size, seq_length, seq_length),
+                                device=attention_mask_1d.device,
+                                dtype=torch.bool).tril_().view(micro_batch_size, 1, seq_length, seq_length)

     if args.tokenizer_padding_side == "left":
-        attention_mask = torch.tril(
-            torch.ones(seq_length, seq_length, device=attention_mask_1d.device, dtype=torch.bool)).view(1, 1,
-                                                                                                         seq_length,
-                                                                                                         seq_length)
-        attention_mask_tran = attention_mask_1d.view(seq_length, 1, -1)
-        attention_mask = attention_mask.masked_fill((attention_mask_tran < 0.5).view(-1, 1, 1, seq_length), value=0)
-    else:
-        attention_mask = torch.tril(torch.ones(
-            (att_mask_batch, seq_length, seq_length), device=attention_mask_1d.device, dtype=torch.bool)).view(
-            att_mask_batch, 1, seq_length, seq_length)
-        attention_mask = attention_mask.masked_fill((attention_mask_1d < 0.5).view(-1, 1, 1, seq_length), value=0)
-    attention_mask = ~attention_mask
+        attention_mask_1d = attention_mask_1d.view(seq_length, 1, -1)
+
+    attention_mask = attention_mask.masked_fill_(attention_mask_1d.bool().bitwise_not_().view(-1, 1, 1, seq_length), value=0)
+    attention_mask.bitwise_not_()
+
     return attention_mask
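For reference, below is a self-contained sketch of the mask construction introduced by the patch, not part of the patch itself. The function name build_tune_attention_mask and the explicit stage / tokenizer_padding_side parameters are illustrative stand-ins for the get_args() globals used by the patched get_tune_attention_mask; the logic mirrors the new code: halve the batch for DPO (chosen and rejected rows share one padding mask), build a lower-triangular causal mask in place, zero out padded key positions, then invert so True marks masked-out positions.

import torch


def build_tune_attention_mask(attention_mask_1d, stage='sft', tokenizer_padding_side='right'):
    """Sketch of the patched get_tune_attention_mask; config is passed as
    arguments here instead of being read through Megatron's get_args()."""
    micro_batch_size, seq_length = attention_mask_1d.size()

    # In DPO the batch is [chosen; rejected] and both halves share the same
    # padding layout, so only the first half of the 1-D mask is needed.
    if stage in ['dpo']:
        micro_batch_size = attention_mask_1d.shape[0] // 2
        attention_mask_1d = attention_mask_1d[:micro_batch_size]

    # Lower-triangular (causal) mask, built in place to avoid extra copies.
    attention_mask = torch.ones((micro_batch_size, seq_length, seq_length),
                                device=attention_mask_1d.device,
                                dtype=torch.bool).tril_().view(micro_batch_size, 1, seq_length, seq_length)

    if tokenizer_padding_side == "left":
        attention_mask_1d = attention_mask_1d.view(seq_length, 1, -1)

    # Drop padded key positions, then flip the convention: True = masked out.
    attention_mask.masked_fill_(attention_mask_1d.bool().bitwise_not_().view(-1, 1, 1, seq_length), value=0)
    attention_mask.bitwise_not_()
    return attention_mask


if __name__ == "__main__":
    # Two rows (one chosen, one rejected) with the same padding layout.
    pad_mask = torch.tensor([[1, 1, 1, 0],
                             [1, 1, 1, 0]])
    mask = build_tune_attention_mask(pad_mask, stage='dpo')
    print(mask.shape)   # torch.Size([1, 1, 4, 4])
    print(mask[0, 0])   # True where attention is NOT allowed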