diff --git a/configs/rlxf/model/llama2-7b.yaml b/configs/rlxf/model/llama2-7b.yaml deleted file mode 100644 index 8a84f722589705f5854c26837a7a467ef1d08f45..0000000000000000000000000000000000000000 --- a/configs/rlxf/model/llama2-7b.yaml +++ /dev/null @@ -1,27 +0,0 @@ -llama2-7b: - use_mcore_models: true - sequence_parallel: true - num_layers: 32 - hidden_size: 4096 - ffn_hidden_size: 11008 - num_attention_heads: 32 - seq_length: 4096 - max_position_embeddings: 4096 - make_vocab_size_divisible_by: 1 - untie_embeddings_and_output_weights: true - disable_bias_linear: true - attention_dropout: 0.0 - init_method_std: 0.01 - hidden_dropout: 0.0 - position_embedding_type: rope - normalization: RMSNorm - use_fused_rmsnorm: true - swiglu: true - use_flash_attn: true - use_mc2: true - no_masked_softmax_fusion: true - attention_softmax_in_fp32: true - no_gradient_accumulation_fusion: true - use_fused_swiglu: true - use_fused_rotary_pos_emb: true - bf16: true \ No newline at end of file diff --git a/configs/rlxf/model/llama3-8b.yaml b/configs/rlxf/model/llama3-8b.yaml deleted file mode 100644 index b6d54851bdcba69528f38a3908dcb2972139633f..0000000000000000000000000000000000000000 --- a/configs/rlxf/model/llama3-8b.yaml +++ /dev/null @@ -1,36 +0,0 @@ -llama3-8b: - use_mcore_models: true - sequence_parallel: true - use_flash_attn: true - use_rotary_position_embeddings: true - rope_scaling_type: llama3 - rope_scaling_factor: 8.0 - rotary_percent: 1.0 - low_freq_factor: 1.0 - high_freq_factor: 4.0 - original_max_position_embeddings: 8192 - num_layers: 32 - hidden_size: 4096 - ffn_hidden_size: 14336 - num_attention_heads: 32 - group_query_attention: true - num_query_groups: 8 - max_position_embeddings: 8192 - make_vocab_size_divisible_by: 1 - padded_vocab_size: 128256 - untie_embeddings_and_output_weights: true - disable_bias_linear: true - attention_dropout: 0.0 - init_method_std: 0.02 - hidden_dropout: 0.0 - position_embedding_type: rope - rotary_base: 500000 - normalization: RMSNorm - norm_epsilon: 1e-5 - swiglu: true - no_masked_softmax_fusion: true - attention_softmax_in_fp32: true - no_gradient_accumulation_fusion: true - bf16: true - seed: 42 - vocab_size: 128256 \ No newline at end of file diff --git a/configs/rlxf/model/llama32-1b.yaml b/configs/rlxf/model/llama32-1b.yaml deleted file mode 100644 index 324548e026623607a47cfb6c480c2376496626a9..0000000000000000000000000000000000000000 --- a/configs/rlxf/model/llama32-1b.yaml +++ /dev/null @@ -1,35 +0,0 @@ -llama32-1b: - use_mcore_models: true - sequence_parallel: true - use_mc2: true - use_flash_attn: true - use_rotary_position_embeddings: true - use_fused_rmsnorm: true - use_fused_swiglu: true - rope_scaling_type: llama3 - rope_scaling_factor: 32.0 - low_freq_factor: 1.0 - high_freq_factor: 4.0 - original_max_position_embeddings: 8192 - max_position_embeddings: 8192 - num_layers: 16 - hidden_size: 2048 - ffn_hidden_size: 8192 - num_attention_heads: 32 - group_query_attention: true - num_query_groups: 8 - make_vocab_size_divisible_by: 1 - padded_vocab_size: 128256 - disable_bias_linear: true - attention_dropout: 0.0 - init_method_std: 0.01 - hidden_dropout: 0.0 - position_embedding_type: rope - rotary_base: 500000 - normalization: RMSNorm - norm_epsilon: 1e-5 - swiglu: true - no_masked_softmax_fusion: true - attention_softmax_in_fp32: true - no_gradient_accumulation_fusion: true - bf16: true \ No newline at end of file diff --git a/configs/rlxf/model/qwen25-7b.yaml b/configs/rlxf/model/qwen25-7b.yaml deleted file mode 100644 index 
b5eb4b00e7601bcb6c9a6fbe9c46e1836428c643..0000000000000000000000000000000000000000 --- a/configs/rlxf/model/qwen25-7b.yaml +++ /dev/null @@ -1,33 +0,0 @@ -qwen25-7b: - use_mcore_models: true - num_layers: 28 - hidden_size: 3584 - ffn_hidden_size: 18944 - num_attention_heads: 28 - seq_length: 4096 - rotary_base: 1000000 - max_position_embeddings: 32768 - make_vocab_size_divisible_by: 1 - padded_vocab_size: 152064 - untie_embeddings_and_output_weights: true - add_qkv_bias: true - disable_bias_linear: true - group_query_attention: true - num_query_groups: 4 - attention_dropout: 0.0 - init_method_std: 0.01 - hidden_dropout: 0.0 - adam_beta1: 0.9 - adam_beta2: 0.95 - position_embedding_type: rope - normalization: RMSNorm - use_fused_rmsnorm: true - swiglu: true - use_mc2: true - use_flash_attn: true - no_masked_softmax_fusion: true - attention_softmax_in_fp32: true - no_gradient_accumulation_fusion: true - use_fused_swiglu: true - use_fused_rotary_pos_emb: true - bf16: true \ No newline at end of file diff --git a/configs/rlxf/online_dpo_trainer_llama32_1b.yaml b/configs/rlxf/online_dpo_trainer_llama32_1b.yaml deleted file mode 100644 index 49fb3fbceaa80a998f811a3f3f9866698074faac..0000000000000000000000000000000000000000 --- a/configs/rlxf/online_dpo_trainer_llama32_1b.yaml +++ /dev/null @@ -1,86 +0,0 @@ -defaults: - - model: - - llama32-1b - -training: - global_batch_size: 4 - seq_length: 512 - tokenizer_type: PretrainedFromHF - tokenizer_name_or_path: ./models/llama-3.2-1b-instruct/ - train_iters: 1000 - distributed_backend: nccl - no_shared_storage: true - save_interval: 10000 - no_load_optim: true - no_load_rng: true - bf16: true - is_instruction_dataset: true - variable_seq_lengths: true - stage: ray_online_dpo - - -actor_rollout_ref: - actor_rollout: - model: llama32-1b - micro_batch_size: 4 - ppo_mini_batch_size: 4 - max_prompt_length: 256 - ppo_epochs: 1 - clip_ratio: 0.2 - entropy_coeff: 0.001 - do_sample: true - shuffle: false - use_kv_cache: true - num_samples_per_step: 4 - tensor_model_parallel_size: 1 - pipeline_model_parallel_size: 1 - lr: 1e-7 - lr_decay_style: constant - min_lr: 0.0 - weight_decay: 0.0 - lr_warmup_fraction: 0.0 - clip_grad: 10000.0 - adam_beta1: 0.9 - adam_beta2: 0.999 - initial_loss_scale: 4096 - finetune: true - load: ./models/llama-3.2-1b-instruct-tp1-pp1 - save: ./ckpt - num_gpus_for_train: 1 - num_gpus_for_infer: 1 - pad_to_multiple_of: 1 - inference_tensor_model_parallel_size: 1 - data_path: ./dataset/descriptiveness/descriptiveness - split: 100,0,0 - no_shuffle: true - missing_eos_penalty: 1.0 - - ref: - model: llama32-1b - tensor_model_parallel_size: 1 - pipeline_model_parallel_size: 1 - micro_batch_size: 4 - load: ./models/llama-3.2-1b-instruct-tp1-pp1 - -reward: - model: llama32-1b - tensor_model_parallel_size: 1 - pipeline_model_parallel_size: 1 - micro_batch_size: 4 - sequence_parallel: False - load: ./models/llama-3.2-1b-rm-mcore-tp1-pp1 - -algorithm: - gamma: 1.0 - lam: 0.95 - adv_estimator: gae - kl_penalty: kl - kl_ctrl: - type: fixed - kl_coef: 0.05 - missing_eos_penalty: 1.0 - -resource_pool: - actor_rollout: [2] - ref: [1] - reward: [1] diff --git a/configs/rlxf/ppo_trainer_llama2_7b.yaml b/configs/rlxf/ppo_trainer_llama2_7b.yaml deleted file mode 100644 index 1dd1788100f68f7336285a8305b2e7522e5b4cf2..0000000000000000000000000000000000000000 --- a/configs/rlxf/ppo_trainer_llama2_7b.yaml +++ /dev/null @@ -1,108 +0,0 @@ -defaults: - - model: - - llama2-7b - -training: - global_batch_size: 1 - seq_length: 512 - tokenizer_type: PretrainedFromHF - 
tokenizer_name_or_path: ./models/llama-2-7b/ - train_iters: 1000 - distributed_backend: nccl - no_shared_storage: true - save_interval: 10000 - no_load_optim: true - no_load_rng: true - bf16: true - is_instruction_dataset: true - variable_seq_lengths: true - no_shuffle: true - stage: ray_ppo - sequence_parallel: False - -actor_rollout_ref: - actor_rollout: - model: llama2-7b - do_sample: false - micro_batch_size: 1 - ppo_mini_batch_size: 1 - num_samples_per_step: 1 - max_prompt_length: 256 - ppo_epochs: 1 - clip_ratio: 0.2 - entropy_coeff: 0.001 - shuffle_minibatch: false - use_kv_cache: true - tensor_model_parallel_size: 4 - pipeline_model_parallel_size: 1 - lr: 1e-7 - lr_decay_style: constant - min_lr: 0.0 - weight_decay: 0.0 - lr_warmup_fraction: 0.0 - clip_grad: 1.0 - adam_beta1: 0.9 - adam_beta2: 0.95 - initial_loss_scale: 1 - finetune: true - load: ./ckpt - save: ./ckpt - num_gpus_for_train: 4 - num_gpus_for_infer: 4 - pad_to_multiple_of: 1 - data_path: ./dataset/descriptiveness/descriptiveness - split: 100,0,0 - - ref: - model: llama2-7b - tensor_model_parallel_size: 2 - pipeline_model_parallel_size: 1 - micro_batch_size: 1 - load: ./ckpt - -critic: - model: llama2-7b - tensor_model_parallel_size: 4 - pipeline_model_parallel_size: 1 - use_mcore_models: True - micro_batch_size: 1 - lr: 1e-7 - lr_decay_style: constant - min_lr: 0.0 - weight_decay: 0.0 - lr_warmup_fraction: 0.0 - use_distributed_optimizer: true - clip_grad: 1.0 - adam_beta1: 0.9 - adam_beta2: 0.95 - initial_loss_scale: 1 - no_load_optim: True - no_load_rng: True - load: ./ckpt - save: ./ckpt - cliprange_value: 0.2 - critic_mini_batch_size: 1 - critic_update_epochs: 1 - -reward: - model: llama2-7b - tensor_model_parallel_size: 2 - pipeline_model_parallel_size: 1 - micro_batch_size: 1 - load: ./ckpt - -algorithm: - gamma: 1.0 - lam: 0.95 - adv_estimator: gae - kl_penalty: kl - kl_ctrl: - type: fixed - kl_coef: 0.05 - missing_eos_penalty: 0.0 - -resource_pool: - actor_rollout: [8] - ref: [2] - critic: [4] - reward: [2] \ No newline at end of file diff --git a/configs/rlxf/ppo_trainer_llama32_1b.yaml b/configs/rlxf/ppo_trainer_llama32_1b.yaml deleted file mode 100644 index 5f1d3c16c953f4b11ad13c525dc86414dbacf140..0000000000000000000000000000000000000000 --- a/configs/rlxf/ppo_trainer_llama32_1b.yaml +++ /dev/null @@ -1,111 +0,0 @@ -defaults: - - model: - - llama32-1b - -training: - global_batch_size: 8 - seq_length: 512 - tokenizer_type: PretrainedFromHF - tokenizer_name_or_path: ./models/llama-3.2-1b-instruct/ - train_iters: 1000 - distributed_backend: nccl - no_shared_storage: true - save_interval: 10000 - no_load_optim: true - no_load_rng: true - bf16: true - is_instruction_dataset: true - variable_seq_lengths: true - no_shuffle: true - stage: ray_ppo - -actor_rollout_ref: - actor_rollout: - model: llama32-1b - do_sample: false - micro_batch_size: 4 - ppo_mini_batch_size: 4 - num_samples_per_step: 2 - max_prompt_length: 256 - ppo_epochs: 1 - clip_ratio: 0.2 - entropy_coeff: 0.001 - shuffle_minibatch: false - use_kv_cache: true - tensor_model_parallel_size: 1 - pipeline_model_parallel_size: 1 - lr: 1e-7 - lr_decay_style: constant - min_lr: 0.0 - weight_decay: 0.0 - lr_warmup_fraction: 0.0 - clip_grad: 10000.0 - adam_beta1: 0.9 - adam_beta2: 0.999 - initial_loss_scale: 4096 - finetune: true - load: ./ckpt - save: ./ckpt - num_gpus_for_train: 1 - num_gpus_for_infer: 1 - pad_to_multiple_of: 1 - data_path: ./dataset/descriptiveness/descriptiveness - split: 100,0,0 - - ref: - model: llama32-1b - tensor_model_parallel_size: 1 
- pipeline_model_parallel_size: 1 - micro_batch_size: 8 - load: ./ckpt - -critic: - model: llama32-1b - tensor_model_parallel_size: 1 - pipeline_model_parallel_size: 1 - use_mcore_models: True - micro_batch_size: 4 - sequence_parallel: False - lr: 1e-7 - lr_decay_style: constant - min_lr: 0.0 - weight_decay: 0.0 - lr_warmup_fraction: 0.0 - use_distributed_optimizer: true - clip_grad: 10000.0 - adam_beta1: 0.9 - adam_beta2: 0.999 - initial_loss_scale: 4096 - no_load_optim: True - no_load_rng: True - is_instruction_dataset: true - variable_seq_lengths: true - load: ./ckpt - save: ./ckpt - cliprange_value: 0.2 - critic_mini_batch_size: 4 - critic_update_epochs: 1 - -reward: - model: llama32-1b - tensor_model_parallel_size: 1 - pipeline_model_parallel_size: 1 - micro_batch_size: 8 - sequence_parallel: false - load: ./ckpt - -algorithm: - gamma: 1.0 - lam: 0.95 - adv_estimator: gae - kl_penalty: kl - kl_ctrl: - type: fixed - kl_coef: 0.05 - missing_eos_penalty: 0.0 - -resource_pool: - actor_rollout: [2] - ref: [1] - critic: [1] - reward: [1] \ No newline at end of file diff --git a/configs/rlxf/ppo_trainer_llama3_8b.yaml b/configs/rlxf/ppo_trainer_llama3_8b.yaml deleted file mode 100644 index d5e37bc60f813b1d4e963781290fd946b5424e07..0000000000000000000000000000000000000000 --- a/configs/rlxf/ppo_trainer_llama3_8b.yaml +++ /dev/null @@ -1,108 +0,0 @@ -defaults: - - model: - - llama3-8b - -training: - global_batch_size: 2 - seq_length: 512 - tokenizer_type: PretrainedFromHF - tokenizer_name_or_path: ./models/llama-3-8b/ - train_iters: 1000 - distributed_backend: nccl - no_shared_storage: true - save_interval: 10000 - no_load_optim: true - no_load_rng: true - bf16: true - is_instruction_dataset: true - variable_seq_lengths: true - no_shuffle: true - stage: ray_ppo - sequence_parallel: False - -actor_rollout_ref: - actor_rollout: - model: llama3-8b - do_sample: false - micro_batch_size: 1 - ppo_mini_batch_size: 1 - num_samples_per_step: 1 - max_prompt_length: 256 - ppo_epochs: 1 - clip_ratio: 0.2 - entropy_coeff: 0.001 - shuffle_minibatch: false - use_kv_cache: true - tensor_model_parallel_size: 2 - pipeline_model_parallel_size: 2 - lr: 1e-7 - lr_decay_style: constant - min_lr: 0.0 - weight_decay: 0.0 - lr_warmup_fraction: 0.0 - clip_grad: 1.0 - adam_beta1: 0.9 - adam_beta2: 0.999 - initial_loss_scale: 4096 - finetune: true - load: ./ckpt - save: ./ckpt - num_gpus_for_train: 4 - num_gpus_for_infer: 4 - pad_to_multiple_of: 1 - data_path: ./dataset/descriptiveness/descriptiveness - split: 100,0,0 - - ref: - model: llama3-8b - tensor_model_parallel_size: 2 - pipeline_model_parallel_size: 1 - micro_batch_size: 1 - load: ./ckpt - -critic: - model: llama3-8b - tensor_model_parallel_size: 2 - pipeline_model_parallel_size: 2 - use_mcore_models: True - micro_batch_size: 1 - lr: 1e-7 - lr_decay_style: constant - min_lr: 0.0 - weight_decay: 0.0 - lr_warmup_fraction: 0.0 - use_distributed_optimizer: true - clip_grad: 1.0 - adam_beta1: 0.9 - adam_beta2: 0.999 - initial_loss_scale: 1 - no_load_optim: True - no_load_rng: True - load: ./ckpt - save: ./ckpt - cliprange_value: 0.2 - critic_mini_batch_size: 1 - critic_update_epochs: 1 - -reward: - model: llama3-8b - tensor_model_parallel_size: 2 - pipeline_model_parallel_size: 1 - micro_batch_size: 1 - load: ./ckpt - -algorithm: - gamma: 1.0 - lam: 0.95 - adv_estimator: gae - kl_penalty: kl - kl_ctrl: - type: fixed - kl_coef: 0.05 - missing_eos_penalty: 0.0 - -resource_pool: - actor_rollout: [8] - ref: [2] - critic: [4] - reward: [2] \ No newline at end of file diff 
--git a/docs/pytorch/solutions/preference-alignment/ray_online_dpo.md b/docs/pytorch/solutions/preference-alignment/ray_online_dpo.md deleted file mode 100644 index 3531230e48da36c4d6b859eead034fdb170bedfe..0000000000000000000000000000000000000000 --- a/docs/pytorch/solutions/preference-alignment/ray_online_dpo.md +++ /dev/null @@ -1,149 +0,0 @@ -# 后训练方法 Ray Online DPO - -Online Direct Preference Optimization (Online DPO) 是 Direct Preference Optimization (DPO) 的一种扩展或变体,旨在通过 在线学习 的方式进一步优化大型语言模型(LLMs)。DPO 是一种基于人类偏好数据的训练方法,而 Online DPO 则专注于在 动态、实时 的环境中使用偏好数据来持续改进模型。 - -Online DPO方法中包含了三个模型:Actor,Reference,Reward。其中Actor/Reference模型是经过预训练和指令微调(Supervised Fine-Tuning,SFT)得到的大语言模型,Reward是训练得到的奖励模型。Online DPO 的训练目标是使得 Actor 模型的回答可以更加符合人类偏好。 - -# 使用说明 - -## 环境配置 - -配置MindSpeed-LLM基础环境: 参考[安装指南](../../../features/install_guide.md) - -## 数据预处理 - -数据集转换参考脚本:MindSpeed-LLM/examples/mcore/llama3/data_convert_llama3_ppo.sh -以 [descriptiveness 数据集](https://huggingface.co/datasets/trl-internal-testing/descriptiveness-sentiment-trl-style/tree/main/data) 为例。 - -```bash -source /usr/local/Ascend/ascend-toolkit/set_env.sh -mkdir ./dataset/llama3-hf/ - -python ./preprocess_data.py \ - --input ./dataset/descriptiveness-00000-of-00001.parquet \ - --tokenizer-name-or-path ./model_from_hf/llama3-hf/ \ - --output-prefix ./dataset/llama3-hf/descriptiveness \ - --workers 16 \ - --log-interval 1000 \ - --tokenizer-type PretrainedFromHF \ - --handler-name PPOAlpacaStyleInstructionHandler \ - --prompt-type llama3 \ - --map-keys '{"prompt":"prompt", "query":"", "response": "prompt", "system":""}' -``` - -## 模型权重转换 - -根据 Online DPO 算法要求,Actor 和 Reference 模型应该使用 SFT 微调后的模型进行初始化,Critic 和 Reward 模型应该使用奖励模型训练后的模型进行初始化。DPO算法模型权重均使用Megatron-mcore格式,其他格式的权重需要进行模型权重转换,具体可参考[权重转换](../checkpoint_convert.md)。 - -下面以llama3.2-1b模型作为示例参考: - -actor_rollout_ref 涉及到的actor_rollout 与 ref 均需要 SFT 微调后的模型,涉及到的权重转换操作与 SFT 阶段的一致。权重转换示例脚本: -llama32-1b - -critic 与 reward 模型需要使用奖励模型训练后的模型,权重转换示例脚本:llama32-1b-orm - - -相应的oneline_dpo_trainer_llama32_1b.yaml配置如下 -``` - actor_rollout_ref: - actor_rollout: - ... - load: ./model_weights/llama32-mcore/ - save: ./model_weights/llama32-mcore-save/ - - ref: - ... - load: ./model_weights/llama32-mcore/ - - reward: - ... - load: ./model_weights/llama32-mcore-orm/ -``` - -## 启动方式 - -### 单机 - -通过 --config-name 传递选取的 config 文件名(不添加.yaml后缀),可以通过下列命令直接启动训练(Llama32 1B 模型可单机运行)。 -目前已支持的配置文件放置在 configs/rlxf/ 文件夹下。配置文件的具体说明见下文。 - -```bash -python ray_gpt.py --config-name online_dpo_trainer_llama32_1b -``` - -### 多机 - -多机运行程序时,需要首先进入对应目录,并激活conda或docker环境: - -```bash -cd MindSpeed-LLM -conda activate xxx -``` - -然后,在主节点上启动 Ray 集群: - -```bash -# 创建一个集群,端口6344,dashboard端口8260,有8个NPU -ray start --head --port 6344 --dashboard-host=0.0.0.0 --dashboard-port=8260 --resources='{"NPU": 8}' -``` - -随后,在其他节点加入主节点的集群 - -```bash -# IP_ADDRESS 处填写主节点 IP 地址 -ray start --address="IP_ADDRESS:6344" --resources='{"NPU": 8}' -``` - -在完成 Ray 集群构建后,在主节点启动运行程序即可(Llama3 8B 模型可双机运行) - -```bash -python ray_gpt.py --config-name online_dpo_trainer_llama3_8b -``` -结束之后,使用如下命令行结束 ray 进程 -``` -ray stop -``` -## 配置文件 - -由于 Online DPO 训练过程中涉及 3 个模型,通过将模型参数和训练配置解耦的层级化参数配置,来简化 Online DPO 训练的参数配置过程。RLXF 训练涉及到的所有配置文件均存储在 configs/rlxf 路径下,其中 model 文件夹下存储了模型结构相关的配置文件,Online DPO训练相关的模型参数文件以online_dpo_trainer_{模型名}.yaml方式命名。 - -在每个 online_dpo_trainer 配置文件中,需要包含defaults,training,resource_pool,algorithm等字段,以及 Online DPO 训练过程中涉及到的 3 个角色 actor,reward,ref的配置。其中: - -1. defaults 负责引入模型配置文件,在 defaults 中应列举本配置文件中所需要用到的所有模型配置,模型配置可以在下方3个角色的具体配置中通过 model 字段进行选择。 -2. 
training 字段设置的参数为所有 3 个角色通用的默认参数,这些参数可以在下方进一步被角色的单独配置所覆盖。 -3. resource_pool 字段指定了各个角色所需的 NPU 资源数量。 -4. actor,reward,ref 字段分别指定了Online DPO算法中三个角色训练相关的参数配置。 - -## 参数解析 - -相较于普通模型训练,DPO增加一些特殊参数: - -### `training:` - -* `stage`:用于指定训练算法,使用 Ray Online DPO 训练须设置为`ray_online_dpo`; - -### `actor_rollout:` - -* `do_sample`:控制 Actor 模型进行推理时是否采样,默认为 False,Online DPO 需要设置为True ; -* `ppo_mini_batch_size`:Actor 模型的 mini_batch_size,默认为1; -* `max_prompt_length`:DPO 训练中最大 prompt 长度,默认为512; -* `num_samples_per_step`:Actor 推理时每个step的推理样本数量,默认为1; -* `ppo_epochs`:Actor 训练对同一批经验数据的重复次数,默认为1; -* `clip_ratio`:Actor模型训练计算损失函数时的clip比例,默认为0.2 一般取值范围 [0.1,0.3] 最大取值范围[0,1] 该数值越大允许策略更新的幅度越大,反之不然; -* `shuffle_minibatch`:Actor 训练时是否对 minibatch 进行 shuffle,默认为 False; -* `num_gpus_for_train` :Actor 模型分配给训练部分的显卡数量; -* `num_gpus_for_infer` :Actor 模型分配给推理部分的显卡数量; -* `missing_eos_penalty`:缺少序列结束符EOS时的惩罚系数; - -### `resource_pool:` - -* `actor_rollout`:给 Actor 模型训练和推理总共分配的显卡数量; -* `ref`:给 Reference 模型分配的显卡数量; -* `reward`:给 Reward 模型分配的显卡数量; - -# 精度对比 - -我们与 HuggingFace 的强化学习开源仓库 [TRL](https://github.com/huggingface/trl/) 进行了精度对比,来辅助验证算法实现的正确性。因为 Online DPO 1Q2A的特性需求,推理状态do sample 设置为 True,为了与基准方法进行精度对齐,在 Actor 推理时固定 responses 方式进行精度对齐的实验。可以看到,固定 responses 后 loss 能够较好地实现对齐。 - -![online_dpo_loss_compare.png](../../../../sources/images/online_dpo/online_dpo_loss_compare.png) - diff --git a/docs/pytorch/solutions/preference-alignment/ray_ppo.md b/docs/pytorch/solutions/preference-alignment/ray_ppo.md deleted file mode 100644 index c06947e14554f15913d30a112e6a0ebb873aee68..0000000000000000000000000000000000000000 --- a/docs/pytorch/solutions/preference-alignment/ray_ppo.md +++ /dev/null @@ -1,195 +0,0 @@ -# 后训练方法 Ray PPO - -[PPO(Proximal Policy Optimization)](https://arxiv.org/abs/1707.06347)是一种强化对齐微调方法,常用于人类反馈强化学习(Reinforcement Learning with Human Feedback, RLHF)任务。 - -PPO方法中包含了四个模型:Actor,Critic,Reference,Reward。其中Actor/Reference模型是经过预训练和指令微调(Supervised Fine-Tuning,SFT)得到的大语言模型,Critic和Reward是训练得到的奖励模型。PPO 的训练目标是使得 Actor 模型的回答可以更加符合人类偏好。 - -# 使用说明 - -## 环境配置 - -配置MindSpeed-LLM基础环境: 参考[安装指南](../../../features/install_guide.md) - -## 数据预处理 - -数据集转换参考脚本:MindSpeed-LLM\examples\mcore\llama3\data_convert_llama3_ppo.sh -以 [descriptiveness 数据集](https://huggingface.co/datasets/trl-internal-testing/descriptiveness-sentiment-trl-style/tree/main/data) 为例。 - -```bash -source /usr/local/Ascend/ascend-toolkit/set_env.sh -mkdir ./dataset/llama3-hf/ - -python ./preprocess_data.py \ - --input ./dataset/descriptiveness-00000-of-00001.parquet \ - --tokenizer-name-or-path ./model_from_hf/llama3-hf/ \ - --output-prefix ./dataset/llama3-hf/descriptiveness \ - --workers 16 \ - --log-interval 1000 \ - --tokenizer-type PretrainedFromHF \ - --handler-name PPOAlpacaStyleInstructionHandler \ - --prompt-type llama3 \ - --map-keys '{"prompt":"prompt", "query":"", "response": "prompt", "system":""}' -``` - -## 模型权重转换 - -根据 PPO 算法要求,Actor 和 Reference 模型应该使用 SFT 微调后的模型进行初始化,Critic 和 Reward 模型应该使用奖励模型训练后的模型进行初始化。PPO算法模型权重均使用Megatron-mcore格式,其他格式的权重需要进行模型权重转换,具体可参考[权重转换](../checkpoint_convert.md)。 - -下面以llama3.2-1b模型作为示例参考: - -actor_rollout_ref 涉及到的actor_rollout 与 ref 均需要 SFT 微调后的模型,涉及到的权重转换操作与 SFT 阶段的一致。权重转换示例脚本: -llama32-1b - -critic 与 reward 模型需要使用奖励模型训练后的模型,权重转换示例脚本:llama32-1b-orm - - -相应的ppo_trainer_llama32_1b.yaml配置如下 -``` - actor_rollout_ref: - actor_rollout: - ... - load: ./model_weights/llama32-mcore/ - save: ./model_weights/llama32-mcore-save/ - - ref: - ... - load: ./model_weights/llama32-mcore/ - - critic: - ... 
- load: ./model_weights/llama32-mcore-orm/ - save: ./model_weights/llama32-mcore-orm-save/ - - reward: - ... - load: ./model_weights/llama32-mcore-orm/ -``` - -## 启动方式 - -### 单机 - -通过 --config-name 传递选取的 config 文件名(不添加.yaml后缀),可以通过下列命令直接启动训练(Llama32 1B 模型可单机运行)。 -目前已支持的配置文件放置在 configs/rlxf/ 文件夹下。配置文件的具体说明见下文。 - -```bash -python ray_gpt.py --config-name ppo_trainer_llama32_1b -``` - -### 多机 - -多机运行程序时,需要首先进入对应目录,并激活conda或docker环境: - -```bash -cd MindSpeed-LLM -conda activate xxx -``` - -然后,在主节点上启动 Ray 集群: - -```bash -# 配置最大文件描述符环境变量 -ulimit -n 32768 -# 创建一个集群,端口6344,dashboard端口8260,有8个NPU -ray start --head --port 6344 --dashboard-host=0.0.0.0 --dashboard-port=8260 --resources='{"NPU": 8}' -``` - -随后,在其他节点加入主节点的集群 - -```bash -# 配置最大文件描述符环境变量 -ulimit -n 32768 -# IP_ADDRESS 处填写主节点 IP 地址 -ray start --address="IP_ADDRESS:6344" --resources='{"NPU": 8}' -``` - -在完成 Ray 集群构建后,在主节点启动运行程序即可(Llama3 8B 模型可双机运行) - -```bash -python ray_gpt.py --config-name ppo_trainer_llama3_8b -``` -结束之后,使用如下命令行结束 ray 进程 -``` -ray stop -``` -## 配置文件 - -由于 PPO 训练过程中涉及 4 个模型,通过将模型参数和训练配置解耦的层级化参数配置,来简化 PPO 训练的参数配置过程。RLXF 训练涉及到的所有配置文件均存储在 configs/rlxf 路径下,其中 model 文件夹下存储了模型结构相关的配置文件,PPO训练相关的模型参数文件以ppo_trainer_{模型名}.yaml方式命名。 - -在每个 ppo_trainer 配置文件中,需要包含defaults,training,resource_pool,algorithm等字段,以及 PPO 训练过程中涉及到的 4 个角色 actor,critic,reward,ref的配置。其中: - -1. defaults 负责引入模型配置文件,在 defaults 中应列举本配置文件中所需要用到的所有模型配置,模型配置可以在下方四个角色的具体配置中通过 model 字段进行选择。 -2. training 字段设置的参数为所有 4 个角色通用的默认参数,这些参数可以在下方进一步被角色的单独配置所覆盖。 -3. resource_pool 字段指定了各个角色所需的 NPU 资源数量。 -4. algorithm 字段配置计算PPO中advantages算法的相关参数。 -5. actor,critic,reward,ref 字段分别指定了PPO算法中四个角色训练相关的参数配置。 - -## 参数解析 - -相较于普通模型训练,PPO增加一些特殊参数: - -### `training:` - -* `stage`:用于指定训练算法,使用 Ray PPO 训练须设置为`ray_ppo`; - -### `actor_rollout:` - -* `do_sample`:控制 Actor 模型进行推理时是否采样,默认为 False ; -* `ppo_mini_batch_size`:Actor 模型的 mini_batch_size,默认为1; -* `max_prompt_length`:PPO 训练中最大 prompt 长度,默认为512; -* `num_samples_per_step`:Actor 推理时每个step的推理样本数量,默认为1; -* `ppo_epochs`:Actor 训练对同一批经验数据的重复次数,默认为1; -* `clip_ratio`:Actor模型训练计算损失函数时的clip比例,默认为0.2 一般取值范围 [0.1,0.3] 最大取值范围[0,1] 该数值越大允许策略更新的幅度越大,反之不然; -* `shuffle_minibatch`:Actor 训练时是否对 minibatch 进行 shuffle,默认为 False; -* `num_gpus_for_train` :Actor 模型分配给训练部分的显卡数量; -* `num_gpus_for_infer` :Actor 模型分配给推理部分的显卡数量; - -### `critic:` - -* `cliprange_value`:Critic 模型计算损失函数时 clip 范围,默认为0.2; -* `critic_mini_batch_size`:Critic 模型设置的 mini_batch_size,默认为1; -* `critic_update_epochs`:Critic 训练对同一批经验数据的重复次数,默认为1; - -### `algorithm:` - -* `adv_estimator`:advantages计算的方式,通常采用gae(广义优势估计Generalized Advantage Estimation, GAE); -* `gamma`:计算 advantage 时的折扣因子,取值范围[0, 1]取值趋向于0表示侧重瞬时奖励,趋向于1表示趋向于延迟奖励; -* `lam`:GAE 优势计算的 lambda 值,取值范围[0, 1]取值趋向于0会减少方差,提高收敛速度,但可能引入偏差。趋向于1会增大方差,降低收敛速度,偏差较小。取值为0等价于单步的TD误差,取值为1等价于蒙特卡洛返回; -* `kl_penalty`:KL 散度计算方式; -* `kl_ctrl:` - * `kl_coef`:施加 KL 散度惩罚的系数,取值范围[0, 1]越大越会限制策略更新,越小则允许策略更新越大; - * `type`:KL 散度惩罚的系数类型; -* `missing_eos_penalty`:缺少序列结束符EOS时的惩罚系数; - -### `resource_pool:` - -* `actor_rollout`:给 Actor 模型训练和推理总共分配的显卡数量; -* `ref`:给 Reference 模型分配的显卡数量; -* `critic`:给 Critic 模型分配的显卡数量; -* `reward`:给 Reward 模型分配的显卡数量; - -# 精度对比 - -我们与 HuggingFace 的强化学习开源仓库 [TRL](https://github.com/huggingface/trl/) 进行了精度对比,来辅助验证算法实现的正确性。为了与基准方法进行精度对齐,在 Actor 推理时采用贪婪(greedy)策略去除随机性,训练过程中的 critic loss和 actor loss对比如下图所示。 - -
-图:未固定 responses 时 loss 对比图(左为 actor loss,右为 critic loss)
-
- -然而,由于 greedy 方法的策略为选取 logits 最大的 token,当两个 token 的 logits 值十分接近时,可能会导致选取的 token 的结果产生偏差。这种误差会被多次迭代逐步累计放大,最终影响到 loss 精度对齐。 - -因此,我们额外补充了固定 responses 方式进行精度对齐的实验。可以看到,固定 responses 后 loss 能够较好地实现对齐。 - -
-图:固定 responses 后的 loss 对比图(左为 actor loss,右为 critic loss)
-
- -注: 为了验证 actor loss 的精度对齐效果,这里并未直接对比 PPO 算法中记录的 actor loss。这是由于 PPO 算法在计算 advantages 时,为保证算法的稳定性,会在 Actor 训练过程中在 minibatch 间做白化操作(将其分布的均值调整为0,方差调整为1)。这导致 Actor 虽然在进行梯度更新时使用每个 minibatch 计算的 loss,但记录下来的 minibatch 间的 loss 均值接近于 0 。因此,我们选择了记录 Actor 每个 minibatch loss 绝对值的均值,来验证精度对齐效果。 - - -# 参考文献 - -[PPO](https://arxiv.org/abs/1707.06347) - diff --git a/mindspeed_llm/core/pipeline_parallel/p2p_communication.py b/mindspeed_llm/core/pipeline_parallel/p2p_communication.py index 7ba04f58c2b16b7942a9aab9f059267f7d2815d2..079bf27cb286c500a8bb3bd1c66be69630fd7943 100644 --- a/mindspeed_llm/core/pipeline_parallel/p2p_communication.py +++ b/mindspeed_llm/core/pipeline_parallel/p2p_communication.py @@ -68,93 +68,3 @@ def _batched_p2p_ops( else: reqs = [] return reqs - - - -def _communicate_shapes(tensor_send_next, tensor_send_prev, recv_prev, recv_next, config): - """ - Add group=get_pipeline_model_parallel_group() in P2POp shape communications to avoid error. - Currently only enable in ray scenarios. - """ - - recv_prev_shape_tensor = None - recv_next_shape_tensor = None - send_prev_shape_tensor = None - send_next_shape_tensor = None - if recv_prev: - recv_prev_shape_tensor = torch.empty( - (3), device=torch.cuda.current_device(), dtype=torch.int64 - ) - if recv_next: - recv_next_shape_tensor = torch.empty( - (3), device=torch.cuda.current_device(), dtype=torch.int64 - ) - if tensor_send_prev is not None: - send_prev_shape_tensor = torch.tensor( - tensor_send_prev.size(), device=torch.cuda.current_device(), dtype=torch.int64 - ) - if tensor_send_next is not None: - send_next_shape_tensor = torch.tensor( - tensor_send_next.size(), device=torch.cuda.current_device(), dtype=torch.int64 - ) - - if config.use_ring_exchange_p2p: - torch.distributed.ring_exchange( - tensor_send_prev=send_prev_shape_tensor, - tensor_recv_prev=recv_prev_shape_tensor, - tensor_send_next=send_next_shape_tensor, - tensor_recv_next=recv_next_shape_tensor, - group=get_pipeline_model_parallel_group(), - ) - else: - ops = [] - if send_prev_shape_tensor is not None: - send_prev_op = torch.distributed.P2POp( - torch.distributed.isend, - send_prev_shape_tensor, - get_pipeline_model_parallel_prev_rank(), - group=get_pipeline_model_parallel_group(), - ) - ops.append(send_prev_op) - if recv_prev_shape_tensor is not None: - recv_prev_op = torch.distributed.P2POp( - torch.distributed.irecv, - recv_prev_shape_tensor, - get_pipeline_model_parallel_prev_rank(), - group=get_pipeline_model_parallel_group(), - ) - ops.append(recv_prev_op) - if send_next_shape_tensor is not None: - send_next_op = torch.distributed.P2POp( - torch.distributed.isend, - send_next_shape_tensor, - get_pipeline_model_parallel_next_rank(), - group=get_pipeline_model_parallel_group(), - ) - ops.append(send_next_op) - if recv_next_shape_tensor is not None: - recv_next_op = torch.distributed.P2POp( - torch.distributed.irecv, - recv_next_shape_tensor, - get_pipeline_model_parallel_next_rank(), - group=get_pipeline_model_parallel_group(), - ) - ops.append(recv_next_op) - if len(ops) > 0: - reqs = torch.distributed.batch_isend_irecv(ops) - for req in reqs: - req.wait() - - # To protect against race condition when using batch_isend_irecv(). - # should take this out once the bug with batch_isend_irecv is resolved. 
- torch.cuda.synchronize() - - recv_prev_shape = [0, 0, 0] - if recv_prev_shape_tensor is not None: - recv_prev_shape = recv_prev_shape_tensor.tolist() - - recv_next_shape = [0, 0, 0] - if recv_next_shape_tensor is not None: - recv_next_shape = recv_next_shape_tensor.tolist() - - return recv_prev_shape, recv_next_shape diff --git a/mindspeed_llm/tasks/megatron_adaptor.py b/mindspeed_llm/tasks/megatron_adaptor.py index 9214ba6daaf3f95a917ed7e682492a78acc827d5..3048a4693dfcf7f4f5485835cfd4f3158de19bc7 100644 --- a/mindspeed_llm/tasks/megatron_adaptor.py +++ b/mindspeed_llm/tasks/megatron_adaptor.py @@ -18,7 +18,6 @@ import sys import types import argparse import torch -import tensordict from torch_npu.contrib import transfer_to_npu from mindspeed_llm.features_manager import FEATURES_LIST @@ -763,7 +762,6 @@ class LegacyAdaptation(MegatronAdaptationABC): self.patch_inference() self.patch_log_handler() self.patch_optimizer() - self.patch_2megatron() def patch_log_handler(self): from megatron.training.log_handler import CustomHandler @@ -929,24 +927,4 @@ class LegacyAdaptation(MegatronAdaptationABC): MegatronAdaptation.register('megatron.training.training.setup_model_and_optimizer', setup_model_and_optimizer_wrapper) - def patch_2megatron(self): - # This patch is only for running ray scenarios - if "--raylet-name" in "".join(sys.argv): - from mindspeed_llm.core.pipeline_parallel.p2p_communication import _communicate_shapes - from mindspeed_llm.tasks.posttrain.rlxf.training.parallel_state import (rank_generator_init_wrapper, rank_generator_get_ranks_wrapper) - from mindspeed_llm.tasks.posttrain.rlxf.training.initialize import initialize_megatron - from mindspeed_llm.tasks.posttrain.rlxf.training.initialize import barrier_wrapper, broadcast_wrapper, is_last_rank, \ - get_world_size_wrapper, get_elapsed_time_all_ranks - - MegatronAdaptation.register('megatron.core.pipeline_parallel.p2p_communication._communicate_shapes', _communicate_shapes) - MegatronAdaptation.register('megatron.core.parallel_state.RankGenerator.__init__', rank_generator_init_wrapper) - MegatronAdaptation.register('megatron.core.parallel_state.RankGenerator.get_ranks', rank_generator_get_ranks_wrapper) - MegatronAdaptation.register('megatron.training.utils.is_last_rank', is_last_rank) - MegatronAdaptation.register('megatron.core.timers.Timers._get_elapsed_time_all_ranks', get_elapsed_time_all_ranks) - MegatronAdaptation.register('torch.distributed.barrier', barrier_wrapper) - MegatronAdaptation.register('torch.distributed.all_reduce', barrier_wrapper) - MegatronAdaptation.register('torch.distributed.broadcast', broadcast_wrapper) - MegatronAdaptation.register('torch.distributed.get_world_size', get_world_size_wrapper) - MegatronAdaptation.register('megatron.training.initialize.initialize_megatron', initialize_megatron, force_patch=True) - MegatronAdaptation.execute() \ No newline at end of file diff --git a/mindspeed_llm/tasks/posttrain/launcher.py b/mindspeed_llm/tasks/posttrain/launcher.py index fc0c6e6d15f38355f9f22c5fffc2f007e83b7b83..f3d1a15a00024992097265a5d92e49bda7d992a2 100644 --- a/mindspeed_llm/tasks/posttrain/launcher.py +++ b/mindspeed_llm/tasks/posttrain/launcher.py @@ -3,9 +3,6 @@ import logging from megatron.training import get_args from megatron.training.initialize import initialize_megatron -from mindspeed_llm.tasks.posttrain.rlxf.ray_trainer.online_dpo_trainer import RayOnlineDPOTrainer -from mindspeed_llm.tasks.posttrain.rlxf.ray_trainer.ppo_trainer import RayPPOTrainer -from 
mindspeed_llm.tasks.posttrain.rlxf.ray_trainer.grpo_trainer import RayGRPOTrainer from mindspeed_llm.tasks.posttrain.sft import SFTTrainer from mindspeed_llm.tasks.posttrain.dpo import DPOTrainer from mindspeed_llm.tasks.posttrain.orm import ORMTrainer @@ -33,14 +30,8 @@ def get_trainer(stage): return PRMTrainer() elif stage == "simpo": return SimPOTrainer() - elif stage == "ray_ppo": - return RayPPOTrainer - elif stage == "ray_online_dpo": - return RayOnlineDPOTrainer elif stage == "trl_ppo": return TrlPPOTrainer() - elif stage == "ray_grpo": - return RayGRPOTrainer else: logger.info(f'Unknown Stage: {stage}') return None diff --git a/mindspeed_llm/tasks/posttrain/rlxf/__init__.py b/mindspeed_llm/tasks/posttrain/rlxf/__init__.py deleted file mode 100644 index d3f0b85929e87e26ba235b74cb211ba2ac519154..0000000000000000000000000000000000000000 --- a/mindspeed_llm/tasks/posttrain/rlxf/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved. \ No newline at end of file diff --git a/mindspeed_llm/tasks/posttrain/rlxf/ray_trainer/__init__.py b/mindspeed_llm/tasks/posttrain/rlxf/ray_trainer/__init__.py deleted file mode 100644 index d3f0b85929e87e26ba235b74cb211ba2ac519154..0000000000000000000000000000000000000000 --- a/mindspeed_llm/tasks/posttrain/rlxf/ray_trainer/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved. \ No newline at end of file diff --git a/mindspeed_llm/tasks/posttrain/rlxf/ray_trainer/grpo_trainer.py b/mindspeed_llm/tasks/posttrain/rlxf/ray_trainer/grpo_trainer.py deleted file mode 100644 index 396a9dd34c25d250d51165e57c076336b7675d8a..0000000000000000000000000000000000000000 --- a/mindspeed_llm/tasks/posttrain/rlxf/ray_trainer/grpo_trainer.py +++ /dev/null @@ -1,188 +0,0 @@ -from typing import Type -from codetiming import Timer - -from mindspeed_llm.tasks.posttrain.rlxf.ray_trainer.ppo_trainer import ResourcePoolManager, Role -from mindspeed_llm.tasks.posttrain.rlxf.single_controller.ray.megatron import NVMegatronRayWorkerGroup -from mindspeed_llm.tasks.posttrain.rlxf.training.core_algos import compute_grpo_data_metrics, reduce_metrics, \ - compute_advantage, compute_score, FixedKLController, AdaptiveKLController -from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base import Worker -from mindspeed_llm.tasks.posttrain.rlxf.single_controller.ray.base import create_colocated_worker_cls, \ - set_actor_infer_world_size, set_actor_train_world_size, RayClassWithInitArgs -from mindspeed_llm.tasks.posttrain.rlxf.utils.loggers import Loggers -from mindspeed_llm.tasks.posttrain.rlxf.workers.actor_train_infer import PPOActorWorker -from mindspeed_llm.tasks.posttrain.rlxf.workers.reference import ReferenceWorker -from mindspeed_llm.tasks.posttrain.rlxf.workers.reward import RewardWorker - - -WorkerType = Type[Worker] - - -class RayGRPOTrainer(object): - """ - Note that this trainer runs on the driver process on a single CPU/GPU node. 
- """ - - def __init__(self, config): - - self.config = config - if hasattr(self.config.training, "dataset_additional_keys"): - self.config.training.dataset_additional_keys = config.training.dataset_additional_keys.strip().split(" ") if config.training.dataset_additional_keys else [] - - self.role_worker_mapping = { - Role.ActorRollout: PPOActorWorker, - Role.RefPolicy: ReferenceWorker, - Role.RewardModel: RewardWorker - } - actor_pool_id = 'actor_pool' - ref_pool_id = 'ref_pool' - reward_pool_id = 'reward_pool' - - if config.resource_pool.reward: - resource_pool_spec = { - actor_pool_id: config.resource_pool.actor_rollout, - ref_pool_id: config.resource_pool.ref, - reward_pool_id: config.resource_pool.reward, - } - mapping = { - Role.ActorRollout: actor_pool_id, - Role.RefPolicy: ref_pool_id, - Role.RewardModel: reward_pool_id, - } - else: - resource_pool_spec = { - actor_pool_id: config.resource_pool.actor_rollout, - ref_pool_id: config.resource_pool.ref - } - - mapping = { - Role.ActorRollout: actor_pool_id, - Role.RefPolicy: ref_pool_id - } - - self.resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) - self.use_reference_policy = Role.RefPolicy in self.role_worker_mapping - self.ray_worker_group_cls = NVMegatronRayWorkerGroup - - # define KL control - if self.use_reference_policy: - if config.algorithm.kl_ctrl.type == 'fixed': - self.kl_ctrl = FixedKLController(kl_coef=config.algorithm.kl_ctrl.kl_coef) - elif config.algorithm.kl_ctrl.type == 'adaptive': - if config.algorithm.kl_ctrl.horizon <= 0: - raise ValueError(f'horizon must be larger than 0. Got {config.critic.kl_ctrl.horizon}') - self.kl_ctrl = AdaptiveKLController(init_kl_coef=config.algorithm.kl_ctrl.kl_coef, - target_kl=config.algorithm.kl_ctrl.target_kl, - horizon=config.algorithm.kl_ctrl.horizon) - else: - raise NotImplementedError - else: - self.kl_ctrl = FixedKLController(kl_coef=0.) 
- self.init_workers() - - def init_workers(self): - """Init resource pool and worker group""" - set_actor_infer_world_size(self.config.actor_rollout_ref.actor_rollout.num_gpus_for_infer) - set_actor_train_world_size(self.config.actor_rollout_ref.actor_rollout.num_gpus_for_train) - self.resource_pool_manager.create_resource_pool() - self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()} - - resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout) - actor_rollout_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.ActorRollout], - config=self.config, - role='actor_rollout') - self.resource_pool_to_cls[resource_pool]['actor_rollout'] = actor_rollout_cls - - resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy) - ref_policy_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RefPolicy], - config=self.config, - role='ref') - self.resource_pool_to_cls[resource_pool]['ref'] = ref_policy_cls - - if self.config.resource_pool.reward: - resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel) - reward_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.RewardModel], - config=self.config, - role='reward') - self.resource_pool_to_cls[resource_pool]['reward'] = reward_cls - - # initialize WorkerGroup - all_wg = {} - for resource_pool, class_dict in self.resource_pool_to_cls.items(): - worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict) - wg_dict = self.ray_worker_group_cls(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls) - spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys()) - all_wg.update(spawn_wg) - - self.ref_policy_wg = all_wg.get('ref') - self.ref_policy_wg.initialize() - - self.actor_rollout_wg = all_wg.get('actor_rollout') - self.actor_rollout_wg.initialize() - - if self.config.resource_pool.reward: - self.reward_wg = all_wg.get('reward') - self.reward_wg.initialize() - else: - self.reward_wg = None - - def train(self): - """ - The training loop of PPO. - """ - logger = Loggers() - - iteration = self.actor_rollout_wg.get_iteration()[0] - - while iteration < self.config.training.train_iters: - with Timer(name='gen', logger=None) as all_timer: - metrics = {} - self.actor_rollout_wg.auto_mapping() - # generate a batch - with Timer(name='gen', logger=None) as timer: - batch = self.actor_rollout_wg.generate_sequences() - batch = self.actor_rollout_wg.get_log_probs(batch) - metrics['timing/gen'] = timer.last - - if self.use_reference_policy: - # compute reference log_prob - with Timer(name='ref', logger=None) as timer: - ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) - batch = batch.union(ref_log_prob) - metrics['timing/ref'] = timer.last - - with Timer(name='adv', logger=None) as timer: - # compute rm scores. 
- batch = compute_score( - self.reward_wg, - batch, - metrics, - self.config - ) - - # compute advantages, executed on the driver process - batch = compute_advantage( - batch, - self.config - ) - - metrics['timing/adv'] = timer.last - kl_info = {'kl_ctrl': self.kl_ctrl} - batch.meta_info.update(kl_info) - - # update actor - with Timer(name='update_actor', logger=None) as timer: - actor_output = self.actor_rollout_wg.update_actor(batch) - metrics['timing/update_actor'] = timer.last - actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics']) - metrics.update(actor_output_metrics) - - # collect metrics - data_metrics = compute_grpo_data_metrics(batch=batch) - metrics.update(data_metrics) - metrics['timing/all'] = all_timer.last - iteration += 1 - logger.info(metrics, iteration, self.config.training.train_iters) - logger.flush() - - if iteration % self.config.training.save_interval == 0: - self.actor_rollout_wg.save_checkpoint(iteration) diff --git a/mindspeed_llm/tasks/posttrain/rlxf/ray_trainer/online_dpo_trainer.py b/mindspeed_llm/tasks/posttrain/rlxf/ray_trainer/online_dpo_trainer.py deleted file mode 100644 index 347e4919e72b6d5dc4669134cb74d26c94c6707f..0000000000000000000000000000000000000000 --- a/mindspeed_llm/tasks/posttrain/rlxf/ray_trainer/online_dpo_trainer.py +++ /dev/null @@ -1,139 +0,0 @@ -# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved. -from typing import Type - -from codetiming import Timer - -from mindspeed_llm.tasks.posttrain.rlxf.ray_trainer.ppo_trainer import ResourcePoolManager, Role -from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base import Worker -from mindspeed_llm.tasks.posttrain.rlxf.single_controller.ray.base import create_colocated_worker_cls, \ - set_actor_infer_world_size, set_actor_train_world_size, RayClassWithInitArgs -from mindspeed_llm.tasks.posttrain.rlxf.single_controller.ray.megatron import NVMegatronRayWorkerGroup -from mindspeed_llm.tasks.posttrain.rlxf.training.core_algos import compute_data_online_dpo_metrics, reduce_metrics -from mindspeed_llm.tasks.posttrain.rlxf.utils.loggers import Loggers -from mindspeed_llm.tasks.posttrain.rlxf.workers.actor_train_infer import PPOActorWorker -from mindspeed_llm.tasks.posttrain.rlxf.workers.reference import ReferenceWorker -from mindspeed_llm.tasks.posttrain.rlxf.workers.reward import RewardWorker - -WorkerType = Type[Worker] - - -class RayOnlineDPOTrainer(object): - """ - Note that this trainer runs on the driver process on a single CPU/GPU node. 
- """ - def __init__(self, config): - self.config = config - self.role_worker_mapping = { - Role.ActorRollout: PPOActorWorker, - Role.RefPolicy: ReferenceWorker, - Role.RewardModel: RewardWorker - } - actor_pool_id = 'actor_pool' - ref_pool_id = 'ref_pool' - reward_pool_id = 'reward_pool' - - resource_pool_spec = { - actor_pool_id: config.resource_pool.actor_rollout, - ref_pool_id: config.resource_pool.ref, - reward_pool_id: config.resource_pool.reward, - } - - mapping = { - Role.ActorRollout: actor_pool_id, - Role.RefPolicy: ref_pool_id, - Role.RewardModel: reward_pool_id, - } - - self.resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) - self.use_reference_policy = Role.RefPolicy in self.role_worker_mapping - self.ray_worker_group_cls = NVMegatronRayWorkerGroup - self.init_workers() - - - def init_workers(self): - """Init resource pool and worker group""" - set_actor_infer_world_size(self.config.actor_rollout_ref.actor_rollout.num_gpus_for_infer) - set_actor_train_world_size(self.config.actor_rollout_ref.actor_rollout.num_gpus_for_train) - self.resource_pool_manager.create_resource_pool() - self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()} - - resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout) - actor_rollout_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.ActorRollout], - config=self.config, - role='actor_rollout') - self.resource_pool_to_cls[resource_pool]['actor_rollout'] = actor_rollout_cls - - resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy) - ref_policy_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RefPolicy], - config=self.config, - role='ref') - self.resource_pool_to_cls[resource_pool]['ref'] = ref_policy_cls - - resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel) - reward_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.RewardModel], - config=self.config, - role='reward') - self.resource_pool_to_cls[resource_pool]['reward'] = reward_cls - - # initialize WorkerGroup - all_wg = {} - for resource_pool, class_dict in self.resource_pool_to_cls.items(): - worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict) - wg_dict = self.ray_worker_group_cls(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls) - spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys()) - all_wg.update(spawn_wg) - - self.ref_policy_wg = all_wg.get('ref') - self.ref_policy_wg.initialize() - - self.actor_rollout_wg = all_wg.get('actor_rollout') - self.actor_rollout_wg.initialize() - - self.reward_wg = all_wg.get('reward') - self.reward_wg.initialize() - - - def train(self): - """ - The training loop of online DPO. - """ - logger = Loggers() - iteration = self.actor_rollout_wg.get_iteration()[0] - while iteration < self.config.training.train_iters: - with Timer(name='gen', logger=None) as all_timer: - metrics = {} - self.actor_rollout_wg.auto_mapping() - # generate a batch - with Timer(name='gen', logger=None) as timer: - batch = self.actor_rollout_wg.generate_sequences() - metrics['timing/gen'] = timer.last - - if self.use_reference_policy: - # compute reference log_prob - with Timer(name='ref', logger=None) as timer: - ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) - batch = batch.union(ref_log_prob) - metrics['timing/ref'] = timer.last - - with Timer(name='adv', logger=None) as timer: - # compute scores. 
- reward_tensor = self.reward_wg.compute_rm_score(batch) - batch = batch.union(reward_tensor) - metrics['timing/adv'] = timer.last - - # update actor - with Timer(name='update_actor', logger=None) as timer: - actor_output = self.actor_rollout_wg.update_actor(batch) - metrics['timing/update_actor'] = timer.last - actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics']) - metrics.update(actor_output_metrics) - - # collect metrics - data_metrics = compute_data_online_dpo_metrics(batch=batch) - metrics.update(data_metrics) - metrics['timing/all'] = all_timer.last - iteration += 1 - logger.info(metrics, iteration, self.config.training.train_iters) - - if iteration % self.config.training.save_interval == 0: - self.actor_rollout_wg.save_checkpoint(iteration) diff --git a/mindspeed_llm/tasks/posttrain/rlxf/ray_trainer/ppo_trainer.py b/mindspeed_llm/tasks/posttrain/rlxf/ray_trainer/ppo_trainer.py deleted file mode 100644 index 8c31b706c5ff664adcfa7520c665b89cbbd40b68..0000000000000000000000000000000000000000 --- a/mindspeed_llm/tasks/posttrain/rlxf/ray_trainer/ppo_trainer.py +++ /dev/null @@ -1,229 +0,0 @@ -from dataclasses import dataclass, field -from enum import Enum -from typing import Type, Dict, List - -from codetiming import Timer - -from mindspeed_llm.tasks.posttrain.rlxf.single_controller.ray.megatron import NVMegatronRayWorkerGroup -from mindspeed_llm.tasks.posttrain.rlxf.training.core_algos import compute_data_metrics, reduce_metrics, compute_advantage, \ - apply_kl_penalty, FixedKLController, AdaptiveKLController -from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base import Worker -from mindspeed_llm.tasks.posttrain.rlxf.single_controller.ray.base import create_colocated_worker_cls, \ - set_actor_infer_world_size, set_actor_train_world_size, RayResourcePool, RayClassWithInitArgs -from mindspeed_llm.tasks.posttrain.rlxf.utils.loggers import Loggers -from mindspeed_llm.tasks.posttrain.rlxf.workers.critic import CriticWorker -from mindspeed_llm.tasks.posttrain.rlxf.workers.actor_train_infer import PPOActorWorker -from mindspeed_llm.tasks.posttrain.rlxf.workers.reference import ReferenceWorker -from mindspeed_llm.tasks.posttrain.rlxf.workers.reward import RewardWorker - -WorkerType = Type[Worker] - - -class Role(Enum): - """ - To create more roles dynamically, you can subclass Role and add new members - """ - ActorRollout = 0 - Critic = 1 - RefPolicy = 2 - RewardModel = 3 - - -@dataclass -class ResourcePoolManager: - """ - Define a resource pool specification. Resource pool will be initialized first. - Mapping - """ - resource_pool_spec: Dict[str, List[int]] - mapping: Dict[Role, str] - resource_pool_dict: Dict[str, RayResourcePool] = field(default_factory=dict) - - def create_resource_pool(self): - for resource_pool_name, process_on_nodes in self.resource_pool_spec.items(): - resource_pool = RayResourcePool(process_on_nodes=process_on_nodes, use_gpu=True, max_colocate_count=1) - self.resource_pool_dict[resource_pool_name] = resource_pool - - def get_resource_pool(self, role: Role) -> RayResourcePool: - """Get the resource pool of the worker_cls""" - print(f"role:{role}, self.mapping[role]:{self.mapping[role]}") - return self.resource_pool_dict[self.mapping[role]] - - -class RayPPOTrainer(object): - """ - Note that this trainer runs on the driver process on a single CPU/GPU node. 
- """ - - def __init__(self, config): - - self.config = config - self.role_worker_mapping = { - Role.ActorRollout: PPOActorWorker, - Role.RefPolicy: ReferenceWorker, - Role.Critic: CriticWorker, - Role.RewardModel: RewardWorker - } - actor_pool_id = 'actor_pool' - ref_pool_id = 'ref_pool' - critic_pool_id = 'critic_pool' - reward_pool_id = 'reward_pool' - - resource_pool_spec = { - actor_pool_id: config.resource_pool.actor_rollout, - ref_pool_id: config.resource_pool.ref, - critic_pool_id: config.resource_pool.critic, - reward_pool_id: config.resource_pool.reward, - } - - mapping = { - Role.ActorRollout: actor_pool_id, - Role.RefPolicy: ref_pool_id, - Role.Critic: critic_pool_id, - Role.RewardModel: reward_pool_id, - } - - self.resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) - self.use_reference_policy = Role.RefPolicy in self.role_worker_mapping - self.ray_worker_group_cls = NVMegatronRayWorkerGroup - - # define KL control - if self.use_reference_policy: - if config.algorithm.kl_ctrl.type == 'fixed': - self.kl_ctrl = FixedKLController(kl_coef=config.algorithm.kl_ctrl.kl_coef) - elif config.algorithm.kl_ctrl.type == 'adaptive': - if config.algorithm.kl_ctrl.horizon <= 0: - raise ValueError(f'horizon must be larger than 0. Got {config.critic.kl_ctrl.horizon}') - self.kl_ctrl = AdaptiveKLController(init_kl_coef=config.algorithm.kl_ctrl.kl_coef, - target_kl=config.algorithm.kl_ctrl.target_kl, - horizon=config.algorithm.kl_ctrl.horizon) - else: - raise NotImplementedError - else: - self.kl_ctrl = FixedKLController(kl_coef=0.) - self.init_workers() - - def init_workers(self): - """Init resource pool and worker group""" - set_actor_infer_world_size(self.config.actor_rollout_ref.actor_rollout.num_gpus_for_infer) - set_actor_train_world_size(self.config.actor_rollout_ref.actor_rollout.num_gpus_for_train) - self.resource_pool_manager.create_resource_pool() - self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()} - - resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout) - actor_rollout_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.ActorRollout], - config=self.config, - role='actor_rollout') - self.resource_pool_to_cls[resource_pool]['actor_rollout'] = actor_rollout_cls - - resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy) - ref_policy_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RefPolicy], - config=self.config, - role='ref') - self.resource_pool_to_cls[resource_pool]['ref'] = ref_policy_cls - - resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic) - critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], - config=self.config, - role='critic') - self.resource_pool_to_cls[resource_pool]['critic'] = critic_cls - - resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel) - reward_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.RewardModel], - config=self.config, - role='reward') - self.resource_pool_to_cls[resource_pool]['reward'] = reward_cls - - # initialize WorkerGroup - all_wg = {} - for resource_pool, class_dict in self.resource_pool_to_cls.items(): - worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict) - wg_dict = self.ray_worker_group_cls(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls) - spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys()) - all_wg.update(spawn_wg) - - self.ref_policy_wg = all_wg.get('ref') 
- self.ref_policy_wg.initialize() - - self.actor_rollout_wg = all_wg.get('actor_rollout') - self.actor_rollout_wg.initialize() - - self.critic_wg = all_wg.get('critic') - self.critic_wg.initialize() - - self.reward_wg = all_wg.get('reward') - self.reward_wg.initialize() - - def train(self): - """ - The training loop of PPO. - """ - logger = Loggers() - - iteration = self.actor_rollout_wg.get_iteration()[0] - - while iteration < self.config.training.train_iters: - with Timer(name='gen', logger=None) as all_timer: - metrics = {} - self.actor_rollout_wg.auto_mapping() - # generate a batch - with Timer(name='gen', logger=None) as timer: - batch = self.actor_rollout_wg.generate_sequences() - batch = self.actor_rollout_wg.get_log_probs(batch) - metrics['timing/gen'] = timer.last - - if self.use_reference_policy: - # compute reference log_prob - with Timer(name='ref', logger=None) as timer: - ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) - batch = batch.union(ref_log_prob) - metrics['timing/ref'] = timer.last - - # compute values - with Timer(name='values', logger=None) as timer: - values = self.critic_wg.compute_values(batch) - batch = batch.union(values) - metrics['timing/values'] = timer.last - - with Timer(name='adv', logger=None) as timer: - # compute rm scores. - reward_tensor = self.reward_wg.compute_rm_score(batch) - batch = batch.union(reward_tensor) - - # compute rewards. apply_kl_penalty if available - batch, kl_metrics = apply_kl_penalty(self.config, batch, - kl_ctrl=self.kl_ctrl) - metrics.update(kl_metrics) - - # compute advantages, executed on the driver process - batch = compute_advantage( - batch, - self.config - ) - metrics['timing/adv'] = timer.last - - # update critic - with Timer(name='update_critic', logger=None) as timer: - critic_output = self.critic_wg.update_critic(batch) - metrics['timing/update_critic'] = timer.last - critic_output_metrics = reduce_metrics(critic_output.meta_info['metrics']) - metrics.update(critic_output_metrics) - - # update actor - with Timer(name='update_actor', logger=None) as timer: - actor_output = self.actor_rollout_wg.update_actor(batch) - metrics['timing/update_actor'] = timer.last - actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics']) - metrics.update(actor_output_metrics) - - # collect metrics - data_metrics = compute_data_metrics(batch=batch) - metrics.update(data_metrics) - metrics['timing/all'] = all_timer.last - iteration += 1 - logger.info(metrics, iteration, self.config.training.train_iters) - - if iteration % self.config.training.save_interval == 0: - self.critic_wg.save_checkpoint(iteration) - self.actor_rollout_wg.save_checkpoint(iteration) - diff --git a/mindspeed_llm/tasks/posttrain/rlxf/single_controller/__init__.py b/mindspeed_llm/tasks/posttrain/rlxf/single_controller/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/mindspeed_llm/tasks/posttrain/rlxf/single_controller/base/__init__.py b/mindspeed_llm/tasks/posttrain/rlxf/single_controller/base/__init__.py deleted file mode 100644 index 75846436cd1285259d2bae6d4a7f190aebed1a80..0000000000000000000000000000000000000000 --- a/mindspeed_llm/tasks/posttrain/rlxf/single_controller/base/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .worker import Worker -from .worker_group import WorkerGroup, ClassWithInitArgs, ResourcePool diff --git a/mindspeed_llm/tasks/posttrain/rlxf/single_controller/base/decorator.py b/mindspeed_llm/tasks/posttrain/rlxf/single_controller/base/decorator.py deleted file mode 100644 index 1597d16fd929f251c0a51854c21977b69ac20189..0000000000000000000000000000000000000000 --- a/mindspeed_llm/tasks/posttrain/rlxf/single_controller/base/decorator.py +++ /dev/null @@ -1,484 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from enum import Enum -from functools import wraps, partial -from typing import Dict, List, Tuple -from types import FunctionType - - -# here we add a magic number of avoid user-defined function already have this attribute -MAGIC_ATTR = 'attrs_3141562937' - - -class Dispatch(Enum): - RANK_ZERO = 0 - ONE_TO_ALL = 1 - ALL_TO_ALL = 2 - MEGATRON_COMPUTE = 3 - MEGATRON_PP_AS_DP = 4 - MEGATRON_PP_ONLY = 5 - MEGATRON_COMPUTE_PROTO = 6 - MEGATRON_PP_AS_DP_PROTO = 7 - DP_COMPUTE = 8 - DP_COMPUTE_PROTO = 9 - DP_COMPUTE_PROTO_WITH_FUNC = 10 - DP_COMPUTE_METRIC = 11 - DP_ALL_GATHER_TRAIN = 12 - DP_ALL_GATHER_INFER = 13 - - - -class Execute(Enum): - ALL = 0 - RANK_ZERO = 1 - INFER = 2 - TRAIN = 3 - - -def _split_args_kwargs_data_proto(chunks, *args, **kwargs): - from mindspeed_llm.tasks.posttrain.rlxf.utils.protocol import DataProto, DataProtoFuture - splitted_args = [] - for arg in args: - if not isinstance(arg, (DataProto, DataProtoFuture)): - raise TypeError(f"Argument {arg} must be an instance of DataProto or DataProtoFuture. Got {type(arg)}") - splitted_args.append(arg.chunk(chunks=chunks)) - - splitted_kwargs = {} - for key, val in kwargs.items(): - if not isinstance(val, (DataProto, DataProtoFuture)): - raise TypeError(f"Value for key {key} must be an instance of DataProto or DataProtoFuture. Got {type(val)}") - splitted_kwargs[key] = val.chunk(chunks=chunks) - - return splitted_args, splitted_kwargs - - -def dispatch_one_to_all(worker_group, *args, **kwargs): - args = tuple([arg] * worker_group.world_size for arg in args) - kwargs = {k: [v] * worker_group.world_size for k, v in kwargs.items()} - return args, kwargs - - -def dispatch_all_to_all(worker_group, *args, **kwargs): - return args, kwargs - - -def collect_all_to_all(worker_group, output): - return output - - -def dispatch_megatron_compute(worker_group, *args, **kwargs): - """ - User passes in dp data. 
The data is dispatched to all tp/pp ranks with the same dp - """ - from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base.megatron.worker_group import MegatronWorkerGroup - if not isinstance(worker_group, MegatronWorkerGroup): - raise TypeError(f'worker_group must be MegatronWorkerGroup, Got {type(worker_group)}') - - all_args = [] - for arg in args: - if not isinstance(arg, (Tuple, List)) or len(arg) != worker_group.dp_size: - raise ValueError(f'Each argument must be a Tuple or List of length {worker_group.dp_size}, Got length {len(arg)}') - transformed_args = [] - for i in range(worker_group.world_size): - local_dp_rank = worker_group.get_megatron_rank_info(rank=i).dp_rank - transformed_args.append(arg[local_dp_rank]) - all_args.append(transformed_args) - all_args = tuple(all_args) - - all_kwargs = {} - for k, v in kwargs.items(): - if not isinstance(v, (Tuple, List)) or len(v) != worker_group.dp_size: - raise ValueError(f'Each argument in kwargs must be a Tuple or List of length {worker_group.dp_size}, Got length {len(v)}') - transformed_v = [] - for i in range(worker_group.world_size): - local_dp_rank = worker_group.get_megatron_rank_info(rank=i).dp_rank - transformed_v.append(v[local_dp_rank]) - all_kwargs[k] = transformed_v - return all_args, all_kwargs - - -def collect_megatron_compute(worker_group, output): - """ - Only collect the data from the tp=0 and pp=last and every dp ranks - """ - from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base.megatron.worker_group import MegatronWorkerGroup - if not isinstance(worker_group, MegatronWorkerGroup): - raise TypeError(f'worker_group must be MegatronWorkerGroup, Got {type(worker_group)}') - output_in_dp = [] - pp_size = worker_group.get_megatron_global_info().pp_size - for global_rank in range(worker_group.world_size): - local_rank_info = worker_group.get_megatron_rank_info(rank=global_rank) - if local_rank_info.tp_rank == 0 and local_rank_info.pp_rank == pp_size - 1: - output_in_dp.append(output[global_rank]) - return output_in_dp - - -def dispatch_megatron_compute_data_proto(worker_group, *args, **kwargs): - """ - All the args and kwargs must be DataProto. The batch will be chunked by dp_size and passed to each rank - """ - from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base.megatron.worker_group import MegatronWorkerGroup - if not isinstance(worker_group, MegatronWorkerGroup): - raise TypeError(f'worker_group must be an instance of MegatronWorkerGroup. Got {type(worker_group)}') - - splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.dp_size, *args, **kwargs) - return dispatch_megatron_compute(worker_group, *splitted_args, **splitted_kwargs) - - -def _concat_data_proto_or_future(output: List): - from mindspeed_llm.tasks.posttrain.rlxf.utils.protocol import DataProto, DataProtoFuture - import ray - - # make sure all the elements in output has the same type - for single_output in output: - if not isinstance(single_output, type(output[0])): - raise TypeError(f"All elements in output must have the same type. Found {type(single_output)} and {type(output[0])}") - - output_prime = output[0] - - if isinstance(output_prime, DataProto): - return DataProto.concat(output) - elif isinstance(output_prime, ray.ObjectRef): - return DataProtoFuture.concat(output) - else: - raise NotImplementedError - - -def collect_megatron_compute_data_proto(worker_group, output): - """ - Each output must be a DataProto. 
We concat the dim=0 of output - """ - from mindspeed_llm.tasks.posttrain.rlxf.utils.protocol import DataProto - import ray - - output = collect_megatron_compute(worker_group, output) - for single_output in output: - if not isinstance(single_output, (DataProto, ray.ObjectRef)): - raise TypeError(f"Expecting {single_output} to be DataProto or ray.ObjectRef, but got {type(single_output)}") - - return _concat_data_proto_or_future(output) - - -def dispatch_megatron_pp_as_dp(worker_group, *args, **kwargs): - """ - treat pp as dp. - """ - from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base.megatron.worker_group import MegatronWorkerGroup - if not isinstance(worker_group, MegatronWorkerGroup): - raise TypeError(f'worker_group must be an instance of MegatronWorkerGroup. Got {type(worker_group)}') - - pp_size = worker_group.pp_size - dp_size = worker_group.dp_size - - pp_dp_size = pp_size * dp_size - - all_args = [] - for arg in args: - if not isinstance(arg, (List, Tuple)) or len(arg) != pp_dp_size: - raise ValueError(f'Each argument in args must be a List or Tuple of length {pp_dp_size}, but got length {len(arg)}') - transformed_args = [] - for i in range(worker_group.world_size): - local_dp_rank = worker_group.get_megatron_rank_info(rank=i).dp_rank - local_pp_rank = worker_group.get_megatron_rank_info(rank=i).pp_rank - # compute the rank in arg. Note that the order is dp then pp - # Also note that the outputs within a pp group will be firstly allgathered, then only the output of pp0 will be collected. - # For pp=2 dp=4, a batch of data "ABCDEFGH" should be dispatched and collected in below order: - # dispatch: pp_allgther: collect: - # dp 0 1 2 3 dp 0 1 2 3 - # pp +---------+ pp +-------------+ - # 0 | A C E G | 0 | AB CD EF GH | ABCDEFGH - # 1 | B D F H | 1 | AB CD EF GH | - # +---------+ +-------------+ - arg_rank = local_dp_rank * worker_group.pp_size + local_pp_rank - - transformed_args.append(arg[arg_rank]) - all_args.append(transformed_args) - all_args = tuple(all_args) - - all_kwargs = {} - for k, v in kwargs.items(): - if not isinstance(v, (List, Tuple)) or len(v) != pp_dp_size: - raise ValueError(f'Each argument in kwargs must be a List or Tuple of length {pp_dp_size}, but got length {len(v)}') - transformed_v = [] - for i in range(worker_group.world_size): - local_dp_rank = worker_group.get_megatron_rank_info(rank=i).dp_rank - local_pp_rank = worker_group.get_megatron_rank_info(rank=i).pp_rank - # compute the rank in arg. Note that the order is dp then pp - arg_rank = local_dp_rank * worker_group.pp_size + local_pp_rank - transformed_v.append(v[arg_rank]) - all_kwargs[k] = transformed_v - return all_args, all_kwargs - - -def collect_megatron_pp_as_dp(worker_group, output): - """ - treat pp as dp. Only collect data on tp=0 - """ - from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base.megatron.worker_group import MegatronWorkerGroup - if not isinstance(worker_group, MegatronWorkerGroup): - raise TypeError(f'worker_group must be an instance of MegatronWorkerGroup. Got {type(worker_group)}') - output_in_dp = [] - for global_rank in range(worker_group.world_size): - local_rank_info = worker_group.get_megatron_rank_info(rank=global_rank) - if local_rank_info.tp_rank == 0 and local_rank_info.pp_rank == 0: - output_in_dp.append(output[global_rank]) - return output_in_dp - - -def collect_megatron_pp_only(worker_group, output): - """ - Only collect output of megatron pp. 
This is useful when examine weight names as they are identical in tp/dp - """ - from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base.megatron.worker_group import MegatronWorkerGroup - if not isinstance(worker_group, MegatronWorkerGroup): - raise TypeError(f'worker_group must be an instance of MegatronWorkerGroup. Got {type(worker_group)}') - output_in_pp = [] - for global_rank in range(worker_group.world_size): - local_rank_info = worker_group.get_megatron_rank_info(rank=global_rank) - if local_rank_info.tp_rank == 0 and local_rank_info.dp_rank == 0: - output_in_pp.append(output[global_rank]) - return output_in_pp - - -def dispatch_megatron_pp_as_dp_data_proto(worker_group, *args, **kwargs): - from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base.megatron.worker_group import MegatronWorkerGroup - if not isinstance(worker_group, MegatronWorkerGroup): - raise TypeError(f'worker_group must be an instance of MegatronWorkerGroup. Got {type(worker_group)}') - - pp_dp_size = worker_group.dp_size * worker_group.pp_size - splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(pp_dp_size, *args, **kwargs) - return dispatch_megatron_pp_as_dp(worker_group, *splitted_args, **splitted_kwargs) - - -def collect_megatron_pp_as_dp_data_proto(worker_group, output): - from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base.megatron.worker_group import MegatronWorkerGroup - if not isinstance(worker_group, MegatronWorkerGroup): - raise TypeError(f'worker_group must be an instance of MegatronWorkerGroup. Got {type(worker_group)}') - - output = collect_megatron_pp_as_dp(worker_group, output) - return _concat_data_proto_or_future(output) - - -def dispatch_dp_compute(worker_group, *args, **kwargs): - from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base.worker_group import WorkerGroup - if not isinstance(worker_group, WorkerGroup): - raise TypeError(f'worker_group must be an instance of WorkerGroup. Got {type(worker_group)}') - for arg in args: - if not isinstance(arg, (Tuple, List)) or len(arg) != worker_group.world_size: - raise ValueError(f'Each argument in args must be a Tuple or List of length {worker_group.world_size}') - for _, v in kwargs.items(): - if not isinstance(v, (Tuple, List)) or len(v) != worker_group.world_size: - raise ValueError(f'Each argument in kwargs must be a Tuple or List of length {worker_group.world_size}') - return args, kwargs - - -def collect_dp_compute(worker_group, output): - from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base.worker_group import WorkerGroup - if not isinstance(worker_group, WorkerGroup): - raise TypeError(f'worker_group must be an instance of WorkerGroup. Got {type(worker_group)}') - - if len(output) != worker_group.world_size: - raise ValueError(f'Output must have a length equal to world_size. Got length {len(output)}') - return output - - -def dispatch_dp_compute_data_proto(worker_group, *args, **kwargs): - from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base.worker_group import WorkerGroup - if not isinstance(worker_group, WorkerGroup): - raise TypeError(f'worker_group must be an instance of WorkerGroup. 
Got {type(worker_group)}') - splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.world_size, *args, **kwargs) - return splitted_args, splitted_kwargs - - -def dispatch_dp_compute_data_proto_with_func(worker_group, *args, **kwargs): - from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base.worker_group import WorkerGroup - if not isinstance(worker_group, WorkerGroup): - raise TypeError(f'worker_group must be an instance of WorkerGroup. Got {type(worker_group)}') - - if type(args[0]) != FunctionType: - raise TypeError(f'The first argument must be a callable function. Got {type(args[0])}') # NOTE: The first one args is a function! - - splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.world_size, *args[1:], **kwargs) - splitted_args_with_func = [[args[0]] * worker_group.world_size] + splitted_args - return splitted_args_with_func, splitted_kwargs - - -def collect_dp_compute_data_proto(worker_group, output): - import ray - from mindspeed_llm.tasks.posttrain.rlxf.utils.protocol import DataProto - for single_output in output: - if not isinstance(single_output, (DataProto, ray.ObjectRef)): - raise TypeError(f"Expecting {single_output} to be DataProto or ray.ObjectRef, but got {type(single_output)}") - - output = collect_dp_compute(worker_group, output) - return _concat_data_proto_or_future(output) - - -def collect_dp_all_gather(worker_group, output, is_train): - """ - collect data in DP groups, in each DP group, only use the output return on TP_0 PP_last. - """ - from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base.megatron.worker_group import MegatronWorkerGroup - if not isinstance(worker_group, MegatronWorkerGroup): - raise TypeError(f'worker_group must be an instance of MegatronWorkerGroup. 
Got {type(worker_group)}') - output_in_dp = [] - from mindspeed_llm.tasks.posttrain.rlxf.single_controller.ray.base import get_actor_train_world_size - actor_train_world_size = get_actor_train_world_size() - pp_size = worker_group.get_megatron_global_info().pp_size if is_train else 1 - rank_offset = 0 if is_train else actor_train_world_size - for global_rank in range(worker_group.world_size): - is_train_node = global_rank < actor_train_world_size - if is_train_node and not is_train: - continue - elif not is_train_node and is_train: - continue - local_rank_info = worker_group.get_megatron_rank_info(rank=global_rank) - if local_rank_info.tp_rank == 0 and local_rank_info.pp_rank == pp_size - 1: - output_in_dp.append(output[global_rank - rank_offset]) - return _concat_data_proto_or_future(output_in_dp) - -collect_dp_train = partial(collect_dp_all_gather, is_train=True) -collect_dp_infer = partial(collect_dp_all_gather, is_train=False) - - - -def get_predefined_dispatch_fn(dispatch_mode): - predefined_dispatch_mode_fn = { - Dispatch.ONE_TO_ALL: { - 'dispatch_fn': dispatch_one_to_all, - 'collect_fn': collect_all_to_all, - }, - Dispatch.ALL_TO_ALL: { - 'dispatch_fn': dispatch_all_to_all, - 'collect_fn': collect_all_to_all, - }, - Dispatch.MEGATRON_COMPUTE: { - 'dispatch_fn': dispatch_megatron_compute, - 'collect_fn': collect_megatron_compute, - }, - Dispatch.MEGATRON_PP_AS_DP: { - 'dispatch_fn': dispatch_megatron_pp_as_dp, - 'collect_fn': collect_megatron_pp_as_dp, - }, - Dispatch.MEGATRON_PP_ONLY: { - 'dispatch_fn': dispatch_one_to_all, - 'collect_fn': collect_megatron_pp_only - }, - Dispatch.MEGATRON_COMPUTE_PROTO: { - 'dispatch_fn': dispatch_megatron_compute_data_proto, - 'collect_fn': collect_megatron_compute_data_proto - }, - Dispatch.MEGATRON_PP_AS_DP_PROTO: { - 'dispatch_fn': dispatch_megatron_pp_as_dp_data_proto, - 'collect_fn': collect_megatron_pp_as_dp_data_proto - }, - Dispatch.DP_COMPUTE: { - 'dispatch_fn': dispatch_dp_compute, - 'collect_fn': collect_dp_compute - }, - Dispatch.DP_COMPUTE_PROTO: { - 'dispatch_fn': dispatch_dp_compute_data_proto, - 'collect_fn': collect_dp_compute_data_proto - }, - Dispatch.DP_COMPUTE_PROTO_WITH_FUNC: { - 'dispatch_fn': dispatch_dp_compute_data_proto_with_func, - 'collect_fn': collect_dp_compute_data_proto - }, - Dispatch.DP_COMPUTE_METRIC: { - 'dispatch_fn': dispatch_dp_compute_data_proto, - 'collect_fn': collect_dp_compute - }, - Dispatch.DP_ALL_GATHER_TRAIN: { - 'dispatch_fn': dispatch_one_to_all, - 'collect_fn': collect_dp_train, - }, - Dispatch.DP_ALL_GATHER_INFER: { - 'dispatch_fn': dispatch_one_to_all, - 'collect_fn': collect_dp_infer, - }, - } - return predefined_dispatch_mode_fn.get(dispatch_mode) - - -def get_predefined_execute_fn(execute_mode): - """ - Note that here we only asks execute_all and execute_rank_zero to be implemented - Leave the choice of how these two functions handle argument 'blocking' to users - """ - predefined_execute_mode_fn = { - Execute.ALL: { - 'execute_fn_name': 'execute_all' - }, - Execute.RANK_ZERO: { - 'execute_fn_name': 'execute_rank_zero' - }, - Execute.INFER: { - 'execute_fn_name': 'execute_infer' - }, - Execute.TRAIN: { - 'execute_fn_name': 'execute_train' - } - } - return predefined_execute_mode_fn.get(execute_mode) - - -def _check_dispatch_mode(dispatch_mode): - if not isinstance(dispatch_mode, (Dispatch, Dict)): - raise TypeError(f'dispatch_mode must be a Dispatch or a Dict. 
Got {type(dispatch_mode)}') - if isinstance(dispatch_mode, Dict): - necessary_keys = ['dispatch_fn', 'collect_fn'] - for key in necessary_keys: - if key not in dispatch_mode: - raise KeyError(f'key {key} should be in dispatch_mode if it is a dictionary') - - -def _check_execute_mode(execute_mode): - if not isinstance(execute_mode, Execute): - raise TypeError(f'execute_mode must be an instance of Execute. Got {type(execute_mode)}') - - -def _materialize_futures(*args, **kwargs): - from mindspeed_llm.tasks.posttrain.rlxf.utils.protocol import DataProtoFuture - new_args = [] - for arg in args: - if isinstance(arg, DataProtoFuture): - arg = arg.get() - # add more type to materialize - new_args.append(arg) - for k, v in kwargs.items(): - if isinstance(v, DataProtoFuture): - kwargs[k] = v.get() - - new_args = tuple(new_args) - return new_args, kwargs - - -def register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.ALL, blocking=True, materialize_futures=True): - _check_dispatch_mode(dispatch_mode=dispatch_mode) - _check_execute_mode(execute_mode=execute_mode) - - def decorator(func): - - @wraps(func) - def inner(*args, **kwargs): - if materialize_futures: - args, kwargs = _materialize_futures(*args, **kwargs) - return func(*args, **kwargs) - - attrs = {'dispatch_mode': dispatch_mode, 'execute_mode': execute_mode, 'blocking': blocking} - setattr(inner, MAGIC_ATTR, attrs) - return inner - - return decorator diff --git a/mindspeed_llm/tasks/posttrain/rlxf/single_controller/base/megatron/__init__.py b/mindspeed_llm/tasks/posttrain/rlxf/single_controller/base/megatron/__init__.py deleted file mode 100644 index 1ce90c5eb352d85c59105c0dc85b5f1dd576f095..0000000000000000000000000000000000000000 --- a/mindspeed_llm/tasks/posttrain/rlxf/single_controller/base/megatron/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/mindspeed_llm/tasks/posttrain/rlxf/single_controller/base/megatron/worker.py b/mindspeed_llm/tasks/posttrain/rlxf/single_controller/base/megatron/worker.py deleted file mode 100644 index f3c3341df05db9c85a59c5c25f14214c6d18bb7f..0000000000000000000000000000000000000000 --- a/mindspeed_llm/tasks/posttrain/rlxf/single_controller/base/megatron/worker.py +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
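The MEGATRON_COMPUTE dispatch/collect rule defined above is easier to see in isolation: the driver passes one item per DP rank, dispatch replicates that item to every TP/PP rank sharing the same DP rank, and collect keeps a single copy per DP group (tp == 0 on the last pipeline stage). A minimal, framework-free sketch follows; the RankInfo dataclass, the rank-ordering helper, and the toy sizes are illustrative stand-ins, not part of the deleted decorator.py.

# Sketch of MEGATRON_COMPUTE dispatch/collect; rank layout is an assumption.
from dataclasses import dataclass
from typing import List


@dataclass
class RankInfo:
    tp_rank: int
    dp_rank: int
    pp_rank: int


def build_rank_infos(tp: int, dp: int, pp: int) -> List[RankInfo]:
    # one RankInfo per global rank; the real code queries each worker instead
    infos = []
    for pp_rank in range(pp):
        for dp_rank in range(dp):
            for tp_rank in range(tp):
                infos.append(RankInfo(tp_rank, dp_rank, pp_rank))
    return infos


def dispatch_megatron_compute(rank_infos: List[RankInfo], dp_sized_arg: List):
    # one entry per global rank, indexed by that rank's dp_rank
    return [dp_sized_arg[info.dp_rank] for info in rank_infos]


def collect_megatron_compute(rank_infos: List[RankInfo], outputs: List, pp: int):
    # keep a single copy per DP group: the tp==0 rank on the last pipeline stage
    return [out for info, out in zip(rank_infos, outputs)
            if info.tp_rank == 0 and info.pp_rank == pp - 1]


if __name__ == "__main__":
    tp, dp, pp = 2, 2, 2                       # world_size = 8
    infos = build_rank_infos(tp, dp, pp)
    per_rank = dispatch_megatron_compute(infos, ["batch_dp0", "batch_dp1"])
    print(per_rank)                            # every rank sees its DP group's batch
    print(collect_megatron_compute(infos, per_rank, pp))  # ['batch_dp0', 'batch_dp1']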
- -import os -from dataclasses import dataclass -from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base.worker import Worker, DistRankInfo, DistGlobalInfo - - -class MegatronWorker(Worker): - - def __init__(self, cuda_visible_devices=None) -> None: - super().__init__(cuda_visible_devices) - - def get_megatron_global_info(self): - from megatron.core import parallel_state as mpu - tp_size = mpu.get_tensor_model_parallel_world_size() - dp_size = mpu.get_data_parallel_world_size() - pp_size = mpu.get_pipeline_model_parallel_world_size() - info = DistGlobalInfo(tp_size=tp_size, dp_size=dp_size, pp_size=pp_size) - return info - - def get_megatron_rank_info(self): - from megatron.core import parallel_state as mpu - tp_rank = mpu.get_tensor_model_parallel_rank() - dp_rank = mpu.get_data_parallel_rank() - pp_rank = mpu.get_pipeline_model_parallel_rank() - info = DistRankInfo(tp_rank=tp_rank, dp_rank=dp_rank, pp_rank=pp_rank) - return info \ No newline at end of file diff --git a/mindspeed_llm/tasks/posttrain/rlxf/single_controller/base/megatron/worker_group.py b/mindspeed_llm/tasks/posttrain/rlxf/single_controller/base/megatron/worker_group.py deleted file mode 100644 index 85371f6eab72b4097b2f599ef4c845272b33b06c..0000000000000000000000000000000000000000 --- a/mindspeed_llm/tasks/posttrain/rlxf/single_controller/base/megatron/worker_group.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from typing import Dict - -from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base import ResourcePool, WorkerGroup -from .worker import DistRankInfo, DistGlobalInfo - - -class MegatronWorkerGroup(WorkerGroup): - - def __init__(self, resource_pool: ResourcePool, **kwargs): - super().__init__(resource_pool=resource_pool, **kwargs) - self._megatron_rank_info = None - self._megatron_global_info: DistGlobalInfo = None - - def init_megatron(self, default_megatron_kwargs: Dict = None): - raise NotImplementedError(f"MegatronWorkerGroup.init_megatron should be overwritten") - - def get_megatron_rank_info(self, rank: int) -> DistRankInfo: - if not (0 <= rank < self.world_size): - raise ValueError(f'rank must be from [0, world_size), Got {rank}') - return self._megatron_rank_info[rank] - - @property - def tp_size(self): - if self._megatron_global_info is None: - raise ValueError("MegatronWorkerGroup._megatron_global_info must be initialized") - return self._megatron_global_info.tp_size - - @property - def dp_size(self): - if self._megatron_global_info is None: - raise ValueError("MegatronWorkerGroup._megatron_global_info must be initialized") - return self._megatron_global_info.dp_size - - @property - def pp_size(self): - if self._megatron_global_info is None: - raise ValueError("MegatronWorkerGroup._megatron_global_info must be initialized") - return self._megatron_global_info.pp_size - - def get_megatron_global_info(self): - return self._megatron_global_info diff --git a/mindspeed_llm/tasks/posttrain/rlxf/single_controller/base/register_center/__init__.py b/mindspeed_llm/tasks/posttrain/rlxf/single_controller/base/register_center/__init__.py deleted file mode 100644 index 1ce90c5eb352d85c59105c0dc85b5f1dd576f095..0000000000000000000000000000000000000000 --- a/mindspeed_llm/tasks/posttrain/rlxf/single_controller/base/register_center/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/mindspeed_llm/tasks/posttrain/rlxf/single_controller/base/register_center/ray.py b/mindspeed_llm/tasks/posttrain/rlxf/single_controller/base/register_center/ray.py deleted file mode 100644 index 430290cf2683d882d35a83256aa363d959265a05..0000000000000000000000000000000000000000 --- a/mindspeed_llm/tasks/posttrain/rlxf/single_controller/base/register_center/ray.py +++ /dev/null @@ -1,29 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
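The register-center module whose body follows implements a small rendezvous: rank 0 publishes its MASTER_ADDR/MASTER_PORT through a named Ray actor, and the driver (or other ranks) look that actor up by name and read the info back. A minimal sketch of the handshake, assuming a local Ray runtime; the helper function names and the polling timeout are illustrative, only the shape of the actor mirrors the module.

# Sketch of the rank-zero rendezvous pattern (not the deleted module itself).
import time
import ray


@ray.remote
class RegisterCenter:
    def __init__(self, rank_zero_info):
        self.rank_zero_info = rank_zero_info

    def get_rank_zero_info(self):
        return self.rank_zero_info


def publish_rank_zero(prefix, addr, port):
    # called on rank 0 only; the name makes the actor discoverable by others
    return RegisterCenter.options(name=f"{prefix}_register_center").remote(
        {"MASTER_ADDR": addr, "MASTER_PORT": port})


def wait_for_rank_zero(prefix, timeout_s=120):
    # called elsewhere; poll until rank 0 has registered the named actor
    deadline = time.time() + timeout_s
    while time.time() < deadline:
        try:
            center = ray.get_actor(f"{prefix}_register_center")
            return ray.get(center.get_rank_zero_info.remote())
        except ValueError:      # named actor does not exist yet
            time.sleep(1)
    raise TimeoutError(f"rank 0 of '{prefix}' never registered")


if __name__ == "__main__":
    ray.init()
    handle = publish_rank_zero("demo", "127.0.0.1", "29500")
    print(wait_for_rank_zero("demo"))   # {'MASTER_ADDR': '127.0.0.1', 'MASTER_PORT': '29500'}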
- -import ray - - -@ray.remote -class WorkerGroupRegisterCenter: - - def __init__(self, rank_zero_info): - self.rank_zero_info = rank_zero_info - - def get_rank_zero_info(self): - return self.rank_zero_info - - -def create_worker_group_register_center(name, info): - return WorkerGroupRegisterCenter.options(name=name).remote(info) diff --git a/mindspeed_llm/tasks/posttrain/rlxf/single_controller/base/worker.py b/mindspeed_llm/tasks/posttrain/rlxf/single_controller/base/worker.py deleted file mode 100644 index fd9c162e70c4821d146c9d9f1b42e251744325f2..0000000000000000000000000000000000000000 --- a/mindspeed_llm/tasks/posttrain/rlxf/single_controller/base/worker.py +++ /dev/null @@ -1,184 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -the class for Worker -""" -import os -import socket -from dataclasses import dataclass -from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base.decorator import register, Dispatch - - -@dataclass -class DistRankInfo: - tp_rank: int - dp_rank: int - pp_rank: int - - -@dataclass -class DistGlobalInfo: - tp_size: int - dp_size: int - pp_size: int - - -class WorkerHelper: - - def _get_node_ip(self): - - def get_node_ip_by_sdk(): - if os.getenv("WG_BACKEND", None) == "ray": - import ray - return ray._private.services.get_node_ip_address() - elif os.getenv("WG_BACKEND", None) == "torch_rpc": - from mindspeed_llm.tasks.posttrain.rlxf.single_controller import get_ip_addr - return get_ip_addr() - return None - - host_ipv4 = os.getenv("MY_HOST_IP", None) - host_ipv6 = os.getenv("MY_HOST_IPV6", None) - host_ip_by_env = host_ipv4 or host_ipv6 - host_ip_by_sdk = get_node_ip_by_sdk() - - host_ip = host_ip_by_env or host_ip_by_sdk - return host_ip - - def _get_free_port(self): - with socket.socket() as sock: - sock.bind(('', 0)) - return sock.getsockname()[1] - - def get_availale_master_addr_port(self): - return self._get_node_ip(), str(self._get_free_port()) - - def _get_pid(self): - return - - -class WorkerMeta: - keys = [ - "WORLD_SIZE", "RANK", "LOCAL_WORLD_SIZE", "LOCAL_RANK", "MASTER_ADDR", "MASTER_PORT", "CUDA_VISIBLE_DEVICES" - ] - - def __init__(self, store) -> None: - self._store = store - - def to_dict(self): - return {f"_{key.lower()}": self._store.get(f"_{key.lower()}", None) for key in WorkerMeta.keys} - - -# we assume that in each WorkerGroup, there is a Master Worker -class Worker(WorkerHelper): - - def __new__(cls, *args, **kwargs): - instance = super().__new__(cls) - - # note that here we use int to distinguish - disable_worker_init = int(os.environ.get('DISABLE_WORKER_INIT', 0)) - if disable_worker_init: - return instance - - rank = os.environ.get("RANK", None) - worker_group_prefix = os.environ.get("WG_PREFIX", None) - - # when decorator @ray.remote applies, __new__ will be called while we don't want to apply _configure_before_init - if None not in [rank, worker_group_prefix] and 'ActorClass(' not in cls.__name__: - 
instance._configure_before_init(f"{worker_group_prefix}_register_center", int(rank)) - - return instance - - def _configure_before_init(self, register_center_name: str, rank: int): - if not isinstance(rank, int): - raise TypeError(f"rank must be int, instead of {type(rank)}") - - if rank == 0: - master_addr, master_port = self.get_availale_master_addr_port() - rank_zero_info = { - "MASTER_ADDR": master_addr, - "MASTER_PORT": master_port, - } - - if os.getenv("WG_BACKEND", None) == "ray": - from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base.register_center.ray import create_worker_group_register_center - self.register_center = create_worker_group_register_center(name=register_center_name, - info=rank_zero_info) - - os.environ.update(rank_zero_info) - - def __init__(self, cuda_visible_devices=None) -> None: - # construct a meta from envrionment variable. Note that the import must be inside the class because it is executed remotely - import os - world_size = int(os.environ['WORLD_SIZE']) - rank = int(os.environ['RANK']) - self._rank = rank - self._world_size = world_size - - master_addr = os.environ["MASTER_ADDR"] - master_port = os.environ["MASTER_PORT"] - - local_world_size = int(os.getenv("LOCAL_WORLD_SIZE", "1")) - local_rank = int(os.getenv("LOCAL_RANK", "0")) - - store = { - '_world_size': world_size, - '_rank': rank, - '_local_world_size': local_world_size, - '_local_rank': local_rank, - '_master_addr': master_addr, - '_master_port': master_port - } - if cuda_visible_devices is not None: - store['_cuda_visible_devices'] = cuda_visible_devices - - meta = WorkerMeta(store=store) - self._configure_with_meta(meta=meta) - - def _configure_with_meta(self, meta: WorkerMeta): - """ - This function should only be called inside by WorkerGroup - """ - if not isinstance(meta, WorkerMeta): - raise TypeError( - f"Invalid meta type: expected WorkerMeta, got {type(meta).__name__}. " - f"(Received value: {repr(meta)})" - ) - self.__dict__.update(meta.to_dict()) # this is hacky - for key in WorkerMeta.keys: - val = self.__dict__.get(f"_{key.lower()}", None) - if val is not None: - os.environ[key] = str(val) - os.environ["REDIS_STORE_SERVER_HOST"] = str(self._master_addr).replace("[", "").replace( - "]", "") if self._master_addr else "" - - def get_master_addr_port(self): - return self._master_addr, self._master_port - - def get_cuda_visible_devices(self): - import os - cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "not set") - return cuda_visible_devices - - @property - def world_size(self): - return self._world_size - - @property - def rank(self): - return self._rank - - @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO_WITH_FUNC) - def execute_with_func_generator(self, func, *args, **kwargs): - ret_proto = func(self, *args, **kwargs) - return ret_proto diff --git a/mindspeed_llm/tasks/posttrain/rlxf/single_controller/base/worker_group.py b/mindspeed_llm/tasks/posttrain/rlxf/single_controller/base/worker_group.py deleted file mode 100644 index e847b7dccde9c2c24d9a782760a07ea859b36ab0..0000000000000000000000000000000000000000 --- a/mindspeed_llm/tasks/posttrain/rlxf/single_controller/base/worker_group.py +++ /dev/null @@ -1,222 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -the class of WorkerGroup -""" -import logging -import threading -import signal -import time -from typing import List, Any, Callable, Dict - -from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base.decorator import MAGIC_ATTR, Dispatch, get_predefined_dispatch_fn, get_predefined_execute_fn - - -class ResourcePool: - - def __init__(self, process_on_nodes=None, max_collocate_count: int = 10, n_gpus_per_node=8) -> None: - if process_on_nodes is None: - process_on_nodes = [] - self._store = process_on_nodes - self.max_collocate_count = max_collocate_count - self.n_gpus_per_node = n_gpus_per_node # this is left for future huawei GPU that contains 16 GPUs per node - - def add_node(self, process_count): - self._store.append(process_count) - - @property - def world_size(self): - return sum(self._store) - - def __call__(self) -> Any: - return self._store - - @property - def store(self): - return self._store - - def local_world_size_list(self) -> List[int]: - nested_local_world_size_list = [] - - for local_world_size in self._store: - inner_list = [] - - for _ in range(local_world_size): - inner_list.append(local_world_size) - - nested_local_world_size_list.append(inner_list) - - return [item for row in nested_local_world_size_list for item in row] - - def local_rank_list(self) -> List[int]: - nested_local_rank_list = [[i for i in range(local_world_size)] for local_world_size in self._store] - return [item for row in nested_local_rank_list for item in row] - - -class ClassWithInitArgs: - """ - This class stores a class constructor and the args/kwargs to construct the class. - It is used to instantiate the remote class. 
- """ - - def __init__(self, cls, *args, **kwargs) -> None: - self.cls = cls - self.args = args - self.kwargs = kwargs - - def __call__(self) -> Any: - return self.cls(*self.args, **self.kwargs) - - -def check_workers_alive(workers: List, is_alive: Callable, gap_time: float = 1) -> None: - import time - while True: - for worker in workers: - if not is_alive(worker): - logging.warning(f"worker {worker} is not alive" + " sending signal to main thread") - signal.raise_signal(signal.SIGABRT) - time.sleep(gap_time) - - -class WorkerGroup: - - def __init__(self, resource_pool: ResourcePool, **kwargs) -> None: - self._is_init_with_detached_workers = True if resource_pool is None else False - - if resource_pool is not None: - # handle the case when WorkGroup is attached to an existing one - self._procecss_dispatch_config = resource_pool() - else: - self._procecss_dispatch_config = None - - self._workers = [] - self._worker_names = [] - - self._master_addr = None - self._master_port = None - - self._checker_thread: threading.Thread = None - - def _is_worker_alive(self, worker): - raise NotImplementedError(f"WorkerGroup._is_worker_alive called, should be implemented in derived class.") - - def _block_until_all_workers_alive(self) -> None: - while True: - all_state = [self._is_worker_alive(worker) for worker in self._workers] - if False in all_state: - time.sleep(1) - else: - break - - def start_worker_aliveness_check(self, every_n_seconds=1) -> None: - # before starting checking worker aliveness, make sure all workers are already alive - self._block_until_all_workers_alive() - - self._checker_thread = threading.Thread(target=check_workers_alive, - args=(self._workers, self._is_worker_alive, every_n_seconds)) - self._checker_thread.start() - - @property - def world_size(self): - return len(self._workers) - - # execute_all_async and execute_rank_zero_async should be implemented by RayWorkerGroup, TorchRPCWorkerGroup, - # MegatronWorkerGroup, XperfWorkerGroup should skip - - def _bind_worker_method(self, user_defined_cls, func_generator): - """ - Bind the worker method to the WorkerGroup - """ - - for method_name in dir(user_defined_cls): - - try: - method = getattr(user_defined_cls, method_name) - if not callable(method): - raise TypeError( - f"{method_name} in {user_defined_cls} is not callable" - ) - except Exception as e: - # if it is a property, it will fail because Class doesn't have instance property - continue - - if hasattr(method, MAGIC_ATTR): - # this method is decorated by register - attribute = getattr(method, MAGIC_ATTR) - if not isinstance(attribute, dict): - raise TypeError( - f"Attribute must be a dictionary. Got {type(attribute)} for {method_name}" - ) - if 'dispatch_mode' not in attribute: - raise KeyError( - f"Attribute must contain 'dispatch_mode' key in {method_name}. " - f"Found keys: {list(attribute.keys())}" - ) - - dispatch_mode = attribute['dispatch_mode'] - execute_mode = attribute['execute_mode'] - blocking = attribute['blocking'] - - # get dispatch fn - if isinstance(dispatch_mode, Dispatch): - # get default dispatch fn - fn = get_predefined_dispatch_fn(dispatch_mode=dispatch_mode) - dispatch_fn = fn['dispatch_fn'] - collect_fn = fn['collect_fn'] - else: - if not isinstance(dispatch_mode, dict): - raise TypeError( - f"dispatch_mode must be dict type. Got {type(dispatch_mode)} " - f"in {method_name}" - ) - - if 'dispatch_fn' not in dispatch_mode: - raise KeyError( - f"dispatch_mode requires 'dispatch_fn' key in {method_name}. 
" - f"Found keys: {list(dispatch_mode.keys())}" - ) - - if 'collect_fn' not in dispatch_mode: - raise KeyError( - f"dispatch_mode requires 'collect_fn' key in {method_name}. " - f"Found keys: {list(dispatch_mode.keys())}" - ) - dispatch_fn = dispatch_mode['dispatch_fn'] - collect_fn = dispatch_mode['collect_fn'] - - # get execute_fn_name - execute_mode = get_predefined_execute_fn(execute_mode=execute_mode) - wg_execute_fn_name = execute_mode['execute_fn_name'] - - # get execute_fn from string - try: - execute_fn = getattr(self, wg_execute_fn_name) - if not callable(execute_fn): - raise TypeError(f"{wg_execute_fn_name} is not callable") - except Exception as e: - print(f'execute_fn {wg_execute_fn_name} is invalid') - raise - - # bind a new method to the RayWorkerGroup - func = func_generator(self, - method_name, - dispatch_fn=dispatch_fn, - collect_fn=collect_fn, - execute_fn=execute_fn, - blocking=blocking) - - try: - setattr(self, method_name, func) - except Exception as e: - raise ValueError(f'Fail to set method_name {method_name}') from e diff --git a/mindspeed_llm/tasks/posttrain/rlxf/single_controller/ray/__init__.py b/mindspeed_llm/tasks/posttrain/rlxf/single_controller/ray/__init__.py deleted file mode 100644 index ed0ede061bc226772a8777e10032ed2511747a86..0000000000000000000000000000000000000000 --- a/mindspeed_llm/tasks/posttrain/rlxf/single_controller/ray/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup, create_colocated_worker_cls -from .megatron import (DistRankInfo, DistGlobalInfo) \ No newline at end of file diff --git a/mindspeed_llm/tasks/posttrain/rlxf/single_controller/ray/base.py b/mindspeed_llm/tasks/posttrain/rlxf/single_controller/ray/base.py deleted file mode 100644 index 8f0fa4cb088458bbdff5c76d718f2897cda8dd60..0000000000000000000000000000000000000000 --- a/mindspeed_llm/tasks/posttrain/rlxf/single_controller/ray/base.py +++ /dev/null @@ -1,480 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -__all__ = ['Worker'] - -import os -import time -from typing import Dict, List, Any -from unittest.mock import patch - -import ray -from ray.util import list_named_actors -from ray.util.placement_group import placement_group -from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy, NodeAffinitySchedulingStrategy -from ray.experimental.state.api import get_actor - -from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base import WorkerGroup, ResourcePool, ClassWithInitArgs, Worker -from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base.decorator import MAGIC_ATTR - -ACTOR_INFER_WORLD_SIZE = None -ACTOR_TRAIN_WORLD_SIZE = None - - -def set_actor_infer_world_size(world_size): - global ACTOR_INFER_WORLD_SIZE - ACTOR_INFER_WORLD_SIZE = world_size - - -def set_actor_train_world_size(world_size): - global ACTOR_TRAIN_WORLD_SIZE - ACTOR_TRAIN_WORLD_SIZE = world_size - - -def get_actor_infer_world_size(): - return ACTOR_INFER_WORLD_SIZE - - -def get_actor_train_world_size(): - return ACTOR_TRAIN_WORLD_SIZE - - -def get_random_string(length: int) -> str: - import random - import string - letters_digits = string.ascii_letters + string.digits - return ''.join(random.choice(letters_digits) for _ in range(length)) - - -def func_generator(self, method_name, dispatch_fn, collect_fn, execute_fn, blocking): - - def func(*args, **kwargs): - args, kwargs = dispatch_fn(self, *args, **kwargs) - output = execute_fn(method_name, *args, **kwargs) - if blocking: - output = ray.get(output) - output = collect_fn(self, output) - return output - - return func - - -class RayResourcePool(ResourcePool): - - def __init__(self, - process_on_nodes: List[int] = None, - use_gpu: bool = True, - name_prefix: str = "", - max_colocate_count: int = 5, - detached=False) -> None: - super().__init__(process_on_nodes, max_colocate_count) - self.use_gpu = use_gpu - self.name_prefix = name_prefix - self.pgs = None - self.detached = detached - - def get_placement_groups(self, strategy="STRICT_PACK", name=None): - if self.pgs is not None: - return self.pgs - - if not self.name_prefix: - self.name_prefix = name - - pg_name_prefix = f"{self.name_prefix}verl_group_{'_'.join([str(count) for count in self._store])}:" - pg_scheme = [[{ - "CPU": self.max_collocate_count, - "NPU": 1 - } if self.use_gpu else { - "CPU": self.max_collocate_count - } for _ in range(process_count)] for process_count in self._store] - - lifetime = 'detached' if self.detached else None - - pgs = [ - placement_group(bundles=bundles, strategy=strategy, name=pg_name_prefix + str(idx), lifetime=lifetime) - for idx, bundles in enumerate(pg_scheme) - ] - - ray.get([pg.ready() for pg in pgs]) - - self.pgs = pgs - return pgs - - -class RayClassWithInitArgs(ClassWithInitArgs): - - def __init__(self, cls, *args, **kwargs) -> None: - super().__init__(cls, *args, **kwargs) - self._options = {} - self._additional_resource = {} - - def set_additional_resource(self, additional_resource): - self._additional_resource = additional_resource - - def update_options(self, options: Dict): - self._options.update(options) - - def __call__(self, - placement_group, - placement_group_bundle_idx, - use_gpu: bool = True, - num_gpus=1, - sharing_with=None) -> Any: - if sharing_with is not None: - target_node_id = ray.get(sharing_with.get_node_id.remote()) - cuda_visible_devices = ray.get(sharing_with.get_cuda_visible_devices.remote()) - options = {"scheduling_strategy": NodeAffinitySchedulingStrategy(node_id=target_node_id, soft=False)} - return 
self.cls.options(**options).remote(*self.args, - cuda_visible_devices=cuda_visible_devices, - **self.kwargs) - - options = { - "scheduling_strategy": - PlacementGroupSchedulingStrategy(placement_group=placement_group, - placement_group_bundle_index=placement_group_bundle_idx) - } - options.update(self._options) - - if use_gpu: - options["resources"] = {"NPU": num_gpus} - - if len(self._additional_resource) > 1: - for k, v in self._additional_resource.items(): - options[k] = v - - print("cls:", self.cls) - print("args: ", self.args) - print("kwargs: ", self.kwargs) - return self.cls.options(**options).remote(*self.args, **self.kwargs) - - -class RayWorkerGroup(WorkerGroup): - - def __init__(self, - resource_pool: RayResourcePool = None, - ray_cls_with_init: RayClassWithInitArgs = None, - bin_pack: bool = True, - name_prefix: str = None, - detached=False, - worker_names=None, - **kwargs) -> None: - super().__init__(resource_pool=resource_pool, **kwargs) - self.ray_cls_with_init = ray_cls_with_init - self.name_prefix = get_random_string(length=6) if name_prefix is None else name_prefix - - if self._is_init_with_detached_workers: - self._init_with_detached_workers(worker_names=worker_names) - else: - self._init_with_resource_pool(resource_pool=resource_pool, - ray_cls_with_init=ray_cls_with_init, - bin_pack=bin_pack, - detached=detached) - - if ray_cls_with_init is not None: - self._bind_worker_method(self.ray_cls_with_init.cls, func_generator) - - def _is_worker_alive(self, worker: ray.actor.ActorHandle): - worker_state_dict = get_actor(worker._actor_id.hex()) - return worker_state_dict.get("state", "undefined") == "ALIVE" if worker_state_dict is not None else False - - def _init_with_detached_workers(self, worker_names): - workers = [ray.get_actor(name=name) for name in worker_names] - self._workers = workers - self._worker_names = worker_names - self._world_size = len(worker_names) - - def _init_with_resource_pool(self, resource_pool, ray_cls_with_init, bin_pack, detached): - use_gpu = resource_pool.use_gpu - - strategy = "PACK" - if bin_pack: - strategy = "STRICT_PACK" - pgs = resource_pool.get_placement_groups(strategy=strategy, name=self.name_prefix) - world_size = resource_pool.world_size - self._world_size = world_size - num_gpus = 1 / resource_pool.max_collocate_count - - rank = -1 - for pg_idx, local_world_size in enumerate(resource_pool.store): - pg = pgs[pg_idx] - if local_world_size > pg.bundle_count: - raise ValueError(f"when generating for {self.name_prefix}, for the ") - for local_rank in range(local_world_size): - rank += 1 - - # we pass in environment variable at option so that Worker can use environment variable to set - env_vars = { - 'WORLD_SIZE': str(world_size), - 'RANK': str(rank), - 'WG_PREFIX': self.name_prefix, - 'WG_BACKEND': 'ray', - 'RAY_LOCAL_WORLD_SIZE': str(local_world_size), - 'RAY_LOCAL_RANK': str(local_rank), - } - if rank != 0: - env_vars['MASTER_ADDR'] = self._master_addr - env_vars['MASTER_PORT'] = self._master_port - - import re - cia_name = type(ray_cls_with_init.cls).__name__ - match = re.search(r"ActorClass\(([^)]+)\)", cia_name) # ray.remote(Obj) -> "ActorClass(Obj)" - cia_name = match.group(1) if match else cia_name # "ActorClass(Obj)" -> "Obj" - name = f"{self.name_prefix}{cia_name}_{pg_idx}:{local_rank}" # e.g. 
Worker_2:5 - ray_cls_with_init.update_options({'runtime_env': {'env_vars': env_vars}, 'name': name}) - - if detached: - ray_cls_with_init.update_options({'lifetime': 'detached'}) - - # create a worker - worker = ray_cls_with_init(placement_group=pg, - placement_group_bundle_idx=local_rank, - use_gpu=use_gpu, - num_gpus=num_gpus) - self._workers.append(worker) - self._worker_names.append(name) - - if rank == 0: - register_center_actor = None - for _ in range(120): - if f"{self.name_prefix}_register_center" not in list_named_actors(): - time.sleep(1) - else: - register_center_actor = ray.get_actor(f"{self.name_prefix}_register_center") - if register_center_actor is None: - available_actors = list_named_actors(all_namespaces=True) - raise ValueError( - f"failed to get register_center_actor: {self.name_prefix}_register_center in {list_named_actors(all_namespaces=True)}" - ) - rank_zero_info = ray.get(register_center_actor.get_rank_zero_info.remote()) - self._master_addr, self._master_port = rank_zero_info['MASTER_ADDR'], rank_zero_info['MASTER_PORT'] - - @property - def worker_names(self): - return self._worker_names - - @classmethod - def from_detached(cls, worker_names=None, ray_cls_with_init=None): - worker_group = cls(resource_pool=None, - ray_cls_with_init=ray_cls_with_init, - name_prefix=None, - worker_names=worker_names) - return worker_group - - def spawn(self, prefix_set): - """ - spawn to a dictionary of worker groups, each with a subset of method with prefix. - - """ - - def remove_prefix(text, prefix): - if text.startswith(prefix): - return text[len(prefix):] - return text - - def _rebind_actor_methods(worker_group, actor_name): - """ - bind the method with actor_prefix to its original name - """ - prefix: str = actor_name + '_' - for method_name in dir(worker_group): - if method_name.startswith(prefix): - original_method_name = remove_prefix(method_name, prefix) - method = getattr(worker_group, method_name) - setattr(worker_group, original_method_name, method) - - new_worker_group_dict = {} - for prefix in prefix_set: - new_worker_group = self.from_detached(worker_names=self._worker_names, - ray_cls_with_init=self.ray_cls_with_init) - - _rebind_actor_methods(new_worker_group, prefix) - new_worker_group_dict[prefix] = new_worker_group - return new_worker_group_dict - - def execute_rank_zero_sync(self, method_name: str, *args, **kwargs): - return ray.get(self.execute_all_async(method_name, **args, **kwargs)) - - def execute_rank_zero_async(self, method_name: str, *args, **kwargs): - remote_call = getattr(self._workers[0], method_name) - return remote_call.remote(*args, **kwargs) - - def execute_rank_zero(self, method_name: str, *args, **kwargs): - return self.execute_rank_zero_async(method_name, *args, **kwargs) - - def execute_all(self, method_name: str, *args, **kwargs): - return self.execute_all_async(method_name, *args, **kwargs) - - def execute_all_sync(self, method_name: str, *args, **kwargs): - return ray.get(self.execute_all_async(method_name, *args, **kwargs)) - - def execute_all_async(self, method_name: str, *args, **kwargs): - # 这里我们假设,如果 args 和 kwargs 里面所有的参数都是 list,且所有的 list 长度都与 len(self._workers) 一致的话,我们会把 - # list 中的每一个分别发到对应的 worker 上去 - length = len(self._workers) - if all(isinstance(arg, list) for arg in args) and all(isinstance(kwarg, list) for kwarg in kwargs.values()): - if all(len(arg) == length for arg in args) and all(len(kwarg) == length for kwarg in kwargs.values()): - result = [] - for i in range(length): - sliced_args = tuple(arg[i] for arg in args) - 
sliced_kwargs = {k: v[i] for k, v in kwargs.items()} - remote_call = getattr(self._workers[i], method_name) - result.append(remote_call.remote(*sliced_args, **sliced_kwargs)) - return result - - return [getattr(worker, method_name).remote(*args, **kwargs) for worker in self._workers] - - def execute_train(self, method_name: str, *args, **kwargs): - return self.execute_train_infer_async(method_name, True, *args, **kwargs) - - def execute_train_sync(self, method_name: str, *args, **kwargs): - return ray.get(self.execute_train_infer_async(method_name, True, *args, **kwargs)) - - def execute_infer(self, method_name: str, *args, **kwargs): - return self.execute_train_infer_async(method_name, False, *args, **kwargs) - - def execute_infer_sync(self, method_name: str, *args, **kwargs): - return ray.get(self.execute_train_infer_async(method_name, False, *args, **kwargs)) - - - def execute_train_infer_async(self, method_name: str, is_train: bool, *args, **kwargs): - # 这里我们假设,如果 args 和 kwargs 里面所有的参数都是 list,且所有的 list 长度都与 len(self._workers) 一致的话,我们会把 - # list 中的每一个分别发到对应的 worker 上去 - all_length = len(self._workers) - infer_length = get_actor_infer_world_size() - train_length = get_actor_train_world_size() - if all(isinstance(arg, list) for arg in args) and all(isinstance(kwarg, list) for kwarg in kwargs.values()): - if all(len(arg) == all_length for arg in args) and all(len(kwarg) == all_length for kwarg in kwargs.values()): - result = [] - loop_length = train_length if is_train else infer_length - offset = 0 if is_train else train_length - for i in range(loop_length): - sliced_args = tuple(arg[i] for arg in args) - sliced_kwargs = {k: v[i] for k, v in kwargs.items()} - remote_call = getattr(self._workers[offset + i], method_name) - result.append(remote_call.remote(*sliced_args, **sliced_kwargs)) - return result - - if is_train: - workers = self._workers[:train_length] - else: - workers = self._workers[train_length:] - return [getattr(worker, method_name).remote(*args, **kwargs) for worker in workers] - - @property - def master_address(self): - return self._master_addr - - @property - def master_port(self): - return self._master_port - - @property - def workers(self): - return self._workers - - @property - def world_size(self): - return self._world_size - - -def _bind_workers_method_to_parent(cls, key, user_defined_cls): - """ - Utilities that enables creating workers inside the same ray.Actor, - with code written in separate ray.Actors. - - Binds the methods of each worker to the WorkerDict. 
- Note that we only bind public methods that are decorated by register - """ - for method_name in dir(user_defined_cls): - try: - method = getattr(user_defined_cls, method_name) - if not callable(method): - raise TypeError( - f"{method_name} in {user_defined_cls} is not callable" - ) - except Exception as e: - # if it is a property, it will fail because Class doesn't have instance property - continue - - if hasattr(method, MAGIC_ATTR): - - def generate_function(name): - - def func(self, *args, **kwargs): - # dispatch to the actual worker - return getattr(self.worker_dict[key], name)(*args, **kwargs) - - return func - - func = generate_function(method_name) - # pass MAGIC_ATTR for outer worker group - setattr(func, MAGIC_ATTR, getattr(method, MAGIC_ATTR)) - try: - method_name_with_prefix = key + '_' + method_name - setattr(cls, method_name_with_prefix, func) - except Exception as e: - raise ValueError(f'Fail to set method_name {method_name}') from e - - -def _unwrap_ray_remote(cls): - if hasattr(cls, '__ray_actor_class__'): - cls = cls.__ray_actor_class__ - return cls - - -def create_colocated_worker_cls(class_dict: Dict[str, RayClassWithInitArgs]): - """ - This function should return a class instance that delegates the calls to every - cls in cls_dict - """ - cls_dict = {} - init_args_dict = {} - worker_cls = None - for key, cls in class_dict.items(): - if worker_cls is None: - worker_cls = cls.cls.__ray_actor_class__.__base__ - else: - if worker_cls != cls.cls.__ray_actor_class__.__base__: - raise ValueError( - "the worker class should be the same when sharing the same process" - ) - cls_dict[key] = cls.cls - init_args_dict[key] = {'args': cls.args, 'kwargs': cls.kwargs} - - if cls_dict.keys() != init_args_dict.keys(): - raise ValueError( - f"Key mismatch: cls_dict keys ({cls_dict.keys()}) " - f"must match init_args_dict keys ({init_args_dict.keys()})" - ) - - class WorkerDict(worker_cls): - - def __init__(self): - super().__init__() - self.worker_dict = {} - for key, user_defined_cls in cls_dict.items(): - user_defined_cls = _unwrap_ray_remote(user_defined_cls) - # directly instantiate the class without remote - with patch.dict(os.environ, {'DISABLE_WORKER_INIT': '1'}): - self.worker_dict[key] = user_defined_cls(*init_args_dict[key].get('args', ()), - **init_args_dict[key].get('kwargs', {})) - - # now monkey-patch the methods from inner class to WorkerDict - for key, user_defined_cls in cls_dict.items(): - user_defined_cls = _unwrap_ray_remote(user_defined_cls) - _bind_workers_method_to_parent(WorkerDict, key, user_defined_cls) - - remote_cls = ray.remote(WorkerDict) - remote_cls = RayClassWithInitArgs(cls=remote_cls) - return remote_cls diff --git a/mindspeed_llm/tasks/posttrain/rlxf/single_controller/ray/megatron.py b/mindspeed_llm/tasks/posttrain/rlxf/single_controller/ray/megatron.py deleted file mode 100644 index 7dbb88876746df934c258a7029f3630280c08293..0000000000000000000000000000000000000000 --- a/mindspeed_llm/tasks/posttrain/rlxf/single_controller/ray/megatron.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Dict, Optional - -import ray - -from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base.megatron.worker import DistRankInfo, DistGlobalInfo -from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base.megatron.worker_group import MegatronWorkerGroup -from .base import RayWorkerGroup, RayResourcePool, RayClassWithInitArgs - - -# NOTE(sgm): for opensource megatron-core -class NVMegatronRayWorkerGroup(RayWorkerGroup, MegatronWorkerGroup): - """ - MegatronWorkerGroup will query each worker of its megatron rank info and store it inside the WorkerGroup - so that the dispatcher can use it to dispatch data. - """ - - def __init__(self, resource_pool: RayResourcePool, ray_cls_with_init: RayClassWithInitArgs, **kwargs): - super().__init__(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init, **kwargs) - self._megatron_rank_info: DistRankInfo = self.execute_all_sync(method_name='get_megatron_rank_info') - self._megatron_global_info: DistGlobalInfo = ray.get( - self.execute_rank_zero_async(method_name='get_megatron_global_info')) diff --git a/mindspeed_llm/tasks/posttrain/rlxf/training/__init__.py b/mindspeed_llm/tasks/posttrain/rlxf/training/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/mindspeed_llm/tasks/posttrain/rlxf/training/core_algos.py b/mindspeed_llm/tasks/posttrain/rlxf/training/core_algos.py deleted file mode 100644 index a4ae895f0b4c61eb5207b3ac8b365785beb3fb6e..0000000000000000000000000000000000000000 --- a/mindspeed_llm/tasks/posttrain/rlxf/training/core_algos.py +++ /dev/null @@ -1,598 +0,0 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Core functions to implement PPO algorithms. 
-The function implemented in this file should be used by trainer with different distributed strategies to -implement PPO -""" - -from copy import deepcopy - -import numpy as np -import torch -from transformers import AutoTokenizer - -import mindspeed_llm.tasks.posttrain.rlxf.utils.torch_functional as F -from mindspeed_llm.tasks.posttrain.rlxf.utils.loggers import Loggers -from mindspeed_llm.tasks.posttrain.rlxf.utils.protocol import DataProto -from mindspeed_llm.tasks.posttrain.verifier.rule_verifier import preprocess_box_response_for_qwen_prompt, format_reward, reasoning_steps_reward, strict_format_reward, \ -base_model_accuracy_reward - -logger = Loggers() - - -class AdaptiveKLController: - def __init__(self, init_kl_coef, target_kl, horizon): - self.value = init_kl_coef - self.target = target_kl - self.horizon = horizon - - def update(self, current_kl, n_steps): - target = self.target - proportional_error = np.clip(current_kl / target - 1, -0.2, 0.2) - mult = 1 + proportional_error * n_steps / self.horizon - self.value *= mult - - -class FixedKLController: - """Fixed KL controller.""" - - def __init__(self, kl_coef): - self.value = kl_coef - - def update(self, current_kl, n_steps): - pass - - -def compute_gae_advantage_return(token_level_rewards: torch.Tensor, values: torch.Tensor, eos_mask: torch.Tensor, - gamma: torch.Tensor, lam: torch.Tensor): - """ - Args: - token_level_rewards: `(torch.Tensor)` - shape: (bs, response_length) - values: `(torch.Tensor)` - shape: (bs, response_length) - eos_mask: `(torch.Tensor)` - shape: (bs, response_length). [EOS] mask. The token after [EOS] have mask zero. - gamma: `(float)` - discounted factor used in RL - lam: `(float)` - lambda value when computing Generalized Advantage Estimation - - Returns: - advantages: `(torch.Tensor)` - shape: (bs, response_length) - Returns: `(torch.Tensor)` - shape: (bs, response_length) - - """ - with torch.no_grad(): - lastgaelam = 0 - advantages_reversed = [] - gen_len = token_level_rewards.shape[-1] - - for t in reversed(range(gen_len)): - nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0 - delta = token_level_rewards[:, t] + gamma * nextvalues - values[:, t] - lastgaelam = delta + gamma * lam * lastgaelam - advantages_reversed.append(lastgaelam) - advantages = torch.stack(advantages_reversed[::-1], dim=1) - - returns = advantages + values - advantages = F.masked_whiten(advantages, eos_mask) - advantages = torch.masked_fill(advantages, ~eos_mask, 0) - return advantages, returns - - -def compute_group_norm_advantage_return( - token_level_rewards: torch.Tensor, - eos_mask: torch.Tensor, - gamma: torch.Tensor, - lam: torch.Tensor, - config -): - """ - Args: - token_level_rewards: `(torch.Tensor)` - shape: (bs, response_length) - eos_mask: `(torch.Tensor)` - shape: (bs, response_length). [EOS] mask. The token after [EOS] have mask zero. 
- gamma: `(float)` - discounted factor used in RL - lam: `(float)` - lambda value when computing Generalized Advantage Estimation - - - Returns: - advantages: `(torch.Tensor)` - shape: (bs, response_length) - Returns: `(torch.Tensor)` - shape: (bs, response_length) - - """ - response_length = token_level_rewards.size(1) - returns = torch.zeros_like(token_level_rewards) - cumulative_return = torch.zeros(token_level_rewards.size(0), device=token_level_rewards.device) - - # Calculate returns by accumulating discounted rewards - for t in reversed(range(response_length)): - cumulative_return = token_level_rewards[:, t] + gamma * cumulative_return - returns[:, t] = cumulative_return - advantages = deepcopy(returns) - if not hasattr(config.algorithm, "advantage_whiten") or config.algorithm.advantage_whiten: - advantages = F.masked_whiten(advantages, eos_mask) - else: - advantages = advantages * eos_mask - return advantages, returns - - -def compute_policy_loss(old_log_prob, log_prob, advantages, eos_mask, cliprange): - """ - Args: - old_log_prob: `(torch.Tensor)` - shape: (bs, response_length) - log_prob: `(torch.Tensor)` - shape: (bs, response_length) - advantages: `(torch.Tensor)` - shape: (bs, response_length) - eos_mask: `(torch.Tensor)` - shape: (bs, response_length) - cliprange: (float) - The clip range used in PPO. - - Returns: - pg_loss: `a scalar torch.Tensor` - policy gradient loss computed via PPO - pg_clipfrac: (float) - a float number indicating the fraction of policy gradient loss being clipped - - """ - negative_approx_kl = log_prob - old_log_prob - ratio = torch.exp(negative_approx_kl) - ppo_kl = F.masked_mean(-negative_approx_kl, eos_mask) - - pg_losses = -advantages * ratio - pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange) - - pg_loss = F.masked_mean(torch.max(pg_losses, pg_losses2), eos_mask) - pg_clipfrac = F.masked_mean(torch.gt(pg_losses2, pg_losses).float(), eos_mask) - return pg_loss, pg_clipfrac, ppo_kl - - -def compute_grpo_policy_loss(old_log_prob, log_prob, ref_log_prob, advantages, eos_mask, cliprange, kl_ctrl): - """ - Args: - old_log_prob: `(torch.Tensor)` - shape: (bs, response_length) - log_prob: `(torch.Tensor)` - shape: (bs, response_length) - ref_log_prob `(torch.Tensor)` - shape: (bs, response_length) - advantages: `(torch.Tensor)` - shape: (bs, response_length) - eos_mask: `(torch.Tensor)` - shape: (bs, response_length) - cliprange: (float) - The clip range used in PPO. - - Returns: - pg_loss: `a scalar torch.Tensor` - policy gradient loss computed via GRPO - pg_clipfrac: (float) - a float number indicating the fraction of policy gradient loss being clipped - - """ - negative_approx_kl = log_prob - old_log_prob - ratio = torch.exp(negative_approx_kl) - ppo_kl = F.masked_mean(-negative_approx_kl, eos_mask) - - pg_losses = -advantages * ratio - pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange) - - pg_loss = F.masked_mean(torch.max(pg_losses, pg_losses2), eos_mask) - pg_clipfrac = F.masked_mean(torch.gt(pg_losses2, pg_losses).float(), eos_mask) - - ref_approx_kl = ref_log_prob - log_prob - ratio_kl = torch.exp(ref_approx_kl) - kl_losses = ratio_kl - ref_approx_kl - 1 - kl_loss = F.masked_mean(kl_losses, eos_mask) - pg_loss = pg_loss + kl_loss * kl_ctrl.value - return pg_loss, pg_clipfrac, ppo_kl - - -def kl_penalty(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, penalty) -> torch.FloatTensor: - """Compute KL divergence given logprob and ref_logprob. 
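compute_gae_advantage_return above is the standard GAE recursion, delta_t = r_t + gamma * V_{t+1} - V_t followed by A_t = delta_t + gamma * lam * A_{t+1}, with masked whitening applied afterwards. A small self-contained check with toy numbers (single sequence, eos mask and whitening deliberately left out so the raw recursion is visible; the values are illustrative, not from any real run):

import torch

def gae_reference(rewards, values, gamma, lam):
    # plain GAE recursion on one sequence: no eos mask, no whitening
    gen_len = rewards.shape[-1]
    advantages = torch.zeros_like(rewards)
    lastgaelam = 0.0
    for t in reversed(range(gen_len)):
        nextvalues = values[..., t + 1] if t < gen_len - 1 else 0.0
        delta = rewards[..., t] + gamma * nextvalues - values[..., t]
        lastgaelam = delta + gamma * lam * lastgaelam
        advantages[..., t] = lastgaelam
    return advantages, advantages + values

rewards = torch.tensor([[0.0, 0.0, 1.0]])   # reward only at the last response token
values = torch.tensor([[0.5, 0.6, 0.7]])
adv, ret = gae_reference(rewards, values, gamma=1.0, lam=0.95)
print(adv)   # roughly [[0.4657, 0.3850, 0.3000]]
print(ret)   # advantages + values, exactly what the deleted function returns before whitening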
- Args: - logprob: - ref_logprob: - - Returns: - - """ - if penalty == "kl": - return logprob - ref_logprob - - if penalty == "abs": - return (logprob - ref_logprob).abs() - - if penalty == "mse": - return 0.5 * (logprob - ref_logprob).square() - - if penalty == "full": - # so, here logprob and ref_logprob should contain the logits for every token in vocabulary - raise NotImplementedError - - raise NotImplementedError - - -def find_first_eos_index(tensor, eos_token_id): - """ - Find the index of the first element equal to eos_token_id in each row of the tensor. - - Args: - tensor (torch.Tensor): input tensor of shape (batch_size, seq_len). - - Returns: - torch.Tensor: index of the first eos_token_id in each row, shape (batch_size,). - Returns -1 for rows that contain no eos_token_id. - """ - - is_eos = (tensor == eos_token_id) - - # use torch.argmax to locate the first index equal to eos_token_id - score_first_eos_index = torch.argmax(is_eos.int(), dim=1) - reward_first_eos_index = torch.argmax(is_eos.int(), dim=1) + 1 - max_id = is_eos.shape[1] - 1 - reward_first_eos_index = torch.min(reward_first_eos_index, torch.tensor(max_id, device=reward_first_eos_index.device)) - has_eos = is_eos.any(dim=1) - score_first_eos_index[~has_eos] = -1 - reward_first_eos_index[~has_eos] = -1 - - return score_first_eos_index, reward_first_eos_index - - -def apply_kl_penalty(config, data: DataProto, kl_ctrl: AdaptiveKLController): - responses = data.batch['responses'] - response_length = responses.size(1) - token_level_scores = data.batch['rm_scores'] - batch_size = data.batch.batch_size[0] - attention_mask = data.batch['attention_mask'] - response_mask = attention_mask[:, -response_length:] - eos_token_id = data.meta_info['eos_token_id'] - contain_eos_token = torch.any(responses == eos_token_id, dim=-1) - if config.algorithm.missing_eos_penalty is not None: - token_level_scores[~contain_eos_token] -= config.algorithm.missing_eos_penalty - data.batch['rm_scores'] = token_level_scores - # compute kl between ref_policy and current policy - if 'ref_log_prob' in data.batch.keys(): - kld = kl_penalty(data.batch['old_log_probs'], data.batch['ref_log_prob'], - penalty=config.algorithm.kl_penalty) # (batch_size, response_length) - kld = kld * response_mask - beta = kl_ctrl.value - else: - beta = 0 - kld = torch.zeros_like(response_mask, dtype=torch.float32) - actual_start = torch.arange(token_level_scores.size(0), device=token_level_scores.device) - score_first_eos_index, reward_first_eos_index = find_first_eos_index(responses, eos_token_id) - token_level_rewards = - beta * kld - token_level_rewards[[actual_start, reward_first_eos_index]] += token_level_scores[[actual_start, score_first_eos_index]] - - current_kl = F.masked_mean(kld, mask=response_mask, axis=-1) # average over sequence - current_kl = torch.mean(current_kl, dim=0).item() - - # according to https://github.com/huggingface/trl/blob/951ca1841f29114b969b57b26c7d3e80a39f75a0/trl/trainer/ppo_trainer.py#L837 - kl_ctrl.update(current_kl=current_kl, n_steps=batch_size) - data.batch['token_level_rewards'] = token_level_rewards - - metrics = {'critic/kl': current_kl, 'critic/kl_coeff': beta} - - return data, metrics - - -def compute_advantage(data: DataProto, config): - responses = data.batch['responses'] - response_length = responses.size(1) - attention_mask = data.batch['attention_mask'] - response_mask = attention_mask[:, -response_length:] - token_level_rewards = data.batch['token_level_rewards'] - - if config.algorithm.adv_estimator == 'gae': - values = data.batch['values'] - advantages, returns = compute_gae_advantage_return(token_level_rewards=token_level_rewards, - values=values, -
eos_mask=response_mask, - gamma=config.algorithm.gamma, - lam=config.algorithm.lam) - data.batch['advantages'] = advantages - data.batch['returns'] = returns - elif config.algorithm.adv_estimator == 'group_norm': - advantages, returns = compute_group_norm_advantage_return(token_level_rewards=token_level_rewards, - eos_mask=response_mask, - gamma=config.algorithm.gamma, - lam=config.algorithm.lam, - config=config) - data.batch['advantages'] = advantages - data.batch['returns'] = returns - else: - raise NotImplementedError - return data - - -def compute_score(reward_wg, batch, metrics, config): - token_level_rewards = torch.zeros_like(batch.batch["responses"], dtype=torch.float32) - - assert reward_wg is not None or config.reward.verifier, "At least one reward should be provided for score computing." - - # 0 for default/general problems, 1 for math problems - if "categories" in batch.batch.keys(): - use_verifier_mask = batch.batch["categories"][:, 0].squeeze().bool() - elif hasattr(config.reward, "verifier") and config.reward.verifier: - use_verifier_mask = torch.ones(len(batch.batch["input_ids"]), dtype=torch.bool) - else: - use_verifier_mask = torch.zeros(len(batch.batch["input_ids"]), dtype=torch.bool) - - if reward_wg and (~use_verifier_mask).sum(): - score_tensor = reward_wg.compute_rm_score(batch).batch['rm_scores'] - - rm_token_level_rewards = get_last_reward( - batch, - rm_scores=score_tensor, - n_sample_batch=config.actor_rollout_ref.actor_rollout.n_samples_per_prompt, - metrics=metrics, - valid_mask=~use_verifier_mask - ) - token_level_rewards[~use_verifier_mask] += rm_token_level_rewards - - - if hasattr(config.reward, "verifier") and config.reward.verifier and use_verifier_mask.sum(): - verifier_token_level_rewards = compute_verifier_score(batch, metrics, config, use_verifier_mask) - token_level_rewards[use_verifier_mask] += verifier_token_level_rewards - - rewards = DataProto.from_dict( - tensors={ - 'token_level_rewards': token_level_rewards, - 'rm_scores': token_level_rewards - } - ) - - return batch.union(rewards) - - -def compute_verifier_score(batch, metrics, config, valid_mask): - tokenizer = AutoTokenizer.from_pretrained(config.training.tokenizer_name_or_path, trust_remote_code=True) - - responses = batch.batch["responses"][valid_mask] - str_responses = tokenizer.batch_decode(responses, skip_special_tokens=True) - question = batch.batch["prompts"][valid_mask] - str_question = tokenizer.batch_decode(question, skip_special_tokens=True) - - reward_index = batch.batch["responses_ori_length"].unsqueeze(1) - 1 - reward_index = reward_index[valid_mask] - - logger.logger.info("=" * 50) - logger.logger.info(">>>>>>>>>> User:\n") - logger.logger.info(str_question[0]) - logger.logger.info(">>>>>>>>>> Assistant:\n") - logger.logger.info(str_responses[0]) - - extra_data = {} - - if hasattr(config.training, "dataset_additional_keys"): - for k in config.training.dataset_additional_keys: - extra_data[k] = tokenizer.batch_decode(batch.batch[k], skip_special_tokens=True) - if k == "categories": - continue - logger.logger.info(f">>>>>>>>>> {k}") - logger.logger.info(extra_data[k][valid_mask.nonzero()[0]]) - - logger.logger.info("=" * 50) - - labels = [label for label, mask in zip(extra_data.get("labels"), valid_mask) if mask] - scores = verifier(str_responses, labels, config, metrics, infos=None) - - scores = torch.tensor( - scores, - dtype=torch.float32, - device=reward_index.device - ) - - scores = scores.reshape(-1, config.actor_rollout_ref.actor_rollout.n_samples_per_prompt) - scores = 
(scores - scores.mean(dim=1, keepdim=True)) / (scores.std(dim=1, keepdim=True) + 1e-8) - scores = scores.reshape(-1).unsqueeze(1) - - token_level_rewards = torch.zeros_like(responses, dtype=torch.float32) - token_level_rewards.scatter_(1, reward_index, scores) - - return token_level_rewards - - -def verifier(responses, labels, config, metrics, infos=None): - """ - User-defined verifier scoring process. - - Parameters: - ---------- - responses(List[`str`]): - Actor rollout answers. - labels(List[`str`]): - Ground Truth. - infos(List[`str`], *optional*): - Additional usable information loaded from the dataset. - - Return: - scores(List[`float`]): Final scores. - """ - rule_verifier_function = { - "acc": preprocess_box_response_for_qwen_prompt, - "format": format_reward, - "step": reasoning_steps_reward, - "strict_format": strict_format_reward, - "base_acc": base_model_accuracy_reward - } - - scores = [0.0] * len(labels) - - verifier_function = config.algorithm.verifier_function if hasattr( - config.algorithm, "verifier_function") else ["acc"] - verifier_weight = config.algorithm.verifier_weight if hasattr( - config.algorithm, "verifier_weight") else [1.0] - - for idx, fun_verifier in enumerate(verifier_function): - if fun_verifier not in rule_verifier_function: - continue - score = rule_verifier_function[fun_verifier](sequences=responses, answers=labels) - metrics[f"grpo/{fun_verifier}_rewards/mean"] = sum(score) / max(len(score), 1) - scores = [all_score + tmp_score * verifier_weight[idx] - for all_score, tmp_score in zip(scores, score)] - - return scores - - -def get_last_reward(data, rm_scores, n_sample_batch, metrics, valid_mask): - eos_indices = data.batch["responses_ori_length"].unsqueeze(1) - 1 - - # gather reward from eos position - rm_scores = rm_scores[valid_mask] - eos_indices = eos_indices[valid_mask] - reward = rm_scores.gather(dim=1, index=eos_indices).squeeze(1) - - # record raw reward - metrics[f"grpo/reward_model_rewards/mean"] = sum(reward) / max(len(reward), 1) - - # calculate group norm - reward = reward.reshape(-1, n_sample_batch) - reward = (reward - reward.mean(dim=1, keepdim=True)) / (reward.std(dim=1, keepdim=True) + 1e-8) - reward = reward.reshape(-1) - token_level_rewards = torch.zeros_like(rm_scores).scatter_(dim=1, index=eos_indices, src=reward.unsqueeze(1).to(rm_scores.dtype)) - return token_level_rewards - - -def reduce_metrics(metrics: dict): - for key, val in metrics.items(): - metrics[key] = np.mean(val) - return metrics - - -def compute_data_metrics(batch): - sequence_score = batch.batch['rm_scores'].sum(-1) - sequence_reward = batch.batch['token_level_rewards'].sum(-1) - - response_length = batch.batch['responses'].shape[-1] - - advantages = batch.batch['advantages'] - prompt_mask = batch.batch['attention_mask'][:, :-response_length] - response_mask = batch.batch['attention_mask'][:, -response_length:] - - prompt_length = prompt_mask.sum(-1).float() - response_length = response_mask.sum(-1).float() # (batch_size,) - - returns = batch.batch['returns'] - values = batch.batch['values'] - - metrics = { - # score - 'critic/score/mean': torch.mean(sequence_score).detach().item(), - 'critic/score/max': torch.max(sequence_score).detach().item(), - 'critic/score/min': torch.min(sequence_score).detach().item(), - # reward - 'critic/rewards/mean': torch.mean(sequence_reward).detach().item(), - 'critic/rewards/max': torch.max(sequence_reward).detach().item(), - 'critic/rewards/min': torch.min(sequence_reward).detach().item(), - # adv - 'critic/advantages/mean': 
F.masked_mean(advantages, response_mask).detach().item(), - 'critic/advantages/max': torch.max(advantages[response_mask]).detach().item(), - 'critic/advantages/min': torch.min(advantages[response_mask]).detach().item(), - # returns - 'critic/returns/mean': F.masked_mean(returns, response_mask).detach().item(), - 'critic/returns/max': torch.max(returns[response_mask]).detach().item(), - 'critic/returns/min': torch.min(returns[response_mask]).detach().item(), - # values - 'critic/values/mean': F.masked_mean(values, response_mask).detach().item(), - 'critic/values/max': torch.max(values[response_mask]).detach().item(), - 'critic/values/min': torch.min(values[response_mask]).detach().item(), - # response length - 'response_length/mean': torch.mean(response_length).detach().item(), - 'response_length/max': torch.max(response_length).detach().item(), - 'response_length/min': torch.min(response_length).detach().item(), - # prompt length - 'prompt_length/mean': torch.mean(prompt_length).detach().item(), - 'prompt_length/max': torch.max(prompt_length).detach().item(), - 'prompt_length/min': torch.min(prompt_length).detach().item(), - } - return metrics - - -def compute_data_online_dpo_metrics(batch): - sequence_score = batch.batch['rm_scores'].sum(-1) - response_length = batch.batch['responses'].shape[-1] - - prompt_mask = batch.batch['attention_mask'][:, :-response_length] - response_mask = batch.batch['attention_mask'][:, -response_length:] - - prompt_length = prompt_mask.sum(-1).float() - response_length = response_mask.sum(-1).float() - - metrics = { - # score - 'reward/score/mean': torch.mean(sequence_score).detach().item(), - 'reward/score/max': torch.max(sequence_score).detach().item(), - 'reward/score/min': torch.min(sequence_score).detach().item(), - # response length - 'response_length/mean': torch.mean(response_length).detach().item(), - 'response_length/max': torch.max(response_length).detach().item(), - 'response_length/min': torch.min(response_length).detach().item(), - # prompt length - 'prompt_length/mean': torch.mean(prompt_length).detach().item(), - 'prompt_length/max': torch.max(prompt_length).detach().item(), - 'prompt_length/min': torch.min(prompt_length).detach().item(), - } - return metrics - - -def compute_grpo_data_metrics(batch): - sequence_score = batch.batch['rm_scores'].sum(-1) - sequence_reward = batch.batch['token_level_rewards'].sum(-1) - response_length = batch.batch['responses'].shape[-1] - advantages = batch.batch['advantages'] - prompt_mask = batch.batch['attention_mask'][:, :-response_length] - response_mask = batch.batch['attention_mask'][:, -response_length:] - prompt_length = prompt_mask.sum(-1).float() - response_length = response_mask.sum(-1).float() - returns = batch.batch['returns'] - metrics = { - # score - 'grpo/score/mean': torch.mean(sequence_score).detach().item(), - 'grpo/score/max': torch.max(sequence_score).detach().item(), - 'grpo/score/min': torch.min(sequence_score).detach().item(), - # reward - 'grpo/rewards/mean': torch.mean(sequence_reward).detach().item(), - 'grpo/rewards/max': torch.max(sequence_reward).detach().item(), - 'grpo/rewards/min': torch.min(sequence_reward).detach().item(), - # adv - 'grpo/advantages/mean': F.masked_mean(advantages, response_mask).detach().item(), - 'grpo/advantages/max': torch.max(advantages[response_mask]).detach().item(), - 'grpo/advantages/min': torch.min(advantages[response_mask]).detach().item(), - 'grpo/returns/mean': F.masked_mean(returns, response_mask).detach().item(), - 'grpo/returns/max': 
torch.max(returns[response_mask]).detach().item(), - 'grpo/returns/min': torch.min(returns[response_mask]).detach().item(), - # response length - 'response_length/mean': torch.mean(response_length).detach().item(), - 'response_length/max': torch.max(response_length).detach().item(), - 'response_length/min': torch.min(response_length).detach().item(), - # prompt length - 'prompt_length/mean': torch.mean(prompt_length).detach().item(), - 'prompt_length/max': torch.max(prompt_length).detach().item(), - 'prompt_length/min': torch.min(prompt_length).detach().item(), - } - return metrics diff --git a/mindspeed_llm/tasks/posttrain/rlxf/training/initialize.py b/mindspeed_llm/tasks/posttrain/rlxf/training/initialize.py deleted file mode 100644 index d1a01469ceb0765beeb9d8647a552e69d763291b..0000000000000000000000000000000000000000 --- a/mindspeed_llm/tasks/posttrain/rlxf/training/initialize.py +++ /dev/null @@ -1,325 +0,0 @@ -from functools import wraps -from omegaconf import OmegaConf - -import ray -import torch - -import megatron -from megatron.core import mpu -from megatron.training import get_args, print_rank_0 -from megatron.training.arguments import validate_args -from megatron.training.yaml_arguments import validate_yaml -from megatron.training.checkpointing import load_args_from_checkpoint -from megatron.training.global_vars import set_global_variables -from megatron.training.initialize import ( - _set_random_seed, - _init_autoresume, _initialize_tp_communicators, -) - -from mindspeed.core.tensor_parallel.ascend_turbo.initialize import initialize_cfg_from_args -from mindspeed_llm.training.arguments import parse_args_decorator -from mindspeed_llm.tasks.utils.error_utils import ensure_valid -from mindspeed_llm.training.utils import seed_all -from mindspeed_llm.tasks.posttrain.rlxf.training.parallel_state import initialize_model_parallel_2megatron -from mindspeed_llm.training.initialize import _compile_dependencies -import mindspeed_llm.tasks.posttrain.rlxf.training.parallel_state as ps - - -def parse_args_from_config(role, config): - import sys - # update role and model configs - OmegaConf.set_struct(config, False) # unset read only properties - role_args_from_config = getattr(config, role, None) if role in ["critic", "reward"] else getattr( - config.actor_rollout_ref, role, None) - model_name = role_args_from_config.model - model_config = config['model'][model_name] - common_config = config['training'] - # override priority: role > training (common) > model - role_args_from_config = OmegaConf.merge(model_config, common_config, role_args_from_config) - role_args_from_config.pop("model") - OmegaConf.set_struct(config, True) - - # Parsing training parameters. - for key, value in role_args_from_config.items(): - if isinstance(value, bool): - if value: - sys.argv.append(f"--{key.replace('_', '-')}") - else: - sys.argv.append(f"--{key.replace('_', '-')}={value}") - - -def initialize_megatron( - extra_args_provider=None, - args_defaults={}, - ignore_unknown_args=False, - allow_no_cuda=False, - skip_mpu_initialization=False, - role=None, - config=None, - two_megatron=False -): - """Set global variables, initialize distributed, and - set autoresume and random seeds. - `allow_no_cuda` should not be set unless using megatron for cpu only - data processing. In general this arg should not be set unless you know - what you are doing. 
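parse_args_from_config above flattens the merged per-role section into Megatron-style CLI flags: booleans become bare --flag entries (only when true) and everything else becomes --key=value, with override priority role > training (common) > model because later arguments win in OmegaConf.merge. A standalone sketch of that merge-and-flatten step with toy keys (not the real schema):

from omegaconf import OmegaConf

model_cfg = OmegaConf.create({"hidden_size": 2048, "lr": 1e-5})
common_cfg = OmegaConf.create({"lr": 1e-6, "bf16": True})
role_cfg = OmegaConf.create({"lr": 1e-7, "micro_batch_size": 4})

# later arguments win in OmegaConf.merge, giving role > training(common) > model
merged = OmegaConf.merge(model_cfg, common_cfg, role_cfg)

argv = []
for key, value in merged.items():
    if isinstance(value, bool):
        if value:
            argv.append(f"--{key.replace('_', '-')}")
    else:
        argv.append(f"--{key.replace('_', '-')}={value}")

print(argv)
# ['--hidden-size=2048', '--lr=1e-07', '--bf16', '--micro-batch-size=4']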
- Returns a function to finalize distributed env initialization - (optionally, only when args.lazy_mpu_init == True) - """ - if not allow_no_cuda: - # Make sure cuda is available. - ensure_valid(torch.cuda.is_available(), "Megatron requires CUDA.") - - # Parse arguments - import sys - origin_sys_argv = sys.argv - if role and config is not None: - sys.argv = [sys.argv[0]] - parse_args_from_config(role, config) - parse_args = parse_args_decorator(megatron.training.arguments.parse_args) - args = parse_args(extra_args_provider, ignore_unknown_args) - args.role = role - sys.argv = origin_sys_argv - - if args.use_checkpoint_args or args_defaults.get("use_checkpoint_args", False): - ensure_valid(args.load is not None, - "--use-checkpoints-args requires --load argument") - load_args_from_checkpoint(args) - - if args.yaml_cfg is not None: - args = validate_yaml(args, args_defaults) - else: - validate_args(args, args_defaults) - - # set global args, build tokenizer, and set adlr-autoresume, - # tensorboard-writer, and timers. - set_global_variables(args) - - # add deterministic computing function - if args.use_deter_comp: - seed_all(args.seed) - print_rank_0("deterministic computing is applied for npu.") - - # torch.distributed initialization - def finish_mpu_init(): - args = get_args() - # Pytorch distributed. - _initialize_distributed(two_megatron) - - # Random seeds for reproducibility. - if args.rank == 0: - print("> setting random seeds to {} ...".format(args.seed)) - _set_random_seed(args.seed, args.data_parallel_random_init) - if args.use_mc2: - initialize_cfg_from_args(args) - - if skip_mpu_initialization: - return None - - args = get_args() - if args.lazy_mpu_init: - args.use_cpu_initialization = True - # delayed initialization of DDP-related stuff - # We only set basic DDP globals - mpu.set_tensor_model_parallel_world_size(args.tensor_model_parallel_size) - # and return function for external DDP manager - # to call when it has DDP initialized - mpu.set_tensor_model_parallel_rank(args.rank) - return finish_mpu_init - else: - # Megatron's MPU is the master. Complete initialization right away. - finish_mpu_init() - - # Autoresume. - _init_autoresume() - - # Compile dependencies. - _compile_dependencies() - - if args.tp_comm_overlap: - _initialize_tp_communicators() - - # No continuation function - return None - - -def _initialize_distributed(two_megatron=False): - """Initialize torch.distributed and core model parallel.""" - args = get_args() - from datetime import timedelta - - device_count = torch.cuda.device_count() - if torch.distributed.is_initialized(): - if args.rank == 0: - print( - "torch distributed is already initialized, " - "skipping initialization ...", - flush=True, - ) - args.rank = torch.distributed.get_rank() - args.world_size = torch.distributed.get_world_size() - else: - if args.rank == 0: - print("> initializing torch distributed ...", flush=True) - # Manually set the device ids. 
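Callers are expected to hand initialize_megatron the role name plus the whole trainer config, so the role's flags are injected into sys.argv before Megatron parses arguments. A hedged usage sketch; the config path is hypothetical, the exact call sites are not shown in this change, and two_megatron is presumably only requested by roles that need the split train/infer world built below:

from omegaconf import OmegaConf
from mindspeed_llm.tasks.posttrain.rlxf.training.initialize import initialize_megatron

# hypothetical path; any config exposing `model`, `training` and per-role sections fits
cfg = OmegaConf.load("ppo_trainer.yaml")

initialize_megatron(
    role="actor_rollout",   # other roles per the config: "ref", "reward", "critic"
    config=cfg,
    two_megatron=True,      # build separate train / infer process groups (see below)
)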
- if device_count > 0: - if args.stage in ["ray_ppo", "ray_online_dpo", "ray_grpo"]: - allocated_device = int(ray.get_runtime_context().get_accelerator_ids()["NPU"][0]) - torch.cuda.set_device(allocated_device) - else: - device = args.rank % device_count - if args.local_rank is not None: - if args.local_rank != device: - raise ValueError("expected local-rank to be the same as rank % device-count.") - else: - args.local_rank = device - torch.cuda.set_device(device) - # Call the init process - torch.distributed.init_process_group( - backend=args.distributed_backend, - world_size=args.world_size, - rank=args.rank, - timeout=timedelta(minutes=args.distributed_timeout_minutes), - ) - - # Set the tensor model-parallel, pipeline model-parallel, and - # data-parallel communicators. - if device_count > 0: - if mpu.model_parallel_is_initialized(): - print("model parallel is already initialized") - else: - if not two_megatron: # normal case - mpu.initialize_model_parallel( - args.tensor_model_parallel_size, - args.pipeline_model_parallel_size, - args.virtual_pipeline_model_parallel_size, - args.pipeline_model_parallel_split_rank, - context_parallel_size=args.context_parallel_size, - expert_model_parallel_size=args.expert_model_parallel_size, - distributed_timeout_minutes=args.distributed_timeout_minutes, - nccl_communicator_config_path=args.nccl_communicator_config_path, - order='tp-cp-ep-dp-pp' if not args.use_tp_pp_dp_mapping else 'tp-pp-dp', - ) - else: - # It is a little tricky here, that both the training and inference nodes need to build the two groups. - TRAIN_SIZE = args.num_gpus_for_train - INFER_SIZE = args.num_gpus_for_infer - if torch.distributed.get_world_size() != TRAIN_SIZE + INFER_SIZE: - raise ValueError("TRAIN_SIZE + INFER_SIZE should equal to total GPU num.") - initialize_model_parallel_2megatron( - args.tensor_model_parallel_size, - args.pipeline_model_parallel_size, - args.virtual_pipeline_model_parallel_size, - args.pipeline_model_parallel_split_rank, - context_parallel_size=args.context_parallel_size, - expert_model_parallel_size=args.expert_model_parallel_size, - distributed_timeout_minutes=args.distributed_timeout_minutes, - nccl_communicator_config_path=args.nccl_communicator_config_path, - infer_size=INFER_SIZE, - ) - initialize_model_parallel_2megatron( # currently only use TP for Inference - args.tensor_model_parallel_size, - 1, # inference do not use PP - distributed_timeout_minutes=args.distributed_timeout_minutes, - nccl_communicator_config_path=args.nccl_communicator_config_path, - infer_size=INFER_SIZE, - is_second_megatron=True - ) - - if ps.in_mg2_inference_group(): # set true TP, PP args for inference groups - args.pipeline_model_parallel_size = 1 - args.virtual_pipeline_model_parallel_size = 1 - - if args.rank != 0 and ps.is_mg2_first_rank(): # first rank in inference - print( - f"> initialized inference tensor model parallel with size " - f"{mpu.get_tensor_model_parallel_world_size()}" - ) - - if args.rank == 0: - print( - f"> initialized tensor model parallel with size " - f"{mpu.get_tensor_model_parallel_world_size()}" - ) - print( - f"> initialized pipeline model parallel with size " - f"{mpu.get_pipeline_model_parallel_world_size()}" - ) - - -def barrier_wrapper(fn): - @wraps(fn) - def wrapper(*args, **kwargs): - mg2_available = ps._MEGATRON2_INITIALIZED - no_group_info = 'group' not in kwargs or kwargs['group'] is None - if no_group_info and mg2_available: - dist_group = ps.get_mg2_local_group() - kwargs['group'] = dist_group - return fn(*args, **kwargs) - 
- return wrapper - - -def broadcast_wrapper(fn): - @wraps(fn) - def wrapper(*args, **kwargs): - no_group_info = 'group' not in kwargs or kwargs['group'] is None - mg2_available = ps._MEGATRON2_INITIALIZED - if no_group_info and mg2_available: - dist_group = ps.get_mg2_local_group() - kwargs['group'] = dist_group - args = list(args) - if len(args) >= 2: - src_rank = ps.get_mg2_local_ranks()[0] - args[1] = src_rank - return fn(*args, **kwargs) - - return wrapper - - -def get_world_size_wrapper(fn): - @wraps(fn) - def wrapper(*args, **kwargs): - mg2_available = ps._MEGATRON2_INITIALIZED - if mg2_available and not args and not kwargs: - dist_group = ps.get_mg2_local_group() - args = [dist_group] - return fn(*args, **kwargs) - - return wrapper - - -def get_elapsed_time_all_ranks(self, names, reset, barrier): - if barrier: - torch.distributed.barrier() - - world_size = torch.distributed.get_world_size() - args = get_args() - if args.role == "actor_rollout": # patch here: use mg2 local rank - rank = ps.get_mg2_local_rank() - group = ps.get_mg2_local_group() - else: - rank = torch.distributed.get_rank() - group = None - - rank_name_to_time = torch.zeros( - (world_size, len(names)), dtype=torch.float, device=torch.cuda.current_device() - ) - for i, name in enumerate(names): - if name in self._timers: - rank_name_to_time[rank, i] = self._timers[name].elapsed(reset=reset) - - torch.distributed._all_gather_base( - rank_name_to_time.view(-1), rank_name_to_time[rank, :].view(-1), group=group - ) - - return rank_name_to_time - - -def is_last_rank(): - rank = torch.distributed.get_rank() - if ps._MEGATRON2_INITIALIZED: - return rank == ps.get_mg2_local_ranks()[-1] - else: - return torch.distributed.get_rank() == ( - torch.distributed.get_world_size() - 1) \ No newline at end of file diff --git a/mindspeed_llm/tasks/posttrain/rlxf/training/parallel_state.py b/mindspeed_llm/tasks/posttrain/rlxf/training/parallel_state.py deleted file mode 100644 index ebfad2062ddd3086d7a9154c301f5fc0e1421fc8..0000000000000000000000000000000000000000 --- a/mindspeed_llm/tasks/posttrain/rlxf/training/parallel_state.py +++ /dev/null @@ -1,390 +0,0 @@ -from functools import wraps -from typing import Optional -from datetime import timedelta -import torch - - -# global variables for two megatron running -_TENSOR_AND_CONTEXT_PARALLEL_GROUP = None -_TENSOR_AND_CONTEXT_PARALLEL_GLOBAL_RANKS = None -_MEGATRON2_LOCAL_RANKS = None -_MEGATRON2_LOCAL_GROUP = None -_MEGATRON2_INITIALIZED = False - - -def initialize_model_parallel_2megatron( - tensor_model_parallel_size: int = 1, - pipeline_model_parallel_size: int = 1, - virtual_pipeline_model_parallel_size: Optional[int] = None, - pipeline_model_parallel_split_rank: Optional[int] = None, - use_sharp: bool = False, - context_parallel_size: int = 1, - expert_model_parallel_size: int = 1, - nccl_communicator_config_path: Optional[str] = None, - distributed_timeout_minutes: int = 30, - order: str = "tp-cp-ep-dp-pp", - infer_size: int = 0, - is_second_megatron: bool = False -) -> None: - """Initialize model data parallel groups with offset. - Assert two groups are initialized : - training group contains GPU with rank[0, world_size - infer_world_size - 1] - inference group contains GPU with rank [world_size - infer_world_size, world_size - 1] - rank_offset is only set for inference groups. 
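barrier_wrapper, broadcast_wrapper and get_world_size_wrapper above all share one idea: when the caller passes no process group and the second-Megatron bookkeeping is initialized, substitute the local (train-only or infer-only) group so collectives stay inside the caller's own partition. A generic sketch of that injection pattern with a stand-in collective; where the real code applies these wrappers to torch.distributed is not shown in this diff, so the patch target below is only illustrative:

from functools import wraps

def default_group_wrapper(fn, get_default_group):
    # same shape as barrier_wrapper: inject a default `group` kwarg when absent
    @wraps(fn)
    def wrapper(*args, **kwargs):
        if kwargs.get('group') is None:
            kwargs['group'] = get_default_group()
        return fn(*args, **kwargs)
    return wrapper

def fake_barrier(group=None):      # stand-in for torch.distributed.barrier
    return f"barrier on {group}"

patched = default_group_wrapper(fake_barrier, lambda: "mg2-local-group")
print(patched())                   # barrier on mg2-local-group
print(patched(group="world"))      # an explicit group still wins: barrier on world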
- """ - import megatron.core.parallel_state as ps - - global _MEGATRON2_LOCAL_RANKS - global _MEGATRON2_LOCAL_GROUP - global _MEGATRON2_INITIALIZED - - timeout = timedelta(minutes=distributed_timeout_minutes) - - # Get world size and rank. Ensure some consistencies. - if not torch.distributed.is_initialized(): - raise RuntimeError("torch.distributed not initialized.") - world_size: int = torch.distributed.get_world_size() - rank = torch.distributed.get_rank() - train_size = world_size - infer_size - - nccl_comm_cfgs = {} - if nccl_communicator_config_path is not None: - try: - import yaml - except ImportError as e: - raise ImportError( - "Cannot import `yaml`. Setting custom nccl communicator configs " - "requires the yaml package." - ) from e - - with open(nccl_communicator_config_path, "r") as stream: - nccl_comm_cfgs = yaml.safe_load(stream) - - # build megatron2 groups for inference and training - if infer_size and not is_second_megatron: # only build megatron2 groups once with positive inf_size - if _MEGATRON2_LOCAL_GROUP is not None: - raise RuntimeError("megatron local group is already initialized.") - ranks_mg2_inference = range(train_size, world_size) - group_mg2_inference = torch.distributed.new_group( - ranks_mg2_inference, timeout=timeout, pg_options=ps.get_nccl_options('dp_cp', nccl_comm_cfgs) - ) - ranks_mg2_training = range(train_size) - group_mg2_training = torch.distributed.new_group( - ranks_mg2_training, timeout=timeout, pg_options=ps.get_nccl_options('dp_cp', nccl_comm_cfgs) - ) - if rank in ranks_mg2_inference: # inf groups - _MEGATRON2_LOCAL_GROUP = group_mg2_inference - _MEGATRON2_LOCAL_RANKS = ranks_mg2_inference - else: - _MEGATRON2_LOCAL_GROUP = group_mg2_training - _MEGATRON2_LOCAL_RANKS = ranks_mg2_training - - # update world_size and rank_offset - if is_second_megatron: # inference group, i.e. 
the second group - world_size = infer_size - rank_offset = train_size - else: - world_size = train_size - rank_offset = 0 - - if ( - world_size - % (tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size) - != 0 - ): - raise RuntimeError( - f"world_size ({world_size}) is not divisible by tensor_model_parallel_size " - f"({tensor_model_parallel_size}) x pipeline_model_parallel_size ({pipeline_model_parallel_size}) " - f"x context_parallel_size ({context_parallel_size})" - ) - - data_parallel_size: int = world_size // ( - tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size - ) - - if data_parallel_size % expert_model_parallel_size != 0: - raise RuntimeError( - f"data_parallel_size ({data_parallel_size}) is not divisible by expert_model_parallel_size " - ) - - if expert_model_parallel_size > 1 and context_parallel_size > 1: - raise RuntimeError( - f"combination of expert model prallellism and context parallelism is not supported" - ) - - # num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size - # num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size - - if virtual_pipeline_model_parallel_size is not None: - if not pipeline_model_parallel_size > 2: - raise RuntimeError( - "pipeline-model-parallel size should be greater than 2 with interleaved schedule" - ) - ps._VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0 - ps._VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size - - if pipeline_model_parallel_split_rank is not None: - ps._PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank - - rank_generator = ps.RankGenerator( - tp=tensor_model_parallel_size, - ep=expert_model_parallel_size, - dp=data_parallel_size, - pp=pipeline_model_parallel_size, - cp=context_parallel_size, - order=order, - offset=rank_offset - ) - - timeout = timedelta(minutes=distributed_timeout_minutes) - - # Build the data-parallel groups. - for ranks in rank_generator.get_ranks('dp'): - group = torch.distributed.new_group( - ranks, timeout=timeout, pg_options=ps.get_nccl_options('dp', nccl_comm_cfgs) - ) - group_gloo = torch.distributed.new_group(ranks, timeout=timeout, backend="gloo") - if rank in ranks: - if ps._DATA_PARALLEL_GROUP is not None: - raise RuntimeError('data parallel group is already initialized') - ps._DATA_PARALLEL_GROUP = group - ps._DATA_PARALLEL_GROUP_GLOO = group_gloo - ps._DATA_PARALLEL_GLOBAL_RANKS = ranks - for ranks_with_cp in rank_generator.get_ranks('dp-cp'): - group_with_cp = torch.distributed.new_group( - ranks_with_cp, timeout=timeout, pg_options=ps.get_nccl_options('dp_cp', nccl_comm_cfgs) - ) - group_with_cp_gloo = torch.distributed.new_group( - ranks_with_cp, timeout=timeout, backend="gloo" - ) - if rank in ranks_with_cp: - ps._DATA_PARALLEL_GROUP_WITH_CP = group_with_cp - ps._DATA_PARALLEL_GROUP_WITH_CP_GLOO = group_with_cp_gloo - ps._DATA_PARALLEL_GLOBAL_RANKS_WITH_CP = ranks_with_cp - - # Build the context-parallel groups. 
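The split above is pure rank arithmetic: training owns ranks [0, world_size - infer_size), inference owns the trailing ranks, and the second initialization pass reuses Megatron's RankGenerator with rank_offset = train_size so every inference-side group lands on those trailing ranks. The toy sketch below reproduces only that bookkeeping and only for TP (which matches the TP-only, PP=1 inference setup here); the real generator walks the full tp-cp-ep-dp-pp order:

world_size, infer_size = 8, 2                  # e.g. 6 devices for training, 2 for rollout
train_size = world_size - infer_size

ranks_mg2_training = list(range(train_size))               # [0, 1, 2, 3, 4, 5]
ranks_mg2_inference = list(range(train_size, world_size))  # [6, 7]

# second pass (is_second_megatron=True): world_size shrinks to infer_size,
# and rank_offset shifts every generated group onto the inference block
tp, rank_offset = 2, train_size
tp_groups_infer = [
    [r + rank_offset for r in range(start, start + tp)]
    for start in range(0, infer_size, tp)
]
print(tp_groups_infer)   # [[6, 7]]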
- global _TENSOR_AND_CONTEXT_PARALLEL_GROUP - global _TENSOR_AND_CONTEXT_PARALLEL_GLOBAL_RANKS - for ranks in rank_generator.get_ranks('cp'): - group = torch.distributed.new_group( - ranks, timeout=timeout, pg_options=ps.get_nccl_options('cp', nccl_comm_cfgs) - ) - if rank in ranks: - if ps._CONTEXT_PARALLEL_GROUP is not None: - raise RuntimeError('context parallel group is already initialized') - ps._CONTEXT_PARALLEL_GROUP = group - ps._CONTEXT_PARALLEL_GLOBAL_RANKS = ranks - - for ranks in rank_generator.get_ranks('tp-cp'): - group = torch.distributed.new_group( - ranks, timeout=timeout, pg_options=ps.get_nccl_options('tp_cp', nccl_comm_cfgs) - ) - if rank in ranks: - if ps._TENSOR_AND_CONTEXT_PARALLEL_GROUP is not None: - raise RuntimeError('tensor and context parallel group is already initialized') - _TENSOR_AND_CONTEXT_PARALLEL_GROUP = group - _TENSOR_AND_CONTEXT_PARALLEL_GLOBAL_RANKS = ranks - - - # Build the model-parallel groups. - for ranks in rank_generator.get_ranks('tp-pp'): - group = torch.distributed.new_group( - ranks, timeout=timeout, pg_options=ps.get_nccl_options('mp', nccl_comm_cfgs) - ) - if rank in ranks: - if ps._MODEL_PARALLEL_GROUP is not None: - raise RuntimeError('model parallel group is already initialized') - ps._MODEL_PARALLEL_GROUP = group - - # Build the tensor model-parallel groups. - for ranks in rank_generator.get_ranks('tp'): - group = torch.distributed.new_group( - ranks, timeout=timeout, pg_options=ps.get_nccl_options('tp', nccl_comm_cfgs) - ) - if rank in ranks: - if ps._TENSOR_MODEL_PARALLEL_GROUP is not None: - raise RuntimeError('tensor model parallel group is already initialized') - - ps._TENSOR_MODEL_PARALLEL_GROUP = group - ps._TENSOR_MODEL_PARALLEL_GLOBAL_RANKS = ranks - - # Build the pipeline model-parallel groups and embedding groups - # (first and last rank in each pipeline model-parallel group). - - for ranks in rank_generator.get_ranks('pp'): - group = torch.distributed.new_group( - ranks, timeout=timeout, pg_options=ps.get_nccl_options('pp', nccl_comm_cfgs) - ) - if rank in ranks: - if ps._PIPELINE_MODEL_PARALLEL_GROUP is not None: - raise RuntimeError('pipeline model parallel group is already initialized') - - ps._PIPELINE_MODEL_PARALLEL_GROUP = group - ps._PIPELINE_GLOBAL_RANKS = ranks - # Setup embedding group (to exchange gradients between - # first and last stages). 
- if len(ranks) > 1: - embedding_ranks = [ranks[0], ranks[-1]] - position_embedding_ranks = [ranks[0]] - if pipeline_model_parallel_split_rank is not None: - if ranks[pipeline_model_parallel_split_rank] not in embedding_ranks: - embedding_ranks = [ - ranks[0], - ranks[pipeline_model_parallel_split_rank], - ranks[-1], - ] - if ranks[pipeline_model_parallel_split_rank] not in position_embedding_ranks: - position_embedding_ranks = [ranks[0], ranks[pipeline_model_parallel_split_rank]] - else: - embedding_ranks = ranks - position_embedding_ranks = ranks - - group = torch.distributed.new_group( - embedding_ranks, timeout=timeout, pg_options=ps.get_nccl_options('embd', nccl_comm_cfgs) - ) - if rank in embedding_ranks: - if ps._EMBEDDING_GROUP is not None: - raise RuntimeError('embedding group is already initialized') - ps._EMBEDDING_GROUP = group - if rank in ranks: - ps._EMBEDDING_GLOBAL_RANKS = embedding_ranks - - group = torch.distributed.new_group( - position_embedding_ranks, - timeout=timeout, - pg_options=ps.get_nccl_options('embd', nccl_comm_cfgs), - ) - if rank in position_embedding_ranks: - if ps._POSITION_EMBEDDING_GROUP is not None: - raise RuntimeError('position embedding group is already initialized') - ps._POSITION_EMBEDDING_GROUP = group - if rank in ranks: - ps._POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks - - # Build the tensor + data parallel groups. - for ranks in rank_generator.get_ranks('tp-dp-cp'): - group = torch.distributed.new_group( - ranks, timeout=timeout, pg_options=ps.get_nccl_options('tp_dp_cp', nccl_comm_cfgs) - ) - if rank in ranks: - if ps._TENSOR_AND_DATA_PARALLEL_GROUP is not None: - raise RuntimeError('Tensor + data parallel group is already initialized') - ps._TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP = group - for ranks in rank_generator.get_ranks('tp-dp'): - group = torch.distributed.new_group( - ranks, timeout=timeout, pg_options=ps.get_nccl_options('tp_dp', nccl_comm_cfgs) - ) - if rank in ranks: - ps._TENSOR_AND_DATA_PARALLEL_GROUP = group - - # Build the tensor + expert parallel groups - for ranks in rank_generator.get_ranks('tp-ep', independent_ep=True): - group = torch.distributed.new_group( - ranks, timeout=timeout, pg_options=ps.get_nccl_options('tp_exp', nccl_comm_cfgs) - ) - if rank in ranks: - if ps._TENSOR_AND_EXPERT_PARALLEL_GROUP is not None: - raise RuntimeError('Tensor + expert parallel group is already initialized') - ps._TENSOR_AND_EXPERT_PARALLEL_GROUP = group - - for ranks in rank_generator.get_ranks('ep', independent_ep=True): - group = torch.distributed.new_group( - ranks, pg_options=ps.get_nccl_options('exp', nccl_comm_cfgs) - ) - if rank in ranks: - if ps._EXPERT_MODEL_PARALLEL_GROUP is not None: - raise RuntimeError('Expert parallel group is already initialized') - ps._EXPERT_MODEL_PARALLEL_GROUP = group - - for ranks in rank_generator.get_ranks('dp', independent_ep=True): - group = torch.distributed.new_group( - ranks, timeout=timeout, pg_options=ps.get_nccl_options('dp_modulo_exp', nccl_comm_cfgs) - ) - group_gloo = torch.distributed.new_group(ranks, backend="gloo") - if rank in ranks: - if ps._DATA_MODULO_EXPERT_PARALLEL_GROUP is not None: - raise RuntimeError('Data modulo expert group is already initialized') - ps._DATA_MODULO_EXPERT_PARALLEL_GROUP = group - ps._DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO = group_gloo - for ranks in rank_generator.get_ranks('dp-cp', independent_ep=True): - # Lazy initialization of the group - group = ps._DATA_MODULO_EXPERT_PARALLEL_GROUP - group_gloo = 
ps._DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO - if rank in ranks: - ps._DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP = group - ps._DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP_GLOO = group_gloo - # Initialize global memory buffer - # This isn't really "parallel state" but there isn't another good place to - # put this. If we end up with a more generic initialization of megatron-core - # we could stick it there - if not is_second_megatron: # global memory buffer should only be set once - ps._set_global_memory_buffer() - else: - _MEGATRON2_INITIALIZED = True - - -def is_mg2_first_rank(): - """ - Check if current node is the first node in the Megatron2 local group. - Use this to extend the old usage rank == 0. - """ - if _MEGATRON2_LOCAL_RANKS is None: - raise RuntimeError('Megatron2 group is not initialized') - return torch.distributed.get_rank() == _MEGATRON2_LOCAL_RANKS[0] - - -def in_mg2_inference_group(): - """ - """ - if _MEGATRON2_LOCAL_RANKS is None: - raise RuntimeError('Megatron2 group is not initialized') - return _MEGATRON2_LOCAL_RANKS[0] != 0 - - -def get_mg2_local_group(): - if _MEGATRON2_LOCAL_GROUP is None: - raise RuntimeError('Megatron2 group is not initialized') - return _MEGATRON2_LOCAL_GROUP - - -def get_mg2_local_ranks(): - if _MEGATRON2_LOCAL_RANKS is None: - raise RuntimeError('Megatron2 ranks are not initialized') - return _MEGATRON2_LOCAL_RANKS - - -def get_mg2_first_rank(): - """ - When the same world size is divided into multiple process groups in the actor-train - and actor-rollout worker roles, this method needs to be converted to local. - """ - if _MEGATRON2_LOCAL_RANKS is None: - raise RuntimeError('Megatron2 group is not initialized') - return _MEGATRON2_LOCAL_RANKS[0] - - -def get_mg2_local_rank(): - """ - When the same world size is divided into multiple process groups in the actor-train - and actor-rollout worker roles, this method needs to be converted to local. - """ - return torch.distributed.get_rank() - get_mg2_first_rank() - - -def rank_generator_init_wrapper(init_func): - @wraps(init_func) - def wrapper(self, *args, **kwargs): - if 'offset' in kwargs: - self.offset = kwargs.pop('offset') - else: - self.offset = 0 - init_func(self, *args, **kwargs) - return wrapper - - -def rank_generator_get_ranks_wrapper(get_ranks): - @wraps(get_ranks) - def wrapper(self, *args, **kwargs): - ranks_list = get_ranks(self, *args, **kwargs) - return [[item + self.offset for item in ranks] for ranks in ranks_list] - return wrapper \ No newline at end of file diff --git a/mindspeed_llm/tasks/posttrain/rlxf/training/parameter_mapping.py b/mindspeed_llm/tasks/posttrain/rlxf/training/parameter_mapping.py deleted file mode 100644 index 8dc3227ae2af364426789cd86423bdeff9ecbc54..0000000000000000000000000000000000000000 --- a/mindspeed_llm/tasks/posttrain/rlxf/training/parameter_mapping.py +++ /dev/null @@ -1,202 +0,0 @@ -# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved. 
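rank_generator_init_wrapper and rank_generator_get_ranks_wrapper above are what make that offset trick work: the first stashes an extra offset kwarg on the generator, the second shifts every produced rank by it. A toy demonstration with a minimal stand-in generator (the wrappers are presumably patched onto Megatron's RankGenerator in the real code; ToyRankGenerator here is hypothetical):

from mindspeed_llm.tasks.posttrain.rlxf.training.parallel_state import (
    rank_generator_init_wrapper, rank_generator_get_ranks_wrapper)

class ToyRankGenerator:
    # stand-in for megatron's RankGenerator: contiguous tp groups over world_size ranks
    def __init__(self, tp, world_size):
        self.tp, self.world_size = tp, world_size

    def get_ranks(self, token):
        return [list(range(s, s + self.tp)) for s in range(0, self.world_size, self.tp)]

ToyRankGenerator.__init__ = rank_generator_init_wrapper(ToyRankGenerator.__init__)
ToyRankGenerator.get_ranks = rank_generator_get_ranks_wrapper(ToyRankGenerator.get_ranks)

gen = ToyRankGenerator(tp=2, world_size=2, offset=6)   # inference block starts at global rank 6
print(gen.get_ranks('tp'))                             # [[6, 7]]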
-import hashlib - -import torch -import torch.distributed as dist -from megatron.training import get_args -from megatron.core import parallel_state as mpu - -import mindspeed_llm.tasks.posttrain.rlxf.training.parallel_state as ps - - -RECEIVE_PARAM_NUMS = None -MODEL_SEND_GROUP = None -MODEL_RECEIVE_GROUPS = [] - - -def init_comm_groups(): - """ - Initialize model auto-mapping communication groups - Scenario example: - Training: tp 2 pp 2 - Inference: tp 2 dp 2 - Principle: Same tp rank weight data is sequentially broadcast to different dp on the inference side, - and aggregated across different dp for training pp dimension. - Number of communication groups: tp * pp - groups = [[0,4,6],[2,4,6],[1,5,7],[3,5,7]] - """ - args = get_args() - pipeline_parallel_groups = [] - data_parallel_groups = [] - for i in range(args.tensor_model_parallel_size): - ranks = list(range(i, args.num_gpus_for_train, args.tensor_model_parallel_size)) - pipeline_parallel_groups.append(ranks) - ranks = list(range(args.num_gpus_for_train + i, args.num_gpus_for_train + args.num_gpus_for_infer, - args.tensor_model_parallel_size)) - data_parallel_groups.append(ranks) - comm_groups = [] - for data_parallel_group, pipeline_parallel_group in zip(data_parallel_groups, pipeline_parallel_groups): - for rank in pipeline_parallel_group: - tmp_group = [rank] + list(data_parallel_group) - comm_groups.append(tmp_group) - print(comm_groups) - return comm_groups - - -def init_parameter_mapping_distributed(): - """ - Automatically mapping communication groups for model initialization - Scenario example: - Training: tp 2 pp 2 - Inference: tp 2 dp 2 - Given a world size of 8 as above - rank 0 acts as the sender, the communication group is [0, 4, 6] - rank 2 acts as the sender, the communication group is [2, 4, 6] - rank 4 acts as the receiver, the communication group is [[0, 4, 6], [2, 4, 6]] - """ - global MODEL_SEND_GROUP, MODEL_RECEIVE_GROUPS - groups = init_comm_groups() - rank = dist.get_rank() - for group in groups: - tmp_group = dist.new_group(group) - if not ps.in_mg2_inference_group() and rank in group: - MODEL_SEND_GROUP = tmp_group - if ps.in_mg2_inference_group() and rank in group: - MODEL_RECEIVE_GROUPS.append((group[0], tmp_group)) - print("init_distributed_sucess") - - -def get_model_send_group(): - return MODEL_SEND_GROUP - - -def get_model_receive_groups(): - return MODEL_RECEIVE_GROUPS - - -def get_receive_param_nums(): - return RECEIVE_PARAM_NUMS - - -def param_nums_is_initialized(): - return get_receive_param_nums() is not None - - -def sync_param_nums(moudle: torch.nn.Module): - """ - Synchronize the number of parameters for sending and receiving, ensuring that - the number of broadcast communications aligns and the model parameters align. 
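The tp 2 / pp 2 training to tp 2 / dp 2 inference example in the init_comm_groups docstring can be reproduced by hand; the standalone re-enumeration below, with explicit arguments instead of get_args(), yields exactly the groups listed there:

def comm_groups(tp, num_gpus_for_train, num_gpus_for_infer):
    # same enumeration as init_comm_groups above, with explicit arguments
    pipeline_parallel_groups = [list(range(i, num_gpus_for_train, tp)) for i in range(tp)]
    data_parallel_groups = [
        list(range(num_gpus_for_train + i, num_gpus_for_train + num_gpus_for_infer, tp))
        for i in range(tp)
    ]
    groups = []
    for dp_group, pp_group in zip(data_parallel_groups, pipeline_parallel_groups):
        for rank in pp_group:
            groups.append([rank] + dp_group)
    return groups

# training tp2/pp2 on ranks 0-3, inference tp2/dp2 on ranks 4-7
print(comm_groups(tp=2, num_gpus_for_train=4, num_gpus_for_infer=4))
# [[0, 4, 6], [2, 4, 6], [1, 5, 7], [3, 5, 7]]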
- """ - args = get_args() - if not ps.in_mg2_inference_group(): - args_need_broadcast = torch.tensor([args.iteration, args.consumed_train_samples], dtype=torch.int64, - device=torch.cuda.current_device()) - dist.broadcast(args_need_broadcast, group=get_model_send_group(), src=dist.get_rank()) - num_parameters = torch.tensor([sum(1 for _ in moudle.named_parameters())], dtype=torch.int64, - device=torch.cuda.current_device()) - dist.broadcast(num_parameters, group=get_model_send_group(), src=dist.get_rank()) - else: - global RECEIVE_PARAM_NUMS - - recv_param_nums = [] - for group in get_model_receive_groups(): - args_need_broadcast = torch.empty(2, dtype=torch.int64, device=torch.cuda.current_device()) - dist.broadcast(args_need_broadcast, group=group[1], src=group[0], async_op=True) - args.iteration, args.consumed_train_samples = args_need_broadcast - - tmp_num_parameters = torch.empty(1, dtype=torch.int64, device=torch.cuda.current_device()) - dist.broadcast(tmp_num_parameters, group=group[1], src=group[0], async_op=True) - recv_param_nums.append(tmp_num_parameters) - RECEIVE_PARAM_NUMS = recv_param_nums - - -def compute_model_hash(model, hash_func): - hash_value = hash_func() - for param in model.parameters(): - param_bytes = param.data.cpu().numpy().tobytes() - hash_value.update(param_bytes) - md5_tensor = torch.tensor([int(h, 16) for h in hash_value.hexdigest()]) - return md5_tensor - - -def send_model_to_infer_model(moudle: torch.nn.Module): - """ - Decompose model information and transfer it directly from the model. - """ - args = get_args() - model_send_group = get_model_send_group() - if args.md5_validate: - hash_value = hashlib.md5() - - is_reuse_output_weights = (not args.untie_embeddings_and_output_weights and - args.pipeline_model_parallel_size >= 2 and - mpu.is_pipeline_last_stage(ignore_virtual=True)) - for name, param in moudle.named_parameters(): - if is_reuse_output_weights and 'output_layer.weight' in name: - continue - param_info_data = param.data - dist.broadcast(param_info_data, group=model_send_group, src=dist.get_rank(), async_op=True) - if args.md5_validate: - param_bytes = param_info_data.to(torch.float32).cpu().numpy().tobytes() - hash_value.update(param_bytes) - if args.md5_validate: - md5_tensor = torch.tensor([int(h, 16) for h in hash_value.hexdigest()], dtype=torch.int64, - device=torch.cuda.current_device()) - dist.broadcast(md5_tensor, group=model_send_group, src=dist.get_rank(), async_op=True) - - -def recv_model_from_train_model(moudle: torch.nn.Module): - """ - Decompose model parameters and directly loaded them into the model. 
- """ - args = get_args() - model_receive_groups = get_model_receive_groups() - recv_param_nums = get_receive_param_nums() - flag = True - idx = 0 - if args.md5_validate: - hash_value = hashlib.md5() - for _, param in moudle.named_parameters(): - if flag: - cur_num = int(recv_param_nums[idx]) - cur_group = model_receive_groups[idx] - flag = False - - param_info_data = param.data - torch.distributed.broadcast(param_info_data, group=cur_group[1], src=cur_group[0], async_op=False) - if args.md5_validate: - param_bytes = param_info_data.to(torch.float32).cpu().numpy().tobytes() - hash_value.update(param_bytes) - - cur_num -= 1 - if cur_num == 0: - if args.md5_validate: - md5_tensor = torch.tensor([int(h, 16) for h in hash_value.hexdigest()], dtype=torch.int64, - device=torch.cuda.current_device()) - md5_tensor_src = torch.zeros_like(md5_tensor, dtype=torch.int64, device=torch.cuda.current_device()) - dist.broadcast(md5_tensor_src, group=cur_group[1], src=cur_group[0], async_op=False) - if torch.equal(md5_tensor_src, md5_tensor): - print("MD5 Hash: The weights of the two models match.") - else: - print("MD5 Hash: The weights of the two models do not match.") - hash_value = hashlib.md5() - flag = True - idx += 1 - - if cur_num != 0: - if args.md5_validate: - md5_tensor = torch.tensor([int(h, 16) for h in hash_value.hexdigest()], dtype=torch.int64, - device=torch.cuda.current_device()) - md5_tensor_src = torch.zeros_like(md5_tensor, dtype=torch.int64, device=torch.cuda.current_device()) - dist.broadcast(md5_tensor_src, group=cur_group[1], src=cur_group[0], async_op=False) - if torch.equal(md5_tensor_src, md5_tensor): - print("MD5 Hash: The weights of the two models match.") - else: - print("MD5 Hash: The weights of the two models do not match.") - - -def run_auto_mapping(model): - if ps.in_mg2_inference_group(): - recv_model_from_train_model(model) - else: - send_model_to_infer_model(model) diff --git a/mindspeed_llm/tasks/posttrain/rlxf/utils/__init__.py b/mindspeed_llm/tasks/posttrain/rlxf/utils/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/mindspeed_llm/tasks/posttrain/rlxf/utils/loggers.py b/mindspeed_llm/tasks/posttrain/rlxf/utils/loggers.py deleted file mode 100644 index 6d884ea6daacf72eed03cfd1db99197d68322563..0000000000000000000000000000000000000000 --- a/mindspeed_llm/tasks/posttrain/rlxf/utils/loggers.py +++ /dev/null @@ -1,46 +0,0 @@ -# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved. 
- -import logging -from datetime import datetime - - -class Loggers(object): - def __init__(self, - name='root', - logger_level='DEBUG', - ): - self.logger = logging.getLogger(name) - self.logger.setLevel(logger_level) - - def handle_msg(self, msg, level, iteration, steps): - current_time = str(datetime.now()).split(".")[0] - - if isinstance(msg, dict): - fmt_msg = f"[{current_time}] " - fmt_msg += f"iteration: {iteration} / {steps} | " - for key in msg: - fmt_msg += f"{key} : {format(msg[key], '.4f')} | " - fmt_msg = fmt_msg[:-2] - else: - fmt_msg = f"[{current_time}] {level} " + str(msg) - return fmt_msg - - def info(self, msg, iteration, steps): - format_msg = self.handle_msg(msg, "INFO", iteration, steps) - self.logger.info(format_msg) - - def warning(self, msg, iteration, steps): - format_msg = self.handle_msg(msg, "WARNING", iteration, steps) - self.logger.warning(format_msg) - - def debug(self, msg, iteration, steps): - format_msg = self.handle_msg(msg, "DEBUG", iteration, steps) - self.logger.debug(format_msg) - - def error(self, msg, iteration, steps): - format_msg = self.handle_msg(msg, "ERROR", iteration, steps) - self.logger.error(format_msg) - - def flush(self): - for handler in self.logger.handlers: - handler.flush() \ No newline at end of file diff --git a/mindspeed_llm/tasks/posttrain/rlxf/utils/protocol.py b/mindspeed_llm/tasks/posttrain/rlxf/utils/protocol.py deleted file mode 100644 index 62f61f9501ea9c4e019a8579797609a740b267e5..0000000000000000000000000000000000000000 --- a/mindspeed_llm/tasks/posttrain/rlxf/utils/protocol.py +++ /dev/null @@ -1,534 +0,0 @@ -# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved. -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Implement base data transfer protocol between any two functions, modules. -We can subclass Protocol to define more detailed batch info with specific keys -""" - -__all__ = ['DataProto', 'union_tensor_dict'] - -import copy -from dataclasses import dataclass, field -from typing import Callable, Dict, List, Union - -import numpy as np -import ray -import torch -import tensordict -from tensordict import TensorDict -from torch.utils.data import DataLoader - -try: - tensordict.set_lazy_legacy(False).set() -except Exception as e: - pass - - -def union_two_dict(dict1: Dict, dict2: Dict): - """Union two dict. Will throw an error if there is an item not the same object with the same key. - - Args: - dict1: - dict2: - - Returns: - - """ - for key, val in dict2.items(): - if key in dict1: - if dict2[key] != dict1[key]: - raise ValueError(f'{key} in dict1 and dict2 are not the same object') - dict1[key] = val - - return dict1 - - -def union_tensor_dict(tensor_dict1: TensorDict, tensor_dict2: TensorDict) -> TensorDict: - """Union two tensordicts.""" - if tensor_dict1.batch_size != tensor_dict2.batch_size: - raise ValueError(f'Two tensor dicts must have identical batch sizes. 
Got {tensor_dict1.batch_size} and {tensor_dict2.batch_size}') - for key in tensor_dict2.keys(): - if key not in tensor_dict1.keys(): - tensor_dict1[key] = tensor_dict2[key] - else: - if not tensor_dict1[key].equal(tensor_dict2[key]): - raise ValueError(f'{key} in tensor_dict1 and tensor_dict2 are not the same object') - - return tensor_dict1 - - -def union_numpy_dict(tensor_dict1, tensor_dict2): - for key, val in tensor_dict2.items(): - if key in tensor_dict1: - if not isinstance(tensor_dict2[key], np.ndarray): - raise TypeError(f"The value for key '{key}' in tensor_dict2 is not a numpy.ndarray.") - if not isinstance(tensor_dict1[key], np.ndarray): - raise TypeError(f"The value for key '{key}' in tensor_dict1 is not a numpy.ndarray.") - if not np.all(tensor_dict2[key] == tensor_dict1[key]): - raise ValueError(f"Arrays for key '{key}' in tensor_dict1 and tensor_dict2 are not the same object.") - tensor_dict1[key] = val - - return tensor_dict1 - - -def list_of_dict_to_dict_of_list(list_of_dict: List[dict]): - if len(list_of_dict) == 0: - return {} - keys = list_of_dict[0].keys() - output = {key: [] for key in keys} - for data in list_of_dict: - for key, item in data.items(): - if key not in output: - raise KeyError(f"Key '{key}' is not found in the output dictionary") - output[key].append(item) - return output - - -def collate_fn(x: List['DataProtoItem']): - batch = [] - non_tensor_batch = [] - for data in x: - batch.append(data.batch) - non_tensor_batch.append(data.non_tensor_batch) - batch = torch.stack(batch).contiguous() - non_tensor_batch = list_of_dict_to_dict_of_list(non_tensor_batch) - for key, val in non_tensor_batch.items(): - non_tensor_batch[key] = np.array(val, dtype=object) - return DataProto(batch=batch, non_tensor_batch=non_tensor_batch) - - -@dataclass -class DataProtoItem: - batch: TensorDict = None - non_tensor_batch: Dict = field(default_factory=dict) - meta_info: Dict = field(default_factory=dict) - - -@dataclass -class DataProto: - """ - A DataProto is a data structure that aims to provide a standard protocol for data exchange between functions. - It contains a batch (TensorDict) and a meta_info (Dict). The batch is a TensorDict. - TensorDict allows you to manipulate a dictionary of Tensors like a single Tensor. Ideally, the tensors with the - same batch size should be put inside batch. 
- """ - batch: TensorDict = None - non_tensor_batch: Dict = field(default_factory=dict) - meta_info: Dict = field(default_factory=dict) - - def __post_init__(self): - # perform necessary checking - self.check_consistency() - - def __len__(self): - return self.batch.batch_size[0] - - def __getitem__(self, item): - tensor_data = self.batch[item] - non_tensor_data = {key: val[item] for key, val in self.non_tensor_batch.items()} - return DataProtoItem(batch=tensor_data, non_tensor_batch=non_tensor_data, meta_info=self.meta_info) - - def __getstate__(self): - import io - buffer = io.BytesIO() - if tensordict.__version__ >= '0.5.0' and self.batch is not None: - self.batch = self.batch.contiguous() - self.batch = self.batch.consolidate() - torch.save(self.batch, buffer) - return buffer, self.non_tensor_batch, self.meta_info - - def __setstate__(self, data): - batch_deserialized, non_tensor_batch, meta_info = data - batch_deserialized.seek(0) - batch = torch.load(batch_deserialized, - weights_only=False, - map_location='cpu' if not torch.cuda.is_available() else None) - self.batch = batch - self.non_tensor_batch = non_tensor_batch - self.meta_info = meta_info - - def check_consistency(self): - """Check the consistency of the DataProto. Mainly for batch and non_tensor_batch - We expose this function as a public one so that user can call themselves directly - """ - if self.batch is not None: - if len(self.batch.batch_size) != 1: - raise ValueError('only support num_batch_dims=1') - - if len(self.non_tensor_batch) != 0: - if len(self.batch.batch_size) != 1: - raise ValueError('only support num_batch_dims=1 when non_tensor_batch is not empty.') - - batch_size = self.batch.batch_size[0] - for key, val in self.non_tensor_batch.items(): - if not isinstance(val, np.ndarray) or val.dtype != object: - raise TypeError(f'data in the non_tensor_batch must be a numpy.array with dtype=object, ' - f'but found {key} with dtype {val.dtype}') - if val.shape[0] != batch_size: - raise ValueError(f'key {key} length {len(val)} is not equal to batch size {batch_size}') - - @classmethod - def from_single_dict(cls, data: Dict[str, Union[torch.Tensor, np.ndarray]], meta_info=None): - tensors = {} - non_tensors = {} - - for key, val in data.items(): - if isinstance(val, torch.Tensor): - tensors[key] = val - elif isinstance(val, np.ndarray): - non_tensors[key] = val - else: - raise ValueError(f'Unsupported type in data {type(val)}') - - return DataProto.from_dict(tensors=tensors, non_tensors=non_tensors, meta_info=meta_info) - - @classmethod - def from_dict(cls, tensors: Dict[str, torch.Tensor], non_tensors=None, meta_info=None, num_batch_dims=1): - """Create a DataProto from a dict of tensors. This assumes that - 1. All the tensor in tensors have the same dim0 - 2. 
Only dim0 is the batch dim - """ - if len(tensors) == 0: - raise ValueError('tensors must not be empty') - - if num_batch_dims <= 0: - raise ValueError('num_batch_dims must be greater than zero') - - if non_tensors is not None: - if num_batch_dims != 1: - raise ValueError('only support num_batch_dims=1 when non_tensors is not None.') - - if meta_info is None: - meta_info = {} - if non_tensors is None: - non_tensors = {} - - if not isinstance(non_tensors, dict): - raise TypeError('non_tensors must be a dictionary') - - # get and check batch size - batch_size = None - pivot_key = None - for key, tensor in tensors.items(): - if batch_size is None: - batch_size = tensor.shape[:num_batch_dims] - pivot_key = key - else: - current_batch = tensor.shape[:num_batch_dims] - if batch_size != current_batch: - raise ValueError(f'Not all the tensor in tensors have the same batch size with batch_dims={num_batch_dims}. ' - f'Got {pivot_key} has {batch_size}, {key} has {current_batch}') - - for key, val in non_tensors.items(): - non_tensors[key] = np.array(val, dtype=object) - - tensor_dict = TensorDict(source=tensors, batch_size=batch_size) - return cls(batch=tensor_dict, non_tensor_batch=non_tensors, meta_info=meta_info) - - def to(self, device) -> 'DataProto': - """move the batch to device - - Args: - device (torch.device, str): torch device - - Returns: - DataProto: the current DataProto - - """ - if self.batch is not None: - self.batch = self.batch.to(device) - return self - - def select(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None, deepcopy=False) -> 'DataProto': - """Select a subset of the DataProto via batch_keys and meta_info_keys - - Args: - batch_keys (list, optional): a list of strings indicating the keys in batch to select - meta_info_keys (list, optional): a list of keys indicating the meta info to select - - Returns: - DataProto: the DataProto with the selected batch_keys and meta_info_keys - """ - if batch_keys is not None: - batch_keys = tuple(batch_keys) - sub_batch = self.batch.select(*batch_keys) - else: - sub_batch = self.batch - - if non_tensor_batch_keys is not None: - non_tensor_batch = {key: val for key, val in self.non_tensor_batch.items() if key in non_tensor_batch_keys} - else: - non_tensor_batch = self.non_tensor_batch - - if deepcopy: - non_tensor_batch = copy.deepcopy(non_tensor_batch) - - if meta_info_keys is not None: - sub_meta_info = {key: val for key, val in self.meta_info.items() if key in meta_info_keys} - else: - sub_meta_info = self.meta_info - - if deepcopy: - sub_meta_info = copy.deepcopy(sub_meta_info) - - return DataProto(batch=sub_batch, non_tensor_batch=non_tensor_batch, meta_info=sub_meta_info) - - def pop(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None) -> 'DataProto': - """Pop a subset of the DataProto via `batch_keys` and `meta_info_keys` - - Args: - batch_keys (list, optional): a list of strings indicating the keys in batch to pop - meta_info_keys (list, optional): a list of keys indicating the meta info to pop - - Returns: - DataProto: the DataProto with the poped batch_keys and meta_info_keys - """ - if batch_keys is None: - raise ValueError("batch_keys cannot be None. 
Please provide a valid list of keys.") - - if meta_info_keys is None: - meta_info_keys = [] - if non_tensor_batch_keys is None: - non_tensor_batch_keys = [] - - tensors = {} - # tensor batch - for key in batch_keys: - if key not in self.batch.keys(): - raise KeyError(f"Key '{key}' not found in self.batch.") - tensors[key] = self.batch.pop(key) - non_tensors = {} - # non tensor batch - for key in non_tensor_batch_keys: - if key not in self.non_tensor_batch.keys(): - raise KeyError(f"Key '{key}' not found in self.non_tensor_batch.") - non_tensors[key] = self.non_tensor_batch.pop(key) - meta_info = {} - for key in meta_info_keys: - if key not in self.meta_info.keys(): - raise KeyError(f"Key '{key}' not found in self.meta_info.") - meta_info[key] = self.meta_info.pop(key) - return DataProto.from_dict(tensors=tensors, non_tensors=non_tensors, meta_info=meta_info) - - def rename(self, old_keys=None, new_keys=None) -> 'DataProto': - """ - Note that this function only rename the key in the batch - """ - - def validate_input(keys): - if keys is not None: - if isinstance(keys, str): - keys = [keys] - elif isinstance(keys, list): - pass - else: - raise TypeError(f'keys must be a list or a string, but got {type(keys)}') - return keys - - old_keys = validate_input(old_keys) - new_keys = validate_input(new_keys) - - if len(new_keys) != len(old_keys): - raise ValueError( - f'new_keys and old_keys must have the same length, but got {len(new_keys)} and {len(old_keys)}') - - self.batch.rename_key_(tuple(old_keys), tuple(new_keys)) - - return self - - def union(self, other: 'DataProto') -> 'DataProto': - """Union with another DataProto. Union batch and meta_info separately. - Throw an error if - - there are conflict keys in batch and they are not equal - - the batch size of two data batch is not the same - - there are conflict keys in meta_info and they are not the same. - - Args: - other (DataProto): another DataProto to union - - Returns: - DataProto: the DataProto after union - """ - self.batch = union_tensor_dict(self.batch, other.batch) - self.non_tensor_batch = union_numpy_dict(self.non_tensor_batch, other.non_tensor_batch) - self.meta_info = union_two_dict(self.meta_info, other.meta_info) - return self - - def make_iterator(self, mini_batch_size, epochs, seed=None, dataloader_kwargs=None): - """Make an iterator from the DataProto. This is built upon that TensorDict can be used as a normal Pytorch - dataset. - - Args: - mini_batch_size (int): mini-batch size when iterating the dataset. We require that - ``batch.batch_size[0] % mini_batch_size == 0`` - epochs (int): number of epochs when iterating the dataset. - dataloader_kwargs: internally, it returns a DataLoader over the batch. - The dataloader_kwargs is the kwargs passed to the DataLoader - - Returns: - Iterator: an iterator that yields a mini-batch data at a time. 
The total number of iteration steps is - ``self.batch.batch_size * epochs // mini_batch_size`` - """ - if self.batch.batch_size[0] % mini_batch_size != 0: - raise ValueError(f"Batch size {self.batch.batch_size[0]} is not divisible by mini_batch_size {mini_batch_size}.") - # we can directly create a dataloader from TensorDict - if dataloader_kwargs is None: - dataloader_kwargs = {} - - if seed is not None: - generator = torch.Generator() - generator.manual_seed(seed) - else: - generator = None - - if not isinstance(dataloader_kwargs, dict): - raise TypeError(f"dataloader_kwargs should be a dictionary, but got {type(dataloader_kwargs)}.") - train_dataloader = DataLoader(dataset=self, - batch_size=mini_batch_size, - collate_fn=collate_fn, - generator=generator, - **dataloader_kwargs) - - def get_data(): - for _ in range(epochs): - for d in train_dataloader: - d.meta_info = self.meta_info - yield d - - return iter(get_data()) - - def chunk(self, chunks: int) -> List['DataProto']: - """Split the batch among dim=0 into chunks. The meta_info is passed to each DataProto after split. - - Args: - chunks (int): the number of chunks to split on dim=0 - - Returns: - List[DataProto]: a list of DataProto after splitting - """ - if self.batch is not None: - batch_lst = self.batch.chunk(chunks=chunks, dim=0) - else: - batch_lst = [None for _ in range(chunks)] - - non_tensor_batch_lst = [{} for _ in range(chunks)] - for key, val in self.non_tensor_batch.items(): - if not isinstance(val, np.ndarray): - raise TypeError(f"Expected value of type np.ndarray for key '{key}', but got {type(val)}.") - non_tensor_lst = np.array_split(val, chunks) - if len(non_tensor_lst) != chunks: - raise ValueError(f"After splitting, the number of chunks for key '{key}' is {len(non_tensor_lst)}, " - f"which does not match the expected number of chunks: {chunks}.") - for i in range(chunks): - non_tensor_batch_lst[i][key] = non_tensor_lst[i] - - output = [] - for i in range(chunks): - output.append( - DataProto(batch=batch_lst[i], non_tensor_batch=non_tensor_batch_lst[i], meta_info=self.meta_info)) - - return output - - @staticmethod - def concat(data: List['DataProto']) -> 'DataProto': - """Concat a list of DataProto. The batch is concatenated among dim=0. - The meta_info is assumed to be identical and will use the first one. - - Args: - data (List[DataProto]): list of DataProto - - Returns: - DataProto: concatenated DataProto - """ - batch_lst = [] - for batch in data: - batch_lst.append(batch.batch) - if batch_lst[0] is not None: - new_batch = torch.cat(batch_lst, dim=0) - else: - new_batch = None - - non_tensor_batch = list_of_dict_to_dict_of_list(list_of_dict=[d.non_tensor_batch for d in data]) - for key, val in non_tensor_batch.items(): - non_tensor_batch[key] = np.concatenate(val, axis=0) - - return DataProto(batch=new_batch, non_tensor_batch=non_tensor_batch, meta_info=data[0].meta_info) - - def reorder(self, indices): - """ - Note that this operation is in-place - """ - indices_np = indices.detach().numpy() - self.batch = self.batch[indices] - self.non_tensor_batch = {key: val[indices_np] for key, val in self.non_tensor_batch.items()} - - -@dataclass -class DataProtoFuture: - """ - DataProtoFuture aims to eliminate actual data fetching on driver. By doing so, the driver doesn't have to wait - for data so that asynchronous execution becomes possible. - DataProtoFuture contains a list of futures from another WorkerGroup of size world_size. 
- - collect_fn is a Callable that reduces the list of futures to a DataProto - - dispatch_fn is a Callable that partitions the DataProto into a list of DataProto of size world_size and then select - - Potential issue: we can optimize dispatch_fn(collect_fn) such that only needed data is fetched on destination - - DataProtoFuture only supports directly passing from the output of a method to another input. You can't perform any - operation on the DataProtoFuture in driver. - """ - collect_fn: Callable - futures: List[ray.ObjectRef] - dispatch_fn: Callable = None - - @staticmethod - def concat(data: List[ray.ObjectRef]) -> 'DataProtoFuture': - output = DataProtoFuture(collect_fn=DataProto.concat, futures=data) - return output - - def chunk(self, chunks: int) -> List['DataProtoFuture']: - from functools import partial - - arg_future_lst = [] - for i in range(chunks): - # note that we can't directly pass i and chunks - def dispatch_fn(x, i, chunks): - return x.chunk(chunks=chunks)[i] - - arg_future = DataProtoFuture(collect_fn=self.collect_fn, - dispatch_fn=partial(dispatch_fn, i=i, chunks=chunks), - futures=self.futures) - arg_future_lst.append(arg_future) - return arg_future_lst - - def get(self): - output = ray.get(self.futures) # dp_size. - for single_output in output: - if not isinstance(single_output, DataProto): - raise TypeError(f"Expected instance of DataProto, but got {type(single_output)}.") - output = self.collect_fn(output) # select dp, concat - if self.dispatch_fn is not None: - output = self.dispatch_fn(output) # split in batch dim, select using dp - return output - - -def make_batch_generator(batches, vpp_size): - if vpp_size > 1: - # has vpp - batch_generator = [batches] * vpp_size # number of vpp chunks - batch_generator = [iter(b) for b in batch_generator] - else: - # no vpp - batch_generator = iter(batches) - return batch_generator diff --git a/mindspeed_llm/tasks/posttrain/rlxf/utils/torch_functional.py b/mindspeed_llm/tasks/posttrain/rlxf/utils/torch_functional.py deleted file mode 100644 index 927d8d2bd203f62bc80f51643e4a6974b1e6ffd3..0000000000000000000000000000000000000000 --- a/mindspeed_llm/tasks/posttrain/rlxf/utils/torch_functional.py +++ /dev/null @@ -1,70 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
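As a usage note for the DataProto protocol deleted above: the workers below build a DataProto from per-key tensors, split it across data-parallel ranks with chunk, and reassemble results with concat. A small sketch of that round trip, assuming the DataProto class as defined above is in scope (tensor shapes, key names, and values are illustrative only):

    import numpy as np
    import torch

    # Build a 4-sample batch: tensors share dim0, object-dtype arrays carry non-tensor data.
    data = DataProto.from_dict(
        tensors={
            "input_ids": torch.randint(0, 100, (4, 8)),
            "attention_mask": torch.ones(4, 8, dtype=torch.int64),
        },
        non_tensors={"raw_prompt": np.array(["a", "b", "c", "d"], dtype=object)},
        meta_info={"eos_token_id": 2},
    )
    assert len(data) == 4

    # Split along dim0 for two workers, then reassemble; meta_info is carried through unchanged.
    shards = data.chunk(chunks=2)
    assert len(shards[0]) == 2 and shards[0].meta_info["eos_token_id"] == 2
    merged = DataProto.concat(shards)
    assert len(merged) == len(data)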
-""" -Contain small torch utilities -""" - -from typing import List - -import torch -import torch.distributed -from tensordict import TensorDict - - -def clip_by_value(x, tensor_min, tensor_max): - """ - Tensor extenstion to torch.clamp - """ - clipped = torch.max(torch.min(x, tensor_max), tensor_min) - return clipped - - -def masked_mean(values, mask, axis=None): - """Compute mean of tensor with a masked values.""" - return (values * mask).sum(axis=axis) / mask.sum(axis=axis) - - -def masked_var(values, mask, unbiased=True): - """Compute variance of tensor with masked values.""" - mean = masked_mean(values, mask) - centered_values = values - mean - variance = masked_mean(centered_values ** 2, mask) - if unbiased: - mask_sum = mask.sum() - if mask_sum == 0: - raise ValueError("At least one element in the mask has to be 1.") - # note that if mask_sum == 1, then there is a division by zero issue - # to avoid it you just need to use a larger minibatch_size - elif mask_sum == 1: - bessel_correction = mask_sum - else: - bessel_correction = mask_sum / (mask_sum - 1) - variance = variance * bessel_correction - return variance - - -def masked_whiten(values, mask, shift_mean=True): - """Whiten values with masked values.""" - mean, var = masked_mean(values, mask), masked_var(values, mask) - whitened = (values - mean) * torch.rsqrt(var + 1e-8) - if not shift_mean: - whitened += mean - return whitened - - -def split_dict_tensor_into_batches(tensors: TensorDict, batch_size) -> List[TensorDict]: - if tensors.batch_size[0] % batch_size != 0: - raise ValueError(f"Input data batch size: {tensors.batch_size[0]}, split batch size: {batch_size}. " - f"Please ensure that the input batch size is divisible by the split batch size.") - return tensors.split(batch_size) diff --git a/mindspeed_llm/tasks/posttrain/rlxf/workers/__init__.py b/mindspeed_llm/tasks/posttrain/rlxf/workers/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/mindspeed_llm/tasks/posttrain/rlxf/workers/actor_train_infer.py b/mindspeed_llm/tasks/posttrain/rlxf/workers/actor_train_infer.py deleted file mode 100644 index eeee7de9475f8b9e2bef9d033ecb4c1b15703b1b..0000000000000000000000000000000000000000 --- a/mindspeed_llm/tasks/posttrain/rlxf/workers/actor_train_infer.py +++ /dev/null @@ -1,890 +0,0 @@ -# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved. 
-import copy -import os -from functools import partial -from typing import Iterable, Dict - -import ray -import torch -import torch_npu -import torch.nn.functional as F -from tensordict import TensorDict - -from megatron.training import get_args, initialize_megatron, get_timers, get_tokenizer -from megatron.core.pipeline_parallel.schedules import get_forward_backward_func -from megatron.core import parallel_state as mpu, tensor_parallel -from megatron.training.training import append_to_progress_log, build_train_valid_test_data_iterators, print_datetime -from megatron.training import get_model -from megatron.training.utils import unwrap_model -from megatron.training.checkpointing import save_checkpoint -from megatron.training.training import num_floating_point_operations -from mindspeed_llm.tasks.posttrain.rlxf.training.core_algos import compute_policy_loss, find_first_eos_index, compute_grpo_policy_loss -from mindspeed_llm.tasks.posttrain.rlxf.utils.torch_functional import split_dict_tensor_into_batches -from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base.megatron.worker import MegatronWorker -from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base.decorator import register, Dispatch, Execute -from mindspeed_llm.tasks.posttrain.rlxf.training.parameter_mapping import sync_param_nums, run_auto_mapping, \ - init_parameter_mapping_distributed -from mindspeed_llm.tasks.posttrain.rlxf.utils.protocol import DataProto, make_batch_generator -from mindspeed_llm.tasks.posttrain.base import BaseTrainer -import mindspeed_llm.tasks.posttrain.rlxf.training.parallel_state as ps -from mindspeed_llm.tasks.inference.module import MegatronModuleForCausalLM -from mindspeed_llm.tasks.preprocess.blended_mtf_dataset import build_blended_mtf_dataset -from mindspeed_llm.training.initialize import set_jit_fusion_options -from mindspeed_llm.training.utils import get_finetune_data_on_this_tp_rank, get_tune_attention_mask -from mindspeed_llm.tasks.posttrain.utils import compute_log_probs, append_to_dict - - -def train_valid_test_datasets_provider(train_val_test_num_samples): - """Build the train test and validation datasets. - - Args: - train_val_test_num_samples : A list containing the number of samples in train test and validation. - """ - args = get_args() - print("> building train, validation, and test datasets for GPT ...") - - train_ds, valid_ds, test_ds = build_blended_mtf_dataset( - data_prefix=args.data_path, - splits_string=args.split, - train_valid_test_num_samples=train_val_test_num_samples, - seq_length=args.max_prompt_length, - seed=args.seed) - - print("> finished creating GPT datasets ...") - - return train_ds, valid_ds, test_ds - - -@ray.remote -class PPOActorWorker(MegatronWorker): - """ - A basic class to launch two megatron instances with different communication groups. - Currently assume that the first group is for training and the second group is for inference. 
- """ - - def __init__(self, config, role): - super().__init__() - self.config = config - self.role = role - self.IGNORE_INDEX = -100 - os.environ['CUDA_DEVICE_MAX_CONNECTIONS'] = '1' - - initialize_megatron(args_defaults={'no_load_rng': True, 'no_load_optim': True}, - role=self.role, - config=self.config, - two_megatron=True) - - @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def initialize(self): - init_parameter_mapping_distributed() - self.is_inference_node = ps.in_mg2_inference_group() - if self.is_inference_node: - self.node = PPOActorInferWorker() - else: - self.node = PPOActorTrainWorker() - torch.cuda.empty_cache() - - @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def auto_mapping(self): - if self.is_inference_node: - run_auto_mapping(self.node.inf_model) - else: - run_auto_mapping(self.node.actor.model[0]) - torch.cuda.empty_cache() - - @register(dispatch_mode=Dispatch.DP_ALL_GATHER_INFER, execute_mode=Execute.INFER) - def generate_sequences(self): - args = get_args() - output = self.node.run_inference() - tokenizer = get_tokenizer() - meta_info = {'eos_token_id': tokenizer.eos_token_id, 'pad_token_id': tokenizer.pad_token_id, 'num_samples_per_step':args.num_samples_per_step} - output.meta_info.update(meta_info) - torch.cuda.empty_cache() - return output - - @register(dispatch_mode=Dispatch.DP_ALL_GATHER_TRAIN, execute_mode=Execute.TRAIN) - def update_actor(self, data): - device = next(self.node.actor.model[0].parameters()).device - data = data.to(device) - - dataloader = self.node.actor.make_minibatch_iterator(data=data) - - metrics = self.node.actor.update_policy(dataloader=dataloader) - - output = DataProto(meta_info={'metrics': metrics}) - output = output.to('cpu') - torch.cuda.empty_cache() - return output - - @register(dispatch_mode=Dispatch.DP_ALL_GATHER_TRAIN, execute_mode=Execute.TRAIN) - def get_log_probs(self, data): - old_log_probs = self.node.actor.compute_log_prob(data) - if old_log_probs is not None: # pp last stage - data.batch['old_log_probs'] = old_log_probs - data = data.to('cpu') - else: # pp intermediate stage, no useful results - data = None - # clear kv cache - torch.cuda.empty_cache() - return data - - @register(dispatch_mode=Dispatch.ONE_TO_ALL, execute_mode=Execute.TRAIN) - def save_checkpoint(self, iteration): - self.node.actor.save_checkpoint(iteration) - - @register(dispatch_mode=Dispatch.ONE_TO_ALL, execute_mode=Execute.TRAIN) - def get_iteration(self): - return self.node.actor.get_iteration() - - -class PPOActorTrainWorker(BaseTrainer): - def __init__(self): - super().__init__() - - def initialize(self): - self.args = get_args() - model, optimizer, opt_param_scheduler = self._build_model_and_optimizer() - self.actor = MegatronPPOActor(model=model, optimizer=optimizer, opt_param_scheduler=opt_param_scheduler) - sync_param_nums(model[0]) - if self.args.stage == "ray_online_dpo": - self.args.micro_batch_size *= 2 - self.args.ppo_mini_batch_size *= 2 - - def _build_model_and_optimizer(self): - from megatron.training.training import setup_model_and_optimizer - model, optimizer, opt_param_scheduler = setup_model_and_optimizer( - self.model_provider, self.model_type) - return model, optimizer, opt_param_scheduler - - def get_batch(self, data_iterator): - """ - Retrieves a batch of data from the data iterator. - Called during each forward step. - """ - pass - - def loss_func(self, input_tensor, output_tensor): - """ - Computes the loss function. - Called during each forward step. 
- """ - pass - - def forward_step(self, data_iterator, model): - """ - Performs a forward pass and computes the loss. - Called during each training iteration. - """ - pass - - -def pad_to_tensor_dict(data, padding_side="right", pad_multi_of=16): - max_length = torch.LongTensor([max(len(val) for val in data)]).cuda() - max_length = max_length if max_length % pad_multi_of == 0 else (max_length // pad_multi_of + 1) * pad_multi_of - torch.distributed.all_reduce(max_length, op=torch.distributed.ReduceOp.MAX) - - tokenizer = get_tokenizer() - - pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id else tokenizer.eos_token_id - context_lengths = [len(val) for val in data] - - data_length = len(data) - for i in range(data_length): - if context_lengths[i] < max_length: - if padding_side == "right": - data[i].extend([pad_id] * (max_length - context_lengths[i])) - else: - data[i] = [pad_id] * (max_length - context_lengths[i]) + data[i] - return context_lengths, max_length - - -class PPOActorInferWorker(BaseTrainer): - def __init__(self): - super().__init__() - self.count = 0 - self.keys = None - - def model_provider(self, pre_process=True, post_process=True): - """Builds the inference model. - - If you set the use_mcore_models to True, it will return the mcore GPT model. - - Args: - pre_process (bool, optional): Set to true if you need to compute embedings. Defaults to True. - post_process (bool, optional): Set to true if you need to want to compute output logits/loss. Defaults to True. - - - Returns: - Union[GPTModelInfer, GPTModel]: The returned model - """ - - from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec, \ - get_gpt_layer_local_spec - from megatron.core.transformer.spec_utils import import_module - from megatron.training import get_args, print_rank_0 - from megatron.training.arguments import core_transformer_config_from_args - from megatron.training.yaml_arguments import core_transformer_config_from_yaml - - from mindspeed_llm.tasks.inference.module import GPTModelInfer - - args = get_args() - use_te = args.transformer_impl == "transformer_engine" - - print_rank_0('building GPT Rollout model ...') - # Experimental loading arguments from yaml - if args.yaml_cfg is not None: - config = core_transformer_config_from_yaml(args, "language_model") - else: - config = core_transformer_config_from_args(args) - - if args.spec is not None: - transformer_layer_spec = import_module(args.spec) - else: - if use_te: - transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(args.num_experts, - args.moe_grouped_gemm) - else: - transformer_layer_spec = get_gpt_layer_local_spec(args.num_experts, args.moe_grouped_gemm) - - model = GPTModelInfer( - config=config, - transformer_layer_spec=transformer_layer_spec, - vocab_size=args.padded_vocab_size, - max_sequence_length=args.max_position_embeddings, - pre_process=pre_process, - post_process=post_process, - fp16_lm_cross_entropy=args.fp16_lm_cross_entropy, - parallel_output=False, - share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights, - position_embedding_type=args.position_embedding_type, - rotary_percent=args.rotary_percent, - seq_len_interpolation_factor=args.rotary_seq_len_interpolation_factor - ) - - return model - - def initialize(self): - train_valid_test_datasets_provider.is_distributed = True - self.args = get_args() - self.timers = get_timers() - self.train_valid_test_datasets_provider = train_valid_test_datasets_provider - - if self.args.log_progress: - 
append_to_progress_log("Starting job") - # Set pytorch JIT layer fusion options and warmup JIT functions. - set_jit_fusion_options() - - self.timers('train/valid/test-data-iterators-setup', log_level=0).start( - barrier=True) - - self.args.num_layer_list = None - self.args.micro_batch_size = 1 - self.args.sequence_parallel = False - - self.args.model = unwrap_model(get_model(self.model_provider, wrap_with_ddp=False)) - self.inf_model = self.args.model[0] - self.args.dataset_additional_keys = eval(self.args.dataset_additional_keys[0]) if self.args.dataset_additional_keys else [] - - sync_param_nums(self.inf_model) - true_pad_to_multiple_of = self.args.pad_to_multiple_of - self.args.pad_to_multiple_of = 1 # we don't want to pad data here - self.train_data_iterator, self.valid_data_iterator, self.test_data_iterator \ - = build_train_valid_test_data_iterators( - self.train_valid_test_datasets_provider) - self.args.pad_to_multiple_of = true_pad_to_multiple_of - self.timers('train/valid/test-data-iterators-setup').stop() - print_datetime('after dataloaders are built') - - # Print setup timing. - print('done with setup ...') - self.timers.log(['model-setup', 'train/valid/test-data-iterators-setup'], barrier=True) - - def get_batch(self, data_iterator): - """Generate a batch identical to Llama factory""" - args = get_args() - - self.keys = ['input_ids', *self.args.dataset_additional_keys] - - if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()): - if args.variable_seq_lengths and args.pipeline_model_parallel_size > 2: - tokens, _ = get_finetune_data_on_this_tp_rank(data_iterator) - else: - tokens = None - return tokens - - # Items and their type. - - data_type = torch.int64 - cur_data = next(data_iterator) - - # add problem category for reward choosing - if args.dataset_category is not None: - dataset_category = [int(item) for item in args.dataset_category.split(",")] - categories = [dataset_category[id.item()] for id in cur_data['dataset_id']] - cur_data['categories'] = torch.tensor(categories, dtype=torch.int64) - - # Broadcast data. - data_b = tensor_parallel.broadcast_data(self.keys, cur_data, data_type) - - # Unpack - batch = {} - for key in self.keys: - batch[key] = data_b.get(key).long() - - return batch - - def run_inference(self): - args = get_args() - num_infer_steps = args.global_batch_size // (args.data_parallel_size * args.num_samples_per_step) - responses = [] - idx_list = [] - idx_list_per_step = [] - additional_dict = {} - additional_dict_per_step = {} - - for k in self.args.dataset_additional_keys: - if not hasattr(additional_dict, k): - additional_dict[k] = [] - - max_new_tokens = args.seq_length - args.max_prompt_length - if max_new_tokens % args.pad_to_multiple_of != 0: - raise ValueError(f"Please adjust pad_to_multiple_of so that max_new_tokens % args.pad_to_multiple_of == 0. 
" - f"Current max_new_tokens: {max_new_tokens}, pad_to_multiple_of: {args.pad_to_multiple_of}") - for _ in range(num_infer_steps): - for k in self.args.dataset_additional_keys: - if not hasattr(additional_dict_per_step, k): - additional_dict_per_step[k] = [] - - for _ in range(args.num_samples_per_step): - batch = self.get_batch(self.train_data_iterator) - - tokens = batch["input_ids"] - tokens_list = tokens.view(-1).cpu().numpy().tolist() - - for additional_key in self.args.dataset_additional_keys: - additional_val = batch.get(additional_key).view(-1).cpu().numpy().tolist() - - for _ in range(args.n_samples_per_prompt): - additional_dict_per_step.get(additional_key).append(copy.deepcopy(additional_val)) - - for _ in range(args.n_samples_per_prompt): - idx_list_per_step.append(copy.deepcopy(tokens_list)) - - if args.stage == "ray_online_dpo": - idx_list_per_step.append(copy.deepcopy(tokens_list)) - - responses_per_step = self.inf_model.generate( - copy.deepcopy(idx_list_per_step), - max_new_tokens=max_new_tokens, - temperature=args.temperature, - do_sample=args.do_sample, - detokenize=False, - broadcast=False, - truncate=True - ) - - if not isinstance(responses_per_step, list): - responses_per_step = [responses_per_step] - - responses.extend(responses_per_step) - idx_list.extend(idx_list_per_step) - idx_list_per_step = [] - - for k in additional_dict: - additional_dict[k].extend(additional_dict_per_step[k]) - - additional_dict_per_step = {} - - - responses_ori_length, responses_pad_length = pad_to_tensor_dict( - responses, - pad_multi_of=args.pad_to_multiple_of - ) - prompts_ori_length, prompts_pad_length = pad_to_tensor_dict( - idx_list, "left", - pad_multi_of=args.pad_to_multiple_of - ) - - for additional_key in self.args.dataset_additional_keys: - tmp_val = additional_dict.get(additional_key) - pad_to_tensor_dict( - tmp_val, - pad_multi_of=args.pad_to_multiple_of - ) - additional_dict[additional_key] = tmp_val - - input_ids = [prompt + response for prompt, response in zip(idx_list, responses)] - - attention_mask = generate_attention_mask(input_ids, prompts_ori_length, prompts_pad_length, - responses_ori_length, responses_pad_length) - - position_ids = generate_position_ids_from_attention_mask(input_ids, prompts_ori_length, prompts_pad_length) - if self.args.stage == "ray_online_dpo": - batch_size = args.global_batch_size // args.data_parallel_size * 2 - else: - batch_size = args.global_batch_size // args.data_parallel_size * args.n_samples_per_prompt - - batch = TensorDict( - dict( - { - "prompts": idx_list, - "responses": responses, - "input_ids": input_ids, - "attention_mask": attention_mask, - "position_ids": position_ids, - "responses_ori_length": responses_ori_length - }, **additional_dict - ), - batch_size=batch_size - ) - - return DataProto(batch=batch) - - def loss_func(self, input_tensor, output_tensor): - """ - Computes the loss function. - Called during each forward step. - """ - pass - - def forward_step(self, data_iterator, model): - """ - Performs a forward pass and computes the loss. - Called during each training iteration. 
- """ - pass - - -def generate_position_ids_from_attention_mask(input_ids_list, prompts_ori_length, prompts_pad_length): - """ - 生成与 attention_mask 对应的 position_ids 列表。 - - 参数: - input_ids_list (list of lists): 包含 input_ids 的列表,每个元素是一个列表。 - prompts_ori_length (list of lists): 包含 prompt_ori_length 的列表,每个元素是int。 - prompts_pad_length int: prompts_pad_length,int。 - - 返回: - list of lists: 包含 position_ids 的列表,每个元素是一个列表。 - """ - position_ids_list = [] - for idx, input_ids in enumerate(input_ids_list): - prompt_pad_length = prompts_pad_length - prompts_ori_length[idx] - position_ids = [0] * prompt_pad_length + list(range(len(input_ids) - prompt_pad_length)) - position_ids_list.append(position_ids) - - return position_ids_list - - -def generate_attention_mask(input_ids_list, prompts_ori_length, prompts_pad_length, responses_ori_length, - responses_pad_length): - """ - 生成与 input_ids 对应的 attention_mask 列表。 - - 参数: - input_ids_list (list of lists): 包含 input_ids 的列表,每个元素是一个列表。 - prompts_ori_length (list of lists): 包含 prompt_ori_length 的列表,每个元素是int。 - prompts_pad_length int: prompts_pad_length,int。 - responses_ori_length (list of lists): 包含 response_ori_length 的列表,每个元素是int。 - responses_pad_length int: responses_pad_length,int。 - - 返回: - list of lists: 包含 attention_mask 的列表,每个元素是一个列表。 - """ - attention_mask_list = [] - - for idx, input_ids in enumerate(input_ids_list): - attention_mask = torch.ones_like(torch.tensor(input_ids)) - prompt_pad_length = prompts_pad_length - prompts_ori_length[idx] - response_pad_length = responses_pad_length - responses_ori_length[idx] - attention_mask[:prompt_pad_length] = 0 - if response_pad_length > 0: - attention_mask[-response_pad_length:] = 0 - attention_mask_list.append(attention_mask.numpy().tolist()) - - return attention_mask_list - - -def split_two_prompts(origin_tensor): - origin_tensor = origin_tensor.reshape(-1, 2) - first_half, second_half = origin_tensor.split(1, dim=1) - return first_half.reshape(-1), second_half.reshape(-1) - - - -class MegatronPPOActor(): - - def __init__(self, model, optimizer, opt_param_scheduler): - """MeagtronPPOActor class. This class implements the simple PPO logics when the model is built with Megatron. - - Args: - model_config (OmegaConf): model configuration. It must contains ``model_config.vocab_size`` and - ``model_config.hidden_size`` - megatron_config (OmegaConf): megatron configuration. It must contains - - ``sequence_parallel_enabled``: whether the sequence parallel is enabled. - - ``param_dtype``: the dtype of the parameters. - - ``virtual_pipeline_model_parallel_size``: virtual pipeline model parallel size. a.k.a number of chunks in each pp stage. - actor_module (nn.ModuleList): actor module is a ModuleList that contains a list of nn.Module in this pp stage. - each nn.Module in this rank holds a vpp module chunk. - The actor module has some constraints to follow in order to use the updating logics implemented here - - 1. It must implement unpad_input before any computation and pad_input after all the computation. Remove padding is an - optimization that removes the padding tokens. See unpad_input and pad_input function in flash-attn - - 2. Each pp stage must return the hidden state with the same shape [total_nnz, 1, hidden_size], - where total_nnz is the number of valid tokens in this batch. If sequence parallel is enabled, the size - of the hidden state is [total_nnz // tp, 1, hidden_size]. - actor_optimizer (DistributedOptimizer): currently, we only support DistributedOptimizer in Megatron. 
It implements - zero1 optimizer that shards the optimizer state across dp ranks. - - """ - self.args = get_args() - self.model = model - self.optimizer = optimizer - self.opt_param_scheduler = opt_param_scheduler - self._beta = self.args.dpo_beta - self.num_floating_point_operations_so_far = 0 - - def get_iteration(self): - return self.args.iteration - - def save_checkpoint(self, iteration): - - save_checkpoint(iteration, self.model, self.optimizer, self.opt_param_scheduler, - self.num_floating_point_operations_so_far) - - def compute_log_prob(self, data) -> torch.Tensor: - """Compute the log probability of the responses given input_ids, attention_mask and position_ids - - Args: - data (DataProto): a DataProto containing keys - - ``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. Note that input_ids is the - concatenation of prompt and response. Note that ``sequence_length = prompt_length + response_length``. - - ``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64. - - ``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. - - ``responses``: tensor of shape [batch_size, response_length]. torch.int64. - - Returns: - DataProto: torch.Tensor: the log_prob tensor - """ - data.batch = data.batch.contiguous() - - def compute_logprobs_fn(output, data): - response = data['responses'] - response_length = response.size(1) - logits = output - logits = logits[:, -response_length - 1:-1] - _, _, log_probs = compute_log_probs(logits, response, per_token=True) - return {'log_probs': log_probs} - - # We make recompute_old_log_prob by default here. - data = data.to(next(self.model[0].parameters()).device) - with torch.no_grad(): - output = self.forward_backward_batch(data, forward_only=True, post_process_fn=compute_logprobs_fn) - if mpu.is_pipeline_last_stage(ignore_virtual=True): - # only on last rank. It should be on every tp rank - log_probs = torch.cat([single_output['log_probs'] for single_output in output], dim=0) # (bs, seq_size) - log_probs = log_probs.to(torch.float32) - else: - log_probs = None - - # add empty cache after each compute - torch.cuda.empty_cache() - - return log_probs - - @property - def beta(self): - if isinstance(self._beta, list): - epoch = self.state.epoch - return self._beta[epoch] if epoch < len(self._beta) else self._beta[-1] - else: - return self._beta - - def make_minibatch_iterator(self, data): - """Make minibatch iterator for updating the actor - - Args: - data (DataProto): a DataProto containing keys - - ``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64, where ``sequence_length = prompt_length + response_length`` - - ``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64 - - ``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64 - - ``responses``: tensor of shape [batch_size, response_length]. torch.int64. Note that responses = input_ids[:, -response_length:] - - ``old_log_probs``: tensor of shape [batch_size, response_length]. torch.float32. The log probability of responses. - - ``advantages``: tensor of shape [batch_size, response_length]. torch.float32. The advantages of responses. - See PPO paper for details. 
- - Returns: - - """ - return data.make_iterator(mini_batch_size=self.args.ppo_mini_batch_size, - epochs=self.args.ppo_epochs, - dataloader_kwargs={'shuffle': self.args.shuffle_minibatch}) - - def forward_backward_batch(self, data, forward_only=False, post_process_fn=None): - """ - We assume: - - The model takes input: (input_ids, attention_mask, position_ids). No rmpad for the input - - The communication shape is (total_nnz_pad_to_sp // tp_size, 1, hidden_size) if sequence parallel is enabled - """ - # broadcast from last pp rank to all other pp ranks - - data.batch['attention_mask'] = data.batch['attention_mask'].to(bool) - - batch_size = self.args.micro_batch_size - batches = split_dict_tensor_into_batches(data.batch, batch_size=batch_size) - - n_micro_batch = len(batches) - seq_len = batches[0]['input_ids'].shape[1] - - forward_backward_func = get_forward_backward_func() - - def loss_func_ppo(output, data, meta_info): - """ - This loss_func has two modes - 1. when forward_only is True: use post_process_fn to calculate the log_probs - 2. when forward_only is False: calculate the policy loss - """ - if forward_only: - if post_process_fn is None: - return 1.0, {'logits': output} - else: - return 1.0, post_process_fn(output, data) - - responses = data['responses'] - response_length = responses.size(1) - attention_mask = data['attention_mask'] - response_mask = attention_mask[:, -response_length:] - old_log_prob = data['old_log_probs'] - advantages = data['advantages'] - - clip_ratio = meta_info['clip_ratio'] - - # compute policy loss - logits = output - logits = logits[:, -response_length - 1:-1] - _, _, log_prob = compute_log_probs(logits, responses, per_token=True) - pg_loss, pg_clipfrac, ppo_kl = compute_policy_loss(old_log_prob=old_log_prob, - log_prob=log_prob, - advantages=advantages, - eos_mask=response_mask, - cliprange=clip_ratio) - policy_loss = pg_loss - # return loss and stats - stats = { - 'actor/pg_loss': abs(pg_loss.detach().item()), - 'actor/pg_clipfrac': pg_clipfrac.detach().item(), - 'actor/ppo_kl': ppo_kl.detach().item() - } - return policy_loss, stats - - def loss_func_grpo(output, data, meta_info): - """ - This loss_func has two modes - 1. when forward_only is True: use post_process_fn to calculate the log_probs - 2. 
when forward_only is False: calculate the policy loss - """ - if forward_only: - if post_process_fn is None: - return 1.0, {'logits': output} - else: - return 1.0, post_process_fn(output, data) - - responses = data['responses'] - response_length = responses.size(1) - attention_mask = data['attention_mask'] - response_mask = attention_mask[:, -response_length:] - old_log_prob = data['old_log_probs'] - advantages = data['advantages'] - ref_log_prob = data['ref_log_prob'] - clip_ratio = meta_info['clip_ratio'] - - # compute policy loss - logits = output - logits = logits[:, -response_length - 1:-1] - _, _, log_prob = compute_log_probs(logits, responses, per_token=True) - - pg_loss, pg_clipfrac, ppo_kl = compute_grpo_policy_loss(old_log_prob=old_log_prob, - log_prob=log_prob, - ref_log_prob=ref_log_prob, - advantages=advantages, - eos_mask=response_mask, - cliprange=clip_ratio, - kl_ctrl=self.args.kl_ctrl) - policy_loss = pg_loss - - stats = { - 'actor/pg_loss': abs(pg_loss.detach().item()), - 'actor/pg_clipfrac': pg_clipfrac.detach().item(), - 'actor/ppo_kl': ppo_kl.detach().item() - } - return policy_loss, stats - - def loss_func_online_dpo(output, data, meta_info): - """ - calculate the policy loss - """ - args = get_args() - scores = data['rm_scores'] - responses = data['responses'] - device = responses.device - ref_logprobs = data['ref_log_prob'] - response_length = responses.size(1) - attention_mask = data['attention_mask'] - response_mask = attention_mask[:, -response_length:] - num_examples = responses.shape[0] // 2 - - actual_start = torch.arange(responses.size(0), device=responses.device) - tokenizer = get_tokenizer() - score_first_eos_index, reward_first_eos_index = find_first_eos_index(responses, tokenizer.eos_token_id) - - scores = scores[[actual_start, score_first_eos_index]] - contain_eos_token = torch.any(responses == tokenizer.eos_token_id, dim=-1) - if args.missing_eos_penalty is not None: - scores[~contain_eos_token] -= args.missing_eos_penalty - data['rm_scores'] = scores - first_half, second_half = split_two_prompts(scores) - - mask = first_half >= second_half - num_examples_range = torch.arange(num_examples, device=device) - chosen_indices = num_examples_range + (~mask * num_examples) - rejected_indices = num_examples_range + (mask * num_examples) - # Build tensor so that the first half is the chosen examples and the second half the rejected examples - cr_indices = torch.cat((chosen_indices, rejected_indices), dim=0) # cr = chosen and rejected - logits = output[:, -response_length - 1:-1] - _, _, log_prob = compute_log_probs(logits, responses, per_token=True) - - cr_logprobs = log_prob[cr_indices] - cr_ref_logprobs = ref_logprobs[cr_indices] - - # mask out the padding tokens - padding_mask = ~response_mask.bool() - cr_padding_mask = padding_mask[cr_indices] - - cr_logprobs_sum = (cr_logprobs * ~cr_padding_mask).sum(1) - cr_ref_logprobs_sum = (cr_ref_logprobs * ~cr_padding_mask).sum(1) - - # Split the chosen and rejected examples - chosen_logprobs_sum, rejected_logprobs_sum = split_two_prompts(cr_logprobs_sum) - chosen_ref_logprobs_sum, rejected_ref_logprobs_sum = split_two_prompts(cr_ref_logprobs_sum) - pi_logratios = chosen_logprobs_sum - rejected_logprobs_sum - ref_logratios = chosen_ref_logprobs_sum - rejected_ref_logprobs_sum - - logits = pi_logratios - ref_logratios - - if args.dpo_loss_type == "sigmoid": - losses = -F.logsigmoid(self.beta * logits) - elif args.dpo_loss_type == "ipo": - losses = (logits - 1 / (2 * self.beta)) ** 2 - else: - raise 
NotImplementedError(f"invalid loss type {self.loss_type}") - - loss = losses.mean() - chosen_rewards = self.beta * (chosen_logprobs_sum - chosen_ref_logprobs_sum) - rejected_rewards = self.beta * (rejected_logprobs_sum - rejected_ref_logprobs_sum) - - stats = { - 'actor/pg_loss': loss.detach().item(), - 'beta': self.beta, - 'logps/chosen': chosen_logprobs_sum.mean().detach().item(), - 'logps/rejected': rejected_logprobs_sum.mean().detach().item(), - 'rewards/chosen': chosen_rewards.mean().detach().item(), - 'rewards/rejected': rejected_rewards.mean().detach().item(), - } - return loss, stats - - - def forward_step(batch_iter, model): - batch = next(batch_iter) - input_ids = batch['input_ids'] - attention_mask_1d = batch['attention_mask'] - position_ids = batch['position_ids'] - attention_mask = get_tune_attention_mask(attention_mask_1d) - output = model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids) - if forward_only: - meta_info = None - else: - meta_info = {'clip_ratio': self.args.clip_ratio} - - loss_funcs = { - "ray_ppo": loss_func_ppo, - "ray_online_dpo": loss_func_online_dpo, - "ray_grpo": loss_func_grpo - } - - loss_func = loss_funcs.get(self.args.stage) - return output, partial(loss_func, data=batch, meta_info=meta_info) - - # batch should be a list of batches inside micro-batches - batch_generator = make_batch_generator(batches, vpp_size=len(self.model)) - losses_reduced = forward_backward_func( - forward_step_func=forward_step, - data_iterator=batch_generator, - model=self.model, - num_microbatches=n_micro_batch, - seq_length=seq_len, # unused when variable_seq_lengths - micro_batch_size=self.args.micro_batch_size, # unused when variable_seq_lengths - forward_only=forward_only - ) - - return losses_reduced - - def update_policy(self, dataloader: Iterable[DataProto]) -> Dict: - """Update the policy with an iterator of DataProto - - Args: - dataloader (Iterable[DataProto]): an iterator over the DataProto that returns by ``make_minibatch_iterator`` - The keys of each data batch is described in the make_minibatch_iterator. - - Returns: - Dict: a dictionary containing the statistics. Note that the statistics are only valid in the last pp stage - and users have to combine the output in each dp rank manually. - - """ - metrics = {} - model = self.model - optimizer = self.optimizer - opt_param_scheduler = self.opt_param_scheduler - - - for model_module in self.model: - model_module.train() - - for data in dataloader: - - for model_chunk in model: - model_chunk.zero_grad_buffer() - optimizer.zero_grad() - if self.args.stage == 'ray_grpo': - self.args.kl_ctrl = data.meta_info['kl_ctrl'] - metric_micro_batch = self.forward_backward_batch(data) - - update_successful, grad_norm, num_zeros_in_grad = optimizer.step() - - if update_successful: - increment = 1 - opt_param_scheduler.step(increment=increment) - - for metric in metric_micro_batch: - append_to_dict(metrics, metric) # append the metric from this micro-batch to global metrics. 
- - self.args.consumed_train_samples += self.args.global_batch_size - self.num_floating_point_operations_so_far += num_floating_point_operations(self.args, self.args.global_batch_size) - - # add empty cache after each compute - torch.cuda.empty_cache() - - return metrics diff --git a/mindspeed_llm/tasks/posttrain/rlxf/workers/critic.py b/mindspeed_llm/tasks/posttrain/rlxf/workers/critic.py deleted file mode 100644 index af2aa1b822d51982686000b4424f862832285839..0000000000000000000000000000000000000000 --- a/mindspeed_llm/tasks/posttrain/rlxf/workers/critic.py +++ /dev/null @@ -1,350 +0,0 @@ -import time -import json -import os -from typing import Dict, Union, Iterable -from functools import partial - -import ray -import torch -import torch_npu - -import megatron -from megatron.training import print_rank_0 -from megatron.core.models.gpt.gpt_layer_specs import ( - get_gpt_layer_local_spec, - get_gpt_layer_with_transformer_engine_spec, -) -from megatron.core.utils import get_model_config -from megatron.core import mpu -from megatron.training.arguments import core_transformer_config_from_args -from megatron.training.yaml_arguments import core_transformer_config_from_yaml -from megatron.core.transformer.spec_utils import import_module -from megatron.core.models.gpt import GPTModel - -from megatron.training.training import ( - print_datetime, - get_one_logger, - append_to_progress_log, -) -from megatron.core.pipeline_parallel import get_forward_backward_func -from megatron.training.checkpointing import save_checkpoint -from megatron.training.training import num_floating_point_operations -from megatron.training import get_args, initialize_megatron, get_timers - -from mindspeed_llm.tasks.posttrain.utils import append_to_dict -from mindspeed_llm.training.utils import get_tune_attention_mask -from mindspeed_llm.tasks.posttrain.base import BaseTrainer -from mindspeed_llm.tasks.posttrain.orm.orm_model import GPTRewardModel -from mindspeed_llm.training.initialize import set_jit_fusion_options - -from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base.decorator import register, Dispatch -from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base.megatron.worker import MegatronWorker -from mindspeed_llm.tasks.posttrain.rlxf.utils.torch_functional import masked_mean, split_dict_tensor_into_batches -from mindspeed_llm.tasks.posttrain.rlxf.utils.protocol import DataProto, make_batch_generator -from mindspeed_llm.tasks.posttrain.rlxf.utils.torch_functional import clip_by_value - -_TRAIN_START_TIME = time.time() - - -@ray.remote -class CriticWorker(MegatronWorker): - def __init__(self, config, role): - """ - """ - super().__init__() - - self.config = config - self.role = role - self.IGNORE_INDEX = -100 - os.environ['CUDA_DEVICE_MAX_CONNECTIONS'] = '1' - initialize_megatron(role=self.role, - config=self.config) - - self.args = get_args() - - @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def initialize(self): - self.critic = MegatronPPOCritic() - - @register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO) - def compute_values(self, data: DataProto): - data = data.to('cuda') - values = self.critic.compute_values(data=data) - output = DataProto.from_dict(tensors={'values': values}) - output = output.to('cpu') - return output - - @register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO) - def update_critic(self, data: DataProto): - data = data.to('cuda') - dataloader = self.critic.make_minibatch_iterator(data) - metrics = self.critic.update_critic(dataloader=dataloader) - output = DataProto(batch=None, 
meta_info={'metrics': metrics}) - output = output.to('cpu') - return output - - @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def save_checkpoint(self, iteration): - self.critic.save_checkpoint(iteration) - - -class MegatronPPOCritic(BaseTrainer): - def __init__(self): - super().__init__() - - def initialize(self): - self.args = get_args() - self.timers = get_timers() - self.num_floating_point_operations_so_far = 0 - - if self.args.log_progress: - append_to_progress_log("Starting job") - # Set pytorch JIT layer fusion options and warmup JIT functions. - set_jit_fusion_options() - # Adjust the startup time, so it reflects the largest value. - # This will be closer to what scheduler will see (outside of - # image ... launches. - global _TRAIN_START_TIME - start_time_tensor = torch.tensor( - [_TRAIN_START_TIME], - dtype=torch.float, - device='cuda' - ) - torch.distributed.all_reduce(start_time_tensor, op=torch.distributed.ReduceOp.MIN) - _TRAIN_START_TIME = start_time_tensor.item() - print_rank_0('time to initialize megatron (seconds): {:.3f}'.format(time.time() - _TRAIN_START_TIME)) - print_datetime('after megatron is initialized') - one_logger = get_one_logger() - if one_logger: - one_logger.log_metrics({ - 'train_iterations_warmup': 5 - }) - - from megatron.training.training import setup_model_and_optimizer - # Model, optimizer, and learning rate. - self.timers('model-and-optimizer-setup', log_level=0).start(barrier=True) - model, optimizer, opt_param_scheduler = setup_model_and_optimizer( - self.model_provider, self.model_type) - - self.timers('model-and-optimizer-setup').stop() - print_datetime('after model, optimizer, and learning rate ' - 'scheduler are built') - model_config = get_model_config(model[0]) - - self.model = model - self.optimizer = optimizer - self.opt_param_scheduler = opt_param_scheduler - self.model_config = model_config - self.process_non_loss_data_func = None - - def save_checkpoint(self, iteration): - save_checkpoint(iteration, self.model, self.optimizer, self.opt_param_scheduler, - self.num_floating_point_operations_so_far) - - - @staticmethod - def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megatron.legacy.model.GPTModel]: - """Builds the model. - - Currently supports only the mcore GPT model. - - Args: - pre_process (bool, optional): Set to true if you need to compute embedings. Defaults to True. - post_process (bool, optional): Set to true if you need to want to compute output logits/loss. - Defaults to True. 
- - Returns: - Union[GPTModel, megatron.legacy.model.GPTModel]: The returned model - """ - args = get_args() - use_te = args.transformer_impl == "transformer_engine" - - print_rank_0('building GPT model ...') - # Experimental loading arguments from yaml - if args.yaml_cfg is not None: - config = core_transformer_config_from_yaml(args, "language_model") - else: - config = core_transformer_config_from_args(args) - - if not args.use_mcore_models: - raise ValueError("Reward model training currently supports mcore only.") - - if args.spec is not None: - transformer_layer_spec = import_module(args.spec) - else: - if use_te: - transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(args.num_experts, - args.moe_grouped_gemm) - else: - transformer_layer_spec = get_gpt_layer_local_spec(args.num_experts, args.moe_grouped_gemm) - - model = GPTRewardModel( - config=config, - transformer_layer_spec=transformer_layer_spec, - vocab_size=args.padded_vocab_size, - max_sequence_length=args.max_position_embeddings, - pre_process=pre_process, - post_process=post_process, - post_layer_norm=not args.no_post_layer_norm, - fp16_lm_cross_entropy=args.fp16_lm_cross_entropy, - parallel_output=True, - share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights, - position_embedding_type=args.position_embedding_type, - rotary_percent=args.rotary_percent, - ) - - return model - - def critic_data_padding(self, data: DataProto) -> DataProto: - if 'response_mask' in data.batch.keys(): - return data - - prompt_length = data.batch['prompts'].shape[1] - response_mask = data.batch['attention_mask'] - response_mask[..., :prompt_length] = 0 - data.batch['response_mask'] = response_mask - - return data - - def get_batch(self, data_iterator): - self.timers('batch-generator', log_level=2).start() - batch = next(data_iterator) - input_ids = batch["input_ids"] - attention_mask_1d = batch["attention_mask"] - attention_mask = get_tune_attention_mask(attention_mask_1d) - position_ids = batch["position_ids"] - - return batch, input_ids, attention_mask, position_ids - - def forward_backward_batch(self, data_proto: DataProto, forward_only=False): - data_proto.batch = data_proto.batch.contiguous() - args = get_args() - data = data_proto.batch - - forward_batch_size = data["input_ids"].shape[0] - forward_num_microbatches = forward_batch_size // args.micro_batch_size - - batches = split_dict_tensor_into_batches(data, batch_size=args.micro_batch_size) - data_iterator = make_batch_generator(batches, vpp_size=len(self.model)) - - forward_backward_func = get_forward_backward_func() - losses_reduced = forward_backward_func( - forward_step_func=self.forward_step, - data_iterator=data_iterator, - model=self.model, - num_microbatches=forward_num_microbatches, - seq_length=args.seq_length, - micro_batch_size=args.micro_batch_size, - decoder_seq_length=args.decoder_seq_length, - collect_non_loss_data=forward_only, - forward_only=forward_only) - - return losses_reduced - - def compute_values(self, data: DataProto): - responses = data.batch['responses'] - attention_mask = data.batch['attention_mask'] - response_length = responses.size(1) - - for model_module in self.model: - model_module.eval() - - with torch.no_grad(): - output = self.forward_backward_batch(data, forward_only=True) - - if mpu.is_pipeline_last_stage(ignore_virtual=True): - values = torch.cat(output, dim=0).squeeze(-1) - values = values.to(torch.float32) - else: - values = torch.empty_like(attention_mask, dtype=torch.float32) - - values = values * 
attention_mask - values = values[:, -response_length - 1:-1] - values = values.contiguous() - - self.args.consumed_train_samples += self.args.global_batch_size - self.num_floating_point_operations_so_far += num_floating_point_operations(self.args, self.args.global_batch_size) - torch.cuda.empty_cache() - return values - - def update_critic(self, dataloader: Iterable[DataProto]): - metrics = {} - for model_module in self.model: - model_module.train() - - for data in dataloader: - for model_chunk in self.model: - model_chunk.zero_grad_buffer() - - self.optimizer.zero_grad() - - metric_micro_batch = self.forward_backward_batch(data, forward_only=False) - - # Empty unused memory. - if self.args.empty_unused_memory_level >= 1: - torch.cuda.empty_cache() - - update_successful, grad_norm, num_zeros_in_grad = self.optimizer.step() - - # # Update learning rate. - if update_successful: - increment = self.args.critic_mini_batch_size - self.opt_param_scheduler.step(increment=increment) - - for metric in metric_micro_batch: - append_to_dict(metrics, metric) - - torch.cuda.empty_cache() - return metrics - - def make_minibatch_iterator(self, data: DataProto): - select_keys = data.batch.keys() - data = data.select(batch_keys=select_keys) - return data.make_iterator(mini_batch_size=self.args.critic_mini_batch_size, - epochs=self.args.critic_update_epochs, - ) - - def loss_func(self, data, output_tensor, non_loss_data=False): - if non_loss_data: - return output_tensor - - responses = data['responses'] - response_length = responses.size(1) - attention_mask = data['attention_mask'] - eos_mask = attention_mask[:, -response_length:] - eos_p1_index = torch.min(torch.cumsum(eos_mask, dim=-1)[:, -1], - torch.tensor(eos_mask.shape[1], device=eos_mask.device)) - eos_mask[:, eos_p1_index] = 1 - - cliprange_value = self.args.cliprange_value - curr_values = output_tensor.squeeze(-1) - curr_values = curr_values[:, -response_length - 1:-1] - curr_values = torch.masked_fill(curr_values, ~eos_mask, 0) - - returns = data['returns'] - - if cliprange_value > 0.0: - prev_values = data['values'] - vpredclipped = clip_by_value(curr_values, prev_values - cliprange_value, prev_values + cliprange_value) - vf_losses1 = (vpredclipped - returns) ** 2 - else: - vf_losses1 = torch.tensor(0.0).to(curr_values.device) - - vf_losses2 = (curr_values - returns) ** 2 - - vf_loss = 0.5 * masked_mean(torch.max(vf_losses1, vf_losses2), eos_mask) - vf_clipfrac = masked_mean(torch.gt(vf_losses2, vf_losses1).float(), eos_mask) - - stats = { - 'critic/vf_loss': vf_loss.detach().item(), - 'critic/vf_clipfrac': vf_clipfrac.detach().item(), - 'critic/vpred_mean': masked_mean(curr_values, eos_mask).detach().item(), - } - - return vf_loss, stats - - def forward_step(self, data_iterator, model): - batch, input_ids, attention_mask, position_ids = self.get_batch(data_iterator) - output_tensor = model(input_ids, position_ids, attention_mask) - - return output_tensor, partial(self.loss_func, batch) diff --git a/mindspeed_llm/tasks/posttrain/rlxf/workers/reference.py b/mindspeed_llm/tasks/posttrain/rlxf/workers/reference.py deleted file mode 100644 index 268cd8b9b661864e101f6148d5d482c51f884440..0000000000000000000000000000000000000000 --- a/mindspeed_llm/tasks/posttrain/rlxf/workers/reference.py +++ /dev/null @@ -1,186 +0,0 @@ -# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved. 
-import os -import time -from functools import partial - -import ray -import torch - -from megatron.training import get_args, get_timers, get_one_logger -from megatron.training import print_rank_0 -from megatron.core import mpu -from megatron.core.pipeline_parallel import get_forward_backward_func -from megatron.training.training import append_to_progress_log, print_datetime, get_model -from megatron.training.utils import unwrap_model -from megatron.training.initialize import initialize_megatron -from mindspeed_llm.tasks.checkpoint.models import load_checkpoint -from mindspeed_llm.tasks.posttrain.base import BaseTrainer -from mindspeed_llm.tasks.posttrain.rlxf.utils.protocol import DataProto, make_batch_generator -from mindspeed_llm.tasks.posttrain.rlxf.utils.torch_functional import split_dict_tensor_into_batches -from mindspeed_llm.tasks.posttrain.utils import compute_log_probs -from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base.megatron.worker import MegatronWorker -from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base.decorator import register, Dispatch -from mindspeed_llm.training.initialize import set_jit_fusion_options -from mindspeed_llm.training.utils import get_tune_attention_mask - -_TRAIN_START_TIME = time.time() - - -@ray.remote -class ReferenceWorker(MegatronWorker): - """ - Ray ReferenceWorker - """ - - def __init__(self, config, role): - super().__init__() - - self.config = config - self.role = role - - os.environ['CUDA_DEVICE_MAX_CONNECTIONS'] = '1' - initialize_megatron(role=self.role, - config=self.config) - - @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def initialize(self): - self.reference = MegatronPPOReference() - - @register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO) - def compute_ref_log_prob(self, data: DataProto): - output = self.reference.compute_log_prob(data=data) - if output is not None: - output = DataProto.from_dict(tensors={'ref_log_prob': output}) - output = output.to('cpu') - torch.cuda.empty_cache() - return output - - -class MegatronPPOReference(BaseTrainer): - def __init__(self): - super().__init__() - - def initialize(self): - self.args = get_args() - self.timers = get_timers() - - if self.args.log_progress: - append_to_progress_log("Starting job") - # Set pytorch JIT layer fusion options and warmup JIT functions. - set_jit_fusion_options() - # Adjust the startup time, so it reflects the largest value. - # This will be closer to what scheduler will see (outside of - # image ... launches. 
- global _TRAIN_START_TIME - start_time_tensor = torch.tensor( - [_TRAIN_START_TIME], - dtype=torch.float, - device='cuda' - ) - torch.distributed.all_reduce(start_time_tensor, op=torch.distributed.ReduceOp.MIN) - _TRAIN_START_TIME = start_time_tensor.item() - print_rank_0('Time to initialize Megatron (seconds): {:.3f}'.format(time.time() - _TRAIN_START_TIME)) - print_datetime('after megatron is initialized') - one_logger = get_one_logger() - if one_logger: - one_logger.log_metrics({ - 'train_iterations_warmup': 5 - }) - - self.timers('model-setup', log_level=0).start(barrier=True) - - self.model = get_model(self.model_provider, self.model_type, wrap_with_ddp=False) - unwrapped_model = unwrap_model(self.model) - if self.args.stage == "ray_online_dpo": - self.args.micro_batch_size *= 2 - - if self.args.load is not None or self.args.pretrained_checkpoint is not None: - self.timers('load-checkpoint', log_level=0).start(barrier=True) - self.args.iteration, self.args.num_floating_point_operations_so_far = load_checkpoint( - self.model, None, None) - self.timers('load-checkpoint').stop(barrier=True) - self.timers.log(['load-checkpoint']) - else: - self.args.iteration = 0 - self.args.num_floating_point_operations_so_far = 0 - - # get model without FP16 and/or DDP wrappers - if self.args.iteration == 0 and len(unwrapped_model) == 1 \ - and hasattr(unwrapped_model[0], 'init_state_dict_from_bert'): - print_rank_0("Initializing ICT from pretrained BERT model") - unwrapped_model[0].init_state_dict_from_bert() - - self.timers('model-setup').stop() - print_datetime('after model built') - - def compute_log_prob(self, data: DataProto): - - data.batch = data.batch.contiguous() - - for model_module in self.model: - model_module.eval() - - data = data.to(next(self.model[0].parameters()).device) - with torch.no_grad(): - output = self.forward_backward_batch(data, forward_only=True) - if mpu.is_pipeline_last_stage(ignore_virtual=True): - ref_log_probs = torch.cat([out['ref_log_probs'] for out in output], dim=0) # (bs, seq_size) - ref_log_probs = ref_log_probs.to(torch.float32) - else: - ref_log_probs = None - - return ref_log_probs - - def forward_backward_batch(self, data, forward_only=False): - """ - We assume: - - The model takes input: (input_ids, attention_mask, position_ids). 
No rmpad for the input - - The communication shape is (total_nnz_pad_to_sp // tp_size, 1, hidden_size) if sequence parallel is enabled - """ - args = get_args() - data.batch['attention_mask'] = data.batch['attention_mask'].to(bool) - - if data.meta_info.get('micro_batch_size', None) is not None: - batch_size = data.meta_info['micro_batch_size'] - else: - batch_size = args.micro_batch_size - batches = split_dict_tensor_into_batches(data.batch, batch_size=batch_size) - - n_micro_batch = len(batches) - seq_len = batches[0]['input_ids'].shape[1] - - forward_backward_func = get_forward_backward_func() - - # batch should be a list of batches inside micro-batches - batch_generator = make_batch_generator(batches, vpp_size=len(self.model)) - losses_reduced = forward_backward_func( - forward_step_func=self.forward_step, - data_iterator=batch_generator, - model=self.model, - num_microbatches=n_micro_batch, - seq_length=seq_len, # unused when variable_seq_lengths - micro_batch_size=args.micro_batch_size, # unused when variable_seq_lengths - forward_only=forward_only - ) - return losses_reduced - - def forward_step(self, batch_iter, model): - input_ids, attention_mask, position_ids, response = self.get_batch(batch_iter) - output = model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids) - return output, partial(self.loss_func, response=response) - - def get_batch(self, batch_iter): - batch = next(batch_iter) - input_ids = batch['input_ids'] - attention_mask_1d = batch['attention_mask'] - position_ids = batch['position_ids'] - attention_mask = get_tune_attention_mask(attention_mask_1d) - response = batch['responses'] - return input_ids, attention_mask, position_ids, response - - def loss_func(self, output, response): - response_length = response.size(1) - logits = output - logits = logits[:, -response_length - 1:-1] - log_probs = compute_log_probs(logits, response, per_token=True)[2] - return 1.0, {'ref_log_probs': log_probs} diff --git a/mindspeed_llm/tasks/posttrain/rlxf/workers/reward.py b/mindspeed_llm/tasks/posttrain/rlxf/workers/reward.py deleted file mode 100644 index f46cf8089d87077421f97dcca4804e6e57a06a26..0000000000000000000000000000000000000000 --- a/mindspeed_llm/tasks/posttrain/rlxf/workers/reward.py +++ /dev/null @@ -1,231 +0,0 @@ -# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved. 
-import os -import time -from functools import partial -import re - -import ray -import torch -from transformers import AutoTokenizer - -from megatron.training import print_rank_0 -from megatron.training import get_args, get_timers, get_one_logger -from megatron.training.utils import unwrap_model -from megatron.training.initialize import initialize_megatron -from megatron.training.checkpointing import load_checkpoint -from megatron.training.training import append_to_progress_log, print_datetime, get_model -from megatron.core import mpu -from megatron.core.utils import get_model_config -from megatron.core.pipeline_parallel import get_forward_backward_func - -from mindspeed_llm.training.initialize import set_jit_fusion_options -from mindspeed_llm.training.utils import get_tune_attention_mask -from mindspeed_llm.tasks.posttrain.orm import ORMTrainer -from mindspeed_llm.tasks.posttrain.rlxf.workers.actor_train_infer import pad_to_tensor_dict, \ - generate_attention_mask, generate_position_ids_from_attention_mask -from mindspeed_llm.tasks.posttrain.rlxf.utils.protocol import DataProto -from mindspeed_llm.tasks.posttrain.rlxf.utils.torch_functional import split_dict_tensor_into_batches -from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base.megatron.worker import MegatronWorker -from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base.decorator import register, Dispatch - - -_TRAIN_START_TIME = time.time() - - -@ray.remote -class RewardWorker(MegatronWorker): - """ - Ray RewardWorker - """ - def __init__(self, config, role): - super().__init__() - - self.config = config - self.role = role - self.rm = None - - os.environ['CUDA_DEVICE_MAX_CONNECTIONS'] = '1' - initialize_megatron(role=self.role, - config=self.config) - - @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def initialize(self): - self.rm = MegatronPPORM() - - @register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO) - def compute_rm_score(self, data: DataProto): - data = data.to('cuda') - output = self.rm.compute_rm_score(data=data) - output = DataProto.from_dict(tensors={'rm_scores': output}) - output = output.to('cpu') - torch.cuda.empty_cache() - return output - - -class MegatronPPORM(ORMTrainer): - def __init__(self): - super().__init__() - - def initialize(self): - self.args = get_args() - self.timers = get_timers() - - if self.args.log_progress: - append_to_progress_log("Starting job") - # Set pytorch JIT layer fusion options and warmup JIT functions. - set_jit_fusion_options() - # Adjust the startup time, so it reflects the largest value. - # This will be closer to what scheduler will see (outside of - # image ... launches. 
- global _TRAIN_START_TIME - start_time_tensor = torch.tensor( - [_TRAIN_START_TIME], - dtype=torch.float, - device='cuda' - ) - torch.distributed.all_reduce(start_time_tensor, op=torch.distributed.ReduceOp.MIN) - _TRAIN_START_TIME = start_time_tensor.item() - print_rank_0('Time to initialize Megatron (seconds): {:.3f}'.format(time.time() - _TRAIN_START_TIME)) - print_datetime('after megatron is initialized') - one_logger = get_one_logger() - if one_logger: - one_logger.log_metrics({ - 'train_iterations_warmup': 5 - }) - - if self.args.stage == "ray_online_dpo": - self.args.micro_batch_size *= 2 - self.timers('model-setup', log_level=0).start(barrier=True) - - model = get_model(self.model_provider, self.model_type, wrap_with_ddp=False) - unwrapped_model = unwrap_model(model) - - if self.args.load is not None or self.args.pretrained_checkpoint is not None: - self.timers('load-checkpoint', log_level=0).start(barrier=True) - self.args.iteration, self.args.num_floating_point_operations_so_far = load_checkpoint( - model, None, None, strict=True) - self.timers('load-checkpoint').stop(barrier=True) - self.timers.log(['load-checkpoint']) - else: - self.args.iteration = 0 - self.args.num_floating_point_operations_so_far = 0 - - # get model without FP16 and/or DDP wrappers - if self.args.iteration == 0 and len(unwrapped_model) == 1 \ - and hasattr(unwrapped_model[0], 'init_state_dict_from_bert'): - print_rank_0("Initializing ICT from pretrained BERT model") - unwrapped_model[0].init_state_dict_from_bert() - - self.timers('model-setup').stop() - print_datetime('after model built') - config = get_model_config(model[0]) - - # Print setup timing. - self.train_args = [self.forward_step, model, config] - self.tokenizer = AutoTokenizer.from_pretrained(self.args.tokenizer_name_or_path) - - def compute_rm_score(self, data: DataProto): - args = get_args() - forward_step_func, model, config = self.train_args - prompt_lens = data.batch["prompts"].size(1) - - for model_module in model: - model_module.eval() - - with torch.no_grad(): - batches = split_dict_tensor_into_batches(data.batch, batch_size=args.micro_batch_size) - n_micro_batch = len(batches) - - batch_generator = iter(batches) - forward_backward_func = get_forward_backward_func() - - rm_score = forward_backward_func( - forward_step_func=forward_step_func, - data_iterator=batch_generator, - model=model, - num_microbatches=n_micro_batch, - seq_length=args.seq_length, - micro_batch_size=args.micro_batch_size, - decoder_seq_length=args.decoder_seq_length, - collect_non_loss_data=True, - forward_only=True - ) - - # Empty unused memory - if args.empty_unused_memory_level >= 1: - torch.cuda.empty_cache() - if mpu.is_pipeline_last_stage(): - rm_score = torch.cat(rm_score, dim=0).squeeze(-1) # (bs, seq_size) - rm_score = rm_score.to(torch.float32) - rm_score = rm_score[:, prompt_lens:] - else: - rm_score = torch.zeros(1) - - return rm_score - - def forward_step(self, data_iterator, model): - """RewardModel forward step to calculate rm scores.
- - Args: - data_iterator : Data iterator that waits to get input ids from the queue generated by the actor server. - model (GPTModel): The GPT model. - """ - self.timers('batch-generator', log_level=2).start() - input_ids, attention_mask, position_ids = self._get_tokens(data_iterator) - self.timers('batch-generator').stop() - - scores = model(input_ids, position_ids, attention_mask) - - return scores, self.loss_func - - def loss_func(self, scores: torch.Tensor, non_loss_data=False): - return scores - - def _get_tokens(self, data_iterator): - self.timers('batch-generator', log_level=2).start() - batch = next(data_iterator) - - if self.args.extract_content_for_reward: - str_responses = self.tokenizer.batch_decode(batch["responses"]) - pattern = r'<answer>(.*?)</answer>' - contents = [] - for str_response in str_responses: - first_pad_position = str_response.find(self.tokenizer.pad_token) - if first_pad_position != -1: - str_response = str_response[:first_pad_position] - within_answer = re.findall(pattern, str_response) - if within_answer: - content = within_answer[0] - else: - content = str_response - contents.append(self.tokenizer.encode(content)) - - responses_ori_length, responses_pad_length = pad_to_tensor_dict( - contents, - pad_multi_of=self.args.pad_to_multiple_of - ) - - prompts = batch["prompts"] - prompts_pad_length = torch.LongTensor([len(prompts[0])]).cuda() - pad_token_id = self.tokenizer.pad_token_id - prompts_ori_length = [len(prompts[i]) - (prompts[i] == pad_token_id).sum().item() for i in range(len(prompts))] - prompts = prompts.cpu().numpy().tolist() - - input_ids = [prompt + response for prompt, response in zip(prompts, contents)] - attention_mask = generate_attention_mask(input_ids, prompts_ori_length, prompts_pad_length, - responses_ori_length, responses_pad_length) - position_ids = generate_position_ids_from_attention_mask(input_ids, prompts_ori_length, prompts_pad_length) - - device = batch["input_ids"].device - input_ids = torch.tensor(input_ids).long().to(device) - attention_mask_1d = torch.tensor(attention_mask).long().to(device) - attention_mask = get_tune_attention_mask(attention_mask_1d) - position_ids = torch.tensor(position_ids).long().to(device) - - else: - input_ids = batch["input_ids"] - attention_mask_1d = batch["attention_mask"] - attention_mask = get_tune_attention_mask(attention_mask_1d) - position_ids = batch["position_ids"] - - return input_ids, attention_mask, position_ids diff --git a/mindspeed_llm/tasks/posttrain/trl_ppo/TrlPPOEngine.py b/mindspeed_llm/tasks/posttrain/trl_ppo/TrlPPOEngine.py index 45c6761f67892db39f9698bd61397a68911536e2..6ec5c27183fe048324b7c7cf8aa9f10337ac9f18 100644 --- a/mindspeed_llm/tasks/posttrain/trl_ppo/TrlPPOEngine.py +++ b/mindspeed_llm/tasks/posttrain/trl_ppo/TrlPPOEngine.py @@ -22,7 +22,6 @@ from megatron.training.training import ( get_optimizer_param_scheduler, build_train_valid_test_data_iterators ) -from mindspeed_llm.tasks.posttrain.rlxf.utils.torch_functional import masked_mean, masked_whiten from mindspeed_llm.training.utils import get_tune_attention_mask from mindspeed_llm.tasks.posttrain.utils import train_valid_test_datasets_provider from mindspeed_llm.tasks.posttrain.trl_ppo.actor_model import ActorModel @@ -316,3 +315,36 @@ class TrlPPOEngine(): def set_model_train(self, model): for model_module in model: model_module.train() + + +def masked_mean(values, mask, axis=None): + """Compute the mean of a tensor over masked values.""" + return (values * mask).sum(axis=axis) / mask.sum(axis=axis) + + +def masked_var(values, mask, unbiased=True): + """Compute
the variance of a tensor over masked values.""" + mean = masked_mean(values, mask) + centered_values = values - mean + variance = masked_mean(centered_values ** 2, mask) + if unbiased: + mask_sum = mask.sum() + if mask_sum == 0: + raise ValueError("At least one element in the mask has to be 1.") + # note that if mask_sum == 1, then there is a division by zero issue + # to avoid it you just need to use a larger minibatch_size + elif mask_sum == 1: + bessel_correction = mask_sum + else: + bessel_correction = mask_sum / (mask_sum - 1) + variance = variance * bessel_correction + return variance + + +def masked_whiten(values, mask, shift_mean=True): + """Whiten values using the masked mean and variance.""" + mean, var = masked_mean(values, mask), masked_var(values, mask) + whitened = (values - mean) * torch.rsqrt(var + 1e-8) + if not shift_mean: + whitened += mean + return whitened diff --git a/mindspeed_llm/tasks/posttrain/trl_ppo/TrlPPOTrainer.py b/mindspeed_llm/tasks/posttrain/trl_ppo/TrlPPOTrainer.py index 23e145a86d737a5df78c9b000b42cc342ed79b62..ca6e57ce5afa1082513bdd3f4a3d577cdc8e38c1 100644 --- a/mindspeed_llm/tasks/posttrain/trl_ppo/TrlPPOTrainer.py +++ b/mindspeed_llm/tasks/posttrain/trl_ppo/TrlPPOTrainer.py @@ -25,7 +25,6 @@ from megatron.training.training import ( training_log ) from megatron.inference.text_generation.communication import broadcast_from_last_pipeline_stage -from mindspeed_llm.tasks.posttrain.rlxf.workers.actor_train_infer import generate_attention_mask, generate_position_ids_from_attention_mask from mindspeed_llm.tasks.posttrain.trl_ppo.TrlPPOEngine import TrlPPOEngine from mindspeed_llm.training.training import get_profiler, is_profile_enabled @@ -601,3 +600,53 @@ class TrlPPOTrainer(): args.micro_batch_size = origin_micro_batch_size return output_tensor + + +def generate_position_ids_from_attention_mask(input_ids_list, prompts_ori_length, prompts_pad_length): + """ + Generate the position_ids list that matches the attention_mask. + + Args: + input_ids_list (list of lists): list of input_ids; each element is a list of token ids. + prompts_ori_length (list of ints): original (unpadded) prompt length of each sample. + prompts_pad_length (int): prompt length after padding. + + Returns: + list of lists: the position_ids of each sample. + """ + position_ids_list = [] + for idx, input_ids in enumerate(input_ids_list): + prompt_pad_length = prompts_pad_length - prompts_ori_length[idx] + position_ids = [0] * prompt_pad_length + list(range(len(input_ids) - prompt_pad_length)) + position_ids_list.append(position_ids) + + return position_ids_list + + +def generate_attention_mask(input_ids_list, prompts_ori_length, prompts_pad_length, responses_ori_length, + responses_pad_length): + """ + Generate the attention_mask list that matches the input_ids. + + Args: + input_ids_list (list of lists): list of input_ids; each element is a list of token ids. + prompts_ori_length (list of ints): original (unpadded) prompt length of each sample. + prompts_pad_length (int): prompt length after padding. + responses_ori_length (list of ints): original (unpadded) response length of each sample. + responses_pad_length (int): response length after padding. + + Returns: + list of lists: the attention_mask of each sample. + """ + attention_mask_list = [] + + for idx, input_ids in enumerate(input_ids_list): + attention_mask = torch.ones_like(torch.tensor(input_ids)) + prompt_pad_length = prompts_pad_length - prompts_ori_length[idx] + response_pad_length = responses_pad_length - responses_ori_length[idx] + attention_mask[:prompt_pad_length] = 0 + if response_pad_length > 0: + attention_mask[-response_pad_length:] = 0 + attention_mask_list.append(attention_mask.numpy().tolist()) + + return
attention_mask_list diff --git a/ray_gpt.py b/ray_gpt.py deleted file mode 100644 index 70c2a5b64fc32608c832e1db208d2c1d9bcbd199..0000000000000000000000000000000000000000 --- a/ray_gpt.py +++ /dev/null @@ -1,42 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Note that we don't combine the main with ray_trainer as ray_trainer is used by other main. -""" -import ray -import hydra - -from mindspeed_llm.tasks.posttrain.launcher import get_trainer - - -@hydra.main(config_path='configs/rlxf', config_name='ppo_trainer_llama32_1b', version_base=None) -def main(config): - if not ray.is_initialized(): - # this is for local ray cluster - ray.init(runtime_env={ - 'env_vars': {"RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES": "True", - 'TOKENIZERS_PARALLELISM': 'true', - 'NCCL_DEBUG': 'WARN'}}) - - ray.get(main_task.remote(config)) - - -@ray.remote -def main_task(config): - trainer = get_trainer(config.training.stage)(config) - trainer.train() - - -if __name__ == '__main__': - main() \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index ed68f4313bae7031fdc166b63b2f379137f59cbc..534cfc296b84375e3fa5dfb23e029c72859528b1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,8 +14,6 @@ protobuf peft==0.7.1 tiktoken ray==2.10.0 -tensordict==0.1.2 -hydra-core==1.3.2 codetiming bitsandbytes-npu-beta==0.45.3 word2number diff --git a/tests/rlxf/ray_grpo_full_llama32_1b_tp1pp1.sh b/tests/rlxf/ray_grpo_full_llama32_1b_tp1pp1.sh deleted file mode 100644 index cde21233b0b57fc4a5b861d2dab0c73a658ab2e0..0000000000000000000000000000000000000000 --- a/tests/rlxf/ray_grpo_full_llama32_1b_tp1pp1.sh +++ /dev/null @@ -1,9 +0,0 @@ -#!/bin/bash -export CUDA_DEVICE_MAX_CONNECTIONS=1 -export HCCL_DETERMINISTIC=True - - -basepath=$(cd `dirname $0`; cd ../../../; pwd) - - -python $basepath/ray_gpt.py --config-dir=$basepath/tests/pipeline/configs --config-name=ray_grpo_full_llama32_1b_tp1pp1 \ No newline at end of file diff --git a/tests/rlxf/ray_online_dpo_full_llama32_1b_tp1pp1.sh b/tests/rlxf/ray_online_dpo_full_llama32_1b_tp1pp1.sh deleted file mode 100644 index eff76dc8532438bf8244cf44fcfab71872192ea6..0000000000000000000000000000000000000000 --- a/tests/rlxf/ray_online_dpo_full_llama32_1b_tp1pp1.sh +++ /dev/null @@ -1,9 +0,0 @@ -#!/bin/bash -export CUDA_DEVICE_MAX_CONNECTIONS=1 -export HCCL_DETERMINISTIC=True - - -basepath=$(cd `dirname $0`; cd ../../../; pwd) - - -python $basepath/ray_gpt.py --config-dir=$basepath/tests/pipeline/configs --config-name=ray_online_dpo_full_llama32_1b_tp1pp1 \ No newline at end of file diff --git a/tests/rlxf/ray_ppo_full_llama32_1b_tp1pp1.sh b/tests/rlxf/ray_ppo_full_llama32_1b_tp1pp1.sh deleted file mode 100644 index 80bcd00c32f717f113187da84bbf5be65accc50b..0000000000000000000000000000000000000000 --- a/tests/rlxf/ray_ppo_full_llama32_1b_tp1pp1.sh +++ /dev/null @@ -1,9 +0,0 @@ -#!/bin/bash -export CUDA_DEVICE_MAX_CONNECTIONS=1 -export HCCL_DETERMINISTIC=True - - -basepath=$(cd `dirname $0`; 
cd ../../../; pwd) - - -python $basepath/ray_gpt.py --config-dir=$basepath/tests/pipeline/configs --config-name=ray_ppo_full_llama32_1b_tp1pp1 \ No newline at end of file diff --git a/tests/st/configs/model/llama32-1b.yaml b/tests/st/configs/model/llama32-1b.yaml deleted file mode 100644 index 324548e026623607a47cfb6c480c2376496626a9..0000000000000000000000000000000000000000 --- a/tests/st/configs/model/llama32-1b.yaml +++ /dev/null @@ -1,35 +0,0 @@ -llama32-1b: - use_mcore_models: true - sequence_parallel: true - use_mc2: true - use_flash_attn: true - use_rotary_position_embeddings: true - use_fused_rmsnorm: true - use_fused_swiglu: true - rope_scaling_type: llama3 - rope_scaling_factor: 32.0 - low_freq_factor: 1.0 - high_freq_factor: 4.0 - original_max_position_embeddings: 8192 - max_position_embeddings: 8192 - num_layers: 16 - hidden_size: 2048 - ffn_hidden_size: 8192 - num_attention_heads: 32 - group_query_attention: true - num_query_groups: 8 - make_vocab_size_divisible_by: 1 - padded_vocab_size: 128256 - disable_bias_linear: true - attention_dropout: 0.0 - init_method_std: 0.01 - hidden_dropout: 0.0 - position_embedding_type: rope - rotary_base: 500000 - normalization: RMSNorm - norm_epsilon: 1e-5 - swiglu: true - no_masked_softmax_fusion: true - attention_softmax_in_fp32: true - no_gradient_accumulation_fusion: true - bf16: true \ No newline at end of file
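
For reference, a minimal usage sketch of the padding helpers that this change inlines into TrlPPOTrainer.py. The toy token ids and lengths below are invented for illustration only, and the import assumes the mindspeed_llm package (and its Megatron dependencies) is importable in your environment:

from mindspeed_llm.tasks.posttrain.trl_ppo.TrlPPOTrainer import (
    generate_attention_mask,
    generate_position_ids_from_attention_mask,
)

# One sample: prompt left-padded to length 4 (3 real tokens),
# response right-padded to length 4 (2 real tokens).
input_ids_list = [[0, 11, 12, 13, 21, 22, 0, 0]]
prompts_ori_length = [3]      # real prompt tokens per sample
prompts_pad_length = 4        # prompt length after left padding
responses_ori_length = [2]    # real response tokens per sample
responses_pad_length = 4      # response length after right padding

attention_mask = generate_attention_mask(input_ids_list, prompts_ori_length, prompts_pad_length,
                                         responses_ori_length, responses_pad_length)
position_ids = generate_position_ids_from_attention_mask(input_ids_list, prompts_ori_length,
                                                         prompts_pad_length)

print(attention_mask)  # [[0, 1, 1, 1, 1, 1, 0, 0]] -> prompt pad and trailing response pad masked out
print(position_ids)    # [[0, 0, 1, 2, 3, 4, 5, 6]] -> left-pad slots stay at 0, real tokens count up from 0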