From f74d14504a7145d299104f5d321fa9a0cdb32271 Mon Sep 17 00:00:00 2001
From: chenzeng
Date: Tue, 9 Sep 2025 14:56:27 +0800
Subject: [PATCH 1/2] Adapt the ring attention algorithm for CP in DPO
 training and optimize attention-mask memory usage under CP partitioning
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 mindspeed_llm/tasks/posttrain/dpo/dpo_trainer.py | 11 ++++++++++-
 mindspeed_llm/training/utils.py                  | 15 +++++++++------
 2 files changed, 19 insertions(+), 7 deletions(-)

diff --git a/mindspeed_llm/tasks/posttrain/dpo/dpo_trainer.py b/mindspeed_llm/tasks/posttrain/dpo/dpo_trainer.py
index bf4e4e762..d5f1384f3 100644
--- a/mindspeed_llm/tasks/posttrain/dpo/dpo_trainer.py
+++ b/mindspeed_llm/tasks/posttrain/dpo/dpo_trainer.py
@@ -55,7 +55,9 @@ class DPOTrainer(BaseTrainer):
         if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()):
             if args.variable_seq_lengths and args.pipeline_model_parallel_size > 2:
                 tokens, attention_mask = get_finetune_data_on_this_tp_rank(data_iterator)
-                return tokens, None, attention_mask, None
+                batch = {'attention_mask': attention_mask}
+                batch = get_batch_on_this_cp_rank(batch)
+                return tokens, None, batch['attention_mask'], None
             else:
                 # Broadcast data.
                 data_b = tensor_parallel.broadcast_data(keys, next(data_iterator), data_type)
@@ -134,6 +136,13 @@ class DPOTrainer(BaseTrainer):
         # Get the batch.
         self.timers('batch-generator', log_level=2).start()
         tokens, labels, attention_mask, position_ids = self.get_batch(data_iterator)
+
+        if self.args.stage in ['dpo']:
+            if attention_mask is not None:
+                if isinstance(attention_mask, list):
+                    attention_mask = [torch.cat((x, x), dim=0) for x in attention_mask]
+                else:
+                    attention_mask = torch.cat((attention_mask, attention_mask), dim=0)
         self.timers('batch-generator').stop()
 
         output_tensor = self.hyper_model(tokens, position_ids, attention_mask)
diff --git a/mindspeed_llm/training/utils.py b/mindspeed_llm/training/utils.py
index 3713c84b1..611335fbf 100644
--- a/mindspeed_llm/training/utils.py
+++ b/mindspeed_llm/training/utils.py
@@ -176,8 +176,11 @@ def print_rank0_by_args(args, message):
 
 def get_tune_attention_mask(attention_mask_1d):
     args = get_args()
+    if args.stage in ['dpo']:
+        bsz_per_model = attention_mask_1d.shape[0] // 2
+        attention_mask_1d = attention_mask_1d[0:bsz_per_model]
     micro_batch_size, seq_length = attention_mask_1d.size()
-    if args.reset_attention_mask:
+    if args.reset_attention_mask or args.stage in ['dpo']:
         att_mask_batch = micro_batch_size
     else:
         att_mask_batch = 1
@@ -190,11 +193,11 @@
         attention_mask_tran = attention_mask_1d.view(seq_length, 1, -1)
         attention_mask = attention_mask.masked_fill((attention_mask_tran < 0.5).view(-1, 1, 1, seq_length), value=0)
     else:
-        attention_mask = torch.tril(torch.ones(
-            (att_mask_batch, seq_length, seq_length), device=attention_mask_1d.device, dtype=torch.bool)).view(
-            att_mask_batch, 1, seq_length, seq_length)
-        attention_mask = attention_mask.masked_fill((attention_mask_1d < 0.5).view(-1, 1, 1, seq_length), value=0)
-        attention_mask = ~attention_mask
+        attention_mask = torch.ones((att_mask_batch, seq_length, seq_length),
+                                    device=attention_mask_1d.device,
+                                    dtype=torch.bool).tril_().view(att_mask_batch, 1, seq_length, seq_length)
+    attention_mask = attention_mask.masked_fill_(attention_mask_1d.bool().bitwise_not_().view(-1, 1, 1, seq_length), value=0)
+    attention_mask.bitwise_not_()
     return attention_mask
-- 
Gitee

From 0fff0f5346f586ae9af50f375f62ad3eebb77b08 Mon Sep 17 00:00:00 2001
From: chenzeng
Date: Tue, 9 Sep 2025 17:39:13 +0800
Subject: [PATCH 2/2] Adapt the ring attention algorithm for CP in DPO
 training and optimize attention-mask memory usage under CP partitioning
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 mindspeed_llm/training/utils.py | 28 +++++++++++-----------------
 1 file changed, 11 insertions(+), 17 deletions(-)

diff --git a/mindspeed_llm/training/utils.py b/mindspeed_llm/training/utils.py
index 611335fbf..9ab4bdaed 100644
--- a/mindspeed_llm/training/utils.py
+++ b/mindspeed_llm/training/utils.py
@@ -176,28 +176,22 @@ def print_rank0_by_args(args, message):
 
 def get_tune_attention_mask(attention_mask_1d):
     args = get_args()
-    if args.stage in ['dpo']:
-        bsz_per_model = attention_mask_1d.shape[0] // 2
-        attention_mask_1d = attention_mask_1d[0:bsz_per_model]
     micro_batch_size, seq_length = attention_mask_1d.size()
-    if args.reset_attention_mask or args.stage in ['dpo']:
-        att_mask_batch = micro_batch_size
-    else:
-        att_mask_batch = 1
+
+    if args.stage in ['dpo']:
+        micro_batch_size = attention_mask_1d.shape[0] // 2
+        attention_mask_1d = attention_mask_1d[:micro_batch_size]
+
+    attention_mask = torch.ones((micro_batch_size, seq_length, seq_length),
+                                device=attention_mask_1d.device,
+                                dtype=torch.bool).tril_().view(micro_batch_size, 1, seq_length, seq_length)
     if args.tokenizer_padding_side == "left":
-        attention_mask = torch.tril(
-            torch.ones(seq_length, seq_length, device=attention_mask_1d.device, dtype=torch.bool)).view(1, 1,
-                                                                                                         seq_length,
-                                                                                                         seq_length)
-        attention_mask_tran = attention_mask_1d.view(seq_length, 1, -1)
-        attention_mask = attention_mask.masked_fill((attention_mask_tran < 0.5).view(-1, 1, 1, seq_length), value=0)
-    else:
-        attention_mask = torch.ones((att_mask_batch, seq_length, seq_length),
-                                    device=attention_mask_1d.device,
-                                    dtype=torch.bool).tril_().view(att_mask_batch, 1, seq_length, seq_length)
+        attention_mask_1d = attention_mask_1d.view(seq_length, 1, -1)
     attention_mask = attention_mask.masked_fill_(attention_mask_1d.bool().bitwise_not_().view(-1, 1, 1, seq_length), value=0)
     attention_mask.bitwise_not_()
+
     return attention_mask
-- 
Gitee
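
Note: the sketch below is a standalone illustration, not repository code, of what the patched get_tune_attention_mask plus the trainer-side torch.cat duplication do with a DPO micro-batch. The function name build_dpo_tune_mask, the is_dpo flag, and the toy tensors are assumptions made for this example; the context-parallel slicing performed by get_batch_on_this_cp_rank is not reproduced here.

    # Standalone sketch; names and shapes are illustrative assumptions.
    import torch


    def build_dpo_tune_mask(attention_mask_1d: torch.Tensor, is_dpo: bool = True) -> torch.Tensor:
        """attention_mask_1d: [batch, seq] padding mask (1 = keep, 0 = pad).

        For DPO the micro-batch stacks chosen and rejected samples along dim 0,
        so only the first half is needed to build the per-sample causal mask.
        Returns an inverted mask of shape [batch/2, 1, seq, seq] (True = masked),
        mirroring the convention used by the patched get_tune_attention_mask.
        """
        micro_batch_size, seq_length = attention_mask_1d.size()
        if is_dpo:
            micro_batch_size //= 2                    # chosen/rejected share the same mask
            attention_mask_1d = attention_mask_1d[:micro_batch_size]

        # Causal (lower-triangular) bool mask built in place: no extra temporaries.
        attention_mask = torch.ones((micro_batch_size, seq_length, seq_length),
                                    dtype=torch.bool).tril_().view(micro_batch_size, 1, seq_length, seq_length)
        # Zero out padded key positions, also in place.
        attention_mask.masked_fill_(attention_mask_1d.bool().bitwise_not_().view(-1, 1, 1, seq_length), value=0)
        # Flip to the "True means masked" convention expected downstream.
        attention_mask.bitwise_not_()
        return attention_mask


    if __name__ == "__main__":
        # Toy DPO micro-batch: 2 chosen + 2 rejected samples, seq_length = 4.
        pad_mask = torch.tensor([[1, 1, 1, 0],
                                 [1, 1, 0, 0],
                                 [1, 1, 1, 0],
                                 [1, 1, 0, 0]])
        mask = build_dpo_tune_mask(pad_mask)          # shape [2, 1, 4, 4]
        # The trainer then duplicates it so chosen and rejected reuse the same mask.
        full_mask = torch.cat((mask, mask), dim=0)    # shape [4, 1, 4, 4]
        print(mask.shape, full_mask.shape)

This is also why patch 1 halves the 1D mask inside get_tune_attention_mask while the trainer concatenates the result back: the 4D mask is built only once per chosen/rejected pair, which is where the memory saving comes from.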