diff --git a/inference_sora.py b/inference_sora.py
index 81dd2bbb684ec94eec0836b6a11dbafd00f88d5a..55742a47580b2fe83143db12f50f72990423d961 100644
--- a/inference_sora.py
+++ b/inference_sora.py
@@ -62,6 +62,15 @@ def main():
     device = get_device(args.device)
     prompts = load_prompts(args.prompt)
+
+    # Generate args.num_inference_videos_per_sample inference videos for the same prompt.
+    if hasattr(args, "num_inference_videos_per_sample") and args.num_inference_videos_per_sample > 1:
+        prompts = [
+            item
+            for item in prompts
+            for _ in range(args.num_inference_videos_per_sample)
+        ]
+
 
     images = load_images(args.image) if hasattr(args, "image") else None
     conditional_pixel_values_path = load_conditional_pixel_values(args.conditional_pixel_values_path) if hasattr(args, "conditional_pixel_values_path") else None
     mask_type = args.mask_type if hasattr(args, "mask_type") else None
diff --git a/mindspeed_mm/data/data_utils/constants.py b/mindspeed_mm/data/data_utils/constants.py
index 06f76e8acac73c9b9ec0a188926ffc14b72c9fde..0ffad32c7044c9adb3b7542bfeada4a60a9f67a9 100644
--- a/mindspeed_mm/data/data_utils/constants.py
+++ b/mindspeed_mm/data/data_utils/constants.py
@@ -5,13 +5,17 @@ PROMPT_MASK_2 = "prompt_mask_2"
 PROMPT_IDS_2 = "prompt_ids_2"
 TEXT = "text"
 VIDEO = "video"
+VIDEO_REJECTED = "video_rejected"
 PROMPT = "prompt"
 LATENTS = "latents"
 VIDEO_MASK = "video_mask"
 MASKED_VIDEO = "masked_video"
 INPUT_MASK = "input_mask"
 FILE_INFO = "file"
+FILE_REJECTED_INFO = "file_rejected"
 CAPTIONS = "captions"
+SCORE = "score"
+SCORE_REJECTED = "score_rejected"
 IMG_FPS = 120
 SORA_MODEL_PROTECTED_KEYS = [
     PROMPT_MASK,
diff --git a/mindspeed_mm/data/datasets/t2v_dataset.py b/mindspeed_mm/data/datasets/t2v_dataset.py
index 2c119d43345557d05a5f3277ad88a1580fd6c611..e5831169a80496b10c8b847268c5dc13df59f70b 100644
--- a/mindspeed_mm/data/datasets/t2v_dataset.py
+++ b/mindspeed_mm/data/datasets/t2v_dataset.py
@@ -15,12 +15,16 @@ from megatron.core import mpu
 from mindspeed_mm.data.data_utils.constants import (
     CAPTIONS,
     FILE_INFO,
+    FILE_REJECTED_INFO,
     PROMPT_IDS,
     PROMPT_MASK,
     TEXT,
     VIDEO,
+    VIDEO_REJECTED,
     IMG_FPS,
     VIDEO_MASK,
+    SCORE,
+    SCORE_REJECTED,
     SORA_MODEL_PROTECTED_KEYS
 )
 from mindspeed_mm.data.data_utils.utils import (
@@ -249,6 +253,16 @@ class T2VDataset(MMBaseDataset):
                 else self.get_value_from_vid_or_img(path)
             )
         examples[VIDEO] = video_value
+        if FILE_REJECTED_INFO in sample.keys():
+            video_rejected_path = os.path.join(self.data_folder, sample[FILE_REJECTED_INFO])
+            video_rejected_value = (
+                self.get_vid_img_fusion(video_rejected_path)
+                if self.vid_img_fusion_by_splicing
+                else self.get_value_from_vid_or_img(video_rejected_path)
+            )
+            examples[VIDEO_REJECTED] = video_rejected_value
+            examples[SCORE] = sample[SCORE]
+            examples[SCORE_REJECTED] = sample[SCORE_REJECTED]
         if self.use_text_processer:
             prompt_ids, prompt_mask = self.get_text_processer(texts)
             examples[PROMPT_IDS], examples[PROMPT_MASK] = (
diff --git a/mindspeed_mm/models/text_encoder/stepllm_tokenizer.py b/mindspeed_mm/models/text_encoder/stepllm_tokenizer.py
index 91d5594a973b598735963bcb1beb5740b105eaf0..3bb4fa2841e49e5e2aa7cebff67836ddf5175984 100644
--- a/mindspeed_mm/models/text_encoder/stepllm_tokenizer.py
+++ b/mindspeed_mm/models/text_encoder/stepllm_tokenizer.py
@@ -127,6 +127,8 @@ class WrappedStepChatTokenizer(StepChatTokenizer):
         self.PAD = 2
         out_tokens = []
         attn_mask = []
+        if not isinstance(text, list):
+            text = [text]
         if len(text) == 0:
             part_tokens = [self.BOS] + [self.EOS]
             valid_size = len(part_tokens)
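Note on the data-side changes above: when an annotation entry carries a rejected counterpart, T2VDataset now also returns the rejected clip and both preference scores. A minimal sketch of what such an entry could look like, assuming the existing JSON-list annotation format of the t2v datasets (the paths and scores are purely illustrative, and the caption field keeps whatever key the current format already uses):

```python
# Hypothetical annotation entry; only "file", "file_rejected", "score" and
# "score_rejected" are the keys introduced by this change. The caption field
# follows the existing t2v annotation format and is not shown here.
sample = {
    "file": "videos/0001_win.mp4",            # preferred (chosen) clip
    "file_rejected": "videos/0001_lose.mp4",  # rejected clip
    "score": 0.82,
    "score_rejected": 0.31,
}
# With such an entry, T2VDataset.__getitem__ additionally returns
# examples[VIDEO_REJECTED], examples[SCORE] and examples[SCORE_REJECTED]
# next to the usual examples[VIDEO] and prompt tensors.
```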
diff --git a/mindspeed_mm/tasks/rl/dpo/dpo_trainer.py b/mindspeed_mm/tasks/rl/dpo/dpo_trainer.py
index 1b3c2f935a53d1454a630fc9112b9513d5827f51..1e34ffe304a20f35cba20b2cefa39510b068ec64 100644
--- a/mindspeed_mm/tasks/rl/dpo/dpo_trainer.py
+++ b/mindspeed_mm/tasks/rl/dpo/dpo_trainer.py
@@ -278,16 +278,6 @@ class DPOTrainer(ABC):
         """
         raise NotImplementedError("Subclasses must implement this method")
 
-    @abstractmethod
-    def _init_reference_model(self):
-        """
-        Initializes the reference model frozen.
-
-        Returns:
-            The initialized reference model.
-        """
-        raise NotImplementedError("Subclasses must implement this method")
-
     def loss_func(self, input_tensor: torch.Tensor, output_tensor: torch.Tensor):
         """DPO Loss function.
@@ -337,23 +327,23 @@
     def dpo_loss(
         self,
-        policy_chosen_log_probs: torch.Tensor,
-        policy_rejected_log_probs: torch.Tensor,
-        reference_chosen_log_probs: torch.Tensor,
-        reference_rejected_log_probs: torch.Tensor,
+        policy_chosen_loss: torch.Tensor,
+        policy_rejected_loss: torch.Tensor,
+        reference_chosen_loss: torch.Tensor,
+        reference_rejected_loss: torch.Tensor,
     ) -> Tuple[torch.Tensor, ...]:
         """
         Compute the DPO loss for a batch of policy and reference model log probabilities.
 
         Args:
-            policy_chosen_log_probs:
-                Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
-            policy_rejected_log_probs:
-                Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
-            reference_chosen_log_probs:
-                Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
-            reference_rejected_log_probs:
-                Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)
+            policy_chosen_loss:
+                Log probabilities or mean squared error of the policy model for the chosen responses. Shape: (batch_size,)
+            policy_rejected_loss:
+                Log probabilities or mean squared error of the policy model for the rejected responses. Shape: (batch_size,)
+            reference_chosen_loss:
+                Log probabilities or mean squared error of the reference model for the chosen responses. Shape: (batch_size,)
+            reference_rejected_loss:
+                Log probabilities or mean squared error of the reference model for the rejected responses. Shape: (batch_size,)
 
         Returns:
             A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
@@ -361,17 +351,17 @@
             The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
         """
-        pi_log_ratios = policy_chosen_log_probs - policy_rejected_log_probs
-        ref_log_ratios = reference_chosen_log_probs - reference_rejected_log_probs
-        logits = pi_log_ratios - ref_log_ratios
+        policy_ratios = policy_chosen_loss - policy_rejected_loss
+        ref_ratios = reference_chosen_loss - reference_rejected_loss
+        loss_diff = policy_ratios - ref_ratios
         # The beta is a temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5.
         # We ignore the reference model as beta -> 0.
         # The label_smoothing parameter encodes our uncertainty about the labels and calculates a conservative DPO loss.
         if self.args.dpo_loss_type == "sigmoid":
             losses = (
-                -F.logsigmoid(self.args.dpo_beta * logits) * (1 - self.args.dpo_label_smoothing)
-                - F.logsigmoid(-self.args.dpo_beta * logits) * self.args.dpo_label_smoothing
+                -F.logsigmoid(self.args.dpo_beta * loss_diff) * (1 - self.args.dpo_label_smoothing)
+                - F.logsigmoid(-self.args.dpo_beta * loss_diff) * self.args.dpo_label_smoothing
             )
         else:
             raise ValueError(
@@ -382,13 +372,13 @@
         chosen_rewards = (
             self.args.dpo_beta
             * (
-                policy_chosen_log_probs - reference_chosen_log_probs
+                policy_chosen_loss - reference_chosen_loss
             ).detach()
         )
         rejected_rewards = (
             self.args.dpo_beta
             * (
-                policy_rejected_log_probs - reference_rejected_log_probs
+                policy_rejected_loss - reference_rejected_loss
             ).detach()
         )
@@ -455,7 +445,7 @@
             all_reference_logits, label
         )
-        
+        # compute DPO loss
         losses, chosen_rewards, rejected_rewards = self.compute_preference_loss(
             policy_chosen_log_probs,
             policy_rejected_log_probs,
@@ -474,7 +464,7 @@
 
         return losses.mean(), metrics
 
-    def _compute_log_probs(self, all_logits, label) -> Tuple[torch.Tensor, ...]:
+    def _compute_log_probs(self, all_logits, label=None) -> Tuple[torch.Tensor, ...]:
         """
         Computes the sum log probabilities of the labels under given logits if loss_type.
         Otherwise, the average log probabilities.
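The renamed arguments in dpo_loss do not change the objective itself; it is still the sigmoid DPO loss over the margin difference between policy and reference. A minimal, self-contained sketch of that computation, with hypothetical beta and label_smoothing values in place of the trainer's argument plumbing:

```python
import torch
import torch.nn.functional as F


def sigmoid_dpo_loss(policy_chosen, policy_rejected, ref_chosen, ref_rejected,
                     beta=0.1, label_smoothing=0.0):
    # Difference of (chosen - rejected) margins between policy and reference.
    loss_diff = (policy_chosen - policy_rejected) - (ref_chosen - ref_rejected)
    losses = (
        -F.logsigmoid(beta * loss_diff) * (1 - label_smoothing)
        - F.logsigmoid(-beta * loss_diff) * label_smoothing
    )
    # Implicit rewards used only for logging/accuracy metrics.
    chosen_rewards = beta * (policy_chosen - ref_chosen).detach()
    rejected_rewards = beta * (policy_rejected - ref_rejected).detach()
    return losses, chosen_rewards, rejected_rewards


# Example: per-sample terms for a batch of 2 preference pairs.
losses, chosen_r, rejected_r = sigmoid_dpo_loss(
    torch.tensor([-1.0, -0.8]), torch.tensor([-1.4, -1.3]),
    torch.tensor([-1.1, -0.9]), torch.tensor([-1.2, -1.2]),
)
```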
diff --git a/mindspeed_mm/tasks/rl/dpo/stepvideo_dpo_model.py b/mindspeed_mm/tasks/rl/dpo/stepvideo_dpo_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d50b843346e8c09d284d15d558dc0a4755295ac
--- /dev/null
+++ b/mindspeed_mm/tasks/rl/dpo/stepvideo_dpo_model.py
@@ -0,0 +1,76 @@
+# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved.
+import torch
+from torch import nn
+
+from megatron.training import get_args, print_rank_0
+from megatron.training.arguments import core_transformer_config_from_args
+
+from mindspeed_mm.models.ae import AEModel
+from mindspeed_mm.models.diffusion import DiffusionModel
+from mindspeed_mm.models.predictor import PredictModel
+from mindspeed_mm.models.text_encoder import TextEncoder
+
+
+class StepVideoDPOModel(nn.Module):
+    """
+    The hyper model wraps the multiple models required for reinforcement learning into a single module,
+    while keeping the original distributed parallel layout unchanged.
+    """
+
+    def __init__(self, config):
+        super().__init__()
+        self.config = core_transformer_config_from_args(get_args())
+        self._model_provider(config)
+
+    def _model_provider(self, config):
+        """Builds the model."""
+
+        print_rank_0("building StepVideo related modules ...")
+        self.ae = AEModel(config.ae).eval()
+        self.ae.requires_grad_(False)
+
+        self.text_encoder = TextEncoder(config.text_encoder).eval()
+        self.text_encoder.requires_grad_(False)
+
+        self.diffusion = DiffusionModel(config.diffusion).get_model()
+
+        self.reference = PredictModel(config.predictor).get_model().eval()
+        self.reference.requires_grad_(False)
+
+        self.actor = PredictModel(config.predictor).get_model()
+        print_rank_0("finish building StepVideo related modules ...")
+
+        return None
+
+    def set_input_tensor(self, input_tensor):
+        self.input_tensor = input_tensor
+        self.actor.set_input_tensor(input_tensor)
+
+    def forward(self, video, video_lose, prompt_ids, video_mask=None, prompt_mask=None, **kwargs):
+        latents, _ = self.ae.encode(video)
+        latents_lose, _ = self.ae.encode(video_lose)
+        noised_latents, noise, timesteps = self.diffusion.q_sample(torch.cat((latents, latents_lose), dim=0), model_kwargs=kwargs, mask=video_mask)
+        prompts = self.text_encoder.encode(prompt_ids, prompt_mask)
+        prompt = [torch.cat((p, p), dim=0) for p in prompts]
+
+        with torch.no_grad():
+            refer_output = self.reference(
+                noised_latents,
+                timestep=timesteps,
+                encoder_hidden_states=prompt,
+                video_mask=video_mask,
+                encoder_attention_mask=prompt_mask,
+                **kwargs,
+            )
+        actor_output = self.actor(
+            noised_latents,
+            timestep=timesteps,
+            encoder_hidden_states=prompt,
+            video_mask=video_mask,
+            encoder_attention_mask=prompt_mask,
+            **kwargs,
+        )
+        output = torch.cat((refer_output, actor_output), dim=0)
+
+        return output, torch.cat((latents, latents_lose), dim=0), noised_latents, noise, timesteps
+
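For orientation, StepVideoDPOModel.forward stacks everything along the batch dimension: chosen and rejected latents are concatenated (chosen first), text embeddings are duplicated to match, and the final output stacks the reference prediction before the actor prediction, which is why downstream code splits it with torch.chunk(..., 2, dim=0). A shape-only sketch with dummy tensors and assumed latent and text-embedding dimensions:

```python
import torch

# Assumed shapes, for illustration only.
B, C, T, H, W = 2, 16, 8, 32, 32
latents = torch.randn(B, C, T, H, W)       # chosen (win) clips
latents_lose = torch.randn(B, C, T, H, W)  # rejected (lose) clips

# q_sample / predictor input: chosen half first, rejected half second.
stacked = torch.cat((latents, latents_lose), dim=0)      # (2B, C, T, H, W)

# Text embeddings are repeated so both halves see the same prompt.
prompt = torch.randn(B, 77, 4096)                         # hypothetical shape
prompt_doubled = torch.cat((prompt, prompt), dim=0)       # (2B, 77, 4096)

# Model output: reference first, actor second, so the trainer recovers the two
# halves with torch.chunk(output, 2, dim=0).
refer_output = torch.randn(2 * B, C, T, H, W)
actor_output = torch.randn(2 * B, C, T, H, W)
output = torch.cat((refer_output, actor_output), dim=0)   # (4B, C, T, H, W)
```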
+ """ + + def __init__(self, config): + super().__init__() + self.config = core_transformer_config_from_args(get_args()) + self._model_provider(config) + + def _model_provider(self, config): + """Builds the model.""" + + print_rank_0("building StepVideo related modules ...") + self.ae = AEModel(config.ae).eval() + self.ae.requires_grad_(False) + + self.text_encoder = TextEncoder(config.text_encoder).eval() + self.text_encoder.requires_grad_(False) + + self.diffusion = DiffusionModel(config.diffusion).get_model() + + self.reference = PredictModel(config.predictor).get_model().eval() + self.reference.requires_grad_(False) + + self.actor = PredictModel(config.predictor).get_model() + print_rank_0("finish building StepVideo related modules ...") + + return None + + def set_input_tensor(self, input_tensor): + self.input_tensor = input_tensor + self.actor.set_input_tensor(input_tensor) + + def forward(self, video, video_lose, prompt_ids, video_mask=None, prompt_mask=None, **kwargs): + latents, _ = self.ae.encode(video) + latents_lose, _ = self.ae.encode(video_lose) + noised_latents, noise, timesteps = self.diffusion.q_sample(torch.cat((latents, latents_lose), dim=0), model_kwargs=kwargs, mask=video_mask) + prompts = self.text_encoder.encode(prompt_ids, prompt_mask) + prompt = [torch.cat((prompt, prompt), dim=0) for prompt in prompts] + + with torch.no_grad(): + refer_output = self.reference( + noised_latents, + timestep=timesteps, + encoder_hidden_states=prompt, + video_mask=video_mask, + encoder_attention_mask=prompt_mask, + **kwargs, + ) + actor_output = self.actor( + noised_latents, + timestep=timesteps, + encoder_hidden_states=prompt, + video_mask=video_mask, + encoder_attention_mask=prompt_mask, + **kwargs, + ) + output = torch.cat((refer_output, actor_output), dim=0) + + return output, torch.cat((latents, latents_lose), dim=0), noised_latents, noise, timesteps + diff --git a/mindspeed_mm/tasks/rl/dpo/stepvideo_dpo_trainer.py b/mindspeed_mm/tasks/rl/dpo/stepvideo_dpo_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..e389595b7be1930a5d244ff26ac068ee6e24f779 --- /dev/null +++ b/mindspeed_mm/tasks/rl/dpo/stepvideo_dpo_trainer.py @@ -0,0 +1,207 @@ +# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved. +import os +import math +from functools import partial + +import torch + +from megatron.training import get_args +from megatron.training.global_vars import set_args +from megatron.training.utils import average_losses_across_data_parallel_group +from megatron.training.utils import print_rank_0 + +from mindspeed_mm.data.data_utils.constants import ( + VIDEO, + PROMPT_IDS, + PROMPT_MASK, + VIDEO_MASK, + VIDEO_REJECTED, + SCORE, + SCORE_REJECTED +) +from mindspeed_mm.tasks.rl.dpo.dpo_trainer import DPOTrainer +from mindspeed_mm.tasks.rl.dpo.stepvideo_dpo_model import StepVideoDPOModel +from mindspeed_mm.tasks.rl.utils import read_json_file, find_probability + + +class StepVideoDPOTrainer(DPOTrainer): + """ + A trainer class for Direct Preference Optimization (DPO). + + This class provides methods for model initialize, computing losses and metrics, and training. + """ + + def __init__( + self, + train_valid_test_dataset_provider, + model_type, + process_non_loss_data_func=None, + extra_args_provider=None, + args_defaults=None, + ): + """ + Initializes the DPOTrainer instance. + + Sets up the instance variables for the model provider, actual micro batch size, + and initializes the DPO model. 
+ """ + super().__init__( + train_valid_test_dataset_provider, + model_type, + process_non_loss_data_func, + extra_args_provider, + args_defaults + ) + + args = get_args() + self.histgram = read_json_file(args.mm.model.dpo.histgram_path) + self.alpha = args.mm.model.dpo.weight_alpha + self.beta = args.mm.model.dpo.weight_beta if args.mm.model.dpo.weight_beta else self.histgram['max_num'] / self.histgram['total_num'] + self.dpo_beta = args.mm.model.dpo.loss_beta + self.args.actual_micro_batch_size = self.args.micro_batch_size * 4 + self.disable_dropout() + + def disable_dropout(self): + """ + disable dropout + """ + args_ = get_args() + args_.attention_dropout = 0.0 + args_.hidden_dropout = 0.0 + args_.retro_encoder_hidden_dropout = 0.0 + args_.retro_encoder_attention_dropout = 0.0 + set_args(args_) + + @staticmethod + def get_batch(data_iterator): + """Generate a batch.""" + if data_iterator is not None: + batch = next(data_iterator) + else: + raise ValueError("Data iterator is None. Unable to retrieve batch.") + + for k, v in batch.items(): + if isinstance(v, torch.Tensor): + batch[k] = v.to(torch.cuda.current_device()) + + video = batch.pop(VIDEO, None) + prompt_ids = batch.pop(PROMPT_IDS, None) + video_mask = batch.pop(VIDEO_MASK, None) + prompt_mask = batch.pop(PROMPT_MASK, None) + video_lose = batch.pop(VIDEO_REJECTED, None) + score = batch.pop(SCORE, 1.0) + score_lose = batch.pop(SCORE_REJECTED, 1.0) + + args = get_args() + print(args.params_dtype) + + video = video.to(args.params_dtype) + video_lose = video_lose.to(args.params_dtype) + + return video, video_lose, prompt_ids, None, prompt_mask, score, score_lose + + def model_provider(self, **kwargs): + args = get_args() + print_rank_0("building StepVideoDPO model ...") + self.hyper_model = StepVideoDPOModel(args.mm.model) + return self.hyper_model + + def forward_step(self, data_iterator, model): + """DPO Forward training step. + + Args: + data_iterator : Input data iterator + model : vlm model + """ + # Get the batch. + video, video_lose, prompt_ids, video_mask, prompt_mask, score, score_lose = self.get_batch(data_iterator) + + output_tensor, latents, noised_latents, noise, timesteps = model(video=video, video_lose=video_lose, prompt_ids=prompt_ids, video_mask=video_mask, prompt_mask=prompt_mask) + + return output_tensor, partial(self.loss_func, latents, noised_latents, noise, timesteps, score, score_lose) + + def loss_func( + self, + latents: torch.Tensor, + noised_latents: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.Tensor, + score_win: torch.Tensor, + score_lose: torch.Tensor, + output_tensor: torch.Tensor + ): + args = get_args() + actor_output, refer_output = torch.chunk(output_tensor, 2, dim=0) + refer_output = refer_output.detach() + + loss, metrics = self.get_batch_loss_metrics(actor_output, refer_output, + latents=latents, noised_latents=noised_latents, timesteps=timesteps, noise=noise, video_mask=None, score_win=score_win, score_lose=score_lose) + + if args.check_for_nan_in_loss_and_grad: + global_rank = torch.distributed.get_rank() + if loss.isnan(): + raise ValueError(f'Rank {global_rank}: found NaN in local forward loss calculation. ' + f'Device: {torch.cuda.current_device()}, node: {os.uname()[1]}') + + # Reduce loss for logging. 
diff --git a/mindspeed_mm/tasks/rl/utils.py b/mindspeed_mm/tasks/rl/utils.py
index 38ea2b8de3d5187b940c0aa2a174110c6ddee745..94ae380c354ec3fff7284d84b4c40ee7421a5987 100644
--- a/mindspeed_mm/tasks/rl/utils.py
+++ b/mindspeed_mm/tasks/rl/utils.py
@@ -1,11 +1,30 @@
 # Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved.
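find_probability assumes the histogram file behind histgram_path provides bin_width, total_num, max_num and one count per "{lower}-{upper}" bin. A hedged example of that structure with made-up numbers (the import assumes the utils module extended by this patch):

```python
from mindspeed_mm.tasks.rl.utils import find_probability  # added in this patch

# Hypothetical histogram contents; only the keys the code reads
# (bin_width, total_num, max_num and the "{lower}-{upper}" bins) are assumed.
histogram = {
    "bin_width": 0.25,
    "total_num": 1000,
    "max_num": 400,
    "0.0-0.25": 100,
    "0.25-0.5": 400,
    "0.5-0.75": 300,
    "0.75-1.0": 200,
}

print(find_probability(0.6, histogram))  # falls in "0.5-0.75" -> 300 / 1000 = 0.3
print(find_probability(1.3, histogram))  # no matching bin    -> 0.0
```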
 from typing import Tuple
+import json
+import math
 
 import torch
 
 from megatron.core import mpu
 
 
+def read_json_file(filename):
+    """Read a JSON file."""
+    with open(filename, 'r', encoding='utf-8') as file:
+        data = json.load(file)
+    return data
+
+
+def find_probability(score, data):
+    bin_index = math.floor(score / data['bin_width'])
+    lower = bin_index * data['bin_width']
+    upper = lower + data['bin_width']
+    key = f"{lower}-{upper}"
+    if key in data:
+        return data[key] / data['total_num']  # Probability
+    return 0.0  # If score is out of bounds
+
+
 def get_attr_from_wrapped_model(model, target_attr):
     def recursive_search(module):
         if hasattr(module, target_attr):
diff --git a/posttrain_stepvideo_dpo.py b/posttrain_stepvideo_dpo.py
new file mode 100644
index 0000000000000000000000000000000000000000..987911dd62f061c23441c5a21ce58223ca089c54
--- /dev/null
+++ b/posttrain_stepvideo_dpo.py
@@ -0,0 +1,40 @@
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved.
+"""Posttrain StepVideo DPO."""
+
+import mindspeed.megatron_adaptor  # noqa
+
+from megatron.core import mpu
+from megatron.core.enums import ModelType
+from megatron.training import get_args, print_rank_0
+
+from mindspeed_mm.configs.config import mm_extra_args_provider
+from mindspeed_mm.data import build_mm_dataloader, build_mm_dataset
+from mindspeed_mm.data.data_utils.utils import build_iterations
+from mindspeed_mm.patchs import dummy_optimizer_patch  # noqa
+from mindspeed_mm.tasks.rl.dpo.stepvideo_dpo_trainer import StepVideoDPOTrainer
+
+
+def train_valid_test_datasets_provider(train_val_test_num_samples):
+    """Build train, valid, and test datasets."""
+    args = get_args()
+    data_config = args.mm.data
+    train_dataset = build_mm_dataset(data_config.dataset_param)
+    train_dataloader = build_mm_dataloader(train_dataset, data_config.dataloader_param,
+                                           process_group=mpu.get_data_parallel_group(),
+                                           dataset_param=data_config.dataset_param,
+                                           consumed_samples=args.consumed_train_samples,)
+    train_dataloader, val_dataloader, test_dataloader = build_iterations(train_dataloader)
+    return train_dataloader, val_dataloader, test_dataloader
+
+
+if __name__ == "__main__":
+    train_valid_test_datasets_provider.is_distributed = True
+
+    trainer = StepVideoDPOTrainer(
+        train_valid_test_dataset_provider=train_valid_test_datasets_provider,
+        model_type=ModelType.encoder_or_decoder,
+        extra_args_provider=mm_extra_args_provider,
+        args_defaults={"dataloader_type": "external"},
+    )
+    trainer.train()