From f7a4ad3488fc527536533327d7ef0df33b10167d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E5=87=AF=E5=AE=87?= Date: Wed, 7 May 2025 15:09:15 +0800 Subject: [PATCH] [built-in][PyTorch][OpenRLHF] OpenRLHF-v0.5.7 support SFT and DPO training for Qwen2VL --- .../vision_scripts/llava_zh_300k.json | 12 + .../examples/vision_scripts/rlhf_v.json | 9 + .../vision_scripts/train_dpo_qwen2vl.sh | 39 ++ .../vision_scripts/train_sft_qwen2vl.sh | 34 ++ .../openrlhf/cli/train_vl_dpo.py | 270 +++++++++ .../openrlhf/cli/train_vl_sft.py | 242 +++++++++ .../openrlhf/datasets/__init__.py | 3 +- .../openrlhf/datasets/vl_dataset.py | 511 ++++++++++++++++++ .../openrlhf/models/actor.py | 56 +- .../openrlhf/trainer/__init__.py | 6 +- .../openrlhf/trainer/dpo_trainer.py | 248 ++++++++- .../openrlhf/trainer/sft_trainer.py | 183 ++++++- .../openrlhf/utils/__init__.py | 5 + .../openrlhf/utils/utils.py | 11 +- .../openrlhf/utils/vision_args.py | 41 ++ 15 files changed, 1630 insertions(+), 40 deletions(-) create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/vision_scripts/llava_zh_300k.json create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/vision_scripts/rlhf_v.json create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/vision_scripts/train_dpo_qwen2vl.sh create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/vision_scripts/train_sft_qwen2vl.sh create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/train_vl_dpo.py create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/train_vl_sft.py create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/datasets/vl_dataset.py create mode 100644 PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/vision_args.py diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/vision_scripts/llava_zh_300k.json b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/vision_scripts/llava_zh_300k.json new file mode 100644 index 0000000000..16886c6651 --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/vision_scripts/llava_zh_300k.json @@ -0,0 +1,12 @@ +{ + "columns": { + "messages": "messages", + "images": "images" + }, + "tags": { + "role_tag": "role", + "content_tag": "content", + "user_tag": "user", + "assistant_tag": "assistant" + } +} \ No newline at end of file diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/vision_scripts/rlhf_v.json b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/vision_scripts/rlhf_v.json new file mode 100644 index 0000000000..6db7beec37 --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/vision_scripts/rlhf_v.json @@ -0,0 +1,9 @@ +{ + "ranking": true, + "columns": { + "messages": "conversations", + "chosen": "chosen", + "rejected": "rejected", + "images": "images" + } +} \ No newline at end of file diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/vision_scripts/train_dpo_qwen2vl.sh b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/vision_scripts/train_dpo_qwen2vl.sh new file mode 100644 index 0000000000..125e84a2c1 --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/vision_scripts/train_dpo_qwen2vl.sh @@ -0,0 +1,39 @@ +set -x + +export ACLNN_CACHE_LIMIT=100000 +export COMBINED_ENABLE=1 +export TASK_QUEUE_ENABLE=2 +export HF_DATASETS_OFFLINE=1 + +read -r -d '' training_commands < 1: + assert args.packing_samples, "packing_samples must be enabled 
when using ring attention" + + train(args) diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/train_vl_sft.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/train_vl_sft.py new file mode 100644 index 0000000000..f98a7f8f70 --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/train_vl_sft.py @@ -0,0 +1,242 @@ +import argparse +import math +import os +from datetime import datetime + +import torch +from transformers.trainer import get_scheduler + +from openrlhf.datasets import build_train_and_valid_datasets, build_data_collator +from openrlhf.models import Actor +from openrlhf.trainer import VLSFTTrainer +from openrlhf.utils import ( + get_strategy, + get_tokenizer, + get_vision_processor, + get_qwen2_vl_utils, + add_vision_args, +) + + +def train(args): + # configure strategy + strategy = get_strategy(args) + strategy.setup_distributed() + if torch.distributed.get_rank() == 0: + print(f"Running args {args}") + + # configure model + # load huggingface model + model = Actor( + args.pretrain, + use_flash_attention_2=args.flash_attn, + bf16=args.bf16, + load_in_4bit=args.load_in_4bit, + lora_rank=args.lora_rank, + lora_alpha=args.lora_alpha, + target_modules=args.target_modules, + lora_dropout=args.lora_dropout, + ds_config=strategy.get_ds_train_config(is_actor=True), + packing_samples=args.packing_samples, + create_vison_model=args.model_arch in ['qwen2_vl'], + ) + # configure tokenizer + tokenizer = get_tokenizer(args.pretrain, model.model, "right", strategy, use_fast=not args.disable_fast_tokenizer) + strategy.print(model) + + # configure processor + vision_processor = get_vision_processor(args, args.pretrain, tokenizer) + if args.model_arch == "qwen2_vl": + encoder_utils = get_qwen2_vl_utils(args) + else: + raise NotImplementedError(f"no support model arch {args.model_arch=}") + + # gradient_checkpointing + if args.gradient_checkpointing: + model.gradient_checkpointing_enable( + gradient_checkpointing_kwargs={"use_reentrant": args.gradient_checkpointing_use_reentrant} + ) + + # configure optimizer + optim = strategy.create_optimizer(model, lr=args.learning_rate, betas=args.adam_betas, weight_decay=args.l2) + + # repace datasets + assert args.task_type == "sft", f"the script is used for SFT training" + train_dataset, eval_dataset = build_train_and_valid_datasets( + args, tokenizer, processor=vision_processor, encoder_utils=encoder_utils, strategy=strategy) + + data_collator = build_data_collator(args, tokenizer, encoder_utils, vision_processor) + + # prepare dataloader + train_dataloader = strategy.setup_dataloader( + train_dataset, + args.micro_train_batch_size, + True, + shuffle=True, + collate_fn=data_collator, + ) + eval_dataloader = strategy.setup_dataloader( + eval_dataset, + args.micro_train_batch_size, + True, + False, + collate_fn=data_collator, + ) + + # scheduler + num_update_steps_per_epoch = len(train_dataset) // args.train_batch_size + max_steps = math.ceil(args.max_epochs * num_update_steps_per_epoch) + + scheduler = get_scheduler( + args.lr_scheduler, + optim, + num_warmup_steps=math.ceil(max_steps * args.lr_warmup_ratio), + num_training_steps=max_steps, + scheduler_specific_kwargs={"min_lr": args.learning_rate * 0.1}, + ) + + # prepare models + (model, optim, scheduler) = strategy.prepare((model, optim, scheduler)) + + # load checkpoint + consumed_samples = 0 + if args.load_checkpoint and os.path.exists(args.ckpt_path): + _, states = strategy.load_ckpt(model.model, args.ckpt_path) + consumed_samples = 
states["consumed_samples"] + strategy.print(f"Loaded the checkpoint: {args.ckpt_path}, consumed_samples: {consumed_samples}") + + os.makedirs(args.save_path, exist_ok=True) + + # configure Trainer + trainer = VLSFTTrainer( + model=model, + strategy=strategy, + optim=optim, + train_dataloader=train_dataloader, + eval_dataloader=eval_dataloader, + scheduler=scheduler, + max_norm=args.max_norm, + pretrain_mode=args.pretrain_mode, + batch_size=args.train_batch_size, + max_epochs=args.max_epochs, + tokenizer=tokenizer, + ) + + trainer.fit(args, consumed_samples, num_update_steps_per_epoch) + + # save model checkpoint after fitting on only rank0 + strategy.save_model(model, tokenizer, args.save_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Checkpoint + parser.add_argument("--save_path", type=str, default="./ckpt") + parser.add_argument("--save_steps", type=int, default=-1) + parser.add_argument("--logging_steps", type=int, default=1) + parser.add_argument("--eval_steps", type=int, default=-1) + parser.add_argument("--ckpt_path", type=str, default="./ckpt/checkpoints_sft") + parser.add_argument("--max_ckpt_num", type=int, default=3) + parser.add_argument("--max_ckpt_mem", type=int, default=1e8) + parser.add_argument("--load_checkpoint", action="store_true", default=False) + + # DeepSpeed + parser.add_argument("--micro_train_batch_size", type=int, default=8, help="batch size per GPU") + parser.add_argument("--train_batch_size", type=int, default=128, help="Global training batch size") + parser.add_argument("--max_norm", type=float, default=1.0, help="Gradient clipping") + parser.add_argument("--gradient_checkpointing", action="store_true", default=False) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for deepspeed") + parser.add_argument("--zero_stage", type=int, default=2, help="DeepSpeed ZeRO stage") + parser.add_argument("--bf16", action="store_true", default=False, help="Enable bfloat16") + parser.add_argument("--zpg", type=int, default=1, help="ZeRO++ max partition size") + parser.add_argument("--adam_offload", action="store_true", default=False, help="Offload Adam Optimizer") + parser.add_argument("--flash_attn", type=str, default="eager", help="Enable FlashAttention2") + parser.add_argument("--grad_accum_dtype", type=str, default=None, help="Adam grad accum data type") + parser.add_argument("--overlap_comm", action="store_true", default=False) + parser.add_argument("--gradient_checkpointing_use_reentrant", action="store_true", default=False) + parser.add_argument("--disable_fast_tokenizer", action="store_true", default=False) + + # SFT + parser.add_argument("--max_epochs", type=int, default=2) + parser.add_argument("--aux_loss_coef", type=float, default=0, help="MoE balancing loss") + parser.add_argument("--pretrain", type=str, default=None) + parser.add_argument("--learning_rate", type=float, default=5e-6) + parser.add_argument("--lr_warmup_ratio", type=float, default=0.03) + parser.add_argument("--pretrain_mode", action="store_true", default=False, help="Use pretrain loss") + parser.add_argument("--lr_scheduler", type=str, default="cosine_with_min_lr") + parser.add_argument("--l2", type=float, default=0, help="weight decay loss") + parser.add_argument("--adam_betas", type=float, nargs=2, default=(0.9, 0.95), help="Betas for Adam optimizer") + + # ring-attention + parser.add_argument("--ring_attn_size", type=int, default=1, help="Ring attention group size") + 
parser.add_argument( + "--ring_head_stride", + type=int, + default=1, + help="the number of heads to do ring attention each time. " + "It should be a divisor of the number of heads. " + "A larger value may results in faster training but will consume more memory.", + ) + + # LoRA + parser.add_argument("--load_in_4bit", action="store_true", default=False) + parser.add_argument("--lora_rank", type=int, default=0) + parser.add_argument("--lora_alpha", type=int, default=16) + parser.add_argument("--target_modules", type=str, nargs="*", default="all-linear") + parser.add_argument("--lora_dropout", type=float, default=0) + + # packing SFT samples without CrossAttention + parser.add_argument("--packing_samples", action="store_true", default=False) + + # custom dataset + parser.add_argument("--dataset", type=str, default=None) + parser.add_argument("--dataset_probs", type=str, default="1.0", help="sampling probs for datasets") + parser.add_argument("--train_split", type=str, default="train", help="train split of the HF dataset") + parser.add_argument("--eval_split", type=str, default="test", help="test split of the dataset") + + parser.add_argument("--input_key", type=str, default="input", help="JSON dataset key") + parser.add_argument("--output_key", type=str, default=None, help="JSON dataset key") + parser.add_argument("--input_template", type=str, default="User: {}\nAssistant: ") + parser.add_argument( + "--apply_chat_template", action="store_true", default=False, help="Use HF tokenizer chat template" + ) + parser.add_argument("--tokenizer_chat_template", type=str, default=None) + parser.add_argument("--max_samples", type=int, default=1e8, help="Max number of samples") + parser.add_argument("--max_len", type=int, default=2048, help="Max tokens for the samples") + + # wandb parameters + parser.add_argument("--use_wandb", type=str, default=None) + parser.add_argument("--wandb_org", type=str, default=None) + parser.add_argument("--wandb_group", type=str, default=None) + parser.add_argument("--wandb_project", type=str, default="openrlhf_train_sft") + parser.add_argument( + "--wandb_run_name", + type=str, + default="sft_%s" % datetime.now().strftime("%m%dT%H:%M"), + ) + + # TensorBoard parameters + parser.add_argument("--use_tensorboard", type=str, default=None, help="TensorBoard logging path") + + parser = add_vision_args(parser) + args = parser.parse_args() + + if args.input_template and "{}" not in args.input_template: + print("[Warning] {} not in args.input_template, set to None") + args.input_template = None + + if args.input_template and "\\n" in args.input_template: + print( + "[Warning] input_template contains \\n chracters instead of newline. " + "You likely want to pass $'\\n' in Bash or \"`n\" in PowerShell." 
+ ) + + if args.packing_samples and not args.flash_attn: + print("[Warning] Please --flash_attn to accelerate when --packing_samples is enabled.") + args.flash_attn = True + + if args.ring_attn_size > 1: + assert args.packing_samples, "packing_samples must be enabled when using ring attention" + + train(args) diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/datasets/__init__.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/datasets/__init__.py index bbb762f1ea..5877886453 100644 --- a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/datasets/__init__.py +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/datasets/__init__.py @@ -3,5 +3,6 @@ from .prompts_dataset import PromptDataset from .reward_dataset import RewardDataset from .sft_dataset import SFTDataset from .unpaired_preference_dataset import UnpairedPreferenceDataset +from .vl_dataset import build_train_and_valid_datasets, build_data_collator -__all__ = ["ProcessRewardDataset", "PromptDataset", "RewardDataset", "SFTDataset", "UnpairedPreferenceDataset"] +__all__ = ["ProcessRewardDataset", "PromptDataset", "RewardDataset", "SFTDataset", "UnpairedPreferenceDataset", "build_train_and_valid_datasets", "build_data_collator"] diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/datasets/vl_dataset.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/datasets/vl_dataset.py new file mode 100644 index 0000000000..4e9ec16497 --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/datasets/vl_dataset.py @@ -0,0 +1,511 @@ +import os +from typing import Any, Dict, Literal, Optional, Sequence, Union, List, Tuple +from functools import partial +from dataclasses import dataclass +from collections import defaultdict + +import torch +from transformers import DataCollatorForSeq2Seq, ProcessorMixin + +from openrlhf.utils.utils import blending_datasets +from openrlhf.utils.vision_utils import ( + IGNORE_INDEX, ImageInput, + VisionEncoderUtils, DatasetAttr, + get_dataset_attr, +) + + +def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]: + r""" + Computes the real sequence length after truncation by the cutoff_len. + """ + if target_len * 2 < cutoff_len: # truncate source + max_target_len = cutoff_len + elif source_len * 2 < cutoff_len: # truncate target + max_target_len = cutoff_len - source_len + else: # truncate both + max_target_len = int(cutoff_len * (target_len / (source_len + target_len))) + + new_target_len = min(max_target_len, target_len) + max_source_len = max(cutoff_len - new_target_len, 0) + new_source_len = min(max_source_len, source_len) + return new_source_len, new_target_len + + +def _convert_images( + images: Union[ImageInput, Sequence[ImageInput]], + dataset_attr: DatasetAttr, +) -> Optional[List[ImageInput]]: + r""" + Optionally concatenates image path to dataset dir when loading from local disk. + """ + if not isinstance(images, list): + images = [images] + elif len(images) == 0: + return None + else: + images = images[:] + + return images + + +def convert_sharegpt( + example: Dict[str, Any], + dataset_attr: DatasetAttr +) -> Dict[str, Any]: + r""" + Converts sharegpt format dataset to the standard format. 
+ """ + tag_mapping = { + dataset_attr.user_tag: "user", + dataset_attr.assistant_tag: "assistant", + dataset_attr.observation_tag: "observation", + dataset_attr.function_tag: "function", + dataset_attr.system_tag: "system", + } + odd_tags = (dataset_attr.user_tag, dataset_attr.observation_tag) + even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag) + accept_tags = (odd_tags, even_tags) + messages = example[dataset_attr.messages] + if ( + dataset_attr.system_tag + and len(messages) != 0 + and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag + ): + system = messages[0][dataset_attr.content_tag] + messages = messages[1:] + else: + system = example[dataset_attr.system] if dataset_attr.system else "" + + aligned_messages = [] + broken_data = False + for turn_idx, message in enumerate(messages): + if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]: + print(f"Invalid role tag in {message}.") + broken_data = True + + aligned_messages.append( + {"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]} + ) + + if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or ( + dataset_attr.ranking and len(aligned_messages) % 2 == 0 + ): + print(f"Invalid message count in {messages}.") + broken_data = True + + if ( + dataset_attr.ranking + and isinstance(example[dataset_attr.chosen], dict) + and isinstance(example[dataset_attr.rejected], dict) + ): # pairwise example + chosen = example[dataset_attr.chosen] + rejected = example[dataset_attr.rejected] + if ( + chosen[dataset_attr.role_tag] not in accept_tags[-1] + or rejected[dataset_attr.role_tag] not in accept_tags[-1] + ): + print(f"Invalid role tag in {[chosen, rejected]}.") + broken_data = True + + prompt = aligned_messages + response = [ + {"role": tag_mapping[chosen[dataset_attr.role_tag]], "content": chosen[dataset_attr.content_tag]}, + {"role": tag_mapping[rejected[dataset_attr.role_tag]], "content": rejected[dataset_attr.content_tag]}, + ] + else: # normal example + prompt = aligned_messages[:-1] + response = aligned_messages[-1:] + + if broken_data: + print("Skipping this abnormal example.") + prompt, response = [], [] + + convert_images = partial(_convert_images, dataset_attr=dataset_attr) + output = { + "_prompt": prompt, + "_response": response, + "_system": system, + "_tools": example[dataset_attr.tools] if dataset_attr.tools else "", + "_images": convert_images(example[dataset_attr.images]) if dataset_attr.images else None, + "_videos": None, + } + return output + + +def _encode_supervised_example( + prompt: Sequence[Dict[str, str]], + response: Sequence[Dict[str, str]], + system: Optional[str], + tools: Optional[str], + images: Sequence[ImageInput], + videos: Sequence, + encoder: VisionEncoderUtils, + tokenizer, + processor, + cutoff_len: int, + train_on_prompt: bool, + mask_history: bool, +) -> Tuple[List[int], List[int]]: + messages = encoder.mm_plugin.process_messages(prompt + response, images, videos, processor) + input_ids, labels = encoder.mm_plugin.process_token_ids([], [], images, videos, tokenizer, processor) + encoded_pairs = encoder.encode_multiturn(tokenizer, messages, system, tools) + total_length = len(input_ids) + (1 if encoder.efficient_eos else 0) + if mask_history: + encoded_pairs = encoded_pairs[::-1] # high priority for last turns + + for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs): + if total_length >= cutoff_len: + break + + source_len, target_len = infer_seqlen(len(source_ids), len(target_ids), cutoff_len 
- total_length) + source_ids = source_ids[:source_len] + target_ids = target_ids[:target_len] + total_length += source_len + target_len + + if train_on_prompt: + source_label = source_ids + elif encoder.efficient_eos: + source_label = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (source_len - 1) + else: + source_label = [IGNORE_INDEX] * source_len + + if mask_history and turn_idx != 0: # train on the last turn only + target_label = [IGNORE_INDEX] * target_len + else: + target_label = target_ids + + if mask_history: # reversed sequences + input_ids = source_ids + target_ids + input_ids + labels = source_label + target_label + labels + else: + input_ids += source_ids + target_ids + labels += source_label + target_label + + if encoder.efficient_eos: + input_ids += [tokenizer.eos_token_id] + labels += [tokenizer.eos_token_id] + + return input_ids, labels + + +def _encode_pairwise_example( + prompt: Sequence[Dict[str, str]], + response: Sequence[Dict[str, str]], + system: Optional[str], + tools: Optional[str], + images: Sequence[ImageInput], + videos: Sequence, + encoder: VisionEncoderUtils, + tokenizer, + processor, + cutoff_len: int, +) -> Tuple[List[int], List[int], List[int], List[int]]: + chosen_messages = encoder.mm_plugin.process_messages(prompt + [response[0]], images, videos, processor) + rejected_messages = encoder.mm_plugin.process_messages(prompt + [response[1]], images, videos, processor) + prompt_ids, chosen_ids = encoder.encode_oneturn(tokenizer, chosen_messages, system, tools) + _, rejected_ids = encoder.encode_oneturn(tokenizer, rejected_messages, system, tools) + + if encoder.efficient_eos: + chosen_ids += [tokenizer.eos_token_id] + rejected_ids += [tokenizer.eos_token_id] + + prompt_ids, _ = encoder.mm_plugin.process_token_ids(prompt_ids, None, images, videos, tokenizer, processor) + # consider the response is more important + source_len, target_len = infer_seqlen(len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), cutoff_len) + prompt_ids = prompt_ids[:source_len] + chosen_ids = chosen_ids[:target_len] + rejected_ids = rejected_ids[:target_len] + + chosen_input_ids = prompt_ids + chosen_ids + chosen_labels = [IGNORE_INDEX] * source_len + chosen_ids + rejected_input_ids = prompt_ids + rejected_ids + rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids + return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels + + +def preprocess_supervised_dataset( + examples: Dict[str, List[Any]], + encoder: VisionEncoderUtils, + tokenizer, + processor, + data_args, +) -> Dict[str, List[Any]]: + # build inputs with format ` X Y ` and labels with format ` ... Y ` + # for multiturn examples, we only mask the prompt part in each prompt-response pair. 
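+ # Illustrative sketch of the resulting masking for a single-turn example
+ # (the token ids below are made up):
+ #   prompt ids  = [p1, p2, p3]        response ids = [r1, r2, eos]
+ #   input_ids   = [p1, p2, p3, r1, r2, eos]
+ #   labels      = [IGNORE_INDEX, IGNORE_INDEX, IGNORE_INDEX, r1, r2, eos]
+ # so the loss is computed only on response tokens (unless train_on_prompt is set).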
+ model_inputs = defaultdict(list) + for i in range(len(examples["_prompt"])): + if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1: + print( + "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]) + ) + continue + + input_ids, labels = _encode_supervised_example( + prompt=examples["_prompt"][i], + response=examples["_response"][i], + system=examples["_system"][i], + tools=examples["_tools"][i], + images=examples["_images"][i] or [], + videos=examples["_videos"][i] or [], + encoder=encoder, + tokenizer=tokenizer, + processor=processor, + cutoff_len=data_args.max_len, + train_on_prompt=data_args.train_on_prompt, + mask_history=data_args.mask_history, + ) + model_inputs["input_ids"].append(input_ids) + model_inputs["attention_mask"].append([1] * len(input_ids)) + model_inputs["labels"].append(labels) + model_inputs["images"].append(examples["_images"][i]) + model_inputs["videos"].append(examples["_videos"][i]) + + return model_inputs + + +def preprocess_pairwise_dataset( + examples: Dict[str, List[Any]], + encoder: VisionEncoderUtils, + tokenizer, + processor, + data_args, +) -> Dict[str, List[Any]]: + # build input pairs with format ` X`, `Y1 ` and `Y2 ` + model_inputs = defaultdict(list) + for i in range(len(examples["_prompt"])): + if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2: + print( + "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]) + ) + continue + + chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example( + prompt=examples["_prompt"][i], + response=examples["_response"][i], + system=examples["_system"][i], + tools=examples["_tools"][i], + images=examples["_images"][i] or [], + videos=examples["_videos"][i] or [], + encoder=encoder, + tokenizer=tokenizer, + processor=processor, + cutoff_len=data_args.max_len, + ) + model_inputs["chosen_input_ids"].append(chosen_input_ids) + model_inputs["chosen_attention_mask"].append([1] * len(chosen_input_ids)) + model_inputs["chosen_labels"].append(chosen_labels) + model_inputs["rejected_input_ids"].append(rejected_input_ids) + model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids)) + model_inputs["rejected_labels"].append(rejected_labels) + model_inputs["images"].append(examples["_images"][i]) + model_inputs["videos"].append(examples["_videos"][i]) + + return model_inputs + + +def get_preprocessed_dataset(args, data_list, encoder, tokenizer, processor): + train_data, eval_data = data_list + + dataset_attr = get_dataset_attr(args.dataset_config_path) + + kwargs = dict( + num_proc=args.processing_num_workers, + load_from_cache_file=(not args.overwrite_cache) or (args.local_process_index != 0), + desc="Converting format of dataset", + ) + convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr) + column_names = list(next(iter(train_data)).keys()) + train_data = train_data.map(convert_func, batched=False, remove_columns=column_names, + **kwargs) + if eval_data is not None: + eval_data = eval_data.map(convert_func, batched=False, remove_columns=column_names, + **kwargs) + + if args.task_type == "sft": + process_dataset_class = preprocess_supervised_dataset + elif args.task_type == "dpo": + process_dataset_class = preprocess_pairwise_dataset + else: + raise NotImplementedError(f"Unknown task_type: {args.task_type}") + + preprocess_func = partial( + process_dataset_class, + encoder=encoder, + tokenizer=tokenizer, + processor=processor, + data_args=args, 
+ ) + kwargs.update({"desc": "Running tokenizer on dataset"}) + column_names = list(next(iter(train_data)).keys()) + train_data = train_data.map(preprocess_func, batched=True, + batch_size=args.preprocessing_batch_size, + remove_columns=column_names, **kwargs) + if eval_data is not None: + eval_data = eval_data.map(preprocess_func, batched=True, + batch_size=args.preprocessing_batch_size, + remove_columns=column_names, **kwargs) + return train_data, eval_data + + +# Copied from https://github.com/hiyouga/LLaMA-Factory/blob/main/src/llamafactory/data/collator.py +@dataclass +class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): + r""" + Data collator that supports VLMs. + + Features should contain input_ids, attention_mask, labels and images. + """ + + encoder_utils: Optional[VisionEncoderUtils] = None + vision_processor: Optional[ProcessorMixin] = None + + def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]: + batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids = [], [], [], [], [] + for feature in features: + images = feature.pop("images", None) or [] + videos = feature.pop("videos", None) or [] + batch_images.extend(images) + batch_videos.extend(videos) + batch_imglens.append(len(images)) + batch_vidlens.append(len(videos)) + batch_input_ids.append(feature["input_ids"]) + + mm_inputs = self.encoder_utils.mm_plugin.get_mm_inputs( + batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids, self.vision_processor + ) + if "token_type_ids" in mm_inputs: + token_type_ids = mm_inputs.pop("token_type_ids") + for i, feature in enumerate(features): + feature["token_type_ids"] = token_type_ids[i] + + features: Dict[str, torch.Tensor] = super().__call__(features) + features.update(mm_inputs) + if isinstance(features.get("pixel_values"), list): # for pixtral inputs + features = features.data # use default_collate() instead of BatchEncoding.to() + return features + + +# Copied from https://github.com/hiyouga/LLaMA-Factory/blob/main/src/llamafactory/data/collator.py +@dataclass +class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq): + r""" + Data collator for 4d attention mask. + """ + + block_diag_attn: bool = False + attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager" + compute_dtype: "torch.dtype" = torch.float32 + + def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]: + features = super().__call__(features) + if self.block_diag_attn and self.attn_implementation != "flash_attention_2": + features["attention_mask"] = self.prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype) + return features + + def prepare_4d_attention_mask(self, attention_mask_with_indices, dtype) -> torch.Tensor: + r""" + Expands the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len), + while handles packed sequences and transforms the mask to lower triangular form to prevent future peeking. + + e.g. + ```python + # input + [[1, 1, 2, 2, 2, 0]] + # output + [ + [ + [ + [o, x, x, x, x, x], + [o, o, x, x, x, x], + [x, x, o, x, x, x], + [x, x, o, o, x, x], + [x, x, o, o, o, x], + [x, x, x, x, x, x], + ] + ] + ] + ``` + where `o` equals to `0.0`, `x` equals to `min_dtype`. 
+ """ + bsz, seq_len = attention_mask_with_indices.size() + min_dtype = torch.finfo(dtype).min + expanded_mask = attention_mask_with_indices[:, None, None, :].expand(bsz, 1, seq_len, seq_len) + # Create a binary mask from the original mask where zeros remain zeros and all other values are set to one + padding_mask = torch.where(expanded_mask != 0, 1, 0) + # Create a block-diagonal mask. + attention_mask_4d = torch.eq(expanded_mask, expanded_mask.transpose(-1, -2)).int() * padding_mask + # Use the lower triangular mask to zero out the upper triangular part + attention_mask_4d *= torch.tril(torch.ones((seq_len, seq_len), dtype=torch.long)) + # Invert the attention mask. + attention_mask_4d = torch.where(attention_mask_4d != 0, torch.tensor(0, dtype=dtype), min_dtype) + return attention_mask_4d + + +# Copied from https://github.com/hiyouga/LLaMA-Factory/blob/main/src/llamafactory/data/collator.py +@dataclass +class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq): + r""" + Data collator for pairwise data. + """ + + def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]: + r""" + Pads batched data to the longest sequence in the batch. + + We generate 2 * n examples where the first n examples represent chosen examples and + the last n examples represent rejected examples. + """ + concatenated_features = [] + for key in ("chosen", "rejected"): + for feature in features: + target_feature = { + "input_ids": feature[f"{key}_input_ids"], + "attention_mask": feature[f"{key}_attention_mask"], + "labels": feature[f"{key}_labels"], + "images": feature["images"], + "videos": feature["videos"], + } + concatenated_features.append(target_feature) + + return super().__call__(concatenated_features) + + +def build_train_and_valid_datasets(args, tokenizer, processor, encoder_utils, strategy): + train_ds, eval_ds = blending_datasets( + args.dataset, + args.dataset_probs, + strategy, + args.seed, + max_count=args.max_samples, + train_split=args.train_split, + eval_split=args.eval_split, + ) + + return get_preprocessed_dataset(args, [train_ds, eval_ds], encoder_utils, tokenizer, processor) + + +def build_data_collator(args, tokenizer, encoder_utils, vision_processor): + collator_class = None + kwargs = {} + if args.task_type == "dpo": + collator_class = PairwiseDataCollatorWithPadding + elif args.task_type == "sft": + collator_class = SFTDataCollatorWith4DAttentionMask + kwargs = { + "block_diag_attn": args.neat_packing, + "attn_implementation": "flash_attention_2" if args.flash_attn else None, + "compute_dtype": torch.bfloat16 if args.bf16 else torch.float16 + } + else: + raise NotImplementedError(f"Task type {args.task_type} not supported.") + + data_collator = collator_class( + encoder_utils=encoder_utils, + vision_processor=vision_processor, + pad_to_multiple_of=8, + label_pad_token_id=IGNORE_INDEX, + tokenizer=tokenizer, + **kwargs + ) + return data_collator diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/models/actor.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/models/actor.py index 68009603c6..13f265f157 100644 --- a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/models/actor.py +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/models/actor.py @@ -5,7 +5,7 @@ import torch.distributed as dist import torch.nn as nn from peft import LoraConfig, TaskType, get_peft_model from peft.tuners.lora import LoraLayer -from transformers import AutoModelForCausalLM, BitsAndBytesConfig +from transformers import 
AutoModelForCausalLM, BitsAndBytesConfig, AutoModelForVision2Seq from transformers.integrations.deepspeed import HfDeepSpeedConfig from .ring_attn_utils import convert_ring_attn_params @@ -45,12 +45,19 @@ class Actor(nn.Module): ds_config=None, device_map=None, packing_samples=False, + create_vison_model=False, + freeze_vision_tower=True, **kwargs, ) -> None: super().__init__() if isinstance(pretrain_or_model, str): - attn_implementation = "flash_attention_2" if use_flash_attention_2 else "eager" + if use_flash_attention_2 == "fa2": + attn_implementation = "flash_attention_2" + elif use_flash_attention_2 == "sdpa": + attn_implementation = "sdpa" + else: + attn_implementation = "eager" # Note: dschf is defined in function scope to avoid global effects # https://huggingface.co/docs/transformers/deepspeed#non-trainer-deepspeed-integration @@ -69,8 +76,12 @@ class Actor(nn.Module): ) else: nf4_config = None - - self.model = AutoModelForCausalLM.from_pretrained( + model_class = None + if create_vison_model: + model_class = AutoModelForVision2Seq + else: + model_class = AutoModelForCausalLM + self.model = model_class.from_pretrained( pretrain_or_model, trust_remote_code=True, attn_implementation=attn_implementation, @@ -78,6 +89,9 @@ class Actor(nn.Module): torch_dtype=torch.bfloat16 if bf16 else "auto", device_map=device_map, ) + self.model_type = getattr(self.model.config, "model_type", None) + if create_vison_model: + self.prepare_model(self.model, freeze_vision_tower) # LoRA if lora_rank > 0: @@ -188,9 +202,22 @@ class Actor(nn.Module): return_output=False, ring_attn_group: Optional[dist.ProcessGroup] = None, packed_seq_lens: Optional[list[int]] = None, + **kwargs ) -> torch.Tensor: """Returns action log probs""" - if not self.packing_samples: + if self.model_type == "qwen2_vl": + # Before Transformers version 4.47, when using the Qwen2VL model, + # the position IDs needed to be externally provided in a specific mrope format + # during the forward pass. Therefore, it was decided to consistently pass them + # externally through the model. 
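+ # Note: in the Transformers Qwen2-VL implementation targeted here, the returned
+ # position_ids stack the three mrope components (temporal, height, width), i.e.
+ # roughly shape (3, batch, seq_len), and rope_deltas is a per-sample offset;
+ # these shapes are stated only for orientation.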
+ position_ids, rope_deltas = self.model.get_rope_index( + input_ids=sequences, + image_grid_thw=kwargs.get("image_grid_thw", None), + video_grid_thw=kwargs.get("video_grid_thw", None), + attention_mask=attention_mask, + ) + kwargs["rope_deltas"] = rope_deltas + elif not self.packing_samples: # https://github.com/OpenRLHF/OpenRLHF/issues/217 position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) @@ -205,7 +232,8 @@ class Actor(nn.Module): # explicitly ignore attention_mask for packing_samples attention_mask = None - output = self.model(sequences, attention_mask=attention_mask, position_ids=position_ids) + output = self.model(sequences, attention_mask=attention_mask, position_ids=position_ids, + **kwargs) # https://github.com/OpenRLHF/OpenRLHF/pull/634 output["logits"] = output["logits"].to(torch.float32) @@ -240,3 +268,19 @@ class Actor(nn.Module): def print_trainable_parameters(self): self.model.print_trainable_parameters() + + def prepare_model(self, model, freeze_vision_tower): + freeze_modules = set() + if self.model_type == "qwen2_vl": + if freeze_vision_tower: + freeze_modules.add("visual") + elif self.model_type == None: + pass + else: + raise NotImplementedError("TODO: Implement prepare_model for model_type: {}".format(self.model_type)) + + for name, param in model.named_parameters(): + if not any(freeze_mod in name for freeze_mod in freeze_modules): + pass + else: + param.requires_grad_(False) diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/__init__.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/__init__.py index a26d247b63..4eaeca47a9 100644 --- a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/__init__.py +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/__init__.py @@ -1,10 +1,10 @@ -from .dpo_trainer import DPOTrainer +from .dpo_trainer import DPOTrainer, VLDPOTrainer from .kd_trainer import KDTrainer from .kto_trainer import KTOTrainer from .ppo_trainer import PPOTrainer from .prm_trainer import ProcessRewardModelTrainer from .rm_trainer import RewardModelTrainer -from .sft_trainer import SFTTrainer +from .sft_trainer import SFTTrainer, VLSFTTrainer __all__ = [ "DPOTrainer", @@ -14,4 +14,6 @@ __all__ = [ "ProcessRewardModelTrainer", "RewardModelTrainer", "SFTTrainer", + "VLSFTTrainer", + "VLDPOTrainer", ] diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/dpo_trainer.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/dpo_trainer.py index 55088cd559..3b8c45bac4 100644 --- a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/dpo_trainer.py +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/dpo_trainer.py @@ -1,14 +1,20 @@ import os +import time + from abc import ABC import torch -from flash_attn.utils.distributed import all_gather from torch.nn import functional as F from torch.optim import Optimizer from tqdm import tqdm +from transformers.utils import is_flash_attn_2_available from openrlhf.models import DPOLoss from openrlhf.utils.distributed_sampler import DistributedSampler +from openrlhf.utils.vision_utils import IGNORE_INDEX + +if is_flash_attn_2_available(): + from flash_attn.utils.distributed import all_gather class DPOTrainer(ABC): @@ -197,7 +203,6 @@ class DPOTrainer(ABC): logs_dict["nll_loss"] = nll_loss.item() # step bar logs_dict = self.strategy.all_reduce(logs_dict) - step_bar.set_postfix(logs_dict) step_bar.update() # logs/checkpoints/evaluation @@ 
-223,9 +228,11 @@ class DPOTrainer(ABC): def save_logs_and_checkpoints(self, args, global_step, step_bar, logs_dict={}, client_states={}): # logs if global_step % args.logging_steps == 0: + logs = {"train/%s" % k: v for k, v in {**logs_dict, "global_step": global_step}.items()} + if self.strategy.is_rank_0(): + step_bar.write(str(logs)) # wandb if self._wandb is not None and self.strategy.is_rank_0(): - logs = {"train/%s" % k: v for k, v in {**logs_dict, "global_step": global_step}.items()} self._wandb.log(logs) # TensorBoard elif self._tensorboard is not None and self.strategy.is_rank_0(): @@ -476,3 +483,238 @@ class DPOTrainer(ABC): index = index + seq_len return torch.stack(logprobs_sums), torch.stack(logprobs_means) + + +class VLDPOTrainer(DPOTrainer): + """ + Trainer for Direct Preference Optimization (DPO) training. + + Args: + model (torch.nn.Module): The primary model to be trained. + ref_model (torch.nn.Module): The reference model for comparing and guiding preference. + strategy (Strategy): The strategy to use for training. + tokenizer (Tokenizer): The tokenizer for processing input data. + optim (Optimizer): The optimizer for training the model. + train_dataloader (DataLoader): The dataloader for the training dataset. + eval_dataloader (DataLoader): The dataloader for the evaluation dataset. + scheduler (Scheduler): The learning rate scheduler to control learning rate during training. + max_norm (float, defaults to 0.5): Maximum gradient norm for gradient clipping. + beta (float, defaults to 0.01): Coefficient for regularizing the preference loss. + max_epochs (int, defaults to 2): Maximum number of training epochs. + """ + def data_to_device(self, input_data): + for key, value in input_data.items(): + input_data[key] = value.to(torch.cuda.current_device()) + return input_data + + def concatenated_forward(self, model, input_ids, attn_masks, labels, pixel_values, image_grid_thw): + """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together. + + We do this to avoid doing two forward passes, because it's faster for FSDP. + """ + + output = model(input_ids, + attention_mask=attn_masks, + return_output=True, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw) + + all_logits = output["logits"] + all_logps_sum, all_logps_mean = self._get_batch_logps( + all_logits, labels, attn_masks, None, average_log_prob=False + ) + assert input_ids.shape[0] % 2 == 0 + batch_size = input_ids.shape[0] // 2 + chosen_logps = all_logps_sum[:batch_size] + rejected_logps = all_logps_sum[batch_size:] + aux_loss = output.aux_loss if "aux_loss" in output else [] + return chosen_logps, rejected_logps, aux_loss, -all_logps_mean[: batch_size].mean() + + def _get_batch_logps( + self, + logits: torch.FloatTensor, + labels: torch.LongTensor, + attention_mask, + prompt_id_lens, + average_log_prob: bool = False, + ) -> torch.FloatTensor: + """Compute the log probabilities of the given labels under the given logits. + + Args: + logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size) + labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length) + average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens. + + Returns: + A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits. 
+ """ + assert average_log_prob == False + assert logits.shape[:-1] == labels.shape + + labels = labels[:, 1:].clone() + logits = logits[:, :-1, :] + loss_masks = (labels != IGNORE_INDEX) + + # dummy token; we'll ignore the losses on these tokens later + labels[labels == IGNORE_INDEX] = 0 + per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) + + logprobs_sums = (per_token_logps * loss_masks).sum(-1) + logprobs_means = (per_token_logps * loss_masks).sum(-1) / loss_masks.sum(-1) + return logprobs_sums, logprobs_means + + def fit(self, args, consumed_samples=0, num_update_steps_per_epoch=None): + # get eval and save steps + if args.eval_steps == -1: + args.eval_steps = num_update_steps_per_epoch # Evaluate once per epoch + if args.save_steps == -1: + args.save_steps = float("inf") # do not save ckpt + + # Restore step and start_epoch + step = consumed_samples // args.train_batch_size * self.strategy.accumulated_gradient + 1 + start_epoch = consumed_samples // args.train_batch_size // num_update_steps_per_epoch + consumed_samples = consumed_samples % (num_update_steps_per_epoch * args.train_batch_size) + + epoch_bar = tqdm( + range(start_epoch, self.epochs), + desc="Train epoch", + disable=not self.strategy.is_rank_0(), + ) + for epoch in range(start_epoch, self.epochs): + if isinstance(self.train_dataloader.sampler, DistributedSampler): + self.train_dataloader.sampler.set_epoch( + epoch, consumed_samples=0 if epoch > start_epoch else consumed_samples + ) + + step_bar = tqdm( + range(self.train_dataloader.__len__()), + desc="Train step of epoch %d" % epoch, + disable=not self.strategy.is_rank_0(), + ) + + self.model.train() + self.ref_model.eval() + acc_mean = 0 + loss_mean = 0 + + assert self.strategy.ring_attn_group is None, f"Ring attention is not supported on vision models currently" + + # train + for input_data in self.train_dataloader: + start_time = time.time() + data = self.data_to_device(input_data) + + chosen_logps, rejected_logps, aux_loss, nll_loss = self.concatenated_forward( + self.model, data["input_ids"], data["attention_mask"], + data["labels"], data["pixel_values"], data["image_grid_thw"], + ) + with torch.no_grad(): + reference_chosen_logps, reference_rejected_logps, _, _ = self.concatenated_forward( + self.ref_model, data["input_ids"], data["attention_mask"], + data["labels"], data["pixel_values"], data["image_grid_thw"], + ) + + # loss function + preference_loss, chosen_reward, reject_reward = self.loss_fn( + chosen_logps, rejected_logps, reference_chosen_logps, reference_rejected_logps + ) + # mixtral + if not self.aux_loss: + aux_loss = 0 + # nll loss + if not self.nll_loss: + nll_loss = 0 + + loss = preference_loss + aux_loss * self.args.aux_loss_coef + nll_loss * self.args.nll_loss_coef + self.strategy.backward(loss, self.model, self.optimizer) + self.strategy.optimizer_step(self.optimizer, self.model, self.scheduler) + + acc = (chosen_reward > reject_reward).float().mean().item() + acc_mean = acc_mean * 0.9 + 0.1 * acc + loss_mean = loss_mean * 0.9 + 0.1 * preference_loss.item() + # dpo logs + logs_dict = { + "loss": preference_loss.item(), + "acc": acc, + "chosen_reward": chosen_reward.mean().item(), + "reject_reward": reject_reward.mean().item(), + "loss_mean": loss_mean, + "acc_mean": acc_mean, + "lr": self.scheduler.get_last_lr()[0], + } + grad_norm = self.model.model.get_global_grad_norm() + if grad_norm is not None: + logs_dict.update({ + "grad_norm": grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) 
else grad_norm}) + + if self.nll_loss: + logs_dict["nll_loss"] = nll_loss.item() + # step bar + logs_dict = self.strategy.all_reduce(logs_dict) + step_bar.update() + end_time = time.time() + step_time = end_time - start_time + + # logs/checkpoints/evaluation + if step % self.strategy.accumulated_gradient == 0: + global_step = step // self.strategy.accumulated_gradient + client_states = {"consumed_samples": global_step * args.train_batch_size} + logs_dict["step_time"] = f"{step_time:.3f}s" + self.save_logs_and_checkpoints(args, global_step, step_bar, logs_dict, client_states) + + step += 1 + epoch_bar.update() + + if self._wandb is not None and self.strategy.is_rank_0(): + self._wandb.finish() + if self._tensorboard is not None and self.strategy.is_rank_0(): + self._tensorboard.close() + + def evaluate(self, eval_dataloader, steps=0): + self.model.eval() + with torch.no_grad(): + step_bar = tqdm( + range(eval_dataloader.__len__()), + desc="Eval stage of global_step %d" % steps, + disable=not self.strategy.is_rank_0(), + ) + acc_sum = 0 + loss_sum = 0 + times = 0 + for input_data in eval_dataloader: + data = self.data_to_device(input_data) + + chosen_logps, rejected_logps, aux_loss, _ = self.concatenated_forward( + self.model, data["input_ids"], data["attention_mask"], + data["labels"], data["pixel_values"], data["image_grid_thw"], + ) + with torch.no_grad(): + reference_chosen_logps, reference_rejected_logps, _, _ = self.concatenated_forward( + self.ref_model, data["input_ids"], data["attention_mask"], + data["labels"], data["pixel_values"], data["image_grid_thw"], + ) + + loss, chosen_reward, reject_reward = self.loss_fn( + chosen_logps, rejected_logps, reference_chosen_logps, reference_rejected_logps + ) + acc_sum += (chosen_reward > reject_reward).float().mean().item() + loss_sum += loss.item() + times += 1 + step_bar.update() + + logs = { + "eval_loss": loss_sum / times, + "acc_mean": acc_sum / times, + } + logs = self.strategy.all_reduce(logs) + step_bar.set_postfix(logs) + + if self.strategy.is_rank_0(): + if self._wandb is not None: + logs = {"eval/%s" % k: v for k, v in {**logs, "global_step": steps}.items()} + self._wandb.log(logs) + elif self._tensorboard is not None: + for k, v in logs.items(): + self._tensorboard.add_scalar(f"eval/{k}", v, steps) + self.model.train() # reset model state + diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/sft_trainer.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/sft_trainer.py index fa92452d50..8523197035 100644 --- a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/sft_trainer.py +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/sft_trainer.py @@ -1,5 +1,8 @@ +import time import os + from abc import ABC +from typing_extensions import override import torch from torch.optim import Optimizer @@ -165,18 +168,10 @@ class SFTTrainer(ABC): if not self.pretrain_mode: if self.packing_samples: - # As response_ranges need to constrain the dataset organization strictly, we handle multiturn feature separately. 
- if infos["response_ranges"]: - dump_labels = torch.full(labels.size(), self.loss_fn.IGNORE_INDEX).to(labels.device) - for response_ranges in infos["response_ranges"]: - for response_range in response_ranges: - dump_labels[0][response_range[0]: response_range[1]] = labels[0][response_range[0]: response_range[1]] - labels = dump_labels - else: - index = 0 - for input_length, source_len in zip(infos["input_length"], prompt_id_lens): - labels[0][index : index + source_len] = self.loss_fn.IGNORE_INDEX - index += input_length + index = 0 + for input_length, source_len in zip(infos["input_length"], prompt_id_lens): + labels[0][index : index + source_len] = self.loss_fn.IGNORE_INDEX + index += input_length else: for label, source_len in zip(labels, prompt_id_lens): label[:source_len] = self.loss_fn.IGNORE_INDEX @@ -195,7 +190,6 @@ class SFTTrainer(ABC): logs_dict["aux_loss"] = aux_loss.item() # step bar logs_dict = self.strategy.all_reduce(logs_dict) - step_bar.set_postfix(logs_dict) step_bar.update() # logs/checkpoints/evaluation @@ -218,9 +212,11 @@ class SFTTrainer(ABC): # logs/checkpoints/evaluation def save_logs_and_checkpoints(self, args, global_step, step_bar, logs_dict={}, client_states={}): if global_step % args.logging_steps == 0: + logs = {"train/%s" % k: v for k, v in {**logs_dict, "global_step": global_step}.items()} + if self.strategy.is_rank_0(): + step_bar.write(str(logs)) # wandb if self._wandb is not None and self.strategy.is_rank_0(): - logs = {"train/%s" % k: v for k, v in {**logs_dict, "global_step": global_step}.items()} self._wandb.log(logs) # TensorBoard elif self._tensorboard is not None and self.strategy.is_rank_0(): @@ -284,17 +280,10 @@ class SFTTrainer(ABC): if not self.pretrain_mode: if self.packing_samples: - if infos["response_ranges"]: - dump_labels = torch.full(labels.size(), self.loss_fn.IGNORE_INDEX).to(labels.device) - for response_ranges in infos["response_ranges"]: - for response_range in response_ranges: - dump_labels[0][response_range[0]: response_range[1]] = labels[0][response_range[0]: response_range[1]] - labels = dump_labels - else: - index = 0 - for input_length, source_len in zip(infos["input_length"], prompt_id_lens): - labels[0][index : index + source_len] = self.loss_fn.IGNORE_INDEX - index += input_length + index = 0 + for input_length, source_len in zip(infos["input_length"], prompt_id_lens): + labels[0][index : index + source_len] = self.loss_fn.IGNORE_INDEX + index += input_length else: for label, source_len in zip(labels, prompt_id_lens): label[:source_len] = self.loss_fn.IGNORE_INDEX @@ -316,3 +305,147 @@ class SFTTrainer(ABC): for k, v in logs.items(): self._tensorboard.add_scalar(f"eval/{k}", v, steps) self.model.train() # reset model state + + +class VLSFTTrainer(SFTTrainer): + def data_to_device(self, input_data): + for key, value in input_data.items(): + input_data[key] = value.to(torch.cuda.current_device()) + return input_data + + @override + def fit(self, args, consumed_samples=0, num_update_steps_per_epoch=None): + # get eval and save steps + if args.eval_steps == -1: + args.eval_steps = num_update_steps_per_epoch # Evaluate once per epoch + if args.save_steps == -1: + args.save_steps = float("inf") # do not save ckpt + + # Restore step and start_epoch + step = consumed_samples // args.train_batch_size * self.strategy.accumulated_gradient + 1 + start_epoch = consumed_samples // args.train_batch_size // num_update_steps_per_epoch + consumed_samples = consumed_samples % (num_update_steps_per_epoch * args.train_batch_size) + + 
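+ # Worked example of the resume arithmetic above (the numbers are hypothetical):
+ # with train_batch_size=128, accumulated_gradient=16, num_update_steps_per_epoch=100
+ # and consumed_samples=12800, training resumes at step = 12800 // 128 * 16 + 1 = 1601,
+ # start_epoch = 12800 // 128 // 100 = 1, and 12800 % (100 * 128) = 0 samples are
+ # counted as already consumed inside that epoch.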
epoch_bar = tqdm( + range(start_epoch, self.epochs), + desc="Train epoch", + disable=not self.strategy.is_rank_0(), + ) + + for epoch in range(start_epoch, self.epochs): + if isinstance(self.train_dataloader.sampler, DistributedSampler): + self.train_dataloader.sampler.set_epoch( + epoch, consumed_samples=0 if epoch > start_epoch else consumed_samples + ) + + step_bar = tqdm( + range(self.train_dataloader.__len__()), + desc="Train step of epoch %d" % epoch, + disable=not self.strategy.is_rank_0(), + ) + + # train + self.model.train() + loss_mean = 0 + + assert self.strategy.ring_attn_group is None, f"Ring attention is not supported on vision models currently" + + for input_data in self.train_dataloader: + start_time = time.time() + data = self.data_to_device(input_data) + labels = data["labels"] + + output = self.model( + data["input_ids"], + attention_mask=data["attention_mask"], + return_output=True, + pixel_values=data["pixel_values"], + image_grid_thw=data["image_grid_thw"], + ) + + # mixtral + if self.aux_loss: + aux_loss = output.aux_loss + else: + aux_loss = 0 + + gpt_loss = self.loss_fn(output.logits, labels) + loss = gpt_loss + aux_loss * self.args.aux_loss_coef + self.strategy.backward(loss, self.model, self.optimizer) + self.strategy.optimizer_step(self.optimizer, self.model, self.scheduler) + + loss_mean = loss_mean * 0.9 + 0.1 * gpt_loss.item() + logs_dict = { + "gpt_loss": gpt_loss.item(), + "loss_mean": loss_mean, + "lr": self.scheduler.get_last_lr()[0], + } + grad_norm = self.model.model.get_global_grad_norm() + if grad_norm is not None: + logs_dict.update({ + "grad_norm": grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm}) + + if self.aux_loss: + logs_dict["aux_loss"] = aux_loss.item() + # step bar + logs_dict = self.strategy.all_reduce(logs_dict) + step_bar.update() + end_time = time.time() + step_time = end_time - start_time + + # logs/checkpoints/evaluation + if step % self.strategy.accumulated_gradient == 0: + global_step = step // self.strategy.accumulated_gradient + client_states = {"consumed_samples": global_step * args.train_batch_size} + logs_dict["step_time"] = f"{step_time:.3f}s" + self.save_logs_and_checkpoints(args, global_step, step_bar, logs_dict, client_states) + + step += 1 + + epoch_bar.update() + + if self._wandb is not None and self.strategy.is_rank_0(): + self._wandb.finish() + if self._tensorboard is not None and self.strategy.is_rank_0(): + self._tensorboard.close() + + def evaluate(self, eval_dataloader, steps=0): + times = 0 + self.model.eval() + with torch.no_grad(): + loss_sum = 0 + step_bar = tqdm( + range(eval_dataloader.__len__()), + desc="Eval stage of steps %d" % steps, + disable=not self.strategy.is_rank_0(), + ) + + for input_data in eval_dataloader: + data = self.data_to_device(input_data) + labels = data["labels"] + + output = self.model( + data["input_ids"], + attention_mask=data["attention_mask"], + return_output=True, + pixel_values=data["pixel_values"], + image_grid_thw=data["image_grid_thw"], + ) + + loss = self.loss_fn(output.logits, labels) + + times += 1 + loss_sum += loss.item() + bar_dict = {"eval gpt_loss": loss_sum / times} + step_bar.update() + logs = self.strategy.all_reduce(bar_dict) + step_bar.set_postfix(logs) + + if self.strategy.is_rank_0(): + if self._wandb is not None: + logs = {"eval/%s" % k: v for k, v in {**logs, "global_step": steps}.items()} + self._wandb.log(logs) + elif self._tensorboard is not None: + for k, v in logs.items(): + self._tensorboard.add_scalar(f"eval/{k}", v, steps) 
+ self.model.train() # reset model state diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/__init__.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/__init__.py index 08ab0a9ba9..e69a0696ac 100644 --- a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/__init__.py +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/__init__.py @@ -1,5 +1,7 @@ from .processor import get_processor, reward_normalization from .utils import blending_datasets, get_strategy, get_tokenizer +from .vision_args import add_vision_args +from .vision_utils import get_qwen2_vl_utils, get_vision_processor __all__ = [ "get_processor", @@ -7,4 +9,7 @@ __all__ = [ "blending_datasets", "get_strategy", "get_tokenizer", + "get_vision_processor", + "get_qwen2_vl_utils", + "add_vision_args", ] diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/utils.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/utils.py index a69b13ece3..f3ad417387 100644 --- a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/utils.py +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/utils.py @@ -73,9 +73,14 @@ def blending_datasets( strategy.print(f"loaded {dataset} with data_files={dataset}") # local dataset saved with `datasets.Dataset.save_to_disk` elif os.path.isdir(dataset): - data = load_from_disk(dataset) - strategy.print(f"loaded {dataset} from disk") - # remote/local folder or common file + try: + data = load_from_disk(dataset) + strategy.print(f"loaded {dataset} from disk") + except Exception as e: + strategy.print(f"failed to load {dataset} from disk, Attempting to load data files from folder or common file") + data = load_dataset(dataset, data_dir=data_dir) + strategy.print(f"loaded {dataset} from files") + # remote/local folder or common files else: data = load_dataset(dataset, data_dir=data_dir) strategy.print(f"loaded {dataset} from files") diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/vision_args.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/vision_args.py new file mode 100644 index 0000000000..be1acde581 --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/vision_args.py @@ -0,0 +1,41 @@ +def add_vision_args(parser): + group = parser.add_argument_group(title='vision args') + group.add_argument("--task_type", type=str, + default="sft", + choices=["sft", "dpo"], + help="task type") + group.add_argument("--model_arch", type=str, choices=["qwen2_vl"], + help="model arch",) + group.add_argument("--dataset_config_path", type=str, default=None, + help="the dataset config") + + group.add_argument("--image_resolution", type=int, default=512 * 512, + help="The number of pixels of image below this resolution.") + group.add_argument("--video_resolution", type=int, default=128 * 128, + help="The number of pixels of video below this resolution.") + group.add_argument("--video_fps", type=float, default=2.0, + help="The frames to sample per second for video inputs.") + group.add_argument("--video_maxlen", type=int, default=64, + help="The maximum number of sampled frames for video inputs.") + + group.add_argument("--efficient_eos", type=bool, default=False, + help="the efficient_eos of VisionEncoderUtils") + group.add_argument("--processing_num_workers", type=int, default=18, + help="num workers processing process") + group.add_argument("--train_on_prompt", type=bool, default=False, + help="train_on_prompt") + 
group.add_argument("--mask_history", type=bool, default=False, + help="mask_history") + group.add_argument("--overwrite_cache", type=bool, default=True, + help="overwrite_cache") + group.add_argument("--local_process_index", type=int, default=0, + help="local_process_index") + group.add_argument("--preprocessing_batch_size", type=int, default=1000, + help="preprocessing_batch_size") + group.add_argument("--neat_packing", action="store_true", + help="enable sequence packing without cross-attention.") + + group.add_argument("--freeze_vision_tower", type=bool, default=True, + help="Whether ot not to freeze vision tower in training. default: True") + + return parser -- Gitee