diff --git a/configs/rlxf/model/llama2-7b.yaml b/configs/rlxf/model/llama2-7b.yaml
deleted file mode 100644
index 8a84f722589705f5854c26837a7a467ef1d08f45..0000000000000000000000000000000000000000
--- a/configs/rlxf/model/llama2-7b.yaml
+++ /dev/null
@@ -1,27 +0,0 @@
-llama2-7b:
- use_mcore_models: true
- sequence_parallel: true
- num_layers: 32
- hidden_size: 4096
- ffn_hidden_size: 11008
- num_attention_heads: 32
- seq_length: 4096
- max_position_embeddings: 4096
- make_vocab_size_divisible_by: 1
- untie_embeddings_and_output_weights: true
- disable_bias_linear: true
- attention_dropout: 0.0
- init_method_std: 0.01
- hidden_dropout: 0.0
- position_embedding_type: rope
- normalization: RMSNorm
- use_fused_rmsnorm: true
- swiglu: true
- use_flash_attn: true
- use_mc2: true
- no_masked_softmax_fusion: true
- attention_softmax_in_fp32: true
- no_gradient_accumulation_fusion: true
- use_fused_swiglu: true
- use_fused_rotary_pos_emb: true
- bf16: true
\ No newline at end of file
diff --git a/configs/rlxf/model/llama3-8b.yaml b/configs/rlxf/model/llama3-8b.yaml
deleted file mode 100644
index b6d54851bdcba69528f38a3908dcb2972139633f..0000000000000000000000000000000000000000
--- a/configs/rlxf/model/llama3-8b.yaml
+++ /dev/null
@@ -1,36 +0,0 @@
-llama3-8b:
- use_mcore_models: true
- sequence_parallel: true
- use_flash_attn: true
- use_rotary_position_embeddings: true
- rope_scaling_type: llama3
- rope_scaling_factor: 8.0
- rotary_percent: 1.0
- low_freq_factor: 1.0
- high_freq_factor: 4.0
- original_max_position_embeddings: 8192
- num_layers: 32
- hidden_size: 4096
- ffn_hidden_size: 14336
- num_attention_heads: 32
- group_query_attention: true
- num_query_groups: 8
- max_position_embeddings: 8192
- make_vocab_size_divisible_by: 1
- padded_vocab_size: 128256
- untie_embeddings_and_output_weights: true
- disable_bias_linear: true
- attention_dropout: 0.0
- init_method_std: 0.02
- hidden_dropout: 0.0
- position_embedding_type: rope
- rotary_base: 500000
- normalization: RMSNorm
- norm_epsilon: 1e-5
- swiglu: true
- no_masked_softmax_fusion: true
- attention_softmax_in_fp32: true
- no_gradient_accumulation_fusion: true
- bf16: true
- seed: 42
- vocab_size: 128256
\ No newline at end of file
diff --git a/configs/rlxf/model/llama32-1b.yaml b/configs/rlxf/model/llama32-1b.yaml
deleted file mode 100644
index 324548e026623607a47cfb6c480c2376496626a9..0000000000000000000000000000000000000000
--- a/configs/rlxf/model/llama32-1b.yaml
+++ /dev/null
@@ -1,35 +0,0 @@
-llama32-1b:
- use_mcore_models: true
- sequence_parallel: true
- use_mc2: true
- use_flash_attn: true
- use_rotary_position_embeddings: true
- use_fused_rmsnorm: true
- use_fused_swiglu: true
- rope_scaling_type: llama3
- rope_scaling_factor: 32.0
- low_freq_factor: 1.0
- high_freq_factor: 4.0
- original_max_position_embeddings: 8192
- max_position_embeddings: 8192
- num_layers: 16
- hidden_size: 2048
- ffn_hidden_size: 8192
- num_attention_heads: 32
- group_query_attention: true
- num_query_groups: 8
- make_vocab_size_divisible_by: 1
- padded_vocab_size: 128256
- disable_bias_linear: true
- attention_dropout: 0.0
- init_method_std: 0.01
- hidden_dropout: 0.0
- position_embedding_type: rope
- rotary_base: 500000
- normalization: RMSNorm
- norm_epsilon: 1e-5
- swiglu: true
- no_masked_softmax_fusion: true
- attention_softmax_in_fp32: true
- no_gradient_accumulation_fusion: true
- bf16: true
\ No newline at end of file
diff --git a/configs/rlxf/model/qwen25-7b.yaml b/configs/rlxf/model/qwen25-7b.yaml
deleted file mode 100644
index b5eb4b00e7601bcb6c9a6fbe9c46e1836428c643..0000000000000000000000000000000000000000
--- a/configs/rlxf/model/qwen25-7b.yaml
+++ /dev/null
@@ -1,33 +0,0 @@
-qwen25-7b:
- use_mcore_models: true
- num_layers: 28
- hidden_size: 3584
- ffn_hidden_size: 18944
- num_attention_heads: 28
- seq_length: 4096
- rotary_base: 1000000
- max_position_embeddings: 32768
- make_vocab_size_divisible_by: 1
- padded_vocab_size: 152064
- untie_embeddings_and_output_weights: true
- add_qkv_bias: true
- disable_bias_linear: true
- group_query_attention: true
- num_query_groups: 4
- attention_dropout: 0.0
- init_method_std: 0.01
- hidden_dropout: 0.0
- adam_beta1: 0.9
- adam_beta2: 0.95
- position_embedding_type: rope
- normalization: RMSNorm
- use_fused_rmsnorm: true
- swiglu: true
- use_mc2: true
- use_flash_attn: true
- no_masked_softmax_fusion: true
- attention_softmax_in_fp32: true
- no_gradient_accumulation_fusion: true
- use_fused_swiglu: true
- use_fused_rotary_pos_emb: true
- bf16: true
\ No newline at end of file
diff --git a/configs/rlxf/online_dpo_trainer_llama32_1b.yaml b/configs/rlxf/online_dpo_trainer_llama32_1b.yaml
deleted file mode 100644
index 49fb3fbceaa80a998f811a3f3f9866698074faac..0000000000000000000000000000000000000000
--- a/configs/rlxf/online_dpo_trainer_llama32_1b.yaml
+++ /dev/null
@@ -1,86 +0,0 @@
-defaults:
- - model:
- - llama32-1b
-
-training:
- global_batch_size: 4
- seq_length: 512
- tokenizer_type: PretrainedFromHF
- tokenizer_name_or_path: ./models/llama-3.2-1b-instruct/
- train_iters: 1000
- distributed_backend: nccl
- no_shared_storage: true
- save_interval: 10000
- no_load_optim: true
- no_load_rng: true
- bf16: true
- is_instruction_dataset: true
- variable_seq_lengths: true
- stage: ray_online_dpo
-
-
-actor_rollout_ref:
- actor_rollout:
- model: llama32-1b
- micro_batch_size: 4
- ppo_mini_batch_size: 4
- max_prompt_length: 256
- ppo_epochs: 1
- clip_ratio: 0.2
- entropy_coeff: 0.001
- do_sample: true
- shuffle: false
- use_kv_cache: true
- num_samples_per_step: 4
- tensor_model_parallel_size: 1
- pipeline_model_parallel_size: 1
- lr: 1e-7
- lr_decay_style: constant
- min_lr: 0.0
- weight_decay: 0.0
- lr_warmup_fraction: 0.0
- clip_grad: 10000.0
- adam_beta1: 0.9
- adam_beta2: 0.999
- initial_loss_scale: 4096
- finetune: true
- load: ./models/llama-3.2-1b-instruct-tp1-pp1
- save: ./ckpt
- num_gpus_for_train: 1
- num_gpus_for_infer: 1
- pad_to_multiple_of: 1
- inference_tensor_model_parallel_size: 1
- data_path: ./dataset/descriptiveness/descriptiveness
- split: 100,0,0
- no_shuffle: true
- missing_eos_penalty: 1.0
-
- ref:
- model: llama32-1b
- tensor_model_parallel_size: 1
- pipeline_model_parallel_size: 1
- micro_batch_size: 4
- load: ./models/llama-3.2-1b-instruct-tp1-pp1
-
-reward:
- model: llama32-1b
- tensor_model_parallel_size: 1
- pipeline_model_parallel_size: 1
- micro_batch_size: 4
- sequence_parallel: False
- load: ./models/llama-3.2-1b-rm-mcore-tp1-pp1
-
-algorithm:
- gamma: 1.0
- lam: 0.95
- adv_estimator: gae
- kl_penalty: kl
- kl_ctrl:
- type: fixed
- kl_coef: 0.05
- missing_eos_penalty: 1.0
-
-resource_pool:
- actor_rollout: [2]
- ref: [1]
- reward: [1]
diff --git a/configs/rlxf/ppo_trainer_llama2_7b.yaml b/configs/rlxf/ppo_trainer_llama2_7b.yaml
deleted file mode 100644
index 1dd1788100f68f7336285a8305b2e7522e5b4cf2..0000000000000000000000000000000000000000
--- a/configs/rlxf/ppo_trainer_llama2_7b.yaml
+++ /dev/null
@@ -1,108 +0,0 @@
-defaults:
- - model:
- - llama2-7b
-
-training:
- global_batch_size: 1
- seq_length: 512
- tokenizer_type: PretrainedFromHF
- tokenizer_name_or_path: ./models/llama-2-7b/
- train_iters: 1000
- distributed_backend: nccl
- no_shared_storage: true
- save_interval: 10000
- no_load_optim: true
- no_load_rng: true
- bf16: true
- is_instruction_dataset: true
- variable_seq_lengths: true
- no_shuffle: true
- stage: ray_ppo
- sequence_parallel: False
-
-actor_rollout_ref:
- actor_rollout:
- model: llama2-7b
- do_sample: false
- micro_batch_size: 1
- ppo_mini_batch_size: 1
- num_samples_per_step: 1
- max_prompt_length: 256
- ppo_epochs: 1
- clip_ratio: 0.2
- entropy_coeff: 0.001
- shuffle_minibatch: false
- use_kv_cache: true
- tensor_model_parallel_size: 4
- pipeline_model_parallel_size: 1
- lr: 1e-7
- lr_decay_style: constant
- min_lr: 0.0
- weight_decay: 0.0
- lr_warmup_fraction: 0.0
- clip_grad: 1.0
- adam_beta1: 0.9
- adam_beta2: 0.95
- initial_loss_scale: 1
- finetune: true
- load: ./ckpt
- save: ./ckpt
- num_gpus_for_train: 4
- num_gpus_for_infer: 4
- pad_to_multiple_of: 1
- data_path: ./dataset/descriptiveness/descriptiveness
- split: 100,0,0
-
- ref:
- model: llama2-7b
- tensor_model_parallel_size: 2
- pipeline_model_parallel_size: 1
- micro_batch_size: 1
- load: ./ckpt
-
-critic:
- model: llama2-7b
- tensor_model_parallel_size: 4
- pipeline_model_parallel_size: 1
- use_mcore_models: True
- micro_batch_size: 1
- lr: 1e-7
- lr_decay_style: constant
- min_lr: 0.0
- weight_decay: 0.0
- lr_warmup_fraction: 0.0
- use_distributed_optimizer: true
- clip_grad: 1.0
- adam_beta1: 0.9
- adam_beta2: 0.95
- initial_loss_scale: 1
- no_load_optim: True
- no_load_rng: True
- load: ./ckpt
- save: ./ckpt
- cliprange_value: 0.2
- critic_mini_batch_size: 1
- critic_update_epochs: 1
-
-reward:
- model: llama2-7b
- tensor_model_parallel_size: 2
- pipeline_model_parallel_size: 1
- micro_batch_size: 1
- load: ./ckpt
-
-algorithm:
- gamma: 1.0
- lam: 0.95
- adv_estimator: gae
- kl_penalty: kl
- kl_ctrl:
- type: fixed
- kl_coef: 0.05
- missing_eos_penalty: 0.0
-
-resource_pool:
- actor_rollout: [8]
- ref: [2]
- critic: [4]
- reward: [2]
\ No newline at end of file
diff --git a/configs/rlxf/ppo_trainer_llama32_1b.yaml b/configs/rlxf/ppo_trainer_llama32_1b.yaml
deleted file mode 100644
index 5f1d3c16c953f4b11ad13c525dc86414dbacf140..0000000000000000000000000000000000000000
--- a/configs/rlxf/ppo_trainer_llama32_1b.yaml
+++ /dev/null
@@ -1,111 +0,0 @@
-defaults:
- - model:
- - llama32-1b
-
-training:
- global_batch_size: 8
- seq_length: 512
- tokenizer_type: PretrainedFromHF
- tokenizer_name_or_path: ./models/llama-3.2-1b-instruct/
- train_iters: 1000
- distributed_backend: nccl
- no_shared_storage: true
- save_interval: 10000
- no_load_optim: true
- no_load_rng: true
- bf16: true
- is_instruction_dataset: true
- variable_seq_lengths: true
- no_shuffle: true
- stage: ray_ppo
-
-actor_rollout_ref:
- actor_rollout:
- model: llama32-1b
- do_sample: false
- micro_batch_size: 4
- ppo_mini_batch_size: 4
- num_samples_per_step: 2
- max_prompt_length: 256
- ppo_epochs: 1
- clip_ratio: 0.2
- entropy_coeff: 0.001
- shuffle_minibatch: false
- use_kv_cache: true
- tensor_model_parallel_size: 1
- pipeline_model_parallel_size: 1
- lr: 1e-7
- lr_decay_style: constant
- min_lr: 0.0
- weight_decay: 0.0
- lr_warmup_fraction: 0.0
- clip_grad: 10000.0
- adam_beta1: 0.9
- adam_beta2: 0.999
- initial_loss_scale: 4096
- finetune: true
- load: ./ckpt
- save: ./ckpt
- num_gpus_for_train: 1
- num_gpus_for_infer: 1
- pad_to_multiple_of: 1
- data_path: ./dataset/descriptiveness/descriptiveness
- split: 100,0,0
-
- ref:
- model: llama32-1b
- tensor_model_parallel_size: 1
- pipeline_model_parallel_size: 1
- micro_batch_size: 8
- load: ./ckpt
-
-critic:
- model: llama32-1b
- tensor_model_parallel_size: 1
- pipeline_model_parallel_size: 1
- use_mcore_models: True
- micro_batch_size: 4
- sequence_parallel: False
- lr: 1e-7
- lr_decay_style: constant
- min_lr: 0.0
- weight_decay: 0.0
- lr_warmup_fraction: 0.0
- use_distributed_optimizer: true
- clip_grad: 10000.0
- adam_beta1: 0.9
- adam_beta2: 0.999
- initial_loss_scale: 4096
- no_load_optim: True
- no_load_rng: True
- is_instruction_dataset: true
- variable_seq_lengths: true
- load: ./ckpt
- save: ./ckpt
- cliprange_value: 0.2
- critic_mini_batch_size: 4
- critic_update_epochs: 1
-
-reward:
- model: llama32-1b
- tensor_model_parallel_size: 1
- pipeline_model_parallel_size: 1
- micro_batch_size: 8
- sequence_parallel: false
- load: ./ckpt
-
-algorithm:
- gamma: 1.0
- lam: 0.95
- adv_estimator: gae
- kl_penalty: kl
- kl_ctrl:
- type: fixed
- kl_coef: 0.05
- missing_eos_penalty: 0.0
-
-resource_pool:
- actor_rollout: [2]
- ref: [1]
- critic: [1]
- reward: [1]
\ No newline at end of file
diff --git a/configs/rlxf/ppo_trainer_llama3_8b.yaml b/configs/rlxf/ppo_trainer_llama3_8b.yaml
deleted file mode 100644
index d5e37bc60f813b1d4e963781290fd946b5424e07..0000000000000000000000000000000000000000
--- a/configs/rlxf/ppo_trainer_llama3_8b.yaml
+++ /dev/null
@@ -1,108 +0,0 @@
-defaults:
- - model:
- - llama3-8b
-
-training:
- global_batch_size: 2
- seq_length: 512
- tokenizer_type: PretrainedFromHF
- tokenizer_name_or_path: ./models/llama-3-8b/
- train_iters: 1000
- distributed_backend: nccl
- no_shared_storage: true
- save_interval: 10000
- no_load_optim: true
- no_load_rng: true
- bf16: true
- is_instruction_dataset: true
- variable_seq_lengths: true
- no_shuffle: true
- stage: ray_ppo
- sequence_parallel: False
-
-actor_rollout_ref:
- actor_rollout:
- model: llama3-8b
- do_sample: false
- micro_batch_size: 1
- ppo_mini_batch_size: 1
- num_samples_per_step: 1
- max_prompt_length: 256
- ppo_epochs: 1
- clip_ratio: 0.2
- entropy_coeff: 0.001
- shuffle_minibatch: false
- use_kv_cache: true
- tensor_model_parallel_size: 2
- pipeline_model_parallel_size: 2
- lr: 1e-7
- lr_decay_style: constant
- min_lr: 0.0
- weight_decay: 0.0
- lr_warmup_fraction: 0.0
- clip_grad: 1.0
- adam_beta1: 0.9
- adam_beta2: 0.999
- initial_loss_scale: 4096
- finetune: true
- load: ./ckpt
- save: ./ckpt
- num_gpus_for_train: 4
- num_gpus_for_infer: 4
- pad_to_multiple_of: 1
- data_path: ./dataset/descriptiveness/descriptiveness
- split: 100,0,0
-
- ref:
- model: llama3-8b
- tensor_model_parallel_size: 2
- pipeline_model_parallel_size: 1
- micro_batch_size: 1
- load: ./ckpt
-
-critic:
- model: llama3-8b
- tensor_model_parallel_size: 2
- pipeline_model_parallel_size: 2
- use_mcore_models: True
- micro_batch_size: 1
- lr: 1e-7
- lr_decay_style: constant
- min_lr: 0.0
- weight_decay: 0.0
- lr_warmup_fraction: 0.0
- use_distributed_optimizer: true
- clip_grad: 1.0
- adam_beta1: 0.9
- adam_beta2: 0.999
- initial_loss_scale: 1
- no_load_optim: True
- no_load_rng: True
- load: ./ckpt
- save: ./ckpt
- cliprange_value: 0.2
- critic_mini_batch_size: 1
- critic_update_epochs: 1
-
-reward:
- model: llama3-8b
- tensor_model_parallel_size: 2
- pipeline_model_parallel_size: 1
- micro_batch_size: 1
- load: ./ckpt
-
-algorithm:
- gamma: 1.0
- lam: 0.95
- adv_estimator: gae
- kl_penalty: kl
- kl_ctrl:
- type: fixed
- kl_coef: 0.05
- missing_eos_penalty: 0.0
-
-resource_pool:
- actor_rollout: [8]
- ref: [2]
- critic: [4]
- reward: [2]
\ No newline at end of file
diff --git a/docs/pytorch/solutions/preference-alignment/ray_online_dpo.md b/docs/pytorch/solutions/preference-alignment/ray_online_dpo.md
deleted file mode 100644
index 3531230e48da36c4d6b859eead034fdb170bedfe..0000000000000000000000000000000000000000
--- a/docs/pytorch/solutions/preference-alignment/ray_online_dpo.md
+++ /dev/null
@@ -1,149 +0,0 @@
-# Post-training Method: Ray Online DPO
-
-Online Direct Preference Optimization (Online DPO) is an extension, or variant, of Direct Preference Optimization (DPO) that further optimizes large language models (LLMs) through online learning. DPO trains on human preference data, while Online DPO focuses on using preference data in a dynamic, real-time setting to continuously improve the model.
-
-The Online DPO method involves three models: Actor, Reference, and Reward. The Actor/Reference models are large language models obtained through pre-training and instruction tuning (Supervised Fine-Tuning, SFT), and the Reward model is a trained reward model. The training objective of Online DPO is to make the Actor model's responses better match human preferences.
-
-# Usage
-
-## Environment Setup
-
-Set up the MindSpeed-LLM base environment by following the [installation guide](../../../features/install_guide.md).
-
-## Data Preprocessing
-
-Reference script for dataset conversion: MindSpeed-LLM/examples/mcore/llama3/data_convert_llama3_ppo.sh
-The [descriptiveness dataset](https://huggingface.co/datasets/trl-internal-testing/descriptiveness-sentiment-trl-style/tree/main/data) is used as an example.
-
-```bash
-source /usr/local/Ascend/ascend-toolkit/set_env.sh
-mkdir ./dataset/llama3-hf/
-
-python ./preprocess_data.py \
- --input ./dataset/descriptiveness-00000-of-00001.parquet \
- --tokenizer-name-or-path ./model_from_hf/llama3-hf/ \
- --output-prefix ./dataset/llama3-hf/descriptiveness \
- --workers 16 \
- --log-interval 1000 \
- --tokenizer-type PretrainedFromHF \
- --handler-name PPOAlpacaStyleInstructionHandler \
- --prompt-type llama3 \
- --map-keys '{"prompt":"prompt", "query":"", "response": "prompt", "system":""}'
-```
-
-## Model Weight Conversion
-
-As required by the Online DPO algorithm, the Actor and Reference models should be initialized from an SFT fine-tuned model, and the Reward model should be initialized from a trained reward model. All model weights for the Online DPO algorithm use the Megatron-mcore format; weights in other formats must be converted first, see [weight conversion](../checkpoint_convert.md).
-
-The llama3.2-1b model is used as an example below:
-
-Both the actor_rollout and ref roles under actor_rollout_ref require the SFT fine-tuned model, and the weight conversion steps are the same as in the SFT stage. Example weight conversion script:
-llama32-1b
-
-The reward model must be converted from a trained reward model checkpoint. Example weight conversion script:
-llama32-1b-orm
-
-
-The corresponding online_dpo_trainer_llama32_1b.yaml is configured as follows:
-```
- actor_rollout_ref:
- actor_rollout:
- ...
- load: ./model_weights/llama32-mcore/
- save: ./model_weights/llama32-mcore-save/
-
- ref:
- ...
- load: ./model_weights/llama32-mcore/
-
- reward:
- ...
- load: ./model_weights/llama32-mcore-orm/
-```
-
-## Launching
-
-### Single Node
-
-Pass the selected config file name (without the .yaml suffix) via --config-name; training can then be launched directly with the command below (the Llama32 1B model can run on a single node).
-The currently supported config files are located in the configs/rlxf/ folder; they are described in detail below.
-
-```bash
-python ray_gpt.py --config-name online_dpo_trainer_llama32_1b
-```
-
-### Multi-Node
-
-For multi-node runs, first change into the corresponding directory and activate the conda or docker environment:
-
-```bash
-cd MindSpeed-LLM
-conda activate xxx
-```
-
-Then start the Ray cluster on the head node:
-
-```bash
-# Create a cluster on port 6344, with the dashboard on port 8260 and 8 NPUs
-ray start --head --port 6344 --dashboard-host=0.0.0.0 --dashboard-port=8260 --resources='{"NPU": 8}'
-```
-
-Next, join the head node's cluster from the other nodes:
-
-```bash
-# Replace IP_ADDRESS with the head node's IP address
-ray start --address="IP_ADDRESS:6344" --resources='{"NPU": 8}'
-```
-
-Once the Ray cluster is up, launch the program on the head node (the Llama3 8B model can run on two nodes):
-
-```bash
-python ray_gpt.py --config-name online_dpo_trainer_llama3_8b
-```
-After training finishes, stop the ray processes with:
-```
-ray stop
-```
-## Configuration Files
-
-Because Online DPO training involves 3 models, a hierarchical parameter layout that decouples model parameters from training configuration is used to simplify configuring Online DPO training. All configuration files for RLXF training live under configs/rlxf; the model folder holds the model-architecture configs, and the Online DPO training configs are named online_dpo_trainer_{model_name}.yaml.
-
-Each online_dpo_trainer config file must contain the defaults, training, resource_pool, and algorithm fields, plus the configuration of the 3 roles involved in Online DPO training: actor, reward, and ref. Specifically:
-
-1. defaults imports the model config files; it should list every model config used in this file, and each of the 3 roles below selects its model config via its model field.
-2. The training field sets default parameters shared by all 3 roles; these can be overridden by each role's individual configuration below.
-3. The resource_pool field specifies the number of NPUs each role requires.
-4. The actor, reward, and ref fields specify the training-related parameter configuration of the three roles in the Online DPO algorithm; a skeleton is sketched below.
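-
-For orientation only, the skeleton below is an abridged sketch of configs/rlxf/online_dpo_trainer_llama32_1b.yaml; `# ...` marks keys omitted here, and values are illustrative.
-
-```yaml
-defaults:
-  - model:
-    - llama32-1b          # model-architecture config from configs/rlxf/model/
-
-training:                 # defaults shared by all roles
-  stage: ray_online_dpo
-  global_batch_size: 4
-  # ...
-
-actor_rollout_ref:
-  actor_rollout:          # Actor training and rollout
-    model: llama32-1b
-    do_sample: true
-    # ...
-  ref:                    # Reference model
-    model: llama32-1b
-    # ...
-
-reward:                   # Reward model
-  model: llama32-1b
-  # ...
-
-algorithm:
-  kl_ctrl:
-    type: fixed
-    kl_coef: 0.05
-  # ...
-
-resource_pool:            # NPUs per role
-  actor_rollout: [2]
-  ref: [1]
-  reward: [1]
-```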
-
-## Parameter Reference
-
-Compared with ordinary model training, DPO adds a few special parameters:
-
-### `training:`
-
-* `stage`: specifies the training algorithm; must be set to `ray_online_dpo` for Ray Online DPO training;
-
-### `actor_rollout:`
-
-* `do_sample`: controls whether the Actor model samples during inference; defaults to False, and must be set to True for Online DPO;
-* `ppo_mini_batch_size`: mini_batch_size of the Actor model; defaults to 1;
-* `max_prompt_length`: maximum prompt length in DPO training; defaults to 512;
-* `num_samples_per_step`: number of samples generated per inference step of the Actor; defaults to 1;
-* `ppo_epochs`: number of passes the Actor makes over the same batch of experience data; defaults to 1;
-* `clip_ratio`: clip ratio used in the Actor loss; defaults to 0.2, typically in [0.1, 0.3] and at most in [0, 1]; larger values allow larger policy updates, and vice versa;
-* `shuffle_minibatch`: whether to shuffle minibatches during Actor training; defaults to False;
-* `num_gpus_for_train`: number of devices allocated to the training part of the Actor model;
-* `num_gpus_for_infer`: number of devices allocated to the inference part of the Actor model;
-* `missing_eos_penalty`: penalty coefficient applied when the end-of-sequence token (EOS) is missing (see the excerpt after this list);
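-
-For a concrete reference, the excerpt below reproduces the parameters discussed above from the actor_rollout block of configs/rlxf/online_dpo_trainer_llama32_1b.yaml; all other keys are omitted.
-
-```yaml
-actor_rollout_ref:
-  actor_rollout:
-    do_sample: true            # Online DPO requires sampling
-    ppo_mini_batch_size: 4
-    max_prompt_length: 256
-    num_samples_per_step: 4
-    ppo_epochs: 1
-    clip_ratio: 0.2
-    num_gpus_for_train: 1      # devices for the training part of the Actor
-    num_gpus_for_infer: 1      # devices for the inference part of the Actor
-    missing_eos_penalty: 1.0
-```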
-
-### `resource_pool:`
-
-* `actor_rollout`: total number of devices allocated to Actor training and inference;
-* `ref`: number of devices allocated to the Reference model;
-* `reward`: number of devices allocated to the Reward model;
-
-# Accuracy Comparison
-
-We compared accuracy against HuggingFace's open-source reinforcement learning repository [TRL](https://github.com/huggingface/trl/) to help verify the correctness of the algorithm implementation. Because Online DPO generates two answers per prompt (1Q2A), do_sample is set to True during inference; to align accuracy with the baseline, the experiment fixes the responses during Actor inference. With fixed responses, the losses align well.
-
-
-
diff --git a/docs/pytorch/solutions/preference-alignment/ray_ppo.md b/docs/pytorch/solutions/preference-alignment/ray_ppo.md
deleted file mode 100644
index c06947e14554f15913d30a112e6a0ebb873aee68..0000000000000000000000000000000000000000
--- a/docs/pytorch/solutions/preference-alignment/ray_ppo.md
+++ /dev/null
@@ -1,195 +0,0 @@
-# Post-training Method: Ray PPO
-
-[PPO (Proximal Policy Optimization)](https://arxiv.org/abs/1707.06347) is a reinforcement-based alignment fine-tuning method commonly used for Reinforcement Learning with Human Feedback (RLHF) tasks.
-
-The PPO method involves four models: Actor, Critic, Reference, and Reward. The Actor/Reference models are large language models obtained through pre-training and instruction tuning (Supervised Fine-Tuning, SFT), and the Critic and Reward models are trained reward models. The training objective of PPO is to make the Actor model's responses better match human preferences.
-
-# Usage
-
-## Environment Setup
-
-Set up the MindSpeed-LLM base environment by following the [installation guide](../../../features/install_guide.md).
-
-## Data Preprocessing
-
-Reference script for dataset conversion: MindSpeed-LLM/examples/mcore/llama3/data_convert_llama3_ppo.sh
-The [descriptiveness dataset](https://huggingface.co/datasets/trl-internal-testing/descriptiveness-sentiment-trl-style/tree/main/data) is used as an example.
-
-```bash
-source /usr/local/Ascend/ascend-toolkit/set_env.sh
-mkdir ./dataset/llama3-hf/
-
-python ./preprocess_data.py \
- --input ./dataset/descriptiveness-00000-of-00001.parquet \
- --tokenizer-name-or-path ./model_from_hf/llama3-hf/ \
- --output-prefix ./dataset/llama3-hf/descriptiveness \
- --workers 16 \
- --log-interval 1000 \
- --tokenizer-type PretrainedFromHF \
- --handler-name PPOAlpacaStyleInstructionHandler \
- --prompt-type llama3 \
- --map-keys '{"prompt":"prompt", "query":"", "response": "prompt", "system":""}'
-```
-
-## Model Weight Conversion
-
-As required by the PPO algorithm, the Actor and Reference models should be initialized from an SFT fine-tuned model, while the Critic and Reward models should be initialized from a trained reward model. All model weights for the PPO algorithm use the Megatron-mcore format; weights in other formats must be converted first, see [weight conversion](../checkpoint_convert.md).
-
-The llama3.2-1b model is used as an example below:
-
-Both the actor_rollout and ref roles under actor_rollout_ref require the SFT fine-tuned model, and the weight conversion steps are the same as in the SFT stage. Example weight conversion script:
-llama32-1b
-
-The critic and reward models must be converted from a trained reward model checkpoint. Example weight conversion script: llama32-1b-orm
-
-
-The corresponding ppo_trainer_llama32_1b.yaml is configured as follows:
-```
- actor_rollout_ref:
- actor_rollout:
- ...
- load: ./model_weights/llama32-mcore/
- save: ./model_weights/llama32-mcore-save/
-
- ref:
- ...
- load: ./model_weights/llama32-mcore/
-
- critic:
- ...
- load: ./model_weights/llama32-mcore-orm/
- save: ./model_weights/llama32-mcore-orm-save/
-
- reward:
- ...
- load: ./model_weights/llama32-mcore-orm/
-```
-
-## Launching
-
-### Single Node
-
-Pass the selected config file name (without the .yaml suffix) via --config-name; training can then be launched directly with the command below (the Llama32 1B model can run on a single node).
-The currently supported config files are located in the configs/rlxf/ folder; they are described in detail below.
-
-```bash
-python ray_gpt.py --config-name ppo_trainer_llama32_1b
-```
-
-### Multi-Node
-
-For multi-node runs, first change into the corresponding directory and activate the conda or docker environment:
-
-```bash
-cd MindSpeed-LLM
-conda activate xxx
-```
-
-Then start the Ray cluster on the head node:
-
-```bash
-# Raise the maximum file descriptor limit
-ulimit -n 32768
-# Create a cluster on port 6344, with the dashboard on port 8260 and 8 NPUs
-ray start --head --port 6344 --dashboard-host=0.0.0.0 --dashboard-port=8260 --resources='{"NPU": 8}'
-```
-
-Next, join the head node's cluster from the other nodes:
-
-```bash
-# Raise the maximum file descriptor limit
-ulimit -n 32768
-# Replace IP_ADDRESS with the head node's IP address
-ray start --address="IP_ADDRESS:6344" --resources='{"NPU": 8}'
-```
-
-Once the Ray cluster is up, launch the program on the head node (the Llama3 8B model can run on two nodes):
-
-```bash
-python ray_gpt.py --config-name ppo_trainer_llama3_8b
-```
-After training finishes, stop the ray processes with:
-```
-ray stop
-```
-## Configuration Files
-
-Because PPO training involves 4 models, a hierarchical parameter layout that decouples model parameters from training configuration is used to simplify configuring PPO training. All configuration files for RLXF training live under configs/rlxf; the model folder holds the model-architecture configs, and the PPO training configs are named ppo_trainer_{model_name}.yaml.
-
-Each ppo_trainer config file must contain the defaults, training, resource_pool, and algorithm fields, plus the configuration of the 4 roles involved in PPO training: actor, critic, reward, and ref. Specifically:
-
-1. defaults imports the model config files; it should list every model config used in this file, and each of the four roles below selects its model config via its model field.
-2. The training field sets default parameters shared by all 4 roles; these can be overridden by each role's individual configuration below.
-3. The resource_pool field specifies the number of NPUs each role requires.
-4. The algorithm field configures the parameters used to compute advantages in PPO.
-5. The actor, critic, reward, and ref fields specify the training-related parameter configuration of the four roles in the PPO algorithm; a skeleton is sketched below.
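-
-For orientation only, the skeleton below is an abridged sketch of configs/rlxf/ppo_trainer_llama32_1b.yaml; `# ...` marks keys omitted here, and values are illustrative.
-
-```yaml
-defaults:
-  - model:
-    - llama32-1b          # model-architecture config from configs/rlxf/model/
-
-training:                 # defaults shared by all roles
-  stage: ray_ppo
-  global_batch_size: 8
-  # ...
-
-actor_rollout_ref:
-  actor_rollout:          # Actor training and rollout
-    model: llama32-1b
-    # ...
-  ref:                    # Reference model
-    model: llama32-1b
-    # ...
-
-critic:                   # Critic model
-  model: llama32-1b
-  # ...
-
-reward:                   # Reward model
-  model: llama32-1b
-  # ...
-
-algorithm:                # advantage computation settings
-  adv_estimator: gae
-  kl_ctrl:
-    type: fixed
-    kl_coef: 0.05
-  # ...
-
-resource_pool:            # NPUs per role
-  actor_rollout: [2]
-  ref: [1]
-  critic: [1]
-  reward: [1]
-```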
-
-## Parameter Reference
-
-Compared with ordinary model training, PPO adds a few special parameters:
-
-### `training:`
-
-* `stage`: specifies the training algorithm; must be set to `ray_ppo` for Ray PPO training;
-
-### `actor_rollout:`
-
-* `do_sample`: controls whether the Actor model samples during inference; defaults to False;
-* `ppo_mini_batch_size`: mini_batch_size of the Actor model; defaults to 1;
-* `max_prompt_length`: maximum prompt length in PPO training; defaults to 512;
-* `num_samples_per_step`: number of samples generated per inference step of the Actor; defaults to 1;
-* `ppo_epochs`: number of passes the Actor makes over the same batch of experience data; defaults to 1;
-* `clip_ratio`: clip ratio used in the Actor loss; defaults to 0.2, typically in [0.1, 0.3] and at most in [0, 1]; larger values allow larger policy updates, and vice versa;
-* `shuffle_minibatch`: whether to shuffle minibatches during Actor training; defaults to False;
-* `num_gpus_for_train`: number of devices allocated to the training part of the Actor model;
-* `num_gpus_for_infer`: number of devices allocated to the inference part of the Actor model;
-
-### `critic:`
-
-* `cliprange_value`: clip range used in the Critic loss; defaults to 0.2;
-* `critic_mini_batch_size`: mini_batch_size of the Critic model; defaults to 1;
-* `critic_update_epochs`: number of passes the Critic makes over the same batch of experience data; defaults to 1;
-
-### `algorithm:`
-
-* `adv_estimator`: method for computing advantages; usually gae (Generalized Advantage Estimation, GAE);
-* `gamma`: discount factor for computing advantages, in [0, 1]; values near 0 emphasize immediate rewards, values near 1 emphasize delayed rewards;
-* `lam`: lambda for GAE advantage estimation, in [0, 1]; values near 0 reduce variance and speed up convergence but may introduce bias, while values near 1 increase variance and slow convergence with less bias. A value of 0 is equivalent to the one-step TD error, and a value of 1 is equivalent to the Monte Carlo return;
-* `kl_penalty`: how the KL divergence is computed;
-* `kl_ctrl:`
-  * `kl_coef`: coefficient of the KL divergence penalty, in [0, 1]; larger values constrain policy updates more, smaller values allow larger updates;
-  * `type`: type of the KL penalty coefficient controller;
-* `missing_eos_penalty`: penalty coefficient applied when the end-of-sequence token (EOS) is missing (see the excerpt after this list);
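-
-For a concrete reference, the excerpt below reproduces the algorithm block of configs/rlxf/ppo_trainer_llama32_1b.yaml:
-
-```yaml
-algorithm:
-  gamma: 1.0               # no discounting of delayed rewards
-  lam: 0.95                # GAE lambda: trades bias against variance
-  adv_estimator: gae
-  kl_penalty: kl
-  kl_ctrl:
-    type: fixed            # fixed KL coefficient
-    kl_coef: 0.05
-  missing_eos_penalty: 0.0 # no penalty for a missing EOS token
-```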
-
-### `resource_pool:`
-
-* `actor_rollout`: total number of devices allocated to Actor training and inference;
-* `ref`: number of devices allocated to the Reference model;
-* `critic`: number of devices allocated to the Critic model;
-* `reward`: number of devices allocated to the Reward model;
-
-# Accuracy Comparison
-
-We compared accuracy against HuggingFace's open-source reinforcement learning repository [TRL](https://github.com/huggingface/trl/) to help verify the correctness of the algorithm implementation. To align accuracy with the baseline, a greedy decoding strategy is used during Actor inference to remove randomness; the critic loss and actor loss during training are compared in the figures below.
-
-
-
- Loss comparison without fixed responses (left: actor loss, right: critic loss)
-
-
-However, because greedy decoding picks the token with the largest logit, the selected token may differ when two tokens have very close logit values. Such discrepancies accumulate and amplify over many iterations, eventually affecting the loss alignment.
-
-Therefore, we additionally ran an alignment experiment with fixed responses. With fixed responses, the losses align well.
-
-
-
- Loss comparison with fixed responses (left: actor loss, right: critic loss)
-
-
-Note: to verify alignment of the actor loss, we do not directly compare the actor loss logged by the PPO algorithm. When computing advantages, PPO whitens them across minibatches during Actor training (shifting the distribution to mean 0 and variance 1) for stability. As a result, although the Actor uses each minibatch's loss for its gradient updates, the logged mean loss across minibatches is close to 0. We therefore log the mean absolute value of the Actor's per-minibatch loss to verify alignment.
-
-
-# References
-
-[PPO](https://arxiv.org/abs/1707.06347)
-
diff --git a/mindspeed_llm/core/pipeline_parallel/p2p_communication.py b/mindspeed_llm/core/pipeline_parallel/p2p_communication.py
index 7ba04f58c2b16b7942a9aab9f059267f7d2815d2..079bf27cb286c500a8bb3bd1c66be69630fd7943 100644
--- a/mindspeed_llm/core/pipeline_parallel/p2p_communication.py
+++ b/mindspeed_llm/core/pipeline_parallel/p2p_communication.py
@@ -68,93 +68,3 @@ def _batched_p2p_ops(
else:
reqs = []
return reqs
-
-
-
-def _communicate_shapes(tensor_send_next, tensor_send_prev, recv_prev, recv_next, config):
- """
- Add group=get_pipeline_model_parallel_group() in P2POp shape communications to avoid error.
- Currently only enable in ray scenarios.
- """
-
- recv_prev_shape_tensor = None
- recv_next_shape_tensor = None
- send_prev_shape_tensor = None
- send_next_shape_tensor = None
- if recv_prev:
- recv_prev_shape_tensor = torch.empty(
- (3), device=torch.cuda.current_device(), dtype=torch.int64
- )
- if recv_next:
- recv_next_shape_tensor = torch.empty(
- (3), device=torch.cuda.current_device(), dtype=torch.int64
- )
- if tensor_send_prev is not None:
- send_prev_shape_tensor = torch.tensor(
- tensor_send_prev.size(), device=torch.cuda.current_device(), dtype=torch.int64
- )
- if tensor_send_next is not None:
- send_next_shape_tensor = torch.tensor(
- tensor_send_next.size(), device=torch.cuda.current_device(), dtype=torch.int64
- )
-
- if config.use_ring_exchange_p2p:
- torch.distributed.ring_exchange(
- tensor_send_prev=send_prev_shape_tensor,
- tensor_recv_prev=recv_prev_shape_tensor,
- tensor_send_next=send_next_shape_tensor,
- tensor_recv_next=recv_next_shape_tensor,
- group=get_pipeline_model_parallel_group(),
- )
- else:
- ops = []
- if send_prev_shape_tensor is not None:
- send_prev_op = torch.distributed.P2POp(
- torch.distributed.isend,
- send_prev_shape_tensor,
- get_pipeline_model_parallel_prev_rank(),
- group=get_pipeline_model_parallel_group(),
- )
- ops.append(send_prev_op)
- if recv_prev_shape_tensor is not None:
- recv_prev_op = torch.distributed.P2POp(
- torch.distributed.irecv,
- recv_prev_shape_tensor,
- get_pipeline_model_parallel_prev_rank(),
- group=get_pipeline_model_parallel_group(),
- )
- ops.append(recv_prev_op)
- if send_next_shape_tensor is not None:
- send_next_op = torch.distributed.P2POp(
- torch.distributed.isend,
- send_next_shape_tensor,
- get_pipeline_model_parallel_next_rank(),
- group=get_pipeline_model_parallel_group(),
- )
- ops.append(send_next_op)
- if recv_next_shape_tensor is not None:
- recv_next_op = torch.distributed.P2POp(
- torch.distributed.irecv,
- recv_next_shape_tensor,
- get_pipeline_model_parallel_next_rank(),
- group=get_pipeline_model_parallel_group(),
- )
- ops.append(recv_next_op)
- if len(ops) > 0:
- reqs = torch.distributed.batch_isend_irecv(ops)
- for req in reqs:
- req.wait()
-
- # To protect against race condition when using batch_isend_irecv().
- # should take this out once the bug with batch_isend_irecv is resolved.
- torch.cuda.synchronize()
-
- recv_prev_shape = [0, 0, 0]
- if recv_prev_shape_tensor is not None:
- recv_prev_shape = recv_prev_shape_tensor.tolist()
-
- recv_next_shape = [0, 0, 0]
- if recv_next_shape_tensor is not None:
- recv_next_shape = recv_next_shape_tensor.tolist()
-
- return recv_prev_shape, recv_next_shape
diff --git a/mindspeed_llm/tasks/megatron_adaptor.py b/mindspeed_llm/tasks/megatron_adaptor.py
index 9214ba6daaf3f95a917ed7e682492a78acc827d5..3048a4693dfcf7f4f5485835cfd4f3158de19bc7 100644
--- a/mindspeed_llm/tasks/megatron_adaptor.py
+++ b/mindspeed_llm/tasks/megatron_adaptor.py
@@ -18,7 +18,6 @@ import sys
import types
import argparse
import torch
-import tensordict
from torch_npu.contrib import transfer_to_npu
from mindspeed_llm.features_manager import FEATURES_LIST
@@ -763,7 +762,6 @@ class LegacyAdaptation(MegatronAdaptationABC):
self.patch_inference()
self.patch_log_handler()
self.patch_optimizer()
- self.patch_2megatron()
def patch_log_handler(self):
from megatron.training.log_handler import CustomHandler
@@ -929,24 +927,4 @@ class LegacyAdaptation(MegatronAdaptationABC):
MegatronAdaptation.register('megatron.training.training.setup_model_and_optimizer',
setup_model_and_optimizer_wrapper)
- def patch_2megatron(self):
- # This patch is only for running ray scenarios
- if "--raylet-name" in "".join(sys.argv):
- from mindspeed_llm.core.pipeline_parallel.p2p_communication import _communicate_shapes
- from mindspeed_llm.tasks.posttrain.rlxf.training.parallel_state import (rank_generator_init_wrapper, rank_generator_get_ranks_wrapper)
- from mindspeed_llm.tasks.posttrain.rlxf.training.initialize import initialize_megatron
- from mindspeed_llm.tasks.posttrain.rlxf.training.initialize import barrier_wrapper, broadcast_wrapper, is_last_rank, \
- get_world_size_wrapper, get_elapsed_time_all_ranks
-
- MegatronAdaptation.register('megatron.core.pipeline_parallel.p2p_communication._communicate_shapes', _communicate_shapes)
- MegatronAdaptation.register('megatron.core.parallel_state.RankGenerator.__init__', rank_generator_init_wrapper)
- MegatronAdaptation.register('megatron.core.parallel_state.RankGenerator.get_ranks', rank_generator_get_ranks_wrapper)
- MegatronAdaptation.register('megatron.training.utils.is_last_rank', is_last_rank)
- MegatronAdaptation.register('megatron.core.timers.Timers._get_elapsed_time_all_ranks', get_elapsed_time_all_ranks)
- MegatronAdaptation.register('torch.distributed.barrier', barrier_wrapper)
- MegatronAdaptation.register('torch.distributed.all_reduce', barrier_wrapper)
- MegatronAdaptation.register('torch.distributed.broadcast', broadcast_wrapper)
- MegatronAdaptation.register('torch.distributed.get_world_size', get_world_size_wrapper)
- MegatronAdaptation.register('megatron.training.initialize.initialize_megatron', initialize_megatron, force_patch=True)
-
MegatronAdaptation.execute()
\ No newline at end of file
diff --git a/mindspeed_llm/tasks/posttrain/launcher.py b/mindspeed_llm/tasks/posttrain/launcher.py
index fc0c6e6d15f38355f9f22c5fffc2f007e83b7b83..f3d1a15a00024992097265a5d92e49bda7d992a2 100644
--- a/mindspeed_llm/tasks/posttrain/launcher.py
+++ b/mindspeed_llm/tasks/posttrain/launcher.py
@@ -3,9 +3,6 @@ import logging
from megatron.training import get_args
from megatron.training.initialize import initialize_megatron
-from mindspeed_llm.tasks.posttrain.rlxf.ray_trainer.online_dpo_trainer import RayOnlineDPOTrainer
-from mindspeed_llm.tasks.posttrain.rlxf.ray_trainer.ppo_trainer import RayPPOTrainer
-from mindspeed_llm.tasks.posttrain.rlxf.ray_trainer.grpo_trainer import RayGRPOTrainer
from mindspeed_llm.tasks.posttrain.sft import SFTTrainer
from mindspeed_llm.tasks.posttrain.dpo import DPOTrainer
from mindspeed_llm.tasks.posttrain.orm import ORMTrainer
@@ -33,14 +30,8 @@ def get_trainer(stage):
return PRMTrainer()
elif stage == "simpo":
return SimPOTrainer()
- elif stage == "ray_ppo":
- return RayPPOTrainer
- elif stage == "ray_online_dpo":
- return RayOnlineDPOTrainer
elif stage == "trl_ppo":
return TrlPPOTrainer()
- elif stage == "ray_grpo":
- return RayGRPOTrainer
else:
logger.info(f'Unknown Stage: {stage}')
return None
diff --git a/mindspeed_llm/tasks/posttrain/rlxf/__init__.py b/mindspeed_llm/tasks/posttrain/rlxf/__init__.py
deleted file mode 100644
index d3f0b85929e87e26ba235b74cb211ba2ac519154..0000000000000000000000000000000000000000
--- a/mindspeed_llm/tasks/posttrain/rlxf/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved.
\ No newline at end of file
diff --git a/mindspeed_llm/tasks/posttrain/rlxf/ray_trainer/__init__.py b/mindspeed_llm/tasks/posttrain/rlxf/ray_trainer/__init__.py
deleted file mode 100644
index d3f0b85929e87e26ba235b74cb211ba2ac519154..0000000000000000000000000000000000000000
--- a/mindspeed_llm/tasks/posttrain/rlxf/ray_trainer/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved.
\ No newline at end of file
diff --git a/mindspeed_llm/tasks/posttrain/rlxf/ray_trainer/grpo_trainer.py b/mindspeed_llm/tasks/posttrain/rlxf/ray_trainer/grpo_trainer.py
deleted file mode 100644
index 396a9dd34c25d250d51165e57c076336b7675d8a..0000000000000000000000000000000000000000
--- a/mindspeed_llm/tasks/posttrain/rlxf/ray_trainer/grpo_trainer.py
+++ /dev/null
@@ -1,188 +0,0 @@
-from typing import Type
-from codetiming import Timer
-
-from mindspeed_llm.tasks.posttrain.rlxf.ray_trainer.ppo_trainer import ResourcePoolManager, Role
-from mindspeed_llm.tasks.posttrain.rlxf.single_controller.ray.megatron import NVMegatronRayWorkerGroup
-from mindspeed_llm.tasks.posttrain.rlxf.training.core_algos import compute_grpo_data_metrics, reduce_metrics, \
- compute_advantage, compute_score, FixedKLController, AdaptiveKLController
-from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base import Worker
-from mindspeed_llm.tasks.posttrain.rlxf.single_controller.ray.base import create_colocated_worker_cls, \
- set_actor_infer_world_size, set_actor_train_world_size, RayClassWithInitArgs
-from mindspeed_llm.tasks.posttrain.rlxf.utils.loggers import Loggers
-from mindspeed_llm.tasks.posttrain.rlxf.workers.actor_train_infer import PPOActorWorker
-from mindspeed_llm.tasks.posttrain.rlxf.workers.reference import ReferenceWorker
-from mindspeed_llm.tasks.posttrain.rlxf.workers.reward import RewardWorker
-
-
-WorkerType = Type[Worker]
-
-
-class RayGRPOTrainer(object):
- """
- Note that this trainer runs on the driver process on a single CPU/GPU node.
- """
-
- def __init__(self, config):
-
- self.config = config
- if hasattr(self.config.training, "dataset_additional_keys"):
- self.config.training.dataset_additional_keys = config.training.dataset_additional_keys.strip().split(" ") if config.training.dataset_additional_keys else []
-
- self.role_worker_mapping = {
- Role.ActorRollout: PPOActorWorker,
- Role.RefPolicy: ReferenceWorker,
- Role.RewardModel: RewardWorker
- }
- actor_pool_id = 'actor_pool'
- ref_pool_id = 'ref_pool'
- reward_pool_id = 'reward_pool'
-
- if config.resource_pool.reward:
- resource_pool_spec = {
- actor_pool_id: config.resource_pool.actor_rollout,
- ref_pool_id: config.resource_pool.ref,
- reward_pool_id: config.resource_pool.reward,
- }
- mapping = {
- Role.ActorRollout: actor_pool_id,
- Role.RefPolicy: ref_pool_id,
- Role.RewardModel: reward_pool_id,
- }
- else:
- resource_pool_spec = {
- actor_pool_id: config.resource_pool.actor_rollout,
- ref_pool_id: config.resource_pool.ref
- }
-
- mapping = {
- Role.ActorRollout: actor_pool_id,
- Role.RefPolicy: ref_pool_id
- }
-
- self.resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
- self.use_reference_policy = Role.RefPolicy in self.role_worker_mapping
- self.ray_worker_group_cls = NVMegatronRayWorkerGroup
-
- # define KL control
- if self.use_reference_policy:
- if config.algorithm.kl_ctrl.type == 'fixed':
- self.kl_ctrl = FixedKLController(kl_coef=config.algorithm.kl_ctrl.kl_coef)
- elif config.algorithm.kl_ctrl.type == 'adaptive':
- if config.algorithm.kl_ctrl.horizon <= 0:
- raise ValueError(f'horizon must be larger than 0. Got {config.critic.kl_ctrl.horizon}')
- self.kl_ctrl = AdaptiveKLController(init_kl_coef=config.algorithm.kl_ctrl.kl_coef,
- target_kl=config.algorithm.kl_ctrl.target_kl,
- horizon=config.algorithm.kl_ctrl.horizon)
- else:
- raise NotImplementedError
- else:
- self.kl_ctrl = FixedKLController(kl_coef=0.)
- self.init_workers()
-
- def init_workers(self):
- """Init resource pool and worker group"""
- set_actor_infer_world_size(self.config.actor_rollout_ref.actor_rollout.num_gpus_for_infer)
- set_actor_train_world_size(self.config.actor_rollout_ref.actor_rollout.num_gpus_for_train)
- self.resource_pool_manager.create_resource_pool()
- self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()}
-
- resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout)
- actor_rollout_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.ActorRollout],
- config=self.config,
- role='actor_rollout')
- self.resource_pool_to_cls[resource_pool]['actor_rollout'] = actor_rollout_cls
-
- resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy)
- ref_policy_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RefPolicy],
- config=self.config,
- role='ref')
- self.resource_pool_to_cls[resource_pool]['ref'] = ref_policy_cls
-
- if self.config.resource_pool.reward:
- resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel)
- reward_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.RewardModel],
- config=self.config,
- role='reward')
- self.resource_pool_to_cls[resource_pool]['reward'] = reward_cls
-
- # initialize WorkerGroup
- all_wg = {}
- for resource_pool, class_dict in self.resource_pool_to_cls.items():
- worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
- wg_dict = self.ray_worker_group_cls(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls)
- spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())
- all_wg.update(spawn_wg)
-
- self.ref_policy_wg = all_wg.get('ref')
- self.ref_policy_wg.initialize()
-
- self.actor_rollout_wg = all_wg.get('actor_rollout')
- self.actor_rollout_wg.initialize()
-
- if self.config.resource_pool.reward:
- self.reward_wg = all_wg.get('reward')
- self.reward_wg.initialize()
- else:
- self.reward_wg = None
-
- def train(self):
- """
- The training loop of PPO.
- """
- logger = Loggers()
-
- iteration = self.actor_rollout_wg.get_iteration()[0]
-
- while iteration < self.config.training.train_iters:
- with Timer(name='gen', logger=None) as all_timer:
- metrics = {}
- self.actor_rollout_wg.auto_mapping()
- # generate a batch
- with Timer(name='gen', logger=None) as timer:
- batch = self.actor_rollout_wg.generate_sequences()
- batch = self.actor_rollout_wg.get_log_probs(batch)
- metrics['timing/gen'] = timer.last
-
- if self.use_reference_policy:
- # compute reference log_prob
- with Timer(name='ref', logger=None) as timer:
- ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
- batch = batch.union(ref_log_prob)
- metrics['timing/ref'] = timer.last
-
- with Timer(name='adv', logger=None) as timer:
- # compute rm scores.
- batch = compute_score(
- self.reward_wg,
- batch,
- metrics,
- self.config
- )
-
- # compute advantages, executed on the driver process
- batch = compute_advantage(
- batch,
- self.config
- )
-
- metrics['timing/adv'] = timer.last
- kl_info = {'kl_ctrl': self.kl_ctrl}
- batch.meta_info.update(kl_info)
-
- # update actor
- with Timer(name='update_actor', logger=None) as timer:
- actor_output = self.actor_rollout_wg.update_actor(batch)
- metrics['timing/update_actor'] = timer.last
- actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics'])
- metrics.update(actor_output_metrics)
-
- # collect metrics
- data_metrics = compute_grpo_data_metrics(batch=batch)
- metrics.update(data_metrics)
- metrics['timing/all'] = all_timer.last
- iteration += 1
- logger.info(metrics, iteration, self.config.training.train_iters)
- logger.flush()
-
- if iteration % self.config.training.save_interval == 0:
- self.actor_rollout_wg.save_checkpoint(iteration)
diff --git a/mindspeed_llm/tasks/posttrain/rlxf/ray_trainer/online_dpo_trainer.py b/mindspeed_llm/tasks/posttrain/rlxf/ray_trainer/online_dpo_trainer.py
deleted file mode 100644
index 347e4919e72b6d5dc4669134cb74d26c94c6707f..0000000000000000000000000000000000000000
--- a/mindspeed_llm/tasks/posttrain/rlxf/ray_trainer/online_dpo_trainer.py
+++ /dev/null
@@ -1,139 +0,0 @@
-# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved.
-from typing import Type
-
-from codetiming import Timer
-
-from mindspeed_llm.tasks.posttrain.rlxf.ray_trainer.ppo_trainer import ResourcePoolManager, Role
-from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base import Worker
-from mindspeed_llm.tasks.posttrain.rlxf.single_controller.ray.base import create_colocated_worker_cls, \
- set_actor_infer_world_size, set_actor_train_world_size, RayClassWithInitArgs
-from mindspeed_llm.tasks.posttrain.rlxf.single_controller.ray.megatron import NVMegatronRayWorkerGroup
-from mindspeed_llm.tasks.posttrain.rlxf.training.core_algos import compute_data_online_dpo_metrics, reduce_metrics
-from mindspeed_llm.tasks.posttrain.rlxf.utils.loggers import Loggers
-from mindspeed_llm.tasks.posttrain.rlxf.workers.actor_train_infer import PPOActorWorker
-from mindspeed_llm.tasks.posttrain.rlxf.workers.reference import ReferenceWorker
-from mindspeed_llm.tasks.posttrain.rlxf.workers.reward import RewardWorker
-
-WorkerType = Type[Worker]
-
-
-class RayOnlineDPOTrainer(object):
- """
- Note that this trainer runs on the driver process on a single CPU/GPU node.
- """
- def __init__(self, config):
- self.config = config
- self.role_worker_mapping = {
- Role.ActorRollout: PPOActorWorker,
- Role.RefPolicy: ReferenceWorker,
- Role.RewardModel: RewardWorker
- }
- actor_pool_id = 'actor_pool'
- ref_pool_id = 'ref_pool'
- reward_pool_id = 'reward_pool'
-
- resource_pool_spec = {
- actor_pool_id: config.resource_pool.actor_rollout,
- ref_pool_id: config.resource_pool.ref,
- reward_pool_id: config.resource_pool.reward,
- }
-
- mapping = {
- Role.ActorRollout: actor_pool_id,
- Role.RefPolicy: ref_pool_id,
- Role.RewardModel: reward_pool_id,
- }
-
- self.resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
- self.use_reference_policy = Role.RefPolicy in self.role_worker_mapping
- self.ray_worker_group_cls = NVMegatronRayWorkerGroup
- self.init_workers()
-
-
- def init_workers(self):
- """Init resource pool and worker group"""
- set_actor_infer_world_size(self.config.actor_rollout_ref.actor_rollout.num_gpus_for_infer)
- set_actor_train_world_size(self.config.actor_rollout_ref.actor_rollout.num_gpus_for_train)
- self.resource_pool_manager.create_resource_pool()
- self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()}
-
- resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout)
- actor_rollout_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.ActorRollout],
- config=self.config,
- role='actor_rollout')
- self.resource_pool_to_cls[resource_pool]['actor_rollout'] = actor_rollout_cls
-
- resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy)
- ref_policy_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RefPolicy],
- config=self.config,
- role='ref')
- self.resource_pool_to_cls[resource_pool]['ref'] = ref_policy_cls
-
- resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel)
- reward_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.RewardModel],
- config=self.config,
- role='reward')
- self.resource_pool_to_cls[resource_pool]['reward'] = reward_cls
-
- # initialize WorkerGroup
- all_wg = {}
- for resource_pool, class_dict in self.resource_pool_to_cls.items():
- worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
- wg_dict = self.ray_worker_group_cls(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls)
- spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())
- all_wg.update(spawn_wg)
-
- self.ref_policy_wg = all_wg.get('ref')
- self.ref_policy_wg.initialize()
-
- self.actor_rollout_wg = all_wg.get('actor_rollout')
- self.actor_rollout_wg.initialize()
-
- self.reward_wg = all_wg.get('reward')
- self.reward_wg.initialize()
-
-
- def train(self):
- """
- The training loop of online DPO.
- """
- logger = Loggers()
- iteration = self.actor_rollout_wg.get_iteration()[0]
- while iteration < self.config.training.train_iters:
- with Timer(name='gen', logger=None) as all_timer:
- metrics = {}
- self.actor_rollout_wg.auto_mapping()
- # generate a batch
- with Timer(name='gen', logger=None) as timer:
- batch = self.actor_rollout_wg.generate_sequences()
- metrics['timing/gen'] = timer.last
-
- if self.use_reference_policy:
- # compute reference log_prob
- with Timer(name='ref', logger=None) as timer:
- ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
- batch = batch.union(ref_log_prob)
- metrics['timing/ref'] = timer.last
-
- with Timer(name='adv', logger=None) as timer:
- # compute scores.
- reward_tensor = self.reward_wg.compute_rm_score(batch)
- batch = batch.union(reward_tensor)
- metrics['timing/adv'] = timer.last
-
- # update actor
- with Timer(name='update_actor', logger=None) as timer:
- actor_output = self.actor_rollout_wg.update_actor(batch)
- metrics['timing/update_actor'] = timer.last
- actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics'])
- metrics.update(actor_output_metrics)
-
- # collect metrics
- data_metrics = compute_data_online_dpo_metrics(batch=batch)
- metrics.update(data_metrics)
- metrics['timing/all'] = all_timer.last
- iteration += 1
- logger.info(metrics, iteration, self.config.training.train_iters)
-
- if iteration % self.config.training.save_interval == 0:
- self.actor_rollout_wg.save_checkpoint(iteration)
diff --git a/mindspeed_llm/tasks/posttrain/rlxf/ray_trainer/ppo_trainer.py b/mindspeed_llm/tasks/posttrain/rlxf/ray_trainer/ppo_trainer.py
deleted file mode 100644
index 8c31b706c5ff664adcfa7520c665b89cbbd40b68..0000000000000000000000000000000000000000
--- a/mindspeed_llm/tasks/posttrain/rlxf/ray_trainer/ppo_trainer.py
+++ /dev/null
@@ -1,229 +0,0 @@
-from dataclasses import dataclass, field
-from enum import Enum
-from typing import Type, Dict, List
-
-from codetiming import Timer
-
-from mindspeed_llm.tasks.posttrain.rlxf.single_controller.ray.megatron import NVMegatronRayWorkerGroup
-from mindspeed_llm.tasks.posttrain.rlxf.training.core_algos import compute_data_metrics, reduce_metrics, compute_advantage, \
- apply_kl_penalty, FixedKLController, AdaptiveKLController
-from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base import Worker
-from mindspeed_llm.tasks.posttrain.rlxf.single_controller.ray.base import create_colocated_worker_cls, \
- set_actor_infer_world_size, set_actor_train_world_size, RayResourcePool, RayClassWithInitArgs
-from mindspeed_llm.tasks.posttrain.rlxf.utils.loggers import Loggers
-from mindspeed_llm.tasks.posttrain.rlxf.workers.critic import CriticWorker
-from mindspeed_llm.tasks.posttrain.rlxf.workers.actor_train_infer import PPOActorWorker
-from mindspeed_llm.tasks.posttrain.rlxf.workers.reference import ReferenceWorker
-from mindspeed_llm.tasks.posttrain.rlxf.workers.reward import RewardWorker
-
-WorkerType = Type[Worker]
-
-
-class Role(Enum):
- """
- To create more roles dynamically, you can subclass Role and add new members
- """
- ActorRollout = 0
- Critic = 1
- RefPolicy = 2
- RewardModel = 3
-
-
-@dataclass
-class ResourcePoolManager:
- """
- Define a resource pool specification. Resource pool will be initialized first.
- Mapping
- """
- resource_pool_spec: Dict[str, List[int]]
- mapping: Dict[Role, str]
- resource_pool_dict: Dict[str, RayResourcePool] = field(default_factory=dict)
-
- def create_resource_pool(self):
- for resource_pool_name, process_on_nodes in self.resource_pool_spec.items():
- resource_pool = RayResourcePool(process_on_nodes=process_on_nodes, use_gpu=True, max_colocate_count=1)
- self.resource_pool_dict[resource_pool_name] = resource_pool
-
- def get_resource_pool(self, role: Role) -> RayResourcePool:
- """Get the resource pool of the worker_cls"""
- print(f"role:{role}, self.mapping[role]:{self.mapping[role]}")
- return self.resource_pool_dict[self.mapping[role]]
-
-
-class RayPPOTrainer(object):
- """
- Note that this trainer runs on the driver process on a single CPU/GPU node.
- """
-
- def __init__(self, config):
-
- self.config = config
- self.role_worker_mapping = {
- Role.ActorRollout: PPOActorWorker,
- Role.RefPolicy: ReferenceWorker,
- Role.Critic: CriticWorker,
- Role.RewardModel: RewardWorker
- }
- actor_pool_id = 'actor_pool'
- ref_pool_id = 'ref_pool'
- critic_pool_id = 'critic_pool'
- reward_pool_id = 'reward_pool'
-
- resource_pool_spec = {
- actor_pool_id: config.resource_pool.actor_rollout,
- ref_pool_id: config.resource_pool.ref,
- critic_pool_id: config.resource_pool.critic,
- reward_pool_id: config.resource_pool.reward,
- }
-
- mapping = {
- Role.ActorRollout: actor_pool_id,
- Role.RefPolicy: ref_pool_id,
- Role.Critic: critic_pool_id,
- Role.RewardModel: reward_pool_id,
- }
-
- self.resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
- self.use_reference_policy = Role.RefPolicy in self.role_worker_mapping
- self.ray_worker_group_cls = NVMegatronRayWorkerGroup
-
- # define KL control
- if self.use_reference_policy:
- if config.algorithm.kl_ctrl.type == 'fixed':
- self.kl_ctrl = FixedKLController(kl_coef=config.algorithm.kl_ctrl.kl_coef)
- elif config.algorithm.kl_ctrl.type == 'adaptive':
- if config.algorithm.kl_ctrl.horizon <= 0:
- raise ValueError(f'horizon must be larger than 0. Got {config.critic.kl_ctrl.horizon}')
- self.kl_ctrl = AdaptiveKLController(init_kl_coef=config.algorithm.kl_ctrl.kl_coef,
- target_kl=config.algorithm.kl_ctrl.target_kl,
- horizon=config.algorithm.kl_ctrl.horizon)
- else:
- raise NotImplementedError
- else:
- self.kl_ctrl = FixedKLController(kl_coef=0.)
- self.init_workers()
-
- def init_workers(self):
- """Init resource pool and worker group"""
- set_actor_infer_world_size(self.config.actor_rollout_ref.actor_rollout.num_gpus_for_infer)
- set_actor_train_world_size(self.config.actor_rollout_ref.actor_rollout.num_gpus_for_train)
- self.resource_pool_manager.create_resource_pool()
- self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()}
-
- resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout)
- actor_rollout_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.ActorRollout],
- config=self.config,
- role='actor_rollout')
- self.resource_pool_to_cls[resource_pool]['actor_rollout'] = actor_rollout_cls
-
- resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy)
- ref_policy_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RefPolicy],
- config=self.config,
- role='ref')
- self.resource_pool_to_cls[resource_pool]['ref'] = ref_policy_cls
-
- resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic)
- critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic],
- config=self.config,
- role='critic')
- self.resource_pool_to_cls[resource_pool]['critic'] = critic_cls
-
- resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel)
- reward_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.RewardModel],
- config=self.config,
- role='reward')
- self.resource_pool_to_cls[resource_pool]['reward'] = reward_cls
-
- # initialize WorkerGroup
- all_wg = {}
- for resource_pool, class_dict in self.resource_pool_to_cls.items():
- worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
- wg_dict = self.ray_worker_group_cls(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls)
- spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())
- all_wg.update(spawn_wg)
-
- self.ref_policy_wg = all_wg.get('ref')
- self.ref_policy_wg.initialize()
-
- self.actor_rollout_wg = all_wg.get('actor_rollout')
- self.actor_rollout_wg.initialize()
-
- self.critic_wg = all_wg.get('critic')
- self.critic_wg.initialize()
-
- self.reward_wg = all_wg.get('reward')
- self.reward_wg.initialize()
-
- def train(self):
- """
- The training loop of PPO.
- """
- logger = Loggers()
-
- iteration = self.actor_rollout_wg.get_iteration()[0]
-
- while iteration < self.config.training.train_iters:
- with Timer(name='gen', logger=None) as all_timer:
- metrics = {}
- self.actor_rollout_wg.auto_mapping()
- # generate a batch
- with Timer(name='gen', logger=None) as timer:
- batch = self.actor_rollout_wg.generate_sequences()
- batch = self.actor_rollout_wg.get_log_probs(batch)
- metrics['timing/gen'] = timer.last
-
- if self.use_reference_policy:
- # compute reference log_prob
- with Timer(name='ref', logger=None) as timer:
- ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
- batch = batch.union(ref_log_prob)
- metrics['timing/ref'] = timer.last
-
- # compute values
- with Timer(name='values', logger=None) as timer:
- values = self.critic_wg.compute_values(batch)
- batch = batch.union(values)
- metrics['timing/values'] = timer.last
-
- with Timer(name='adv', logger=None) as timer:
- # compute rm scores.
- reward_tensor = self.reward_wg.compute_rm_score(batch)
- batch = batch.union(reward_tensor)
-
- # compute rewards. apply_kl_penalty if available
- batch, kl_metrics = apply_kl_penalty(self.config, batch,
- kl_ctrl=self.kl_ctrl)
- metrics.update(kl_metrics)
-
- # compute advantages, executed on the driver process
- batch = compute_advantage(
- batch,
- self.config
- )
- metrics['timing/adv'] = timer.last
-
- # update critic
- with Timer(name='update_critic', logger=None) as timer:
- critic_output = self.critic_wg.update_critic(batch)
- metrics['timing/update_critic'] = timer.last
- critic_output_metrics = reduce_metrics(critic_output.meta_info['metrics'])
- metrics.update(critic_output_metrics)
-
- # update actor
- with Timer(name='update_actor', logger=None) as timer:
- actor_output = self.actor_rollout_wg.update_actor(batch)
- metrics['timing/update_actor'] = timer.last
- actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics'])
- metrics.update(actor_output_metrics)
-
- # collect metrics
- data_metrics = compute_data_metrics(batch=batch)
- metrics.update(data_metrics)
- metrics['timing/all'] = all_timer.last
- iteration += 1
- logger.info(metrics, iteration, self.config.training.train_iters)
-
- if iteration % self.config.training.save_interval == 0:
- self.critic_wg.save_checkpoint(iteration)
- self.actor_rollout_wg.save_checkpoint(iteration)
-
diff --git a/mindspeed_llm/tasks/posttrain/rlxf/single_controller/__init__.py b/mindspeed_llm/tasks/posttrain/rlxf/single_controller/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/mindspeed_llm/tasks/posttrain/rlxf/single_controller/base/__init__.py b/mindspeed_llm/tasks/posttrain/rlxf/single_controller/base/__init__.py
deleted file mode 100644
index 75846436cd1285259d2bae6d4a7f190aebed1a80..0000000000000000000000000000000000000000
--- a/mindspeed_llm/tasks/posttrain/rlxf/single_controller/base/__init__.py
+++ /dev/null
@@ -1,16 +0,0 @@
-# Copyright 2024 Bytedance Ltd. and/or its affiliates
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from .worker import Worker
-from .worker_group import WorkerGroup, ClassWithInitArgs, ResourcePool
diff --git a/mindspeed_llm/tasks/posttrain/rlxf/single_controller/base/decorator.py b/mindspeed_llm/tasks/posttrain/rlxf/single_controller/base/decorator.py
deleted file mode 100644
index 1597d16fd929f251c0a51854c21977b69ac20189..0000000000000000000000000000000000000000
--- a/mindspeed_llm/tasks/posttrain/rlxf/single_controller/base/decorator.py
+++ /dev/null
@@ -1,484 +0,0 @@
-# Copyright 2024 Bytedance Ltd. and/or its affiliates
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from enum import Enum
-from functools import wraps, partial
-from typing import Dict, List, Tuple
-from types import FunctionType
-
-
-# here we add a magic number to avoid clashing with an attribute that a user-defined function may already have
-MAGIC_ATTR = 'attrs_3141562937'
-
-
-class Dispatch(Enum):
- RANK_ZERO = 0
- ONE_TO_ALL = 1
- ALL_TO_ALL = 2
- MEGATRON_COMPUTE = 3
- MEGATRON_PP_AS_DP = 4
- MEGATRON_PP_ONLY = 5
- MEGATRON_COMPUTE_PROTO = 6
- MEGATRON_PP_AS_DP_PROTO = 7
- DP_COMPUTE = 8
- DP_COMPUTE_PROTO = 9
- DP_COMPUTE_PROTO_WITH_FUNC = 10
- DP_COMPUTE_METRIC = 11
- DP_ALL_GATHER_TRAIN = 12
- DP_ALL_GATHER_INFER = 13
-
-
-
-class Execute(Enum):
- ALL = 0
- RANK_ZERO = 1
- INFER = 2
- TRAIN = 3
-
-
-def _split_args_kwargs_data_proto(chunks, *args, **kwargs):
- from mindspeed_llm.tasks.posttrain.rlxf.utils.protocol import DataProto, DataProtoFuture
- splitted_args = []
- for arg in args:
- if not isinstance(arg, (DataProto, DataProtoFuture)):
- raise TypeError(f"Argument {arg} must be an instance of DataProto or DataProtoFuture. Got {type(arg)}")
- splitted_args.append(arg.chunk(chunks=chunks))
-
- splitted_kwargs = {}
- for key, val in kwargs.items():
- if not isinstance(val, (DataProto, DataProtoFuture)):
- raise TypeError(f"Value for key {key} must be an instance of DataProto or DataProtoFuture. Got {type(val)}")
- splitted_kwargs[key] = val.chunk(chunks=chunks)
-
- return splitted_args, splitted_kwargs
-
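-# Illustrative sketch (added comment, not from the original file): with chunks equal to the dp size
-# and a DataProto batch of N samples, arg.chunk(chunks) yields `chunks` smaller DataProto objects of
-# roughly N / chunks samples each, one per data-parallel group.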
-
-def dispatch_one_to_all(worker_group, *args, **kwargs):
- args = tuple([arg] * worker_group.world_size for arg in args)
- kwargs = {k: [v] * worker_group.world_size for k, v in kwargs.items()}
- return args, kwargs
-
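-# Illustrative sketch (added comment): for a worker group with world_size = 4,
-# dispatch_one_to_all(wg, x, key=y) produces ([x, x, x, x],) and {'key': [y, y, y, y]},
-# i.e. every worker receives the same object.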
-
-def dispatch_all_to_all(worker_group, *args, **kwargs):
- return args, kwargs
-
-
-def collect_all_to_all(worker_group, output):
- return output
-
-
-def dispatch_megatron_compute(worker_group, *args, **kwargs):
- """
- The user passes in data split along dp. Each dp chunk is dispatched to all tp/pp ranks that share the same dp rank.
- """
- from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base.megatron.worker_group import MegatronWorkerGroup
- if not isinstance(worker_group, MegatronWorkerGroup):
- raise TypeError(f'worker_group must be MegatronWorkerGroup, Got {type(worker_group)}')
-
- all_args = []
- for arg in args:
- if not isinstance(arg, (Tuple, List)) or len(arg) != worker_group.dp_size:
- raise ValueError(f'Each argument must be a Tuple or List of length {worker_group.dp_size}, Got length {len(arg)}')
- transformed_args = []
- for i in range(worker_group.world_size):
- local_dp_rank = worker_group.get_megatron_rank_info(rank=i).dp_rank
- transformed_args.append(arg[local_dp_rank])
- all_args.append(transformed_args)
- all_args = tuple(all_args)
-
- all_kwargs = {}
- for k, v in kwargs.items():
- if not isinstance(v, (Tuple, List)) or len(v) != worker_group.dp_size:
- raise ValueError(f'Each argument in kwargs must be a Tuple or List of length {worker_group.dp_size}, Got length {len(v)}')
- transformed_v = []
- for i in range(worker_group.world_size):
- local_dp_rank = worker_group.get_megatron_rank_info(rank=i).dp_rank
- transformed_v.append(v[local_dp_rank])
- all_kwargs[k] = transformed_v
- return all_args, all_kwargs
-
-
-def collect_megatron_compute(worker_group, output):
- """
- Only collect data from tp rank 0 and the last pp rank, across every dp rank
- """
- from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base.megatron.worker_group import MegatronWorkerGroup
- if not isinstance(worker_group, MegatronWorkerGroup):
- raise TypeError(f'worker_group must be MegatronWorkerGroup, Got {type(worker_group)}')
- output_in_dp = []
- pp_size = worker_group.get_megatron_global_info().pp_size
- for global_rank in range(worker_group.world_size):
- local_rank_info = worker_group.get_megatron_rank_info(rank=global_rank)
- if local_rank_info.tp_rank == 0 and local_rank_info.pp_rank == pp_size - 1:
- output_in_dp.append(output[global_rank])
- return output_in_dp
-
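-# Illustrative sketch (added comment): with tp=2, pp=2, dp=2 (world_size = 8), only the two global
-# ranks whose tp_rank == 0 and pp_rank == pp_size - 1 contribute, so the collected list has
-# dp_size entries, ordered by global rank.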
-
-def dispatch_megatron_compute_data_proto(worker_group, *args, **kwargs):
- """
- All the args and kwargs must be DataProto. The batch will be chunked by dp_size and passed to each rank
- """
- from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base.megatron.worker_group import MegatronWorkerGroup
- if not isinstance(worker_group, MegatronWorkerGroup):
- raise TypeError(f'worker_group must be an instance of MegatronWorkerGroup. Got {type(worker_group)}')
-
- splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.dp_size, *args, **kwargs)
- return dispatch_megatron_compute(worker_group, *splitted_args, **splitted_kwargs)
-
-
-def _concat_data_proto_or_future(output: List):
- from mindspeed_llm.tasks.posttrain.rlxf.utils.protocol import DataProto, DataProtoFuture
- import ray
-
- # make sure all the elements in output have the same type
- for single_output in output:
- if not isinstance(single_output, type(output[0])):
- raise TypeError(f"All elements in output must have the same type. Found {type(single_output)} and {type(output[0])}")
-
- output_prime = output[0]
-
- if isinstance(output_prime, DataProto):
- return DataProto.concat(output)
- elif isinstance(output_prime, ray.ObjectRef):
- return DataProtoFuture.concat(output)
- else:
- raise NotImplementedError
-
-
-def collect_megatron_compute_data_proto(worker_group, output):
- """
- Each output must be a DataProto. We concatenate the outputs along dim=0
- """
- from mindspeed_llm.tasks.posttrain.rlxf.utils.protocol import DataProto
- import ray
-
- output = collect_megatron_compute(worker_group, output)
- for single_output in output:
- if not isinstance(single_output, (DataProto, ray.ObjectRef)):
- raise TypeError(f"Expecting {single_output} to be DataProto or ray.ObjectRef, but got {type(single_output)}")
-
- return _concat_data_proto_or_future(output)
-
-
-def dispatch_megatron_pp_as_dp(worker_group, *args, **kwargs):
- """
- Treat pp ranks as additional dp ranks: data is chunked into pp_size * dp_size pieces.
- """
- from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base.megatron.worker_group import MegatronWorkerGroup
- if not isinstance(worker_group, MegatronWorkerGroup):
- raise TypeError(f'worker_group must be an instance of MegatronWorkerGroup. Got {type(worker_group)}')
-
- pp_size = worker_group.pp_size
- dp_size = worker_group.dp_size
-
- pp_dp_size = pp_size * dp_size
-
- all_args = []
- for arg in args:
- if not isinstance(arg, (List, Tuple)) or len(arg) != pp_dp_size:
- raise ValueError(f'Each argument in args must be a List or Tuple of length {pp_dp_size}, but got length {len(arg)}')
- transformed_args = []
- for i in range(worker_group.world_size):
- local_dp_rank = worker_group.get_megatron_rank_info(rank=i).dp_rank
- local_pp_rank = worker_group.get_megatron_rank_info(rank=i).pp_rank
- # compute the rank in arg. Note that the order is dp then pp
- # Also note that the outputs within a pp group are first all-gathered, and then only the output of pp rank 0 is collected.
- # For pp=2 dp=4, a batch of data "ABCDEFGH" should be dispatched and collected in below order:
- # dispatch: pp_allgther: collect:
- # dp 0 1 2 3 dp 0 1 2 3
- # pp +---------+ pp +-------------+
- # 0 | A C E G | 0 | AB CD EF GH | ABCDEFGH
- # 1 | B D F H | 1 | AB CD EF GH |
- # +---------+ +-------------+
- arg_rank = local_dp_rank * worker_group.pp_size + local_pp_rank
-
- transformed_args.append(arg[arg_rank])
- all_args.append(transformed_args)
- all_args = tuple(all_args)
-
- all_kwargs = {}
- for k, v in kwargs.items():
- if not isinstance(v, (List, Tuple)) or len(v) != pp_dp_size:
- raise ValueError(f'Each argument in kwargs must be a List or Tuple of length {pp_dp_size}, but got length {len(v)}')
- transformed_v = []
- for i in range(worker_group.world_size):
- local_dp_rank = worker_group.get_megatron_rank_info(rank=i).dp_rank
- local_pp_rank = worker_group.get_megatron_rank_info(rank=i).pp_rank
- # compute the rank in arg. Note that the order is dp then pp
- arg_rank = local_dp_rank * worker_group.pp_size + local_pp_rank
- transformed_v.append(v[arg_rank])
- all_kwargs[k] = transformed_v
- return all_args, all_kwargs
-
-
-def collect_megatron_pp_as_dp(worker_group, output):
- """
- Treat pp as dp. Only collect data on tp=0 and pp=0
- """
- from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base.megatron.worker_group import MegatronWorkerGroup
- if not isinstance(worker_group, MegatronWorkerGroup):
- raise TypeError(f'worker_group must be an instance of MegatronWorkerGroup. Got {type(worker_group)}')
- output_in_dp = []
- for global_rank in range(worker_group.world_size):
- local_rank_info = worker_group.get_megatron_rank_info(rank=global_rank)
- if local_rank_info.tp_rank == 0 and local_rank_info.pp_rank == 0:
- output_in_dp.append(output[global_rank])
- return output_in_dp
-
-
-def collect_megatron_pp_only(worker_group, output):
- """
- Only collect output across megatron pp ranks. This is useful when examining weight names, since they are identical across tp/dp ranks
- """
- from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base.megatron.worker_group import MegatronWorkerGroup
- if not isinstance(worker_group, MegatronWorkerGroup):
- raise TypeError(f'worker_group must be an instance of MegatronWorkerGroup. Got {type(worker_group)}')
- output_in_pp = []
- for global_rank in range(worker_group.world_size):
- local_rank_info = worker_group.get_megatron_rank_info(rank=global_rank)
- if local_rank_info.tp_rank == 0 and local_rank_info.dp_rank == 0:
- output_in_pp.append(output[global_rank])
- return output_in_pp
-
-
-def dispatch_megatron_pp_as_dp_data_proto(worker_group, *args, **kwargs):
- from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base.megatron.worker_group import MegatronWorkerGroup
- if not isinstance(worker_group, MegatronWorkerGroup):
- raise TypeError(f'worker_group must be an instance of MegatronWorkerGroup. Got {type(worker_group)}')
-
- pp_dp_size = worker_group.dp_size * worker_group.pp_size
- splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(pp_dp_size, *args, **kwargs)
- return dispatch_megatron_pp_as_dp(worker_group, *splitted_args, **splitted_kwargs)
-
-
-def collect_megatron_pp_as_dp_data_proto(worker_group, output):
- from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base.megatron.worker_group import MegatronWorkerGroup
- if not isinstance(worker_group, MegatronWorkerGroup):
- raise TypeError(f'worker_group must be an instance of MegatronWorkerGroup. Got {type(worker_group)}')
-
- output = collect_megatron_pp_as_dp(worker_group, output)
- return _concat_data_proto_or_future(output)
-
-
-def dispatch_dp_compute(worker_group, *args, **kwargs):
- from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base.worker_group import WorkerGroup
- if not isinstance(worker_group, WorkerGroup):
- raise TypeError(f'worker_group must be an instance of WorkerGroup. Got {type(worker_group)}')
- for arg in args:
- if not isinstance(arg, (Tuple, List)) or len(arg) != worker_group.world_size:
- raise ValueError(f'Each argument in args must be a Tuple or List of length {worker_group.world_size}')
- for _, v in kwargs.items():
- if not isinstance(v, (Tuple, List)) or len(v) != worker_group.world_size:
- raise ValueError(f'Each argument in kwargs must be a Tuple or List of length {worker_group.world_size}')
- return args, kwargs
-
-
-def collect_dp_compute(worker_group, output):
- from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base.worker_group import WorkerGroup
- if not isinstance(worker_group, WorkerGroup):
- raise TypeError(f'worker_group must be an instance of WorkerGroup. Got {type(worker_group)}')
-
- if len(output) != worker_group.world_size:
- raise ValueError(f'Output must have a length equal to world_size. Got length {len(output)}')
- return output
-
-
-def dispatch_dp_compute_data_proto(worker_group, *args, **kwargs):
- from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base.worker_group import WorkerGroup
- if not isinstance(worker_group, WorkerGroup):
- raise TypeError(f'worker_group must be an instance of WorkerGroup. Got {type(worker_group)}')
- splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.world_size, *args, **kwargs)
- return splitted_args, splitted_kwargs
-
-
-def dispatch_dp_compute_data_proto_with_func(worker_group, *args, **kwargs):
- from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base.worker_group import WorkerGroup
- if not isinstance(worker_group, WorkerGroup):
- raise TypeError(f'worker_group must be an instance of WorkerGroup. Got {type(worker_group)}')
-
- if type(args[0]) != FunctionType:
- raise TypeError(f'The first argument must be a callable function. Got {type(args[0])}') # NOTE: the first positional argument is expected to be a function
-
- splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.world_size, *args[1:], **kwargs)
- splitted_args_with_func = [[args[0]] * worker_group.world_size] + splitted_args
- return splitted_args_with_func, splitted_kwargs
-
-
-def collect_dp_compute_data_proto(worker_group, output):
- import ray
- from mindspeed_llm.tasks.posttrain.rlxf.utils.protocol import DataProto
- for single_output in output:
- if not isinstance(single_output, (DataProto, ray.ObjectRef)):
- raise TypeError(f"Expecting {single_output} to be DataProto or ray.ObjectRef, but got {type(single_output)}")
-
- output = collect_dp_compute(worker_group, output)
- return _concat_data_proto_or_future(output)
-
-
-def collect_dp_all_gather(worker_group, output, is_train):
- """
- Collect data across DP groups; within each DP group, only use the output returned on TP rank 0 and the last PP rank.
- """
- from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base.megatron.worker_group import MegatronWorkerGroup
- if not isinstance(worker_group, MegatronWorkerGroup):
- raise TypeError(f'worker_group must be an instance of MegatronWorkerGroup. Got {type(worker_group)}')
- output_in_dp = []
- from mindspeed_llm.tasks.posttrain.rlxf.single_controller.ray.base import get_actor_train_world_size
- actor_train_world_size = get_actor_train_world_size()
- pp_size = worker_group.get_megatron_global_info().pp_size if is_train else 1
- rank_offset = 0 if is_train else actor_train_world_size
- for global_rank in range(worker_group.world_size):
- is_train_node = global_rank < actor_train_world_size
- if is_train_node and not is_train:
- continue
- elif not is_train_node and is_train:
- continue
- local_rank_info = worker_group.get_megatron_rank_info(rank=global_rank)
- if local_rank_info.tp_rank == 0 and local_rank_info.pp_rank == pp_size - 1:
- output_in_dp.append(output[global_rank - rank_offset])
- return _concat_data_proto_or_future(output_in_dp)
-
-collect_dp_train = partial(collect_dp_all_gather, is_train=True)
-collect_dp_infer = partial(collect_dp_all_gather, is_train=False)
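-# Added note: the two partials above select which half of a combined train/infer worker group
-# contributes output -- collect_dp_train gathers from the training ranks, collect_dp_infer from the
-# inference ranks that follow them (see rank_offset in collect_dp_all_gather).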
-
-
-
-def get_predefined_dispatch_fn(dispatch_mode):
- predefined_dispatch_mode_fn = {
- Dispatch.ONE_TO_ALL: {
- 'dispatch_fn': dispatch_one_to_all,
- 'collect_fn': collect_all_to_all,
- },
- Dispatch.ALL_TO_ALL: {
- 'dispatch_fn': dispatch_all_to_all,
- 'collect_fn': collect_all_to_all,
- },
- Dispatch.MEGATRON_COMPUTE: {
- 'dispatch_fn': dispatch_megatron_compute,
- 'collect_fn': collect_megatron_compute,
- },
- Dispatch.MEGATRON_PP_AS_DP: {
- 'dispatch_fn': dispatch_megatron_pp_as_dp,
- 'collect_fn': collect_megatron_pp_as_dp,
- },
- Dispatch.MEGATRON_PP_ONLY: {
- 'dispatch_fn': dispatch_one_to_all,
- 'collect_fn': collect_megatron_pp_only
- },
- Dispatch.MEGATRON_COMPUTE_PROTO: {
- 'dispatch_fn': dispatch_megatron_compute_data_proto,
- 'collect_fn': collect_megatron_compute_data_proto
- },
- Dispatch.MEGATRON_PP_AS_DP_PROTO: {
- 'dispatch_fn': dispatch_megatron_pp_as_dp_data_proto,
- 'collect_fn': collect_megatron_pp_as_dp_data_proto
- },
- Dispatch.DP_COMPUTE: {
- 'dispatch_fn': dispatch_dp_compute,
- 'collect_fn': collect_dp_compute
- },
- Dispatch.DP_COMPUTE_PROTO: {
- 'dispatch_fn': dispatch_dp_compute_data_proto,
- 'collect_fn': collect_dp_compute_data_proto
- },
- Dispatch.DP_COMPUTE_PROTO_WITH_FUNC: {
- 'dispatch_fn': dispatch_dp_compute_data_proto_with_func,
- 'collect_fn': collect_dp_compute_data_proto
- },
- Dispatch.DP_COMPUTE_METRIC: {
- 'dispatch_fn': dispatch_dp_compute_data_proto,
- 'collect_fn': collect_dp_compute
- },
- Dispatch.DP_ALL_GATHER_TRAIN: {
- 'dispatch_fn': dispatch_one_to_all,
- 'collect_fn': collect_dp_train,
- },
- Dispatch.DP_ALL_GATHER_INFER: {
- 'dispatch_fn': dispatch_one_to_all,
- 'collect_fn': collect_dp_infer,
- },
- }
- return predefined_dispatch_mode_fn.get(dispatch_mode)
-
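-# Illustrative sketch (added comment): get_predefined_dispatch_fn(Dispatch.DP_COMPUTE_PROTO) returns
-# {'dispatch_fn': dispatch_dp_compute_data_proto, 'collect_fn': collect_dp_compute_data_proto};
-# an unknown mode returns None because dict.get is used.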
-
-def get_predefined_execute_fn(execute_mode):
- """
- Note that the worker group only needs to implement the execute functions referenced below
- (execute_all, execute_rank_zero, execute_infer and execute_train).
- The choice of how these functions handle the 'blocking' argument is left to the user.
- """
- predefined_execute_mode_fn = {
- Execute.ALL: {
- 'execute_fn_name': 'execute_all'
- },
- Execute.RANK_ZERO: {
- 'execute_fn_name': 'execute_rank_zero'
- },
- Execute.INFER: {
- 'execute_fn_name': 'execute_infer'
- },
- Execute.TRAIN: {
- 'execute_fn_name': 'execute_train'
- }
- }
- return predefined_execute_mode_fn.get(execute_mode)
-
-
-def _check_dispatch_mode(dispatch_mode):
- if not isinstance(dispatch_mode, (Dispatch, Dict)):
- raise TypeError(f'dispatch_mode must be a Dispatch or a Dict. Got {type(dispatch_mode)}')
- if isinstance(dispatch_mode, Dict):
- necessary_keys = ['dispatch_fn', 'collect_fn']
- for key in necessary_keys:
- if key not in dispatch_mode:
- raise KeyError(f'key {key} should be in dispatch_mode if it is a dictionary')
-
-
-def _check_execute_mode(execute_mode):
- if not isinstance(execute_mode, Execute):
- raise TypeError(f'execute_mode must be an instance of Execute. Got {type(execute_mode)}')
-
-
-def _materialize_futures(*args, **kwargs):
- from mindspeed_llm.tasks.posttrain.rlxf.utils.protocol import DataProtoFuture
- new_args = []
- for arg in args:
- if isinstance(arg, DataProtoFuture):
- arg = arg.get()
- # add more type to materialize
- new_args.append(arg)
- for k, v in kwargs.items():
- if isinstance(v, DataProtoFuture):
- kwargs[k] = v.get()
-
- new_args = tuple(new_args)
- return new_args, kwargs
-
-
-def register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.ALL, blocking=True, materialize_futures=True):
- _check_dispatch_mode(dispatch_mode=dispatch_mode)
- _check_execute_mode(execute_mode=execute_mode)
-
- def decorator(func):
-
- @wraps(func)
- def inner(*args, **kwargs):
- if materialize_futures:
- args, kwargs = _materialize_futures(*args, **kwargs)
- return func(*args, **kwargs)
-
- attrs = {'dispatch_mode': dispatch_mode, 'execute_mode': execute_mode, 'blocking': blocking}
- setattr(inner, MAGIC_ATTR, attrs)
- return inner
-
- return decorator
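-# Illustrative usage sketch (added comment; the worker class and method names are hypothetical):
-#
-#     class MyWorker(Worker):
-#         @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
-#         def compute(self, data):
-#             ...
-#
-# register() only tags `compute` with MAGIC_ATTR; WorkerGroup._bind_worker_method later reads that
-# tag to choose the dispatch/collect/execute functions when the method is called on a worker group.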
diff --git a/mindspeed_llm/tasks/posttrain/rlxf/single_controller/base/megatron/__init__.py b/mindspeed_llm/tasks/posttrain/rlxf/single_controller/base/megatron/__init__.py
deleted file mode 100644
index 1ce90c5eb352d85c59105c0dc85b5f1dd576f095..0000000000000000000000000000000000000000
--- a/mindspeed_llm/tasks/posttrain/rlxf/single_controller/base/megatron/__init__.py
+++ /dev/null
@@ -1,13 +0,0 @@
-# Copyright 2024 Bytedance Ltd. and/or its affiliates
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
diff --git a/mindspeed_llm/tasks/posttrain/rlxf/single_controller/base/megatron/worker.py b/mindspeed_llm/tasks/posttrain/rlxf/single_controller/base/megatron/worker.py
deleted file mode 100644
index f3c3341df05db9c85a59c5c25f14214c6d18bb7f..0000000000000000000000000000000000000000
--- a/mindspeed_llm/tasks/posttrain/rlxf/single_controller/base/megatron/worker.py
+++ /dev/null
@@ -1,39 +0,0 @@
-# Copyright 2024 Bytedance Ltd. and/or its affiliates
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-from dataclasses import dataclass
-from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base.worker import Worker, DistRankInfo, DistGlobalInfo
-
-
-class MegatronWorker(Worker):
-
- def __init__(self, cuda_visible_devices=None) -> None:
- super().__init__(cuda_visible_devices)
-
- def get_megatron_global_info(self):
- from megatron.core import parallel_state as mpu
- tp_size = mpu.get_tensor_model_parallel_world_size()
- dp_size = mpu.get_data_parallel_world_size()
- pp_size = mpu.get_pipeline_model_parallel_world_size()
- info = DistGlobalInfo(tp_size=tp_size, dp_size=dp_size, pp_size=pp_size)
- return info
-
- def get_megatron_rank_info(self):
- from megatron.core import parallel_state as mpu
- tp_rank = mpu.get_tensor_model_parallel_rank()
- dp_rank = mpu.get_data_parallel_rank()
- pp_rank = mpu.get_pipeline_model_parallel_rank()
- info = DistRankInfo(tp_rank=tp_rank, dp_rank=dp_rank, pp_rank=pp_rank)
- return info
\ No newline at end of file
diff --git a/mindspeed_llm/tasks/posttrain/rlxf/single_controller/base/megatron/worker_group.py b/mindspeed_llm/tasks/posttrain/rlxf/single_controller/base/megatron/worker_group.py
deleted file mode 100644
index 85371f6eab72b4097b2f599ef4c845272b33b06c..0000000000000000000000000000000000000000
--- a/mindspeed_llm/tasks/posttrain/rlxf/single_controller/base/megatron/worker_group.py
+++ /dev/null
@@ -1,55 +0,0 @@
-# Copyright 2024 Bytedance Ltd. and/or its affiliates
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import Dict
-
-from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base import ResourcePool, WorkerGroup
-from .worker import DistRankInfo, DistGlobalInfo
-
-
-class MegatronWorkerGroup(WorkerGroup):
-
- def __init__(self, resource_pool: ResourcePool, **kwargs):
- super().__init__(resource_pool=resource_pool, **kwargs)
- self._megatron_rank_info = None
- self._megatron_global_info: DistGlobalInfo = None
-
- def init_megatron(self, default_megatron_kwargs: Dict = None):
- raise NotImplementedError(f"MegatronWorkerGroup.init_megatron should be overwritten")
-
- def get_megatron_rank_info(self, rank: int) -> DistRankInfo:
- if not (0 <= rank < self.world_size):
- raise ValueError(f'rank must be from [0, world_size), Got {rank}')
- return self._megatron_rank_info[rank]
-
- @property
- def tp_size(self):
- if self._megatron_global_info is None:
- raise ValueError("MegatronWorkerGroup._megatron_global_info must be initialized")
- return self._megatron_global_info.tp_size
-
- @property
- def dp_size(self):
- if self._megatron_global_info is None:
- raise ValueError("MegatronWorkerGroup._megatron_global_info must be initialized")
- return self._megatron_global_info.dp_size
-
- @property
- def pp_size(self):
- if self._megatron_global_info is None:
- raise ValueError("MegatronWorkerGroup._megatron_global_info must be initialized")
- return self._megatron_global_info.pp_size
-
- def get_megatron_global_info(self):
- return self._megatron_global_info
diff --git a/mindspeed_llm/tasks/posttrain/rlxf/single_controller/base/register_center/__init__.py b/mindspeed_llm/tasks/posttrain/rlxf/single_controller/base/register_center/__init__.py
deleted file mode 100644
index 1ce90c5eb352d85c59105c0dc85b5f1dd576f095..0000000000000000000000000000000000000000
--- a/mindspeed_llm/tasks/posttrain/rlxf/single_controller/base/register_center/__init__.py
+++ /dev/null
@@ -1,13 +0,0 @@
-# Copyright 2024 Bytedance Ltd. and/or its affiliates
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
diff --git a/mindspeed_llm/tasks/posttrain/rlxf/single_controller/base/register_center/ray.py b/mindspeed_llm/tasks/posttrain/rlxf/single_controller/base/register_center/ray.py
deleted file mode 100644
index 430290cf2683d882d35a83256aa363d959265a05..0000000000000000000000000000000000000000
--- a/mindspeed_llm/tasks/posttrain/rlxf/single_controller/base/register_center/ray.py
+++ /dev/null
@@ -1,29 +0,0 @@
-# Copyright 2024 Bytedance Ltd. and/or its affiliates
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import ray
-
-
-@ray.remote
-class WorkerGroupRegisterCenter:
-
- def __init__(self, rank_zero_info):
- self.rank_zero_info = rank_zero_info
-
- def get_rank_zero_info(self):
- return self.rank_zero_info
-
-
-def create_worker_group_register_center(name, info):
- return WorkerGroupRegisterCenter.options(name=name).remote(info)
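-# Added note: the actor created here is named "{prefix}_register_center" by Worker rank 0 and is
-# later looked up by RayWorkerGroup to recover MASTER_ADDR / MASTER_PORT for the other ranks.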
diff --git a/mindspeed_llm/tasks/posttrain/rlxf/single_controller/base/worker.py b/mindspeed_llm/tasks/posttrain/rlxf/single_controller/base/worker.py
deleted file mode 100644
index fd9c162e70c4821d146c9d9f1b42e251744325f2..0000000000000000000000000000000000000000
--- a/mindspeed_llm/tasks/posttrain/rlxf/single_controller/base/worker.py
+++ /dev/null
@@ -1,184 +0,0 @@
-# Copyright 2024 Bytedance Ltd. and/or its affiliates
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-the class for Worker
-"""
-import os
-import socket
-from dataclasses import dataclass
-from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base.decorator import register, Dispatch
-
-
-@dataclass
-class DistRankInfo:
- tp_rank: int
- dp_rank: int
- pp_rank: int
-
-
-@dataclass
-class DistGlobalInfo:
- tp_size: int
- dp_size: int
- pp_size: int
-
-
-class WorkerHelper:
-
- def _get_node_ip(self):
-
- def get_node_ip_by_sdk():
- if os.getenv("WG_BACKEND", None) == "ray":
- import ray
- return ray._private.services.get_node_ip_address()
- elif os.getenv("WG_BACKEND", None) == "torch_rpc":
- from mindspeed_llm.tasks.posttrain.rlxf.single_controller import get_ip_addr
- return get_ip_addr()
- return None
-
- host_ipv4 = os.getenv("MY_HOST_IP", None)
- host_ipv6 = os.getenv("MY_HOST_IPV6", None)
- host_ip_by_env = host_ipv4 or host_ipv6
- host_ip_by_sdk = get_node_ip_by_sdk()
-
- host_ip = host_ip_by_env or host_ip_by_sdk
- return host_ip
-
- def _get_free_port(self):
- with socket.socket() as sock:
- sock.bind(('', 0))
- return sock.getsockname()[1]
-
- def get_availale_master_addr_port(self):
- return self._get_node_ip(), str(self._get_free_port())
-
- def _get_pid(self):
- return
-
-
-class WorkerMeta:
- keys = [
- "WORLD_SIZE", "RANK", "LOCAL_WORLD_SIZE", "LOCAL_RANK", "MASTER_ADDR", "MASTER_PORT", "CUDA_VISIBLE_DEVICES"
- ]
-
- def __init__(self, store) -> None:
- self._store = store
-
- def to_dict(self):
- return {f"_{key.lower()}": self._store.get(f"_{key.lower()}", None) for key in WorkerMeta.keys}
-
-
-# we assume that in each WorkerGroup, there is a Master Worker
-class Worker(WorkerHelper):
-
- def __new__(cls, *args, **kwargs):
- instance = super().__new__(cls)
-
- # note that here we use int to distinguish
- disable_worker_init = int(os.environ.get('DISABLE_WORKER_INIT', 0))
- if disable_worker_init:
- return instance
-
- rank = os.environ.get("RANK", None)
- worker_group_prefix = os.environ.get("WG_PREFIX", None)
-
- # when the @ray.remote decorator is applied, __new__ is still called, but in that case we don't want to run _configure_before_init
- if None not in [rank, worker_group_prefix] and 'ActorClass(' not in cls.__name__:
- instance._configure_before_init(f"{worker_group_prefix}_register_center", int(rank))
-
- return instance
-
- def _configure_before_init(self, register_center_name: str, rank: int):
- if not isinstance(rank, int):
- raise TypeError(f"rank must be int, instead of {type(rank)}")
-
- if rank == 0:
- master_addr, master_port = self.get_availale_master_addr_port()
- rank_zero_info = {
- "MASTER_ADDR": master_addr,
- "MASTER_PORT": master_port,
- }
-
- if os.getenv("WG_BACKEND", None) == "ray":
- from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base.register_center.ray import create_worker_group_register_center
- self.register_center = create_worker_group_register_center(name=register_center_name,
- info=rank_zero_info)
-
- os.environ.update(rank_zero_info)
-
- def __init__(self, cuda_visible_devices=None) -> None:
- # construct a meta from environment variables. Note that the import must be inside the function because it is executed remotely
- import os
- world_size = int(os.environ['WORLD_SIZE'])
- rank = int(os.environ['RANK'])
- self._rank = rank
- self._world_size = world_size
-
- master_addr = os.environ["MASTER_ADDR"]
- master_port = os.environ["MASTER_PORT"]
-
- local_world_size = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
- local_rank = int(os.getenv("LOCAL_RANK", "0"))
-
- store = {
- '_world_size': world_size,
- '_rank': rank,
- '_local_world_size': local_world_size,
- '_local_rank': local_rank,
- '_master_addr': master_addr,
- '_master_port': master_port
- }
- if cuda_visible_devices is not None:
- store['_cuda_visible_devices'] = cuda_visible_devices
-
- meta = WorkerMeta(store=store)
- self._configure_with_meta(meta=meta)
-
- def _configure_with_meta(self, meta: WorkerMeta):
- """
- This function should only be called inside by WorkerGroup
- """
- if not isinstance(meta, WorkerMeta):
- raise TypeError(
- f"Invalid meta type: expected WorkerMeta, got {type(meta).__name__}. "
- f"(Received value: {repr(meta)})"
- )
- self.__dict__.update(meta.to_dict()) # this is hacky
- for key in WorkerMeta.keys:
- val = self.__dict__.get(f"_{key.lower()}", None)
- if val is not None:
- os.environ[key] = str(val)
- os.environ["REDIS_STORE_SERVER_HOST"] = str(self._master_addr).replace("[", "").replace(
- "]", "") if self._master_addr else ""
-
- def get_master_addr_port(self):
- return self._master_addr, self._master_port
-
- def get_cuda_visible_devices(self):
- import os
- cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "not set")
- return cuda_visible_devices
-
- @property
- def world_size(self):
- return self._world_size
-
- @property
- def rank(self):
- return self._rank
-
- @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO_WITH_FUNC)
- def execute_with_func_generator(self, func, *args, **kwargs):
- ret_proto = func(self, *args, **kwargs)
- return ret_proto
diff --git a/mindspeed_llm/tasks/posttrain/rlxf/single_controller/base/worker_group.py b/mindspeed_llm/tasks/posttrain/rlxf/single_controller/base/worker_group.py
deleted file mode 100644
index e847b7dccde9c2c24d9a782760a07ea859b36ab0..0000000000000000000000000000000000000000
--- a/mindspeed_llm/tasks/posttrain/rlxf/single_controller/base/worker_group.py
+++ /dev/null
@@ -1,222 +0,0 @@
-# Copyright 2024 Bytedance Ltd. and/or its affiliates
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-the class of WorkerGroup
-"""
-import logging
-import threading
-import signal
-import time
-from typing import List, Any, Callable, Dict
-
-from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base.decorator import MAGIC_ATTR, Dispatch, get_predefined_dispatch_fn, get_predefined_execute_fn
-
-
-class ResourcePool:
-
- def __init__(self, process_on_nodes=None, max_collocate_count: int = 10, n_gpus_per_node=8) -> None:
- if process_on_nodes is None:
- process_on_nodes = []
- self._store = process_on_nodes
- self.max_collocate_count = max_collocate_count
- self.n_gpus_per_node = n_gpus_per_node # this is left for future Huawei hardware with 16 devices per node
-
- def add_node(self, process_count):
- self._store.append(process_count)
-
- @property
- def world_size(self):
- return sum(self._store)
-
- def __call__(self) -> Any:
- return self._store
-
- @property
- def store(self):
- return self._store
-
- def local_world_size_list(self) -> List[int]:
- nested_local_world_size_list = []
-
- for local_world_size in self._store:
- inner_list = []
-
- for _ in range(local_world_size):
- inner_list.append(local_world_size)
-
- nested_local_world_size_list.append(inner_list)
-
- return [item for row in nested_local_world_size_list for item in row]
-
- def local_rank_list(self) -> List[int]:
- nested_local_rank_list = [[i for i in range(local_world_size)] for local_world_size in self._store]
- return [item for row in nested_local_rank_list for item in row]
-
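-# Illustrative sketch (added comment): ResourcePool(process_on_nodes=[8, 8]) has world_size 16,
-# local_world_size_list() == [8] * 16 and local_rank_list() == [0..7, 0..7] (flattened).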
-
-class ClassWithInitArgs:
- """
- This class stores a class constructor and the args/kwargs to construct the class.
- It is used to instantiate the remote class.
- """
-
- def __init__(self, cls, *args, **kwargs) -> None:
- self.cls = cls
- self.args = args
- self.kwargs = kwargs
-
- def __call__(self) -> Any:
- return self.cls(*self.args, **self.kwargs)
-
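-# Illustrative usage sketch (added comment; MyWorker is hypothetical):
-#     cwia = ClassWithInitArgs(MyWorker, some_arg, some_kwarg=1)
-#     worker = cwia()   # equivalent to MyWorker(some_arg, some_kwarg=1); the Ray subclass overrides
-#                       # __call__ to create the instance as a remote actor instead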
-
-def check_workers_alive(workers: List, is_alive: Callable, gap_time: float = 1) -> None:
- import time
- while True:
- for worker in workers:
- if not is_alive(worker):
- logging.warning(f"worker {worker} is not alive" + " sending signal to main thread")
- signal.raise_signal(signal.SIGABRT)
- time.sleep(gap_time)
-
-
-class WorkerGroup:
-
- def __init__(self, resource_pool: ResourcePool, **kwargs) -> None:
- self._is_init_with_detached_workers = resource_pool is None
-
- if resource_pool is not None:
- # handle the case when WorkGroup is attached to an existing one
- self._procecss_dispatch_config = resource_pool()
- else:
- self._procecss_dispatch_config = None
-
- self._workers = []
- self._worker_names = []
-
- self._master_addr = None
- self._master_port = None
-
- self._checker_thread: threading.Thread = None
-
- def _is_worker_alive(self, worker):
- raise NotImplementedError(f"WorkerGroup._is_worker_alive called, should be implemented in derived class.")
-
- def _block_until_all_workers_alive(self) -> None:
- while True:
- all_state = [self._is_worker_alive(worker) for worker in self._workers]
- if False in all_state:
- time.sleep(1)
- else:
- break
-
- def start_worker_aliveness_check(self, every_n_seconds=1) -> None:
- # before starting checking worker aliveness, make sure all workers are already alive
- self._block_until_all_workers_alive()
-
- self._checker_thread = threading.Thread(target=check_workers_alive,
- args=(self._workers, self._is_worker_alive, every_n_seconds))
- self._checker_thread.start()
-
- @property
- def world_size(self):
- return len(self._workers)
-
- # execute_all_async and execute_rank_zero_async should be implemented by RayWorkerGroup and TorchRPCWorkerGroup;
- # MegatronWorkerGroup and XperfWorkerGroup may skip them
-
- def _bind_worker_method(self, user_defined_cls, func_generator):
- """
- Bind the register-decorated methods of user_defined_cls to this WorkerGroup
- """
-
- for method_name in dir(user_defined_cls):
-
- try:
- method = getattr(user_defined_cls, method_name)
- if not callable(method):
- raise TypeError(
- f"{method_name} in {user_defined_cls} is not callable"
- )
- except Exception as e:
- # if it is a property, it will fail because Class doesn't have instance property
- continue
-
- if hasattr(method, MAGIC_ATTR):
- # this method is decorated by register
- attribute = getattr(method, MAGIC_ATTR)
- if not isinstance(attribute, dict):
- raise TypeError(
- f"Attribute must be a dictionary. Got {type(attribute)} for {method_name}"
- )
- if 'dispatch_mode' not in attribute:
- raise KeyError(
- f"Attribute must contain 'dispatch_mode' key in {method_name}. "
- f"Found keys: {list(attribute.keys())}"
- )
-
- dispatch_mode = attribute['dispatch_mode']
- execute_mode = attribute['execute_mode']
- blocking = attribute['blocking']
-
- # get dispatch fn
- if isinstance(dispatch_mode, Dispatch):
- # get default dispatch fn
- fn = get_predefined_dispatch_fn(dispatch_mode=dispatch_mode)
- dispatch_fn = fn['dispatch_fn']
- collect_fn = fn['collect_fn']
- else:
- if not isinstance(dispatch_mode, dict):
- raise TypeError(
- f"dispatch_mode must be dict type. Got {type(dispatch_mode)} "
- f"in {method_name}"
- )
-
- if 'dispatch_fn' not in dispatch_mode:
- raise KeyError(
- f"dispatch_mode requires 'dispatch_fn' key in {method_name}. "
- f"Found keys: {list(dispatch_mode.keys())}"
- )
-
- if 'collect_fn' not in dispatch_mode:
- raise KeyError(
- f"dispatch_mode requires 'collect_fn' key in {method_name}. "
- f"Found keys: {list(dispatch_mode.keys())}"
- )
- dispatch_fn = dispatch_mode['dispatch_fn']
- collect_fn = dispatch_mode['collect_fn']
-
- # get execute_fn_name
- execute_mode = get_predefined_execute_fn(execute_mode=execute_mode)
- wg_execute_fn_name = execute_mode['execute_fn_name']
-
- # get execute_fn from string
- try:
- execute_fn = getattr(self, wg_execute_fn_name)
- if not callable(execute_fn):
- raise TypeError(f"{wg_execute_fn_name} is not callable")
- except Exception as e:
- print(f'execute_fn {wg_execute_fn_name} is invalid')
- raise
-
- # bind a new method to the RayWorkerGroup
- func = func_generator(self,
- method_name,
- dispatch_fn=dispatch_fn,
- collect_fn=collect_fn,
- execute_fn=execute_fn,
- blocking=blocking)
-
- try:
- setattr(self, method_name, func)
- except Exception as e:
- raise ValueError(f'Fail to set method_name {method_name}') from e
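-# Added note: after _bind_worker_method runs, calling worker_group.<method_name>(data) goes through
-# dispatch_fn (split the input per worker) -> execute_fn (issue the remote calls) -> collect_fn
-# (merge the per-worker outputs).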
diff --git a/mindspeed_llm/tasks/posttrain/rlxf/single_controller/ray/__init__.py b/mindspeed_llm/tasks/posttrain/rlxf/single_controller/ray/__init__.py
deleted file mode 100644
index ed0ede061bc226772a8777e10032ed2511747a86..0000000000000000000000000000000000000000
--- a/mindspeed_llm/tasks/posttrain/rlxf/single_controller/ray/__init__.py
+++ /dev/null
@@ -1,16 +0,0 @@
-# Copyright 2024 Bytedance Ltd. and/or its affiliates
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from .base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup, create_colocated_worker_cls
-from .megatron import (DistRankInfo, DistGlobalInfo)
\ No newline at end of file
diff --git a/mindspeed_llm/tasks/posttrain/rlxf/single_controller/ray/base.py b/mindspeed_llm/tasks/posttrain/rlxf/single_controller/ray/base.py
deleted file mode 100644
index 8f0fa4cb088458bbdff5c76d718f2897cda8dd60..0000000000000000000000000000000000000000
--- a/mindspeed_llm/tasks/posttrain/rlxf/single_controller/ray/base.py
+++ /dev/null
@@ -1,480 +0,0 @@
-# Copyright 2024 Bytedance Ltd. and/or its affiliates
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-__all__ = ['Worker']
-
-import os
-import time
-from typing import Dict, List, Any
-from unittest.mock import patch
-
-import ray
-from ray.util import list_named_actors
-from ray.util.placement_group import placement_group
-from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy, NodeAffinitySchedulingStrategy
-from ray.experimental.state.api import get_actor
-
-from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base import WorkerGroup, ResourcePool, ClassWithInitArgs, Worker
-from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base.decorator import MAGIC_ATTR
-
-ACTOR_INFER_WORLD_SIZE = None
-ACTOR_TRAIN_WORLD_SIZE = None
-
-
-def set_actor_infer_world_size(world_size):
- global ACTOR_INFER_WORLD_SIZE
- ACTOR_INFER_WORLD_SIZE = world_size
-
-
-def set_actor_train_world_size(world_size):
- global ACTOR_TRAIN_WORLD_SIZE
- ACTOR_TRAIN_WORLD_SIZE = world_size
-
-
-def get_actor_infer_world_size():
- return ACTOR_INFER_WORLD_SIZE
-
-
-def get_actor_train_world_size():
- return ACTOR_TRAIN_WORLD_SIZE
-
-
-def get_random_string(length: int) -> str:
- import random
- import string
- letters_digits = string.ascii_letters + string.digits
- return ''.join(random.choice(letters_digits) for _ in range(length))
-
-
-def func_generator(self, method_name, dispatch_fn, collect_fn, execute_fn, blocking):
-
- def func(*args, **kwargs):
- args, kwargs = dispatch_fn(self, *args, **kwargs)
- output = execute_fn(method_name, *args, **kwargs)
- if blocking:
- output = ray.get(output)
- output = collect_fn(self, output)
- return output
-
- return func
-
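-# Added note: func_generator builds the proxy that WorkerGroup._bind_worker_method binds onto the
-# worker group; with blocking=True the ray.get happens before collect_fn, so the caller receives
-# concrete outputs rather than ObjectRefs.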
-
-class RayResourcePool(ResourcePool):
-
- def __init__(self,
- process_on_nodes: List[int] = None,
- use_gpu: bool = True,
- name_prefix: str = "",
- max_colocate_count: int = 5,
- detached=False) -> None:
- super().__init__(process_on_nodes, max_colocate_count)
- self.use_gpu = use_gpu
- self.name_prefix = name_prefix
- self.pgs = None
- self.detached = detached
-
- def get_placement_groups(self, strategy="STRICT_PACK", name=None):
- if self.pgs is not None:
- return self.pgs
-
- if not self.name_prefix:
- self.name_prefix = name
-
- pg_name_prefix = f"{self.name_prefix}verl_group_{'_'.join([str(count) for count in self._store])}:"
- pg_scheme = [[{
- "CPU": self.max_collocate_count,
- "NPU": 1
- } if self.use_gpu else {
- "CPU": self.max_collocate_count
- } for _ in range(process_count)] for process_count in self._store]
-
- lifetime = 'detached' if self.detached else None
-
- pgs = [
- placement_group(bundles=bundles, strategy=strategy, name=pg_name_prefix + str(idx), lifetime=lifetime)
- for idx, bundles in enumerate(pg_scheme)
- ]
-
- ray.get([pg.ready() for pg in pgs])
-
- self.pgs = pgs
- return pgs
-
-
-class RayClassWithInitArgs(ClassWithInitArgs):
-
- def __init__(self, cls, *args, **kwargs) -> None:
- super().__init__(cls, *args, **kwargs)
- self._options = {}
- self._additional_resource = {}
-
- def set_additional_resource(self, additional_resource):
- self._additional_resource = additional_resource
-
- def update_options(self, options: Dict):
- self._options.update(options)
-
- def __call__(self,
- placement_group,
- placement_group_bundle_idx,
- use_gpu: bool = True,
- num_gpus=1,
- sharing_with=None) -> Any:
- if sharing_with is not None:
- target_node_id = ray.get(sharing_with.get_node_id.remote())
- cuda_visible_devices = ray.get(sharing_with.get_cuda_visible_devices.remote())
- options = {"scheduling_strategy": NodeAffinitySchedulingStrategy(node_id=target_node_id, soft=False)}
- return self.cls.options(**options).remote(*self.args,
- cuda_visible_devices=cuda_visible_devices,
- **self.kwargs)
-
- options = {
- "scheduling_strategy":
- PlacementGroupSchedulingStrategy(placement_group=placement_group,
- placement_group_bundle_index=placement_group_bundle_idx)
- }
- options.update(self._options)
-
- if use_gpu:
- options["resources"] = {"NPU": num_gpus}
-
- if len(self._additional_resource) >= 1:
- for k, v in self._additional_resource.items():
- options[k] = v
-
- print("cls:", self.cls)
- print("args: ", self.args)
- print("kwargs: ", self.kwargs)
- return self.cls.options(**options).remote(*self.args, **self.kwargs)
-
-
-class RayWorkerGroup(WorkerGroup):
-
- def __init__(self,
- resource_pool: RayResourcePool = None,
- ray_cls_with_init: RayClassWithInitArgs = None,
- bin_pack: bool = True,
- name_prefix: str = None,
- detached=False,
- worker_names=None,
- **kwargs) -> None:
- super().__init__(resource_pool=resource_pool, **kwargs)
- self.ray_cls_with_init = ray_cls_with_init
- self.name_prefix = get_random_string(length=6) if name_prefix is None else name_prefix
-
- if self._is_init_with_detached_workers:
- self._init_with_detached_workers(worker_names=worker_names)
- else:
- self._init_with_resource_pool(resource_pool=resource_pool,
- ray_cls_with_init=ray_cls_with_init,
- bin_pack=bin_pack,
- detached=detached)
-
- if ray_cls_with_init is not None:
- self._bind_worker_method(self.ray_cls_with_init.cls, func_generator)
-
- def _is_worker_alive(self, worker: ray.actor.ActorHandle):
- worker_state_dict = get_actor(worker._actor_id.hex())
- return worker_state_dict.get("state", "undefined") == "ALIVE" if worker_state_dict is not None else False
-
- def _init_with_detached_workers(self, worker_names):
- workers = [ray.get_actor(name=name) for name in worker_names]
- self._workers = workers
- self._worker_names = worker_names
- self._world_size = len(worker_names)
-
- def _init_with_resource_pool(self, resource_pool, ray_cls_with_init, bin_pack, detached):
- use_gpu = resource_pool.use_gpu
-
- strategy = "PACK"
- if bin_pack:
- strategy = "STRICT_PACK"
- pgs = resource_pool.get_placement_groups(strategy=strategy, name=self.name_prefix)
- world_size = resource_pool.world_size
- self._world_size = world_size
- num_gpus = 1 / resource_pool.max_collocate_count
-
- rank = -1
- for pg_idx, local_world_size in enumerate(resource_pool.store):
- pg = pgs[pg_idx]
- if local_world_size > pg.bundle_count:
- raise ValueError(f"when generating for {self.name_prefix}, for the ")
- for local_rank in range(local_world_size):
- rank += 1
-
- # we pass in environment variable at option so that Worker can use environment variable to set
- env_vars = {
- 'WORLD_SIZE': str(world_size),
- 'RANK': str(rank),
- 'WG_PREFIX': self.name_prefix,
- 'WG_BACKEND': 'ray',
- 'RAY_LOCAL_WORLD_SIZE': str(local_world_size),
- 'RAY_LOCAL_RANK': str(local_rank),
- }
- if rank != 0:
- env_vars['MASTER_ADDR'] = self._master_addr
- env_vars['MASTER_PORT'] = self._master_port
-
- import re
- cia_name = type(ray_cls_with_init.cls).__name__
- match = re.search(r"ActorClass\(([^)]+)\)", cia_name) # ray.remote(Obj) -> "ActorClass(Obj)"
- cia_name = match.group(1) if match else cia_name # "ActorClass(Obj)" -> "Obj"
- name = f"{self.name_prefix}{cia_name}_{pg_idx}:{local_rank}" # e.g. Worker_2:5
- ray_cls_with_init.update_options({'runtime_env': {'env_vars': env_vars}, 'name': name})
-
- if detached:
- ray_cls_with_init.update_options({'lifetime': 'detached'})
-
- # create a worker
- worker = ray_cls_with_init(placement_group=pg,
- placement_group_bundle_idx=local_rank,
- use_gpu=use_gpu,
- num_gpus=num_gpus)
- self._workers.append(worker)
- self._worker_names.append(name)
-
- if rank == 0:
- register_center_actor = None
- for _ in range(120):
- if f"{self.name_prefix}_register_center" not in list_named_actors():
- time.sleep(1)
- else:
- register_center_actor = ray.get_actor(f"{self.name_prefix}_register_center")
- break
- if register_center_actor is None:
- available_actors = list_named_actors(all_namespaces=True)
- raise ValueError(
- f"failed to get register_center_actor: {self.name_prefix}_register_center in {list_named_actors(all_namespaces=True)}"
- )
- rank_zero_info = ray.get(register_center_actor.get_rank_zero_info.remote())
- self._master_addr, self._master_port = rank_zero_info['MASTER_ADDR'], rank_zero_info['MASTER_PORT']
-
- @property
- def worker_names(self):
- return self._worker_names
-
- @classmethod
- def from_detached(cls, worker_names=None, ray_cls_with_init=None):
- worker_group = cls(resource_pool=None,
- ray_cls_with_init=ray_cls_with_init,
- name_prefix=None,
- worker_names=worker_names)
- return worker_group
-
- def spawn(self, prefix_set):
- """
- Spawn a dictionary of worker groups, each exposing the subset of methods that share the given prefix.
-
- """
-
- def remove_prefix(text, prefix):
- if text.startswith(prefix):
- return text[len(prefix):]
- return text
-
- def _rebind_actor_methods(worker_group, actor_name):
- """
- bind the method with actor_prefix to its original name
- """
- prefix: str = actor_name + '_'
- for method_name in dir(worker_group):
- if method_name.startswith(prefix):
- original_method_name = remove_prefix(method_name, prefix)
- method = getattr(worker_group, method_name)
- setattr(worker_group, original_method_name, method)
-
- new_worker_group_dict = {}
- for prefix in prefix_set:
- new_worker_group = self.from_detached(worker_names=self._worker_names,
- ray_cls_with_init=self.ray_cls_with_init)
-
- _rebind_actor_methods(new_worker_group, prefix)
- new_worker_group_dict[prefix] = new_worker_group
- return new_worker_group_dict
-
- def execute_rank_zero_sync(self, method_name: str, *args, **kwargs):
- return ray.get(self.execute_rank_zero_async(method_name, *args, **kwargs))
-
- def execute_rank_zero_async(self, method_name: str, *args, **kwargs):
- remote_call = getattr(self._workers[0], method_name)
- return remote_call.remote(*args, **kwargs)
-
- def execute_rank_zero(self, method_name: str, *args, **kwargs):
- return self.execute_rank_zero_async(method_name, *args, **kwargs)
-
- def execute_all(self, method_name: str, *args, **kwargs):
- return self.execute_all_async(method_name, *args, **kwargs)
-
- def execute_all_sync(self, method_name: str, *args, **kwargs):
- return ray.get(self.execute_all_async(method_name, *args, **kwargs))
-
- def execute_all_async(self, method_name: str, *args, **kwargs):
- # Here we assume that if every argument in args and kwargs is a list and every list has the same
- # length as len(self._workers), then each element of the list is sent to the corresponding worker.
- length = len(self._workers)
- if all(isinstance(arg, list) for arg in args) and all(isinstance(kwarg, list) for kwarg in kwargs.values()):
- if all(len(arg) == length for arg in args) and all(len(kwarg) == length for kwarg in kwargs.values()):
- result = []
- for i in range(length):
- sliced_args = tuple(arg[i] for arg in args)
- sliced_kwargs = {k: v[i] for k, v in kwargs.items()}
- remote_call = getattr(self._workers[i], method_name)
- result.append(remote_call.remote(*sliced_args, **sliced_kwargs))
- return result
-
- return [getattr(worker, method_name).remote(*args, **kwargs) for worker in self._workers]
-
- def execute_train(self, method_name: str, *args, **kwargs):
- return self.execute_train_infer_async(method_name, True, *args, **kwargs)
-
- def execute_train_sync(self, method_name: str, *args, **kwargs):
- return ray.get(self.execute_train_infer_async(method_name, True, *args, **kwargs))
-
- def execute_infer(self, method_name: str, *args, **kwargs):
- return self.execute_train_infer_async(method_name, False, *args, **kwargs)
-
- def execute_infer_sync(self, method_name: str, *args, **kwargs):
- return ray.get(self.execute_train_infer_async(method_name, False, *args, **kwargs))
-
-
- def execute_train_infer_async(self, method_name: str, is_train: bool, *args, **kwargs):
- # Here we assume that if every argument in args and kwargs is a list and every list has the same
- # length as len(self._workers), then each element of the list is sent to the corresponding worker.
- all_length = len(self._workers)
- infer_length = get_actor_infer_world_size()
- train_length = get_actor_train_world_size()
- if all(isinstance(arg, list) for arg in args) and all(isinstance(kwarg, list) for kwarg in kwargs.values()):
- if all(len(arg) == all_length for arg in args) and all(len(kwarg) == all_length for kwarg in kwargs.values()):
- result = []
- loop_length = train_length if is_train else infer_length
- offset = 0 if is_train else train_length
- for i in range(loop_length):
- sliced_args = tuple(arg[i] for arg in args)
- sliced_kwargs = {k: v[i] for k, v in kwargs.items()}
- remote_call = getattr(self._workers[offset + i], method_name)
- result.append(remote_call.remote(*sliced_args, **sliced_kwargs))
- return result
-
- if is_train:
- workers = self._workers[:train_length]
- else:
- workers = self._workers[train_length:]
- return [getattr(worker, method_name).remote(*args, **kwargs) for worker in workers]
-
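- # Added note: in the combined train/infer layout assumed here, the first
- # get_actor_train_world_size() workers are the training ranks and the remaining workers are the
- # inference ranks; execute_train/execute_infer simply route the call to one of the two slices.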
- @property
- def master_address(self):
- return self._master_addr
-
- @property
- def master_port(self):
- return self._master_port
-
- @property
- def workers(self):
- return self._workers
-
- @property
- def world_size(self):
- return self._world_size
-
-
-def _bind_workers_method_to_parent(cls, key, user_defined_cls):
- """
- Utility that enables creating multiple workers inside the same ray.Actor,
- even though their code is written as separate ray.Actor classes.
-
- Binds the methods of each worker to the WorkerDict.
- Note that we only bind public methods that are decorated by register
- """
- for method_name in dir(user_defined_cls):
- try:
- method = getattr(user_defined_cls, method_name)
- if not callable(method):
- raise TypeError(
- f"{method_name} in {user_defined_cls} is not callable"
- )
- except Exception as e:
- # if it is a property, it will fail because Class doesn't have instance property
- continue
-
- if hasattr(method, MAGIC_ATTR):
-
- def generate_function(name):
-
- def func(self, *args, **kwargs):
- # dispatch to the actual worker
- return getattr(self.worker_dict[key], name)(*args, **kwargs)
-
- return func
-
- func = generate_function(method_name)
- # pass MAGIC_ATTR for outer worker group
- setattr(func, MAGIC_ATTR, getattr(method, MAGIC_ATTR))
- try:
- method_name_with_prefix = key + '_' + method_name
- setattr(cls, method_name_with_prefix, func)
- except Exception as e:
- raise ValueError(f'Fail to set method_name {method_name}') from e
-
-
-def _unwrap_ray_remote(cls):
- if hasattr(cls, '__ray_actor_class__'):
- cls = cls.__ray_actor_class__
- return cls
-
-
-def create_colocated_worker_cls(class_dict: Dict[str, RayClassWithInitArgs]):
- """
- Return a RayClassWithInitArgs wrapping a WorkerDict class that delegates calls to every
- cls in class_dict
- """
- cls_dict = {}
- init_args_dict = {}
- worker_cls = None
- for key, cls in class_dict.items():
- if worker_cls is None:
- worker_cls = cls.cls.__ray_actor_class__.__base__
- else:
- if worker_cls != cls.cls.__ray_actor_class__.__base__:
- raise ValueError(
- "the worker class should be the same when sharing the same process"
- )
- cls_dict[key] = cls.cls
- init_args_dict[key] = {'args': cls.args, 'kwargs': cls.kwargs}
-
- if cls_dict.keys() != init_args_dict.keys():
- raise ValueError(
- f"Key mismatch: cls_dict keys ({cls_dict.keys()}) "
- f"must match init_args_dict keys ({init_args_dict.keys()})"
- )
-
- class WorkerDict(worker_cls):
-
- def __init__(self):
- super().__init__()
- self.worker_dict = {}
- for key, user_defined_cls in cls_dict.items():
- user_defined_cls = _unwrap_ray_remote(user_defined_cls)
- # directly instantiate the class without remote
- with patch.dict(os.environ, {'DISABLE_WORKER_INIT': '1'}):
- self.worker_dict[key] = user_defined_cls(*init_args_dict[key].get('args', ()),
- **init_args_dict[key].get('kwargs', {}))
-
- # now monkey-patch the methods from inner class to WorkerDict
- for key, user_defined_cls in cls_dict.items():
- user_defined_cls = _unwrap_ray_remote(user_defined_cls)
- _bind_workers_method_to_parent(WorkerDict, key, user_defined_cls)
-
- remote_cls = ray.remote(WorkerDict)
- remote_cls = RayClassWithInitArgs(cls=remote_cls)
- return remote_cls
diff --git a/mindspeed_llm/tasks/posttrain/rlxf/single_controller/ray/megatron.py b/mindspeed_llm/tasks/posttrain/rlxf/single_controller/ray/megatron.py
deleted file mode 100644
index 7dbb88876746df934c258a7029f3630280c08293..0000000000000000000000000000000000000000
--- a/mindspeed_llm/tasks/posttrain/rlxf/single_controller/ray/megatron.py
+++ /dev/null
@@ -1,35 +0,0 @@
-# Copyright 2024 Bytedance Ltd. and/or its affiliates
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import Dict, Optional
-
-import ray
-
-from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base.megatron.worker import DistRankInfo, DistGlobalInfo
-from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base.megatron.worker_group import MegatronWorkerGroup
-from .base import RayWorkerGroup, RayResourcePool, RayClassWithInitArgs
-
-
-# NOTE(sgm): for opensource megatron-core
-class NVMegatronRayWorkerGroup(RayWorkerGroup, MegatronWorkerGroup):
- """
- MegatronWorkerGroup queries each worker for its Megatron rank info and stores it inside the WorkerGroup
- so that the dispatcher can use it to dispatch data.
- """
-
- def __init__(self, resource_pool: RayResourcePool, ray_cls_with_init: RayClassWithInitArgs, **kwargs):
- super().__init__(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init, **kwargs)
- self._megatron_rank_info: DistRankInfo = self.execute_all_sync(method_name='get_megatron_rank_info')
- self._megatron_global_info: DistGlobalInfo = ray.get(
- self.execute_rank_zero_async(method_name='get_megatron_global_info'))
diff --git a/mindspeed_llm/tasks/posttrain/rlxf/training/__init__.py b/mindspeed_llm/tasks/posttrain/rlxf/training/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/mindspeed_llm/tasks/posttrain/rlxf/training/core_algos.py b/mindspeed_llm/tasks/posttrain/rlxf/training/core_algos.py
deleted file mode 100644
index a4ae895f0b4c61eb5207b3ac8b365785beb3fb6e..0000000000000000000000000000000000000000
--- a/mindspeed_llm/tasks/posttrain/rlxf/training/core_algos.py
+++ /dev/null
@@ -1,598 +0,0 @@
-# Copyright 2022 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-Core functions to implement PPO algorithms.
-The functions implemented in this file are meant to be used by trainers with different distributed
-strategies to implement PPO.
-"""
-
-from copy import deepcopy
-
-import numpy as np
-import torch
-from transformers import AutoTokenizer
-
-import mindspeed_llm.tasks.posttrain.rlxf.utils.torch_functional as F
-from mindspeed_llm.tasks.posttrain.rlxf.utils.loggers import Loggers
-from mindspeed_llm.tasks.posttrain.rlxf.utils.protocol import DataProto
-from mindspeed_llm.tasks.posttrain.verifier.rule_verifier import preprocess_box_response_for_qwen_prompt, format_reward, reasoning_steps_reward, strict_format_reward, \
-base_model_accuracy_reward
-
-logger = Loggers()
-
-
-class AdaptiveKLController:
- def __init__(self, init_kl_coef, target_kl, horizon):
- self.value = init_kl_coef
- self.target = target_kl
- self.horizon = horizon
-
- def update(self, current_kl, n_steps):
- target = self.target
- proportional_error = np.clip(current_kl / target - 1, -0.2, 0.2)
- mult = 1 + proportional_error * n_steps / self.horizon
- self.value *= mult
-
-
-class FixedKLController:
- """Fixed KL controller."""
-
- def __init__(self, kl_coef):
- self.value = kl_coef
-
- def update(self, current_kl, n_steps):
- pass
-
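
A small numeric sketch of how the adaptive controller above nudges the coefficient toward target_kl; the constants are made-up example values, not defaults from this repository:

import numpy as np

value, target_kl, horizon = 0.2, 6.0, 10000
for current_kl, n_steps in [(12.0, 256), (3.0, 256)]:
    # same update rule as AdaptiveKLController.update above
    proportional_error = np.clip(current_kl / target_kl - 1, -0.2, 0.2)
    value *= 1 + proportional_error * n_steps / horizon
    print(f"observed KL {current_kl}: coef -> {value:.5f}")
# observed KL 12.0: coef -> 0.20102  (KL above target, coefficient grows)
# observed KL 3.0: coef -> 0.19999   (KL below target, coefficient shrinks)
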
-
-def compute_gae_advantage_return(token_level_rewards: torch.Tensor, values: torch.Tensor, eos_mask: torch.Tensor,
- gamma: torch.Tensor, lam: torch.Tensor):
- """
- Args:
- token_level_rewards: `(torch.Tensor)`
- shape: (bs, response_length)
- values: `(torch.Tensor)`
- shape: (bs, response_length)
- eos_mask: `(torch.Tensor)`
- shape: (bs, response_length). [EOS] mask. Tokens after [EOS] have mask zero.
- gamma: `(float)`
- discount factor used in RL
- lam: `(float)`
- lambda value used when computing Generalized Advantage Estimation
-
- Returns:
- advantages: `(torch.Tensor)`
- shape: (bs, response_length)
- returns: `(torch.Tensor)`
- shape: (bs, response_length)
-
- """
- with torch.no_grad():
- lastgaelam = 0
- advantages_reversed = []
- gen_len = token_level_rewards.shape[-1]
-
- for t in reversed(range(gen_len)):
- nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0
- delta = token_level_rewards[:, t] + gamma * nextvalues - values[:, t]
- lastgaelam = delta + gamma * lam * lastgaelam
- advantages_reversed.append(lastgaelam)
- advantages = torch.stack(advantages_reversed[::-1], dim=1)
-
- returns = advantages + values
- advantages = F.masked_whiten(advantages, eos_mask)
- advantages = torch.masked_fill(advantages, ~eos_mask, 0)
- return advantages, returns
-
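
The recursion above is easiest to sanity-check on a toy case; this standalone sketch re-runs it for a single 3-token response with gamma = lam = 1 and skips the whitening step:

import torch

rewards = torch.tensor([[0.0, 0.0, 1.0]])
values = torch.tensor([[0.5, 0.4, 0.3]])
gamma, lam = 1.0, 1.0

lastgaelam, adv_reversed = 0.0, []
for t in reversed(range(rewards.shape[-1])):
    nextvalues = values[:, t + 1] if t < rewards.shape[-1] - 1 else 0.0
    delta = rewards[:, t] + gamma * nextvalues - values[:, t]
    lastgaelam = delta + gamma * lam * lastgaelam
    adv_reversed.append(lastgaelam)
advantages = torch.stack(adv_reversed[::-1], dim=1)
returns = advantages + values
print(advantages)  # tensor([[0.5000, 0.6000, 0.7000]])
print(returns)     # tensor([[1., 1., 1.]]) -- with gamma = lam = 1 every position sees the full reward-to-go
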
-
-def compute_group_norm_advantage_return(
- token_level_rewards: torch.Tensor,
- eos_mask: torch.Tensor,
- gamma: torch.Tensor,
- lam: torch.Tensor,
- config
-):
- """
- Args:
- token_level_rewards: `(torch.Tensor)`
- shape: (bs, response_length)
- eos_mask: `(torch.Tensor)`
- shape: (bs, response_length). [EOS] mask. Tokens after [EOS] have mask zero.
- gamma: `(float)`
- discount factor used in RL
- lam: `(float)`
- lambda value used when computing Generalized Advantage Estimation
- config:
- algorithm config; config.algorithm.advantage_whiten toggles whitening of the advantages.
-
- Returns:
- advantages: `(torch.Tensor)`
- shape: (bs, response_length)
- returns: `(torch.Tensor)`
- shape: (bs, response_length)
-
- """
- response_length = token_level_rewards.size(1)
- returns = torch.zeros_like(token_level_rewards)
- cumulative_return = torch.zeros(token_level_rewards.size(0), device=token_level_rewards.device)
-
- # Calculate returns by accumulating discounted rewards
- for t in reversed(range(response_length)):
- cumulative_return = token_level_rewards[:, t] + gamma * cumulative_return
- returns[:, t] = cumulative_return
- advantages = deepcopy(returns)
- if not hasattr(config.algorithm, "advantage_whiten") or config.algorithm.advantage_whiten:
- advantages = F.masked_whiten(advantages, eos_mask)
- else:
- advantages = advantages * eos_mask
- return advantages, returns
-
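
The return accumulation above is a reverse discounted cumulative sum; a 3-token toy case with gamma = 0.9 (whitening omitted) makes the shape of the computation clear:

import torch

rewards = torch.tensor([[0.0, 0.0, 1.0]])
gamma = 0.9
returns = torch.zeros_like(rewards)
running = torch.zeros(rewards.size(0))
for t in reversed(range(rewards.size(1))):
    running = rewards[:, t] + gamma * running
    returns[:, t] = running
print(returns)  # tensor([[0.8100, 0.9000, 1.0000]])
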
-
-def compute_policy_loss(old_log_prob, log_prob, advantages, eos_mask, cliprange):
- """
- Args:
- old_log_prob: `(torch.Tensor)`
- shape: (bs, response_length)
- log_prob: `(torch.Tensor)`
- shape: (bs, response_length)
- advantages: `(torch.Tensor)`
- shape: (bs, response_length)
- eos_mask: `(torch.Tensor)`
- shape: (bs, response_length)
- cliprange: (float)
- The clip range used in PPO.
-
- Returns:
- pg_loss: `a scalar torch.Tensor`
- policy gradient loss computed via PPO
- pg_clipfrac: (float)
- a float number indicating the fraction of policy gradient loss being clipped
-
- """
- negative_approx_kl = log_prob - old_log_prob
- ratio = torch.exp(negative_approx_kl)
- ppo_kl = F.masked_mean(-negative_approx_kl, eos_mask)
-
- pg_losses = -advantages * ratio
- pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)
-
- pg_loss = F.masked_mean(torch.max(pg_losses, pg_losses2), eos_mask)
- pg_clipfrac = F.masked_mean(torch.gt(pg_losses2, pg_losses).float(), eos_mask)
- return pg_loss, pg_clipfrac, ppo_kl
-
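
A tiny numeric check of the clipped surrogate above: with cliprange = 0.2, a ratio of 1.5 on a positive advantage is held back at 1.2, so the clipped term wins the max():

import torch

adv = torch.tensor([2.0])
ratio = torch.tensor([1.5])
cliprange = 0.2
pg_losses = -adv * ratio                                              # tensor([-3.0000])
pg_losses2 = -adv * torch.clamp(ratio, 1 - cliprange, 1 + cliprange)  # tensor([-2.4000])
print(torch.max(pg_losses, pg_losses2))  # tensor([-2.4000]) -- the objective stops improving past the clip
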
-
-def compute_grpo_policy_loss(old_log_prob, log_prob, ref_log_prob, advantages, eos_mask, cliprange, kl_ctrl):
- """
- Args:
- old_log_prob: `(torch.Tensor)`
- shape: (bs, response_length)
- log_prob: `(torch.Tensor)`
- shape: (bs, response_length)
- ref_log_prob: `(torch.Tensor)`
- shape: (bs, response_length)
- advantages: `(torch.Tensor)`
- shape: (bs, response_length)
- eos_mask: `(torch.Tensor)`
- shape: (bs, response_length)
- cliprange: (float)
- The clip range used in PPO.
- kl_ctrl:
- KL controller that supplies the coefficient applied to the KL loss term.
-
- Returns:
- pg_loss: `a scalar torch.Tensor`
- policy gradient loss computed via GRPO
- pg_clipfrac: (float)
- a float number indicating the fraction of policy gradient loss being clipped
-
- """
- negative_approx_kl = log_prob - old_log_prob
- ratio = torch.exp(negative_approx_kl)
- ppo_kl = F.masked_mean(-negative_approx_kl, eos_mask)
-
- pg_losses = -advantages * ratio
- pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)
-
- pg_loss = F.masked_mean(torch.max(pg_losses, pg_losses2), eos_mask)
- pg_clipfrac = F.masked_mean(torch.gt(pg_losses2, pg_losses).float(), eos_mask)
-
- ref_approx_kl = ref_log_prob - log_prob
- ratio_kl = torch.exp(ref_approx_kl)
- kl_losses = ratio_kl - ref_approx_kl - 1
- kl_loss = F.masked_mean(kl_losses, eos_mask)
- pg_loss = pg_loss + kl_loss * kl_ctrl.value
- return pg_loss, pg_clipfrac, ppo_kl
-
-
-def kl_penalty(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, penalty) -> torch.FloatTensor:
- """Compute KL divergence given logprob and ref_logprob.
- Args:
- logprob:
- ref_logprob:
-
- Returns:
-
- """
- if penalty == "kl":
- return logprob - ref_logprob
-
- if penalty == "abs":
- return (logprob - ref_logprob).abs()
-
- if penalty == "mse":
- return 0.5 * (logprob - ref_logprob).square()
-
- if penalty == "full":
- # so, here logprob and ref_logprob should contain the logits for every token in vocabulary
- raise NotImplementedError
-
- raise NotImplementedError
-
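
The three implemented penalty flavours differ only in how they transform the log-probability gap; on a toy gap of 0.5:

import torch

logprob = torch.tensor([-1.0])
ref_logprob = torch.tensor([-1.5])
diff = logprob - ref_logprob
print(diff)                 # "kl"  penalty: tensor([0.5000])
print(diff.abs())           # "abs" penalty: tensor([0.5000])
print(0.5 * diff.square())  # "mse" penalty: tensor([0.1250])
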
-
-def find_first_eos_index(tensor, eos_token_id):
- """
- Find, for each row, the index of the first token equal to eos_token_id.
-
- Args:
- tensor (torch.Tensor): input tensor of shape (batch_size, seq_len).
- eos_token_id: id of the [EOS] token to search for.
-
- Returns:
- (torch.Tensor, torch.Tensor): per-row index of the first eos_token_id (used for scores) and the
- position right after it, clamped to the last column (used for rewards); both of shape (batch_size,).
- Rows that contain no eos_token_id get -1 in both tensors.
- """
-
- is_eos = (tensor == eos_token_id)
-
- # use torch.argmax to locate the first index equal to eos_token_id
- score_first_eos_index = torch.argmax(is_eos.int(), dim=1)
- reward_first_eos_index = torch.argmax(is_eos.int(), dim=1) + 1
- max_id = is_eos.shape[1] - 1
- reward_first_eos_index = torch.min(reward_first_eos_index, torch.tensor(max_id, device=reward_first_eos_index.device))
- has_eos = is_eos.any(dim=1)
- score_first_eos_index[~has_eos] = -1
- reward_first_eos_index[~has_eos] = -1
-
- return score_first_eos_index, reward_first_eos_index
-
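
A compressed, standalone reproduction of the behaviour above (torch.clamp here plays the role of the torch.min call): the score index points at the [EOS] token itself, the reward index at the position right after it, and rows without [EOS] get -1:

import torch

responses = torch.tensor([[5, 2, 0, 0],
                          [7, 8, 9, 3]])
eos_token_id = 2
is_eos = responses == eos_token_id
score_idx = torch.argmax(is_eos.int(), dim=1)
reward_idx = torch.clamp(score_idx + 1, max=is_eos.shape[1] - 1)
has_eos = is_eos.any(dim=1)
score_idx[~has_eos] = -1
reward_idx[~has_eos] = -1
print(score_idx, reward_idx)  # tensor([ 1, -1]) tensor([ 2, -1])
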
-
-def apply_kl_penalty(config, data: DataProto, kl_ctrl: AdaptiveKLController):
- responses = data.batch['responses']
- response_length = responses.size(1)
- token_level_scores = data.batch['rm_scores']
- batch_size = data.batch.batch_size[0]
- attention_mask = data.batch['attention_mask']
- response_mask = attention_mask[:, -response_length:]
- eos_token_id = data.meta_info['eos_token_id']
- contain_eos_token = torch.any(responses == eos_token_id, dim=-1)
- if config.algorithm.missing_eos_penalty is not None:
- token_level_scores[~contain_eos_token] -= config.algorithm.missing_eos_penalty
- data.batch['rm_scores'] = token_level_scores
- # compute kl between ref_policy and current policy
- if 'ref_log_prob' in data.batch.keys():
- kld = kl_penalty(data.batch['old_log_probs'], data.batch['ref_log_prob'],
- penalty=config.algorithm.kl_penalty) # (batch_size, response_length)
- kld = kld * response_mask
- beta = kl_ctrl.value
- else:
- beta = 0
- kld = torch.zeros_like(response_mask, dtype=torch.float32)
- actual_start = torch.arange(token_level_scores.size(0), device=token_level_scores.device)
- score_first_eos_index, reward_first_eos_index = find_first_eos_index(responses, eos_token_id)
- token_level_rewards = - beta * kld
- token_level_rewards[[actual_start, reward_first_eos_index]] += token_level_scores[[actual_start, score_first_eos_index]]
-
- current_kl = F.masked_mean(kld, mask=response_mask, axis=-1) # average over sequence
- current_kl = torch.mean(current_kl, dim=0).item()
-
- # according to https://github.com/huggingface/trl/blob/951ca1841f29114b969b57b26c7d3e80a39f75a0/trl/trainer/ppo_trainer.py#L837
- kl_ctrl.update(current_kl=current_kl, n_steps=batch_size)
- data.batch['token_level_rewards'] = token_level_rewards
-
- metrics = {'critic/kl': current_kl, 'critic/kl_coeff': beta}
-
- return data, metrics
-
-
-def compute_advantage(data: DataProto, config):
- responses = data.batch['responses']
- response_length = responses.size(1)
- attention_mask = data.batch['attention_mask']
- response_mask = attention_mask[:, -response_length:]
- token_level_rewards = data.batch['token_level_rewards']
-
- if config.algorithm.adv_estimator == 'gae':
- values = data.batch['values']
- advantages, returns = compute_gae_advantage_return(token_level_rewards=token_level_rewards,
- values=values,
- eos_mask=response_mask,
- gamma=config.algorithm.gamma,
- lam=config.algorithm.lam)
- data.batch['advantages'] = advantages
- data.batch['returns'] = returns
- elif config.algorithm.adv_estimator == 'group_norm':
- advantages, returns = compute_group_norm_advantage_return(token_level_rewards=token_level_rewards,
- eos_mask=response_mask,
- gamma=config.algorithm.gamma,
- lam=config.algorithm.lam,
- config=config)
- data.batch['advantages'] = advantages
- data.batch['returns'] = returns
- else:
- raise NotImplementedError
- return data
-
-
-def compute_score(reward_wg, batch, metrics, config):
- token_level_rewards = torch.zeros_like(batch.batch["responses"], dtype=torch.float32)
-
- assert reward_wg is not None or config.reward.verifier, "At least one reward should be provided for score computing."
-
- # 0 for default/general problems, 1 for math problems
- if "categories" in batch.batch.keys():
- use_verifier_mask = batch.batch["categories"][:, 0].squeeze().bool()
- elif hasattr(config.reward, "verifier") and config.reward.verifier:
- use_verifier_mask = torch.ones(len(batch.batch["input_ids"]), dtype=torch.bool)
- else:
- use_verifier_mask = torch.zeros(len(batch.batch["input_ids"]), dtype=torch.bool)
-
- if reward_wg and (~use_verifier_mask).sum():
- score_tensor = reward_wg.compute_rm_score(batch).batch['rm_scores']
-
- rm_token_level_rewards = get_last_reward(
- batch,
- rm_scores=score_tensor,
- n_sample_batch=config.actor_rollout_ref.actor_rollout.n_samples_per_prompt,
- metrics=metrics,
- valid_mask=~use_verifier_mask
- )
- token_level_rewards[~use_verifier_mask] += rm_token_level_rewards
-
-
- if hasattr(config.reward, "verifier") and config.reward.verifier and use_verifier_mask.sum():
- verifier_token_level_rewards = compute_verifier_score(batch, metrics, config, use_verifier_mask)
- token_level_rewards[use_verifier_mask] += verifier_token_level_rewards
-
- rewards = DataProto.from_dict(
- tensors={
- 'token_level_rewards': token_level_rewards,
- 'rm_scores': token_level_rewards
- }
- )
-
- return batch.union(rewards)
-
-
-def compute_verifier_score(batch, metrics, config, valid_mask):
- tokenizer = AutoTokenizer.from_pretrained(config.training.tokenizer_name_or_path, trust_remote_code=True)
-
- responses = batch.batch["responses"][valid_mask]
- str_responses = tokenizer.batch_decode(responses, skip_special_tokens=True)
- question = batch.batch["prompts"][valid_mask]
- str_question = tokenizer.batch_decode(question, skip_special_tokens=True)
-
- reward_index = batch.batch["responses_ori_length"].unsqueeze(1) - 1
- reward_index = reward_index[valid_mask]
-
- logger.logger.info("=" * 50)
- logger.logger.info(">>>>>>>>>> User:\n")
- logger.logger.info(str_question[0])
- logger.logger.info(">>>>>>>>>> Assistant:\n")
- logger.logger.info(str_responses[0])
-
- extra_data = {}
-
- if hasattr(config.training, "dataset_additional_keys"):
- for k in config.training.dataset_additional_keys:
- extra_data[k] = tokenizer.batch_decode(batch.batch[k], skip_special_tokens=True)
- if k == "categories":
- continue
- logger.logger.info(f">>>>>>>>>> {k}")
- logger.logger.info(extra_data[k][valid_mask.nonzero()[0]])
-
- logger.logger.info("=" * 50)
-
- labels = [label for label, mask in zip(extra_data.get("labels"), valid_mask) if mask]
- scores = verifier(str_responses, labels, config, metrics, infos=None)
-
- scores = torch.tensor(
- scores,
- dtype=torch.float32,
- device=reward_index.device
- )
-
- scores = scores.reshape(-1, config.actor_rollout_ref.actor_rollout.n_samples_per_prompt)
- scores = (scores - scores.mean(dim=1, keepdim=True)) / (scores.std(dim=1, keepdim=True) + 1e-8)
- scores = scores.reshape(-1).unsqueeze(1)
-
- token_level_rewards = torch.zeros_like(responses, dtype=torch.float32)
- token_level_rewards.scatter_(1, reward_index, scores)
-
- return token_level_rewards
-
-
-def verifier(responses, labels, config, metrics, infos=None):
- """
- User-defined verifier scoring process.
-
- Parameters:
- ----------
- responses(List[`str`]):
- Actor rollout answers.
- labels(List[`str`]):
- Ground Truth.
- infos(List[`str`], *optional*):
- Additional usable information loaded from the dataset.
-
- Returns:
- scores(List[`float`]): Final scores.
- """
- rule_verifier_function = {
- "acc": preprocess_box_response_for_qwen_prompt,
- "format": format_reward,
- "step": reasoning_steps_reward,
- "strict_format": strict_format_reward,
- "base_acc": base_model_accuracy_reward
- }
-
- scores = [0.0] * len(labels)
-
- verifier_function = config.algorithm.verifier_function if hasattr(
- config.algorithm, "verifier_function") else ["acc"]
- verifier_weight = config.algorithm.verifier_weight if hasattr(
- config.algorithm, "verifier_weight") else [1.0]
-
- for idx, fun_verifier in enumerate(verifier_function):
- if fun_verifier not in rule_verifier_function:
- continue
- score = rule_verifier_function[fun_verifier](sequences=responses, answers=labels)
- metrics[f"grpo/{fun_verifier}_rewards/mean"] = sum(score) / max(len(score), 1)
- scores = [all_score + tmp_score * verifier_weight[idx]
- for all_score, tmp_score in zip(scores, score)]
-
- return scores
-
-
-def get_last_reward(data, rm_scores, n_sample_batch, metrics, valid_mask):
- eos_indices = data.batch["responses_ori_length"].unsqueeze(1) - 1
-
- # gather reward from eos position
- rm_scores = rm_scores[valid_mask]
- eos_indices = eos_indices[valid_mask]
- reward = rm_scores.gather(dim=1, index=eos_indices).squeeze(1)
-
- # record raw reward
- metrics[f"grpo/reward_model_rewards/mean"] = sum(reward) / max(len(reward), 1)
-
- # calculate group norm
- reward = reward.reshape(-1, n_sample_batch)
- reward = (reward - reward.mean(dim=1, keepdim=True)) / (reward.std(dim=1, keepdim=True) + 1e-8)
- reward = reward.reshape(-1)
- token_level_rewards = torch.zeros_like(rm_scores).scatter_(dim=1, index=eos_indices, src=reward.unsqueeze(1).to(rm_scores.dtype))
- return token_level_rewards
-
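
The group normalisation used above standardises the rewards within each group of n_sample_batch rollouts of the same prompt before scattering them back to the [EOS] positions; a 2-prompt, 2-sample toy case:

import torch

reward = torch.tensor([1.0, 3.0, 0.0, 2.0])  # 2 prompts x 2 samples, flattened
n_sample_batch = 2
reward = reward.reshape(-1, n_sample_batch)
reward = (reward - reward.mean(dim=1, keepdim=True)) / (reward.std(dim=1, keepdim=True) + 1e-8)
print(reward.reshape(-1))  # tensor([-0.7071,  0.7071, -0.7071,  0.7071])
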
-
-def reduce_metrics(metrics: dict):
- for key, val in metrics.items():
- metrics[key] = np.mean(val)
- return metrics
-
-
-def compute_data_metrics(batch):
- sequence_score = batch.batch['rm_scores'].sum(-1)
- sequence_reward = batch.batch['token_level_rewards'].sum(-1)
-
- response_length = batch.batch['responses'].shape[-1]
-
- advantages = batch.batch['advantages']
- prompt_mask = batch.batch['attention_mask'][:, :-response_length]
- response_mask = batch.batch['attention_mask'][:, -response_length:]
-
- prompt_length = prompt_mask.sum(-1).float()
- response_length = response_mask.sum(-1).float() # (batch_size,)
-
- returns = batch.batch['returns']
- values = batch.batch['values']
-
- metrics = {
- # score
- 'critic/score/mean': torch.mean(sequence_score).detach().item(),
- 'critic/score/max': torch.max(sequence_score).detach().item(),
- 'critic/score/min': torch.min(sequence_score).detach().item(),
- # reward
- 'critic/rewards/mean': torch.mean(sequence_reward).detach().item(),
- 'critic/rewards/max': torch.max(sequence_reward).detach().item(),
- 'critic/rewards/min': torch.min(sequence_reward).detach().item(),
- # adv
- 'critic/advantages/mean': F.masked_mean(advantages, response_mask).detach().item(),
- 'critic/advantages/max': torch.max(advantages[response_mask]).detach().item(),
- 'critic/advantages/min': torch.min(advantages[response_mask]).detach().item(),
- # returns
- 'critic/returns/mean': F.masked_mean(returns, response_mask).detach().item(),
- 'critic/returns/max': torch.max(returns[response_mask]).detach().item(),
- 'critic/returns/min': torch.min(returns[response_mask]).detach().item(),
- # values
- 'critic/values/mean': F.masked_mean(values, response_mask).detach().item(),
- 'critic/values/max': torch.max(values[response_mask]).detach().item(),
- 'critic/values/min': torch.min(values[response_mask]).detach().item(),
- # response length
- 'response_length/mean': torch.mean(response_length).detach().item(),
- 'response_length/max': torch.max(response_length).detach().item(),
- 'response_length/min': torch.min(response_length).detach().item(),
- # prompt length
- 'prompt_length/mean': torch.mean(prompt_length).detach().item(),
- 'prompt_length/max': torch.max(prompt_length).detach().item(),
- 'prompt_length/min': torch.min(prompt_length).detach().item(),
- }
- return metrics
-
-
-def compute_data_online_dpo_metrics(batch):
- sequence_score = batch.batch['rm_scores'].sum(-1)
- response_length = batch.batch['responses'].shape[-1]
-
- prompt_mask = batch.batch['attention_mask'][:, :-response_length]
- response_mask = batch.batch['attention_mask'][:, -response_length:]
-
- prompt_length = prompt_mask.sum(-1).float()
- response_length = response_mask.sum(-1).float()
-
- metrics = {
- # score
- 'reward/score/mean': torch.mean(sequence_score).detach().item(),
- 'reward/score/max': torch.max(sequence_score).detach().item(),
- 'reward/score/min': torch.min(sequence_score).detach().item(),
- # response length
- 'response_length/mean': torch.mean(response_length).detach().item(),
- 'response_length/max': torch.max(response_length).detach().item(),
- 'response_length/min': torch.min(response_length).detach().item(),
- # prompt length
- 'prompt_length/mean': torch.mean(prompt_length).detach().item(),
- 'prompt_length/max': torch.max(prompt_length).detach().item(),
- 'prompt_length/min': torch.min(prompt_length).detach().item(),
- }
- return metrics
-
-
-def compute_grpo_data_metrics(batch):
- sequence_score = batch.batch['rm_scores'].sum(-1)
- sequence_reward = batch.batch['token_level_rewards'].sum(-1)
- response_length = batch.batch['responses'].shape[-1]
- advantages = batch.batch['advantages']
- prompt_mask = batch.batch['attention_mask'][:, :-response_length]
- response_mask = batch.batch['attention_mask'][:, -response_length:]
- prompt_length = prompt_mask.sum(-1).float()
- response_length = response_mask.sum(-1).float()
- returns = batch.batch['returns']
- metrics = {
- # score
- 'grpo/score/mean': torch.mean(sequence_score).detach().item(),
- 'grpo/score/max': torch.max(sequence_score).detach().item(),
- 'grpo/score/min': torch.min(sequence_score).detach().item(),
- # reward
- 'grpo/rewards/mean': torch.mean(sequence_reward).detach().item(),
- 'grpo/rewards/max': torch.max(sequence_reward).detach().item(),
- 'grpo/rewards/min': torch.min(sequence_reward).detach().item(),
- # adv
- 'grpo/advantages/mean': F.masked_mean(advantages, response_mask).detach().item(),
- 'grpo/advantages/max': torch.max(advantages[response_mask]).detach().item(),
- 'grpo/advantages/min': torch.min(advantages[response_mask]).detach().item(),
- 'grpo/returns/mean': F.masked_mean(returns, response_mask).detach().item(),
- 'grpo/returns/max': torch.max(returns[response_mask]).detach().item(),
- 'grpo/returns/min': torch.min(returns[response_mask]).detach().item(),
- # response length
- 'response_length/mean': torch.mean(response_length).detach().item(),
- 'response_length/max': torch.max(response_length).detach().item(),
- 'response_length/min': torch.min(response_length).detach().item(),
- # prompt length
- 'prompt_length/mean': torch.mean(prompt_length).detach().item(),
- 'prompt_length/max': torch.max(prompt_length).detach().item(),
- 'prompt_length/min': torch.min(prompt_length).detach().item(),
- }
- return metrics
diff --git a/mindspeed_llm/tasks/posttrain/rlxf/training/initialize.py b/mindspeed_llm/tasks/posttrain/rlxf/training/initialize.py
deleted file mode 100644
index d1a01469ceb0765beeb9d8647a552e69d763291b..0000000000000000000000000000000000000000
--- a/mindspeed_llm/tasks/posttrain/rlxf/training/initialize.py
+++ /dev/null
@@ -1,325 +0,0 @@
-from functools import wraps
-from omegaconf import OmegaConf
-
-import ray
-import torch
-
-import megatron
-from megatron.core import mpu
-from megatron.training import get_args, print_rank_0
-from megatron.training.arguments import validate_args
-from megatron.training.yaml_arguments import validate_yaml
-from megatron.training.checkpointing import load_args_from_checkpoint
-from megatron.training.global_vars import set_global_variables
-from megatron.training.initialize import (
- _set_random_seed,
- _init_autoresume, _initialize_tp_communicators,
-)
-
-from mindspeed.core.tensor_parallel.ascend_turbo.initialize import initialize_cfg_from_args
-from mindspeed_llm.training.arguments import parse_args_decorator
-from mindspeed_llm.tasks.utils.error_utils import ensure_valid
-from mindspeed_llm.training.utils import seed_all
-from mindspeed_llm.tasks.posttrain.rlxf.training.parallel_state import initialize_model_parallel_2megatron
-from mindspeed_llm.training.initialize import _compile_dependencies
-import mindspeed_llm.tasks.posttrain.rlxf.training.parallel_state as ps
-
-
-def parse_args_from_config(role, config):
- import sys
- # update role and model configs
- OmegaConf.set_struct(config, False) # unset read only properties
- role_args_from_config = getattr(config, role, None) if role in ["critic", "reward"] else getattr(
- config.actor_rollout_ref, role, None)
- model_name = role_args_from_config.model
- model_config = config['model'][model_name]
- common_config = config['training']
- # override priority: role > training (common) > model
- role_args_from_config = OmegaConf.merge(model_config, common_config, role_args_from_config)
- role_args_from_config.pop("model")
- OmegaConf.set_struct(config, True)
-
- # Parsing training parameters.
- for key, value in role_args_from_config.items():
- if isinstance(value, bool):
- if value:
- sys.argv.append(f"--{key.replace('_', '-')}")
- else:
- sys.argv.append(f"--{key.replace('_', '-')}={value}")
-
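
Illustrative only: how the loop above flattens a merged config into Megatron-style argv flags; the keys shown are arbitrary examples, and false booleans are simply dropped:

merged = {"micro_batch_size": 4, "use_flash_attn": True, "sequence_parallel": False}
argv = []
for key, value in merged.items():
    if isinstance(value, bool):
        if value:
            argv.append(f"--{key.replace('_', '-')}")
    else:
        argv.append(f"--{key.replace('_', '-')}={value}")
print(argv)  # ['--micro-batch-size=4', '--use-flash-attn']
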
-
-def initialize_megatron(
- extra_args_provider=None,
- args_defaults={},
- ignore_unknown_args=False,
- allow_no_cuda=False,
- skip_mpu_initialization=False,
- role=None,
- config=None,
- two_megatron=False
-):
- """Set global variables, initialize distributed, and
- set autoresume and random seeds.
- `allow_no_cuda` should not be set unless using megatron for cpu only
- data processing. In general this arg should not be set unless you know
- what you are doing.
- Returns a function to finalize distributed env initialization
- (optionally, only when args.lazy_mpu_init == True)
- """
- if not allow_no_cuda:
- # Make sure cuda is available.
- ensure_valid(torch.cuda.is_available(), "Megatron requires CUDA.")
-
- # Parse arguments
- import sys
- origin_sys_argv = sys.argv
- if role and config is not None:
- sys.argv = [sys.argv[0]]
- parse_args_from_config(role, config)
- parse_args = parse_args_decorator(megatron.training.arguments.parse_args)
- args = parse_args(extra_args_provider, ignore_unknown_args)
- args.role = role
- sys.argv = origin_sys_argv
-
- if args.use_checkpoint_args or args_defaults.get("use_checkpoint_args", False):
- ensure_valid(args.load is not None,
- "--use-checkpoints-args requires --load argument")
- load_args_from_checkpoint(args)
-
- if args.yaml_cfg is not None:
- args = validate_yaml(args, args_defaults)
- else:
- validate_args(args, args_defaults)
-
- # set global args, build tokenizer, and set adlr-autoresume,
- # tensorboard-writer, and timers.
- set_global_variables(args)
-
- # add deterministic computing function
- if args.use_deter_comp:
- seed_all(args.seed)
- print_rank_0("deterministic computing is applied for npu.")
-
- # torch.distributed initialization
- def finish_mpu_init():
- args = get_args()
- # Pytorch distributed.
- _initialize_distributed(two_megatron)
-
- # Random seeds for reproducibility.
- if args.rank == 0:
- print("> setting random seeds to {} ...".format(args.seed))
- _set_random_seed(args.seed, args.data_parallel_random_init)
- if args.use_mc2:
- initialize_cfg_from_args(args)
-
- if skip_mpu_initialization:
- return None
-
- args = get_args()
- if args.lazy_mpu_init:
- args.use_cpu_initialization = True
- # delayed initialization of DDP-related stuff
- # We only set basic DDP globals
- mpu.set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
- # and return function for external DDP manager
- # to call when it has DDP initialized
- mpu.set_tensor_model_parallel_rank(args.rank)
- return finish_mpu_init
- else:
- # Megatron's MPU is the master. Complete initialization right away.
- finish_mpu_init()
-
- # Autoresume.
- _init_autoresume()
-
- # Compile dependencies.
- _compile_dependencies()
-
- if args.tp_comm_overlap:
- _initialize_tp_communicators()
-
- # No continuation function
- return None
-
-
-def _initialize_distributed(two_megatron=False):
- """Initialize torch.distributed and core model parallel."""
- args = get_args()
- from datetime import timedelta
-
- device_count = torch.cuda.device_count()
- if torch.distributed.is_initialized():
- if args.rank == 0:
- print(
- "torch distributed is already initialized, "
- "skipping initialization ...",
- flush=True,
- )
- args.rank = torch.distributed.get_rank()
- args.world_size = torch.distributed.get_world_size()
- else:
- if args.rank == 0:
- print("> initializing torch distributed ...", flush=True)
- # Manually set the device ids.
- if device_count > 0:
- if args.stage in ["ray_ppo", "ray_online_dpo", "ray_grpo"]:
- allocated_device = int(ray.get_runtime_context().get_accelerator_ids()["NPU"][0])
- torch.cuda.set_device(allocated_device)
- else:
- device = args.rank % device_count
- if args.local_rank is not None:
- if args.local_rank != device:
- raise ValueError("expected local-rank to be the same as rank % device-count.")
- else:
- args.local_rank = device
- torch.cuda.set_device(device)
- # Call the init process
- torch.distributed.init_process_group(
- backend=args.distributed_backend,
- world_size=args.world_size,
- rank=args.rank,
- timeout=timedelta(minutes=args.distributed_timeout_minutes),
- )
-
- # Set the tensor model-parallel, pipeline model-parallel, and
- # data-parallel communicators.
- if device_count > 0:
- if mpu.model_parallel_is_initialized():
- print("model parallel is already initialized")
- else:
- if not two_megatron: # normal case
- mpu.initialize_model_parallel(
- args.tensor_model_parallel_size,
- args.pipeline_model_parallel_size,
- args.virtual_pipeline_model_parallel_size,
- args.pipeline_model_parallel_split_rank,
- context_parallel_size=args.context_parallel_size,
- expert_model_parallel_size=args.expert_model_parallel_size,
- distributed_timeout_minutes=args.distributed_timeout_minutes,
- nccl_communicator_config_path=args.nccl_communicator_config_path,
- order='tp-cp-ep-dp-pp' if not args.use_tp_pp_dp_mapping else 'tp-pp-dp',
- )
- else:
- # Note: both the training and inference nodes must build both process groups.
- TRAIN_SIZE = args.num_gpus_for_train
- INFER_SIZE = args.num_gpus_for_infer
- if torch.distributed.get_world_size() != TRAIN_SIZE + INFER_SIZE:
- raise ValueError("TRAIN_SIZE + INFER_SIZE should equal to total GPU num.")
- initialize_model_parallel_2megatron(
- args.tensor_model_parallel_size,
- args.pipeline_model_parallel_size,
- args.virtual_pipeline_model_parallel_size,
- args.pipeline_model_parallel_split_rank,
- context_parallel_size=args.context_parallel_size,
- expert_model_parallel_size=args.expert_model_parallel_size,
- distributed_timeout_minutes=args.distributed_timeout_minutes,
- nccl_communicator_config_path=args.nccl_communicator_config_path,
- infer_size=INFER_SIZE,
- )
- initialize_model_parallel_2megatron( # currently only use TP for Inference
- args.tensor_model_parallel_size,
- 1, # inference do not use PP
- distributed_timeout_minutes=args.distributed_timeout_minutes,
- nccl_communicator_config_path=args.nccl_communicator_config_path,
- infer_size=INFER_SIZE,
- is_second_megatron=True
- )
-
- if ps.in_mg2_inference_group(): # set true TP, PP args for inference groups
- args.pipeline_model_parallel_size = 1
- args.virtual_pipeline_model_parallel_size = 1
-
- if args.rank != 0 and ps.is_mg2_first_rank(): # first rank in inference
- print(
- f"> initialized inference tensor model parallel with size "
- f"{mpu.get_tensor_model_parallel_world_size()}"
- )
-
- if args.rank == 0:
- print(
- f"> initialized tensor model parallel with size "
- f"{mpu.get_tensor_model_parallel_world_size()}"
- )
- print(
- f"> initialized pipeline model parallel with size "
- f"{mpu.get_pipeline_model_parallel_world_size()}"
- )
-
-
-def barrier_wrapper(fn):
- @wraps(fn)
- def wrapper(*args, **kwargs):
- mg2_available = ps._MEGATRON2_INITIALIZED
- no_group_info = 'group' not in kwargs or kwargs['group'] is None
- if no_group_info and mg2_available:
- dist_group = ps.get_mg2_local_group()
- kwargs['group'] = dist_group
- return fn(*args, **kwargs)
-
- return wrapper
-
-
-def broadcast_wrapper(fn):
- @wraps(fn)
- def wrapper(*args, **kwargs):
- no_group_info = 'group' not in kwargs or kwargs['group'] is None
- mg2_available = ps._MEGATRON2_INITIALIZED
- if no_group_info and mg2_available:
- dist_group = ps.get_mg2_local_group()
- kwargs['group'] = dist_group
- args = list(args)
- if len(args) >= 2:
- src_rank = ps.get_mg2_local_ranks()[0]
- args[1] = src_rank
- return fn(*args, **kwargs)
-
- return wrapper
-
-
-def get_world_size_wrapper(fn):
- @wraps(fn)
- def wrapper(*args, **kwargs):
- mg2_available = ps._MEGATRON2_INITIALIZED
- if mg2_available and not args and not kwargs:
- dist_group = ps.get_mg2_local_group()
- args = [dist_group]
- return fn(*args, **kwargs)
-
- return wrapper
-
-
-def get_elapsed_time_all_ranks(self, names, reset, barrier):
- if barrier:
- torch.distributed.barrier()
-
- world_size = torch.distributed.get_world_size()
- args = get_args()
- if args.role == "actor_rollout": # patch here: use mg2 local rank
- rank = ps.get_mg2_local_rank()
- group = ps.get_mg2_local_group()
- else:
- rank = torch.distributed.get_rank()
- group = None
-
- rank_name_to_time = torch.zeros(
- (world_size, len(names)), dtype=torch.float, device=torch.cuda.current_device()
- )
- for i, name in enumerate(names):
- if name in self._timers:
- rank_name_to_time[rank, i] = self._timers[name].elapsed(reset=reset)
-
- torch.distributed._all_gather_base(
- rank_name_to_time.view(-1), rank_name_to_time[rank, :].view(-1), group=group
- )
-
- return rank_name_to_time
-
-
-def is_last_rank():
- rank = torch.distributed.get_rank()
- if ps._MEGATRON2_INITIALIZED:
- return rank == ps.get_mg2_local_ranks()[-1]
- else:
- return torch.distributed.get_rank() == (
- torch.distributed.get_world_size() - 1)
\ No newline at end of file
diff --git a/mindspeed_llm/tasks/posttrain/rlxf/training/parallel_state.py b/mindspeed_llm/tasks/posttrain/rlxf/training/parallel_state.py
deleted file mode 100644
index ebfad2062ddd3086d7a9154c301f5fc0e1421fc8..0000000000000000000000000000000000000000
--- a/mindspeed_llm/tasks/posttrain/rlxf/training/parallel_state.py
+++ /dev/null
@@ -1,390 +0,0 @@
-from functools import wraps
-from typing import Optional
-from datetime import timedelta
-import torch
-
-
-# global variables for two megatron running
-_TENSOR_AND_CONTEXT_PARALLEL_GROUP = None
-_TENSOR_AND_CONTEXT_PARALLEL_GLOBAL_RANKS = None
-_MEGATRON2_LOCAL_RANKS = None
-_MEGATRON2_LOCAL_GROUP = None
-_MEGATRON2_INITIALIZED = False
-
-
-def initialize_model_parallel_2megatron(
- tensor_model_parallel_size: int = 1,
- pipeline_model_parallel_size: int = 1,
- virtual_pipeline_model_parallel_size: Optional[int] = None,
- pipeline_model_parallel_split_rank: Optional[int] = None,
- use_sharp: bool = False,
- context_parallel_size: int = 1,
- expert_model_parallel_size: int = 1,
- nccl_communicator_config_path: Optional[str] = None,
- distributed_timeout_minutes: int = 30,
- order: str = "tp-cp-ep-dp-pp",
- infer_size: int = 0,
- is_second_megatron: bool = False
-) -> None:
- """Initialize model data parallel groups with offset.
- Assert two groups are initialized :
- training group contains GPU with rank[0, world_size - infer_world_size - 1]
- inference group contains GPU with rank [world_size - infer_world_size, world_size - 1]
- rank_offset is only set for inference groups.
- """
- import megatron.core.parallel_state as ps
-
- global _MEGATRON2_LOCAL_RANKS
- global _MEGATRON2_LOCAL_GROUP
- global _MEGATRON2_INITIALIZED
-
- timeout = timedelta(minutes=distributed_timeout_minutes)
-
- # Get world size and rank. Ensure some consistencies.
- if not torch.distributed.is_initialized():
- raise RuntimeError("torch.distributed not initialized.")
- world_size: int = torch.distributed.get_world_size()
- rank = torch.distributed.get_rank()
- train_size = world_size - infer_size
-
- nccl_comm_cfgs = {}
- if nccl_communicator_config_path is not None:
- try:
- import yaml
- except ImportError as e:
- raise ImportError(
- "Cannot import `yaml`. Setting custom nccl communicator configs "
- "requires the yaml package."
- ) from e
-
- with open(nccl_communicator_config_path, "r") as stream:
- nccl_comm_cfgs = yaml.safe_load(stream)
-
- # build megatron2 groups for inference and training
- if infer_size and not is_second_megatron: # only build megatron2 groups once with positive inf_size
- if _MEGATRON2_LOCAL_GROUP is not None:
- raise RuntimeError("megatron local group is already initialized.")
- ranks_mg2_inference = range(train_size, world_size)
- group_mg2_inference = torch.distributed.new_group(
- ranks_mg2_inference, timeout=timeout, pg_options=ps.get_nccl_options('dp_cp', nccl_comm_cfgs)
- )
- ranks_mg2_training = range(train_size)
- group_mg2_training = torch.distributed.new_group(
- ranks_mg2_training, timeout=timeout, pg_options=ps.get_nccl_options('dp_cp', nccl_comm_cfgs)
- )
- if rank in ranks_mg2_inference: # inf groups
- _MEGATRON2_LOCAL_GROUP = group_mg2_inference
- _MEGATRON2_LOCAL_RANKS = ranks_mg2_inference
- else:
- _MEGATRON2_LOCAL_GROUP = group_mg2_training
- _MEGATRON2_LOCAL_RANKS = ranks_mg2_training
-
- # update world_size and rank_offset
- if is_second_megatron: # inference group, i.e. the second group
- world_size = infer_size
- rank_offset = train_size
- else:
- world_size = train_size
- rank_offset = 0
-
- if (
- world_size
- % (tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size)
- != 0
- ):
- raise RuntimeError(
- f"world_size ({world_size}) is not divisible by tensor_model_parallel_size "
- f"({tensor_model_parallel_size}) x pipeline_model_parallel_size ({pipeline_model_parallel_size}) "
- f"x context_parallel_size ({context_parallel_size})"
- )
-
- data_parallel_size: int = world_size // (
- tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size
- )
-
- if data_parallel_size % expert_model_parallel_size != 0:
- raise RuntimeError(
- f"data_parallel_size ({data_parallel_size}) is not divisible by expert_model_parallel_size "
- )
-
- if expert_model_parallel_size > 1 and context_parallel_size > 1:
- raise RuntimeError(
- f"combination of expert model prallellism and context parallelism is not supported"
- )
-
- # num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size
- # num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size
-
- if virtual_pipeline_model_parallel_size is not None:
- if not pipeline_model_parallel_size > 2:
- raise RuntimeError(
- "pipeline-model-parallel size should be greater than 2 with interleaved schedule"
- )
- ps._VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
- ps._VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size
-
- if pipeline_model_parallel_split_rank is not None:
- ps._PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank
-
- rank_generator = ps.RankGenerator(
- tp=tensor_model_parallel_size,
- ep=expert_model_parallel_size,
- dp=data_parallel_size,
- pp=pipeline_model_parallel_size,
- cp=context_parallel_size,
- order=order,
- offset=rank_offset
- )
-
- timeout = timedelta(minutes=distributed_timeout_minutes)
-
- # Build the data-parallel groups.
- for ranks in rank_generator.get_ranks('dp'):
- group = torch.distributed.new_group(
- ranks, timeout=timeout, pg_options=ps.get_nccl_options('dp', nccl_comm_cfgs)
- )
- group_gloo = torch.distributed.new_group(ranks, timeout=timeout, backend="gloo")
- if rank in ranks:
- if ps._DATA_PARALLEL_GROUP is not None:
- raise RuntimeError('data parallel group is already initialized')
- ps._DATA_PARALLEL_GROUP = group
- ps._DATA_PARALLEL_GROUP_GLOO = group_gloo
- ps._DATA_PARALLEL_GLOBAL_RANKS = ranks
- for ranks_with_cp in rank_generator.get_ranks('dp-cp'):
- group_with_cp = torch.distributed.new_group(
- ranks_with_cp, timeout=timeout, pg_options=ps.get_nccl_options('dp_cp', nccl_comm_cfgs)
- )
- group_with_cp_gloo = torch.distributed.new_group(
- ranks_with_cp, timeout=timeout, backend="gloo"
- )
- if rank in ranks_with_cp:
- ps._DATA_PARALLEL_GROUP_WITH_CP = group_with_cp
- ps._DATA_PARALLEL_GROUP_WITH_CP_GLOO = group_with_cp_gloo
- ps._DATA_PARALLEL_GLOBAL_RANKS_WITH_CP = ranks_with_cp
-
- # Build the context-parallel groups.
- global _TENSOR_AND_CONTEXT_PARALLEL_GROUP
- global _TENSOR_AND_CONTEXT_PARALLEL_GLOBAL_RANKS
- for ranks in rank_generator.get_ranks('cp'):
- group = torch.distributed.new_group(
- ranks, timeout=timeout, pg_options=ps.get_nccl_options('cp', nccl_comm_cfgs)
- )
- if rank in ranks:
- if ps._CONTEXT_PARALLEL_GROUP is not None:
- raise RuntimeError('context parallel group is already initialized')
- ps._CONTEXT_PARALLEL_GROUP = group
- ps._CONTEXT_PARALLEL_GLOBAL_RANKS = ranks
-
- for ranks in rank_generator.get_ranks('tp-cp'):
- group = torch.distributed.new_group(
- ranks, timeout=timeout, pg_options=ps.get_nccl_options('tp_cp', nccl_comm_cfgs)
- )
- if rank in ranks:
- if ps._TENSOR_AND_CONTEXT_PARALLEL_GROUP is not None:
- raise RuntimeError('tensor and context parallel group is already initialized')
- _TENSOR_AND_CONTEXT_PARALLEL_GROUP = group
- _TENSOR_AND_CONTEXT_PARALLEL_GLOBAL_RANKS = ranks
-
-
- # Build the model-parallel groups.
- for ranks in rank_generator.get_ranks('tp-pp'):
- group = torch.distributed.new_group(
- ranks, timeout=timeout, pg_options=ps.get_nccl_options('mp', nccl_comm_cfgs)
- )
- if rank in ranks:
- if ps._MODEL_PARALLEL_GROUP is not None:
- raise RuntimeError('model parallel group is already initialized')
- ps._MODEL_PARALLEL_GROUP = group
-
- # Build the tensor model-parallel groups.
- for ranks in rank_generator.get_ranks('tp'):
- group = torch.distributed.new_group(
- ranks, timeout=timeout, pg_options=ps.get_nccl_options('tp', nccl_comm_cfgs)
- )
- if rank in ranks:
- if ps._TENSOR_MODEL_PARALLEL_GROUP is not None:
- raise RuntimeError('tensor model parallel group is already initialized')
-
- ps._TENSOR_MODEL_PARALLEL_GROUP = group
- ps._TENSOR_MODEL_PARALLEL_GLOBAL_RANKS = ranks
-
- # Build the pipeline model-parallel groups and embedding groups
- # (first and last rank in each pipeline model-parallel group).
-
- for ranks in rank_generator.get_ranks('pp'):
- group = torch.distributed.new_group(
- ranks, timeout=timeout, pg_options=ps.get_nccl_options('pp', nccl_comm_cfgs)
- )
- if rank in ranks:
- if ps._PIPELINE_MODEL_PARALLEL_GROUP is not None:
- raise RuntimeError('pipeline model parallel group is already initialized')
-
- ps._PIPELINE_MODEL_PARALLEL_GROUP = group
- ps._PIPELINE_GLOBAL_RANKS = ranks
- # Setup embedding group (to exchange gradients between
- # first and last stages).
- if len(ranks) > 1:
- embedding_ranks = [ranks[0], ranks[-1]]
- position_embedding_ranks = [ranks[0]]
- if pipeline_model_parallel_split_rank is not None:
- if ranks[pipeline_model_parallel_split_rank] not in embedding_ranks:
- embedding_ranks = [
- ranks[0],
- ranks[pipeline_model_parallel_split_rank],
- ranks[-1],
- ]
- if ranks[pipeline_model_parallel_split_rank] not in position_embedding_ranks:
- position_embedding_ranks = [ranks[0], ranks[pipeline_model_parallel_split_rank]]
- else:
- embedding_ranks = ranks
- position_embedding_ranks = ranks
-
- group = torch.distributed.new_group(
- embedding_ranks, timeout=timeout, pg_options=ps.get_nccl_options('embd', nccl_comm_cfgs)
- )
- if rank in embedding_ranks:
- if ps._EMBEDDING_GROUP is not None:
- raise RuntimeError('embedding group is already initialized')
- ps._EMBEDDING_GROUP = group
- if rank in ranks:
- ps._EMBEDDING_GLOBAL_RANKS = embedding_ranks
-
- group = torch.distributed.new_group(
- position_embedding_ranks,
- timeout=timeout,
- pg_options=ps.get_nccl_options('embd', nccl_comm_cfgs),
- )
- if rank in position_embedding_ranks:
- if ps._POSITION_EMBEDDING_GROUP is not None:
- raise RuntimeError('position embedding group is already initialized')
- ps._POSITION_EMBEDDING_GROUP = group
- if rank in ranks:
- ps._POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks
-
- # Build the tensor + data parallel groups.
- for ranks in rank_generator.get_ranks('tp-dp-cp'):
- group = torch.distributed.new_group(
- ranks, timeout=timeout, pg_options=ps.get_nccl_options('tp_dp_cp', nccl_comm_cfgs)
- )
- if rank in ranks:
- if ps._TENSOR_AND_DATA_PARALLEL_GROUP is not None:
- raise RuntimeError('Tensor + data parallel group is already initialized')
- ps._TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP = group
- for ranks in rank_generator.get_ranks('tp-dp'):
- group = torch.distributed.new_group(
- ranks, timeout=timeout, pg_options=ps.get_nccl_options('tp_dp', nccl_comm_cfgs)
- )
- if rank in ranks:
- ps._TENSOR_AND_DATA_PARALLEL_GROUP = group
-
- # Build the tensor + expert parallel groups
- for ranks in rank_generator.get_ranks('tp-ep', independent_ep=True):
- group = torch.distributed.new_group(
- ranks, timeout=timeout, pg_options=ps.get_nccl_options('tp_exp', nccl_comm_cfgs)
- )
- if rank in ranks:
- if ps._TENSOR_AND_EXPERT_PARALLEL_GROUP is not None:
- raise RuntimeError('Tensor + expert parallel group is already initialized')
- ps._TENSOR_AND_EXPERT_PARALLEL_GROUP = group
-
- for ranks in rank_generator.get_ranks('ep', independent_ep=True):
- group = torch.distributed.new_group(
- ranks, pg_options=ps.get_nccl_options('exp', nccl_comm_cfgs)
- )
- if rank in ranks:
- if ps._EXPERT_MODEL_PARALLEL_GROUP is not None:
- raise RuntimeError('Expert parallel group is already initialized')
- ps._EXPERT_MODEL_PARALLEL_GROUP = group
-
- for ranks in rank_generator.get_ranks('dp', independent_ep=True):
- group = torch.distributed.new_group(
- ranks, timeout=timeout, pg_options=ps.get_nccl_options('dp_modulo_exp', nccl_comm_cfgs)
- )
- group_gloo = torch.distributed.new_group(ranks, backend="gloo")
- if rank in ranks:
- if ps._DATA_MODULO_EXPERT_PARALLEL_GROUP is not None:
- raise RuntimeError('Data modulo expert group is already initialized')
- ps._DATA_MODULO_EXPERT_PARALLEL_GROUP = group
- ps._DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO = group_gloo
- for ranks in rank_generator.get_ranks('dp-cp', independent_ep=True):
- # Lazy initialization of the group
- group = ps._DATA_MODULO_EXPERT_PARALLEL_GROUP
- group_gloo = ps._DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO
- if rank in ranks:
- ps._DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP = group
- ps._DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP_GLOO = group_gloo
- # Initialize global memory buffer
- # This isn't really "parallel state" but there isn't another good place to
- # put this. If we end up with a more generic initialization of megatron-core
- # we could stick it there
- if not is_second_megatron: # global memory buffer should only be set once
- ps._set_global_memory_buffer()
- else:
- _MEGATRON2_INITIALIZED = True
-
-
-def is_mg2_first_rank():
- """
- Check if current node is the first node in the Megatron2 local group.
- Use this to extend the old usage rank == 0.
- """
- if _MEGATRON2_LOCAL_RANKS is None:
- raise RuntimeError('Megatron2 group is not initialized')
- return torch.distributed.get_rank() == _MEGATRON2_LOCAL_RANKS[0]
-
-
-def in_mg2_inference_group():
- """
- """
- if _MEGATRON2_LOCAL_RANKS is None:
- raise RuntimeError('Megatron2 group is not initialized')
- return _MEGATRON2_LOCAL_RANKS[0] != 0
-
-
-def get_mg2_local_group():
- if _MEGATRON2_LOCAL_GROUP is None:
- raise RuntimeError('Megatron2 group is not initialized')
- return _MEGATRON2_LOCAL_GROUP
-
-
-def get_mg2_local_ranks():
- if _MEGATRON2_LOCAL_RANKS is None:
- raise RuntimeError('Megatron2 ranks are not initialized')
- return _MEGATRON2_LOCAL_RANKS
-
-
-def get_mg2_first_rank():
- """
- When the world is split into separate process groups for the actor-train and
- actor-rollout roles, this returns the first global rank of the local Megatron2 group.
- """
- if _MEGATRON2_LOCAL_RANKS is None:
- raise RuntimeError('Megatron2 group is not initialized')
- return _MEGATRON2_LOCAL_RANKS[0]
-
-
-def get_mg2_local_rank():
- """
- When the world is split into separate process groups for the actor-train and
- actor-rollout roles, this returns the rank of the current process within its local Megatron2 group.
- """
- return torch.distributed.get_rank() - get_mg2_first_rank()
-
-
-def rank_generator_init_wrapper(init_func):
- @wraps(init_func)
- def wrapper(self, *args, **kwargs):
- if 'offset' in kwargs:
- self.offset = kwargs.pop('offset')
- else:
- self.offset = 0
- init_func(self, *args, **kwargs)
- return wrapper
-
-
-def rank_generator_get_ranks_wrapper(get_ranks):
- @wraps(get_ranks)
- def wrapper(self, *args, **kwargs):
- ranks_list = get_ranks(self, *args, **kwargs)
- return [[item + self.offset for item in ranks] for ranks in ranks_list]
- return wrapper
\ No newline at end of file
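
The two wrappers above give Megatron's RankGenerator an extra offset so the inference-side groups start after the training ranks; a toy reproduction of the effect (tp_groups is an illustrative helper, not part of the original module):

def tp_groups(world_size, tp, offset=0):
    # enumerate tensor-parallel groups of size tp, shifted by offset
    return [[r + offset for r in range(start, start + tp)]
            for start in range(0, world_size, tp)]

print(tp_groups(4, 2))            # training:  [[0, 1], [2, 3]]
print(tp_groups(4, 2, offset=4))  # inference: [[4, 5], [6, 7]]
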
diff --git a/mindspeed_llm/tasks/posttrain/rlxf/training/parameter_mapping.py b/mindspeed_llm/tasks/posttrain/rlxf/training/parameter_mapping.py
deleted file mode 100644
index 8dc3227ae2af364426789cd86423bdeff9ecbc54..0000000000000000000000000000000000000000
--- a/mindspeed_llm/tasks/posttrain/rlxf/training/parameter_mapping.py
+++ /dev/null
@@ -1,202 +0,0 @@
-# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved.
-import hashlib
-
-import torch
-import torch.distributed as dist
-from megatron.training import get_args
-from megatron.core import parallel_state as mpu
-
-import mindspeed_llm.tasks.posttrain.rlxf.training.parallel_state as ps
-
-
-RECEIVE_PARAM_NUMS = None
-MODEL_SEND_GROUP = None
-MODEL_RECEIVE_GROUPS = []
-
-
-def init_comm_groups():
- """
- Initialize the communication groups used for automatic model parameter mapping.
- Scenario example:
- Training: tp 2, pp 2
- Inference: tp 2, dp 2
- Principle: the weights held by a given training tp rank are broadcast, pp stage by pp stage, to the
- matching inference dp replicas, so the training pp dimension is reassembled across those replicas.
- Number of communication groups: tp * pp
- groups = [[0,4,6],[2,4,6],[1,5,7],[3,5,7]]
- """
- args = get_args()
- pipeline_parallel_groups = []
- data_parallel_groups = []
- for i in range(args.tensor_model_parallel_size):
- ranks = list(range(i, args.num_gpus_for_train, args.tensor_model_parallel_size))
- pipeline_parallel_groups.append(ranks)
- ranks = list(range(args.num_gpus_for_train + i, args.num_gpus_for_train + args.num_gpus_for_infer,
- args.tensor_model_parallel_size))
- data_parallel_groups.append(ranks)
- comm_groups = []
- for data_parallel_group, pipeline_parallel_group in zip(data_parallel_groups, pipeline_parallel_groups):
- for rank in pipeline_parallel_group:
- tmp_group = [rank] + list(data_parallel_group)
- comm_groups.append(tmp_group)
- print(comm_groups)
- return comm_groups
-
-
-def init_parameter_mapping_distributed():
- """
- Build the broadcast groups used for automatic parameter mapping at model initialization.
- Scenario example:
- Training: tp 2, pp 2
- Inference: tp 2, dp 2
- Given a world size of 8 as above:
- rank 0 acts as a sender; its communication group is [0, 4, 6]
- rank 2 acts as a sender; its communication group is [2, 4, 6]
- rank 4 acts as a receiver; it joins the communication groups [0, 4, 6] and [2, 4, 6]
- """
- global MODEL_SEND_GROUP, MODEL_RECEIVE_GROUPS
- groups = init_comm_groups()
- rank = dist.get_rank()
- for group in groups:
- tmp_group = dist.new_group(group)
- if not ps.in_mg2_inference_group() and rank in group:
- MODEL_SEND_GROUP = tmp_group
- if ps.in_mg2_inference_group() and rank in group:
- MODEL_RECEIVE_GROUPS.append((group[0], tmp_group))
- print("init_distributed_sucess")
-
-
-def get_model_send_group():
- return MODEL_SEND_GROUP
-
-
-def get_model_receive_groups():
- return MODEL_RECEIVE_GROUPS
-
-
-def get_receive_param_nums():
- return RECEIVE_PARAM_NUMS
-
-
-def param_nums_is_initialized():
- return get_receive_param_nums() is not None
-
-
-def sync_param_nums(module: torch.nn.Module):
- """
- Synchronize the parameter count between sender and receiver so that the number of
- broadcast calls, and hence the model parameters, stay aligned on both sides.
- """
- args = get_args()
- if not ps.in_mg2_inference_group():
- args_need_broadcast = torch.tensor([args.iteration, args.consumed_train_samples], dtype=torch.int64,
- device=torch.cuda.current_device())
- dist.broadcast(args_need_broadcast, group=get_model_send_group(), src=dist.get_rank())
- num_parameters = torch.tensor([sum(1 for _ in module.named_parameters())], dtype=torch.int64,
- device=torch.cuda.current_device())
- dist.broadcast(num_parameters, group=get_model_send_group(), src=dist.get_rank())
- else:
- global RECEIVE_PARAM_NUMS
-
- recv_param_nums = []
- for group in get_model_receive_groups():
- args_need_broadcast = torch.empty(2, dtype=torch.int64, device=torch.cuda.current_device())
- dist.broadcast(args_need_broadcast, group=group[1], src=group[0], async_op=True)
- args.iteration, args.consumed_train_samples = args_need_broadcast
-
- tmp_num_parameters = torch.empty(1, dtype=torch.int64, device=torch.cuda.current_device())
- dist.broadcast(tmp_num_parameters, group=group[1], src=group[0], async_op=True)
- recv_param_nums.append(tmp_num_parameters)
- RECEIVE_PARAM_NUMS = recv_param_nums
-
-
-def compute_model_hash(model, hash_func):
- hash_value = hash_func()
- for param in model.parameters():
- param_bytes = param.data.cpu().numpy().tobytes()
- hash_value.update(param_bytes)
- md5_tensor = torch.tensor([int(h, 16) for h in hash_value.hexdigest()])
- return md5_tensor
-
-
-def send_model_to_infer_model(module: torch.nn.Module):
- """
- Broadcast the model parameters directly from the training-side model to the inference side.
- """
- args = get_args()
- model_send_group = get_model_send_group()
- if args.md5_validate:
- hash_value = hashlib.md5()
-
- is_reuse_output_weights = (not args.untie_embeddings_and_output_weights and
- args.pipeline_model_parallel_size >= 2 and
- mpu.is_pipeline_last_stage(ignore_virtual=True))
- for name, param in module.named_parameters():
- if is_reuse_output_weights and 'output_layer.weight' in name:
- continue
- param_info_data = param.data
- dist.broadcast(param_info_data, group=model_send_group, src=dist.get_rank(), async_op=True)
- if args.md5_validate:
- param_bytes = param_info_data.to(torch.float32).cpu().numpy().tobytes()
- hash_value.update(param_bytes)
- if args.md5_validate:
- md5_tensor = torch.tensor([int(h, 16) for h in hash_value.hexdigest()], dtype=torch.int64,
- device=torch.cuda.current_device())
- dist.broadcast(md5_tensor, group=model_send_group, src=dist.get_rank(), async_op=True)
-
-
-def recv_model_from_train_model(module: torch.nn.Module):
-    """
-    Receive model parameters and load them directly into the inference model.
-    """
- args = get_args()
- model_receive_groups = get_model_receive_groups()
- recv_param_nums = get_receive_param_nums()
- flag = True
- idx = 0
- if args.md5_validate:
- hash_value = hashlib.md5()
-    for _, param in module.named_parameters():
- if flag:
- cur_num = int(recv_param_nums[idx])
- cur_group = model_receive_groups[idx]
- flag = False
-
- param_info_data = param.data
- torch.distributed.broadcast(param_info_data, group=cur_group[1], src=cur_group[0], async_op=False)
- if args.md5_validate:
- param_bytes = param_info_data.to(torch.float32).cpu().numpy().tobytes()
- hash_value.update(param_bytes)
-
- cur_num -= 1
- if cur_num == 0:
- if args.md5_validate:
- md5_tensor = torch.tensor([int(h, 16) for h in hash_value.hexdigest()], dtype=torch.int64,
- device=torch.cuda.current_device())
- md5_tensor_src = torch.zeros_like(md5_tensor, dtype=torch.int64, device=torch.cuda.current_device())
- dist.broadcast(md5_tensor_src, group=cur_group[1], src=cur_group[0], async_op=False)
- if torch.equal(md5_tensor_src, md5_tensor):
- print("MD5 Hash: The weights of the two models match.")
- else:
- print("MD5 Hash: The weights of the two models do not match.")
- hash_value = hashlib.md5()
- flag = True
- idx += 1
-
- if cur_num != 0:
- if args.md5_validate:
- md5_tensor = torch.tensor([int(h, 16) for h in hash_value.hexdigest()], dtype=torch.int64,
- device=torch.cuda.current_device())
- md5_tensor_src = torch.zeros_like(md5_tensor, dtype=torch.int64, device=torch.cuda.current_device())
- dist.broadcast(md5_tensor_src, group=cur_group[1], src=cur_group[0], async_op=False)
- if torch.equal(md5_tensor_src, md5_tensor):
- print("MD5 Hash: The weights of the two models match.")
- else:
- print("MD5 Hash: The weights of the two models do not match.")
-
-
-def run_auto_mapping(model):
- if ps.in_mg2_inference_group():
- recv_model_from_train_model(model)
- else:
- send_model_to_infer_model(model)
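The mapping above relies on the sender and each receiver issuing matching broadcast calls over the same group with the same src. A minimal two-process sketch of that pattern with plain torch.distributed over gloo (hypothetical ranks and address, not the repository's group layout):

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank, world_size):
    # Rank 0 plays the training-side sender, rank 1 the inference-side receiver.
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29511",
                            rank=rank, world_size=world_size)
    group = dist.new_group([0, 1])
    param = torch.arange(4.0) if rank == 0 else torch.zeros(4)
    dist.broadcast(param, src=0, group=group)  # both sides call broadcast with the same src
    assert torch.equal(param, torch.arange(4.0))
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)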
diff --git a/mindspeed_llm/tasks/posttrain/rlxf/utils/__init__.py b/mindspeed_llm/tasks/posttrain/rlxf/utils/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/mindspeed_llm/tasks/posttrain/rlxf/utils/loggers.py b/mindspeed_llm/tasks/posttrain/rlxf/utils/loggers.py
deleted file mode 100644
index 6d884ea6daacf72eed03cfd1db99197d68322563..0000000000000000000000000000000000000000
--- a/mindspeed_llm/tasks/posttrain/rlxf/utils/loggers.py
+++ /dev/null
@@ -1,46 +0,0 @@
-# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved.
-
-import logging
-from datetime import datetime
-
-
-class Loggers(object):
- def __init__(self,
- name='root',
- logger_level='DEBUG',
- ):
- self.logger = logging.getLogger(name)
- self.logger.setLevel(logger_level)
-
- def handle_msg(self, msg, level, iteration, steps):
- current_time = str(datetime.now()).split(".")[0]
-
- if isinstance(msg, dict):
- fmt_msg = f"[{current_time}] "
- fmt_msg += f"iteration: {iteration} / {steps} | "
- for key in msg:
- fmt_msg += f"{key} : {format(msg[key], '.4f')} | "
- fmt_msg = fmt_msg[:-2]
- else:
- fmt_msg = f"[{current_time}] {level} " + str(msg)
- return fmt_msg
-
- def info(self, msg, iteration, steps):
- format_msg = self.handle_msg(msg, "INFO", iteration, steps)
- self.logger.info(format_msg)
-
- def warning(self, msg, iteration, steps):
- format_msg = self.handle_msg(msg, "WARNING", iteration, steps)
- self.logger.warning(format_msg)
-
- def debug(self, msg, iteration, steps):
- format_msg = self.handle_msg(msg, "DEBUG", iteration, steps)
- self.logger.debug(format_msg)
-
- def error(self, msg, iteration, steps):
- format_msg = self.handle_msg(msg, "ERROR", iteration, steps)
- self.logger.error(format_msg)
-
- def flush(self):
- for handler in self.logger.handlers:
- handler.flush()
\ No newline at end of file
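A brief usage sketch of the Loggers wrapper above (attaching a stream handler is an assumption; the class itself only sets the level):

import logging

logger = Loggers(name="rlxf")                      # class defined above
logger.logger.addHandler(logging.StreamHandler())  # make the formatted messages visible
logger.info({"pg_loss": 0.1234, "ppo_kl": 0.0021}, iteration=10, steps=100)
logger.warning("reward model not ready, retrying", iteration=10, steps=100)
logger.flush()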
diff --git a/mindspeed_llm/tasks/posttrain/rlxf/utils/protocol.py b/mindspeed_llm/tasks/posttrain/rlxf/utils/protocol.py
deleted file mode 100644
index 62f61f9501ea9c4e019a8579797609a740b267e5..0000000000000000000000000000000000000000
--- a/mindspeed_llm/tasks/posttrain/rlxf/utils/protocol.py
+++ /dev/null
@@ -1,534 +0,0 @@
-# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved.
-# Copyright 2024 Bytedance Ltd. and/or its affiliates
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-Implement base data transfer protocol between any two functions, modules.
-We can subclass Protocol to define more detailed batch info with specific keys
-"""
-
-__all__ = ['DataProto', 'union_tensor_dict']
-
-import copy
-from dataclasses import dataclass, field
-from typing import Callable, Dict, List, Union
-
-import numpy as np
-import ray
-import torch
-import tensordict
-from tensordict import TensorDict
-from torch.utils.data import DataLoader
-
-try:
- tensordict.set_lazy_legacy(False).set()
-except Exception as e:
- pass
-
-
-def union_two_dict(dict1: Dict, dict2: Dict):
-    """Union two dicts. Throws an error if the two dicts hold different values for the same key.
-
- Args:
- dict1:
- dict2:
-
- Returns:
-
- """
- for key, val in dict2.items():
- if key in dict1:
- if dict2[key] != dict1[key]:
- raise ValueError(f'{key} in dict1 and dict2 are not the same object')
- dict1[key] = val
-
- return dict1
-
-
-def union_tensor_dict(tensor_dict1: TensorDict, tensor_dict2: TensorDict) -> TensorDict:
- """Union two tensordicts."""
- if tensor_dict1.batch_size != tensor_dict2.batch_size:
- raise ValueError(f'Two tensor dicts must have identical batch sizes. Got {tensor_dict1.batch_size} and {tensor_dict2.batch_size}')
- for key in tensor_dict2.keys():
- if key not in tensor_dict1.keys():
- tensor_dict1[key] = tensor_dict2[key]
- else:
- if not tensor_dict1[key].equal(tensor_dict2[key]):
- raise ValueError(f'{key} in tensor_dict1 and tensor_dict2 are not the same object')
-
- return tensor_dict1
-
-
-def union_numpy_dict(tensor_dict1, tensor_dict2):
- for key, val in tensor_dict2.items():
- if key in tensor_dict1:
- if not isinstance(tensor_dict2[key], np.ndarray):
- raise TypeError(f"The value for key '{key}' in tensor_dict2 is not a numpy.ndarray.")
- if not isinstance(tensor_dict1[key], np.ndarray):
- raise TypeError(f"The value for key '{key}' in tensor_dict1 is not a numpy.ndarray.")
- if not np.all(tensor_dict2[key] == tensor_dict1[key]):
- raise ValueError(f"Arrays for key '{key}' in tensor_dict1 and tensor_dict2 are not the same object.")
- tensor_dict1[key] = val
-
- return tensor_dict1
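A tiny illustration of the union semantics above: new keys are merged in, while keys present on both sides must hold equal values (assuming the helpers above are in scope):

import torch
from tensordict import TensorDict

a = TensorDict({"obs": torch.zeros(2, 3)}, batch_size=[2])
b = TensorDict({"act": torch.ones(2, 1)}, batch_size=[2])
merged = union_tensor_dict(a, b)                   # gains the "act" key
assert set(merged.keys()) == {"obs", "act"}

meta = union_two_dict({"lr": 1e-5}, {"lr": 1e-5, "step": 3})
assert meta == {"lr": 1e-5, "step": 3}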
-
-
-def list_of_dict_to_dict_of_list(list_of_dict: List[dict]):
- if len(list_of_dict) == 0:
- return {}
- keys = list_of_dict[0].keys()
- output = {key: [] for key in keys}
- for data in list_of_dict:
- for key, item in data.items():
- if key not in output:
- raise KeyError(f"Key '{key}' is not found in the output dictionary")
- output[key].append(item)
- return output
-
-
-def collate_fn(x: List['DataProtoItem']):
- batch = []
- non_tensor_batch = []
- for data in x:
- batch.append(data.batch)
- non_tensor_batch.append(data.non_tensor_batch)
- batch = torch.stack(batch).contiguous()
- non_tensor_batch = list_of_dict_to_dict_of_list(non_tensor_batch)
- for key, val in non_tensor_batch.items():
- non_tensor_batch[key] = np.array(val, dtype=object)
- return DataProto(batch=batch, non_tensor_batch=non_tensor_batch)
-
-
-@dataclass
-class DataProtoItem:
- batch: TensorDict = None
- non_tensor_batch: Dict = field(default_factory=dict)
- meta_info: Dict = field(default_factory=dict)
-
-
-@dataclass
-class DataProto:
- """
- A DataProto is a data structure that aims to provide a standard protocol for data exchange between functions.
-    It contains a batch (TensorDict), a non_tensor_batch (dict of numpy arrays) and a meta_info (dict).
- TensorDict allows you to manipulate a dictionary of Tensors like a single Tensor. Ideally, the tensors with the
- same batch size should be put inside batch.
- """
- batch: TensorDict = None
- non_tensor_batch: Dict = field(default_factory=dict)
- meta_info: Dict = field(default_factory=dict)
-
- def __post_init__(self):
- # perform necessary checking
- self.check_consistency()
-
- def __len__(self):
- return self.batch.batch_size[0]
-
- def __getitem__(self, item):
- tensor_data = self.batch[item]
- non_tensor_data = {key: val[item] for key, val in self.non_tensor_batch.items()}
- return DataProtoItem(batch=tensor_data, non_tensor_batch=non_tensor_data, meta_info=self.meta_info)
-
- def __getstate__(self):
- import io
- buffer = io.BytesIO()
- if tensordict.__version__ >= '0.5.0' and self.batch is not None:
- self.batch = self.batch.contiguous()
- self.batch = self.batch.consolidate()
- torch.save(self.batch, buffer)
- return buffer, self.non_tensor_batch, self.meta_info
-
- def __setstate__(self, data):
- batch_deserialized, non_tensor_batch, meta_info = data
- batch_deserialized.seek(0)
- batch = torch.load(batch_deserialized,
- weights_only=False,
- map_location='cpu' if not torch.cuda.is_available() else None)
- self.batch = batch
- self.non_tensor_batch = non_tensor_batch
- self.meta_info = meta_info
-
- def check_consistency(self):
- """Check the consistency of the DataProto. Mainly for batch and non_tensor_batch
-        We expose this function as a public one so that users can call it directly
- """
- if self.batch is not None:
- if len(self.batch.batch_size) != 1:
- raise ValueError('only support num_batch_dims=1')
-
- if len(self.non_tensor_batch) != 0:
- if len(self.batch.batch_size) != 1:
- raise ValueError('only support num_batch_dims=1 when non_tensor_batch is not empty.')
-
- batch_size = self.batch.batch_size[0]
- for key, val in self.non_tensor_batch.items():
- if not isinstance(val, np.ndarray) or val.dtype != object:
- raise TypeError(f'data in the non_tensor_batch must be a numpy.array with dtype=object, '
- f'but found {key} with dtype {val.dtype}')
- if val.shape[0] != batch_size:
- raise ValueError(f'key {key} length {len(val)} is not equal to batch size {batch_size}')
-
- @classmethod
- def from_single_dict(cls, data: Dict[str, Union[torch.Tensor, np.ndarray]], meta_info=None):
- tensors = {}
- non_tensors = {}
-
- for key, val in data.items():
- if isinstance(val, torch.Tensor):
- tensors[key] = val
- elif isinstance(val, np.ndarray):
- non_tensors[key] = val
- else:
- raise ValueError(f'Unsupported type in data {type(val)}')
-
- return DataProto.from_dict(tensors=tensors, non_tensors=non_tensors, meta_info=meta_info)
-
- @classmethod
- def from_dict(cls, tensors: Dict[str, torch.Tensor], non_tensors=None, meta_info=None, num_batch_dims=1):
- """Create a DataProto from a dict of tensors. This assumes that
-        1. All the tensors in tensors have the same dim0
- 2. Only dim0 is the batch dim
- """
- if len(tensors) == 0:
- raise ValueError('tensors must not be empty')
-
- if num_batch_dims <= 0:
- raise ValueError('num_batch_dims must be greater than zero')
-
- if non_tensors is not None:
- if num_batch_dims != 1:
- raise ValueError('only support num_batch_dims=1 when non_tensors is not None.')
-
- if meta_info is None:
- meta_info = {}
- if non_tensors is None:
- non_tensors = {}
-
- if not isinstance(non_tensors, dict):
- raise TypeError('non_tensors must be a dictionary')
-
- # get and check batch size
- batch_size = None
- pivot_key = None
- for key, tensor in tensors.items():
- if batch_size is None:
- batch_size = tensor.shape[:num_batch_dims]
- pivot_key = key
- else:
- current_batch = tensor.shape[:num_batch_dims]
- if batch_size != current_batch:
- raise ValueError(f'Not all the tensor in tensors have the same batch size with batch_dims={num_batch_dims}. '
- f'Got {pivot_key} has {batch_size}, {key} has {current_batch}')
-
- for key, val in non_tensors.items():
- non_tensors[key] = np.array(val, dtype=object)
-
- tensor_dict = TensorDict(source=tensors, batch_size=batch_size)
- return cls(batch=tensor_dict, non_tensor_batch=non_tensors, meta_info=meta_info)
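A short sketch of building a DataProto from tensors plus an object-dtype side array (assuming the class above is importable; key names are illustrative):

import numpy as np
import torch

proto = DataProto.from_dict(
    tensors={"input_ids": torch.randint(0, 100, (4, 8)),
             "attention_mask": torch.ones(4, 8, dtype=torch.int64)},
    non_tensors={"uid": np.array(["a", "b", "c", "d"], dtype=object)},
    meta_info={"temperature": 1.0},
)
assert len(proto) == 4                       # batch dim inferred from dim0
assert proto.non_tensor_batch["uid"].dtype == object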
-
- def to(self, device) -> 'DataProto':
- """move the batch to device
-
- Args:
- device (torch.device, str): torch device
-
- Returns:
- DataProto: the current DataProto
-
- """
- if self.batch is not None:
- self.batch = self.batch.to(device)
- return self
-
- def select(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None, deepcopy=False) -> 'DataProto':
- """Select a subset of the DataProto via batch_keys and meta_info_keys
-
- Args:
- batch_keys (list, optional): a list of strings indicating the keys in batch to select
- meta_info_keys (list, optional): a list of keys indicating the meta info to select
-
- Returns:
- DataProto: the DataProto with the selected batch_keys and meta_info_keys
- """
- if batch_keys is not None:
- batch_keys = tuple(batch_keys)
- sub_batch = self.batch.select(*batch_keys)
- else:
- sub_batch = self.batch
-
- if non_tensor_batch_keys is not None:
- non_tensor_batch = {key: val for key, val in self.non_tensor_batch.items() if key in non_tensor_batch_keys}
- else:
- non_tensor_batch = self.non_tensor_batch
-
- if deepcopy:
- non_tensor_batch = copy.deepcopy(non_tensor_batch)
-
- if meta_info_keys is not None:
- sub_meta_info = {key: val for key, val in self.meta_info.items() if key in meta_info_keys}
- else:
- sub_meta_info = self.meta_info
-
- if deepcopy:
- sub_meta_info = copy.deepcopy(sub_meta_info)
-
- return DataProto(batch=sub_batch, non_tensor_batch=non_tensor_batch, meta_info=sub_meta_info)
-
- def pop(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None) -> 'DataProto':
- """Pop a subset of the DataProto via `batch_keys` and `meta_info_keys`
-
- Args:
- batch_keys (list, optional): a list of strings indicating the keys in batch to pop
- meta_info_keys (list, optional): a list of keys indicating the meta info to pop
-
- Returns:
-            DataProto: the DataProto with the popped batch_keys and meta_info_keys
- """
- if batch_keys is None:
- raise ValueError("batch_keys cannot be None. Please provide a valid list of keys.")
-
- if meta_info_keys is None:
- meta_info_keys = []
- if non_tensor_batch_keys is None:
- non_tensor_batch_keys = []
-
- tensors = {}
- # tensor batch
- for key in batch_keys:
- if key not in self.batch.keys():
- raise KeyError(f"Key '{key}' not found in self.batch.")
- tensors[key] = self.batch.pop(key)
- non_tensors = {}
- # non tensor batch
- for key in non_tensor_batch_keys:
- if key not in self.non_tensor_batch.keys():
- raise KeyError(f"Key '{key}' not found in self.non_tensor_batch.")
- non_tensors[key] = self.non_tensor_batch.pop(key)
- meta_info = {}
- for key in meta_info_keys:
- if key not in self.meta_info.keys():
- raise KeyError(f"Key '{key}' not found in self.meta_info.")
- meta_info[key] = self.meta_info.pop(key)
- return DataProto.from_dict(tensors=tensors, non_tensors=non_tensors, meta_info=meta_info)
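select keeps the source intact while pop removes the requested keys from it. A minimal example (assuming the DataProto class above):

import torch

proto = DataProto.from_dict(tensors={"input_ids": torch.zeros(2, 4),
                                     "responses": torch.ones(2, 4)})
subset = proto.select(batch_keys=["responses"])   # non-destructive
popped = proto.pop(batch_keys=["input_ids"])      # removes the key from proto
assert "input_ids" not in proto.batch.keys()
assert "input_ids" in popped.batch.keys()
assert "responses" in subset.batch.keys()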
-
- def rename(self, old_keys=None, new_keys=None) -> 'DataProto':
- """
-        Note that this function only renames keys in the batch
- """
-
- def validate_input(keys):
- if keys is not None:
- if isinstance(keys, str):
- keys = [keys]
- elif isinstance(keys, list):
- pass
- else:
- raise TypeError(f'keys must be a list or a string, but got {type(keys)}')
- return keys
-
- old_keys = validate_input(old_keys)
- new_keys = validate_input(new_keys)
-
- if len(new_keys) != len(old_keys):
- raise ValueError(
- f'new_keys and old_keys must have the same length, but got {len(new_keys)} and {len(old_keys)}')
-
- self.batch.rename_key_(tuple(old_keys), tuple(new_keys))
-
- return self
-
- def union(self, other: 'DataProto') -> 'DataProto':
- """Union with another DataProto. Union batch and meta_info separately.
-        Throws an error if
-        - there are conflicting keys in batch and they are not equal
-        - the batch sizes of the two data batches are not the same
-        - there are conflicting keys in meta_info and they are not the same.
-
- Args:
- other (DataProto): another DataProto to union
-
- Returns:
- DataProto: the DataProto after union
- """
- self.batch = union_tensor_dict(self.batch, other.batch)
- self.non_tensor_batch = union_numpy_dict(self.non_tensor_batch, other.non_tensor_batch)
- self.meta_info = union_two_dict(self.meta_info, other.meta_info)
- return self
-
- def make_iterator(self, mini_batch_size, epochs, seed=None, dataloader_kwargs=None):
-        """Make an iterator from the DataProto. This relies on the fact that a TensorDict can be used as a normal
-        PyTorch dataset.
-
- Args:
- mini_batch_size (int): mini-batch size when iterating the dataset. We require that
- ``batch.batch_size[0] % mini_batch_size == 0``
- epochs (int): number of epochs when iterating the dataset.
- dataloader_kwargs: internally, it returns a DataLoader over the batch.
- The dataloader_kwargs is the kwargs passed to the DataLoader
-
- Returns:
- Iterator: an iterator that yields a mini-batch data at a time. The total number of iteration steps is
- ``self.batch.batch_size * epochs // mini_batch_size``
- """
- if self.batch.batch_size[0] % mini_batch_size != 0:
- raise ValueError(f"Batch size {self.batch.batch_size[0]} is not divisible by mini_batch_size {mini_batch_size}.")
- # we can directly create a dataloader from TensorDict
- if dataloader_kwargs is None:
- dataloader_kwargs = {}
-
- if seed is not None:
- generator = torch.Generator()
- generator.manual_seed(seed)
- else:
- generator = None
-
- if not isinstance(dataloader_kwargs, dict):
- raise TypeError(f"dataloader_kwargs should be a dictionary, but got {type(dataloader_kwargs)}.")
- train_dataloader = DataLoader(dataset=self,
- batch_size=mini_batch_size,
- collate_fn=collate_fn,
- generator=generator,
- **dataloader_kwargs)
-
- def get_data():
- for _ in range(epochs):
- for d in train_dataloader:
- d.meta_info = self.meta_info
- yield d
-
- return iter(get_data())
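A quick sketch of the iterator: with 8 samples, a mini-batch size of 4 and 2 epochs, the loop below yields 4 mini-batches (assuming the class above):

import torch

proto = DataProto.from_dict(tensors={"input_ids": torch.arange(8).reshape(8, 1)},
                            meta_info={"epoch": 0})
it = proto.make_iterator(mini_batch_size=4, epochs=2, seed=0)
assert sum(1 for _ in it) == 4   # (8 / 4) mini-batches per epoch * 2 epochs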
-
- def chunk(self, chunks: int) -> List['DataProto']:
-        """Split the batch along dim=0 into chunks. The meta_info is passed to each DataProto after the split.
-
- Args:
- chunks (int): the number of chunks to split on dim=0
-
- Returns:
- List[DataProto]: a list of DataProto after splitting
- """
- if self.batch is not None:
- batch_lst = self.batch.chunk(chunks=chunks, dim=0)
- else:
- batch_lst = [None for _ in range(chunks)]
-
- non_tensor_batch_lst = [{} for _ in range(chunks)]
- for key, val in self.non_tensor_batch.items():
- if not isinstance(val, np.ndarray):
- raise TypeError(f"Expected value of type np.ndarray for key '{key}', but got {type(val)}.")
- non_tensor_lst = np.array_split(val, chunks)
- if len(non_tensor_lst) != chunks:
- raise ValueError(f"After splitting, the number of chunks for key '{key}' is {len(non_tensor_lst)}, "
- f"which does not match the expected number of chunks: {chunks}.")
- for i in range(chunks):
- non_tensor_batch_lst[i][key] = non_tensor_lst[i]
-
- output = []
- for i in range(chunks):
- output.append(
- DataProto(batch=batch_lst[i], non_tensor_batch=non_tensor_batch_lst[i], meta_info=self.meta_info))
-
- return output
-
- @staticmethod
- def concat(data: List['DataProto']) -> 'DataProto':
-        """Concat a list of DataProto. The batch is concatenated along dim=0.
-        The meta_info is assumed to be identical and the first one is used.
-
- Args:
- data (List[DataProto]): list of DataProto
-
- Returns:
- DataProto: concatenated DataProto
- """
- batch_lst = []
- for batch in data:
- batch_lst.append(batch.batch)
- if batch_lst[0] is not None:
- new_batch = torch.cat(batch_lst, dim=0)
- else:
- new_batch = None
-
- non_tensor_batch = list_of_dict_to_dict_of_list(list_of_dict=[d.non_tensor_batch for d in data])
- for key, val in non_tensor_batch.items():
- non_tensor_batch[key] = np.concatenate(val, axis=0)
-
- return DataProto(batch=new_batch, non_tensor_batch=non_tensor_batch, meta_info=data[0].meta_info)
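chunk and concat are inverses along dim 0; this is how batches are scattered to data-parallel workers and gathered back. A quick check (assuming the class above):

import torch

proto = DataProto.from_dict(tensors={"x": torch.arange(6).reshape(6, 1)})
parts = proto.chunk(chunks=3)                 # three DataProto of batch size 2
rebuilt = DataProto.concat(parts)
assert torch.equal(rebuilt.batch["x"], proto.batch["x"])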
-
- def reorder(self, indices):
- """
- Note that this operation is in-place
- """
- indices_np = indices.detach().numpy()
- self.batch = self.batch[indices]
- self.non_tensor_batch = {key: val[indices_np] for key, val in self.non_tensor_batch.items()}
-
-
-@dataclass
-class DataProtoFuture:
- """
-    DataProtoFuture aims to eliminate actual data fetching on the driver. By doing so, the driver doesn't have to wait
-    for data, so asynchronous execution becomes possible.
-    DataProtoFuture contains a list of futures from another WorkerGroup of size world_size.
-    - collect_fn is a Callable that reduces the list of futures to a DataProto
-    - dispatch_fn is a Callable that partitions the DataProto into a list of DataProto of size world_size and then selects the chunk for the current rank
-
- Potential issue: we can optimize dispatch_fn(collect_fn) such that only needed data is fetched on destination
- - DataProtoFuture only supports directly passing from the output of a method to another input. You can't perform any
- operation on the DataProtoFuture in driver.
- """
- collect_fn: Callable
- futures: List[ray.ObjectRef]
- dispatch_fn: Callable = None
-
- @staticmethod
- def concat(data: List[ray.ObjectRef]) -> 'DataProtoFuture':
- output = DataProtoFuture(collect_fn=DataProto.concat, futures=data)
- return output
-
- def chunk(self, chunks: int) -> List['DataProtoFuture']:
- from functools import partial
-
- arg_future_lst = []
- for i in range(chunks):
- # note that we can't directly pass i and chunks
- def dispatch_fn(x, i, chunks):
- return x.chunk(chunks=chunks)[i]
-
- arg_future = DataProtoFuture(collect_fn=self.collect_fn,
- dispatch_fn=partial(dispatch_fn, i=i, chunks=chunks),
- futures=self.futures)
- arg_future_lst.append(arg_future)
- return arg_future_lst
-
- def get(self):
- output = ray.get(self.futures) # dp_size.
- for single_output in output:
- if not isinstance(single_output, DataProto):
- raise TypeError(f"Expected instance of DataProto, but got {type(single_output)}.")
- output = self.collect_fn(output) # select dp, concat
- if self.dispatch_fn is not None:
- output = self.dispatch_fn(output) # split in batch dim, select using dp
- return output
-
-
-def make_batch_generator(batches, vpp_size):
- if vpp_size > 1:
- # has vpp
- batch_generator = [batches] * vpp_size # number of vpp chunks
- batch_generator = [iter(b) for b in batch_generator]
- else:
- # no vpp
- batch_generator = iter(batches)
- return batch_generator
diff --git a/mindspeed_llm/tasks/posttrain/rlxf/utils/torch_functional.py b/mindspeed_llm/tasks/posttrain/rlxf/utils/torch_functional.py
deleted file mode 100644
index 927d8d2bd203f62bc80f51643e4a6974b1e6ffd3..0000000000000000000000000000000000000000
--- a/mindspeed_llm/tasks/posttrain/rlxf/utils/torch_functional.py
+++ /dev/null
@@ -1,70 +0,0 @@
-# Copyright 2024 Bytedance Ltd. and/or its affiliates
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-Contain small torch utilities
-"""
-
-from typing import List
-
-import torch
-import torch.distributed
-from tensordict import TensorDict
-
-
-def clip_by_value(x, tensor_min, tensor_max):
- """
-    Tensor extension of torch.clamp
- """
- clipped = torch.max(torch.min(x, tensor_max), tensor_min)
- return clipped
-
-
-def masked_mean(values, mask, axis=None):
-    """Compute the mean of a tensor over positions where the mask is set."""
- return (values * mask).sum(axis=axis) / mask.sum(axis=axis)
-
-
-def masked_var(values, mask, unbiased=True):
- """Compute variance of tensor with masked values."""
- mean = masked_mean(values, mask)
- centered_values = values - mean
- variance = masked_mean(centered_values ** 2, mask)
- if unbiased:
- mask_sum = mask.sum()
- if mask_sum == 0:
- raise ValueError("At least one element in the mask has to be 1.")
- # note that if mask_sum == 1, then there is a division by zero issue
- # to avoid it you just need to use a larger minibatch_size
- elif mask_sum == 1:
- bessel_correction = mask_sum
- else:
- bessel_correction = mask_sum / (mask_sum - 1)
- variance = variance * bessel_correction
- return variance
-
-
-def masked_whiten(values, mask, shift_mean=True):
-    """Whiten values using the masked mean and variance."""
- mean, var = masked_mean(values, mask), masked_var(values, mask)
- whitened = (values - mean) * torch.rsqrt(var + 1e-8)
- if not shift_mean:
- whitened += mean
- return whitened
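A quick numeric check of the masked statistics above: with mask [1, 1, 0] only the first two values contribute (assuming the helpers are in scope):

import torch

values = torch.tensor([1.0, 3.0, 100.0])
mask = torch.tensor([1.0, 1.0, 0.0])
assert masked_mean(values, mask).item() == 2.0            # (1 + 3) / 2
assert abs(masked_var(values, mask).item() - 2.0) < 1e-6  # biased var 1.0 times Bessel factor 2/(2-1)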
-
-
-def split_dict_tensor_into_batches(tensors: TensorDict, batch_size) -> List[TensorDict]:
- if tensors.batch_size[0] % batch_size != 0:
- raise ValueError(f"Input data batch size: {tensors.batch_size[0]}, split batch size: {batch_size}. "
- f"Please ensure that the input batch size is divisible by the split batch size.")
- return tensors.split(batch_size)
diff --git a/mindspeed_llm/tasks/posttrain/rlxf/workers/__init__.py b/mindspeed_llm/tasks/posttrain/rlxf/workers/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/mindspeed_llm/tasks/posttrain/rlxf/workers/actor_train_infer.py b/mindspeed_llm/tasks/posttrain/rlxf/workers/actor_train_infer.py
deleted file mode 100644
index eeee7de9475f8b9e2bef9d033ecb4c1b15703b1b..0000000000000000000000000000000000000000
--- a/mindspeed_llm/tasks/posttrain/rlxf/workers/actor_train_infer.py
+++ /dev/null
@@ -1,890 +0,0 @@
-# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved.
-import copy
-import os
-from functools import partial
-from typing import Iterable, Dict
-
-import ray
-import torch
-import torch_npu
-import torch.nn.functional as F
-from tensordict import TensorDict
-
-from megatron.training import get_args, initialize_megatron, get_timers, get_tokenizer
-from megatron.core.pipeline_parallel.schedules import get_forward_backward_func
-from megatron.core import parallel_state as mpu, tensor_parallel
-from megatron.training.training import append_to_progress_log, build_train_valid_test_data_iterators, print_datetime
-from megatron.training import get_model
-from megatron.training.utils import unwrap_model
-from megatron.training.checkpointing import save_checkpoint
-from megatron.training.training import num_floating_point_operations
-from mindspeed_llm.tasks.posttrain.rlxf.training.core_algos import compute_policy_loss, find_first_eos_index, compute_grpo_policy_loss
-from mindspeed_llm.tasks.posttrain.rlxf.utils.torch_functional import split_dict_tensor_into_batches
-from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base.megatron.worker import MegatronWorker
-from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base.decorator import register, Dispatch, Execute
-from mindspeed_llm.tasks.posttrain.rlxf.training.parameter_mapping import sync_param_nums, run_auto_mapping, \
- init_parameter_mapping_distributed
-from mindspeed_llm.tasks.posttrain.rlxf.utils.protocol import DataProto, make_batch_generator
-from mindspeed_llm.tasks.posttrain.base import BaseTrainer
-import mindspeed_llm.tasks.posttrain.rlxf.training.parallel_state as ps
-from mindspeed_llm.tasks.inference.module import MegatronModuleForCausalLM
-from mindspeed_llm.tasks.preprocess.blended_mtf_dataset import build_blended_mtf_dataset
-from mindspeed_llm.training.initialize import set_jit_fusion_options
-from mindspeed_llm.training.utils import get_finetune_data_on_this_tp_rank, get_tune_attention_mask
-from mindspeed_llm.tasks.posttrain.utils import compute_log_probs, append_to_dict
-
-
-def train_valid_test_datasets_provider(train_val_test_num_samples):
- """Build the train test and validation datasets.
-
- Args:
- train_val_test_num_samples : A list containing the number of samples in train test and validation.
- """
- args = get_args()
- print("> building train, validation, and test datasets for GPT ...")
-
- train_ds, valid_ds, test_ds = build_blended_mtf_dataset(
- data_prefix=args.data_path,
- splits_string=args.split,
- train_valid_test_num_samples=train_val_test_num_samples,
- seq_length=args.max_prompt_length,
- seed=args.seed)
-
- print("> finished creating GPT datasets ...")
-
- return train_ds, valid_ds, test_ds
-
-
-@ray.remote
-class PPOActorWorker(MegatronWorker):
- """
- A basic class to launch two megatron instances with different communication groups.
- Currently assume that the first group is for training and the second group is for inference.
-    Currently assumes that the first group is for training and the second group is for inference.
-
- def __init__(self, config, role):
- super().__init__()
- self.config = config
- self.role = role
- self.IGNORE_INDEX = -100
- os.environ['CUDA_DEVICE_MAX_CONNECTIONS'] = '1'
-
- initialize_megatron(args_defaults={'no_load_rng': True, 'no_load_optim': True},
- role=self.role,
- config=self.config,
- two_megatron=True)
-
- @register(dispatch_mode=Dispatch.ONE_TO_ALL)
- def initialize(self):
- init_parameter_mapping_distributed()
- self.is_inference_node = ps.in_mg2_inference_group()
- if self.is_inference_node:
- self.node = PPOActorInferWorker()
- else:
- self.node = PPOActorTrainWorker()
- torch.cuda.empty_cache()
-
- @register(dispatch_mode=Dispatch.ONE_TO_ALL)
- def auto_mapping(self):
- if self.is_inference_node:
- run_auto_mapping(self.node.inf_model)
- else:
- run_auto_mapping(self.node.actor.model[0])
- torch.cuda.empty_cache()
-
- @register(dispatch_mode=Dispatch.DP_ALL_GATHER_INFER, execute_mode=Execute.INFER)
- def generate_sequences(self):
- args = get_args()
- output = self.node.run_inference()
- tokenizer = get_tokenizer()
- meta_info = {'eos_token_id': tokenizer.eos_token_id, 'pad_token_id': tokenizer.pad_token_id, 'num_samples_per_step':args.num_samples_per_step}
- output.meta_info.update(meta_info)
- torch.cuda.empty_cache()
- return output
-
- @register(dispatch_mode=Dispatch.DP_ALL_GATHER_TRAIN, execute_mode=Execute.TRAIN)
- def update_actor(self, data):
- device = next(self.node.actor.model[0].parameters()).device
- data = data.to(device)
-
- dataloader = self.node.actor.make_minibatch_iterator(data=data)
-
- metrics = self.node.actor.update_policy(dataloader=dataloader)
-
- output = DataProto(meta_info={'metrics': metrics})
- output = output.to('cpu')
- torch.cuda.empty_cache()
- return output
-
- @register(dispatch_mode=Dispatch.DP_ALL_GATHER_TRAIN, execute_mode=Execute.TRAIN)
- def get_log_probs(self, data):
- old_log_probs = self.node.actor.compute_log_prob(data)
- if old_log_probs is not None: # pp last stage
- data.batch['old_log_probs'] = old_log_probs
- data = data.to('cpu')
- else: # pp intermediate stage, no useful results
- data = None
- # clear kv cache
- torch.cuda.empty_cache()
- return data
-
- @register(dispatch_mode=Dispatch.ONE_TO_ALL, execute_mode=Execute.TRAIN)
- def save_checkpoint(self, iteration):
- self.node.actor.save_checkpoint(iteration)
-
- @register(dispatch_mode=Dispatch.ONE_TO_ALL, execute_mode=Execute.TRAIN)
- def get_iteration(self):
- return self.node.actor.get_iteration()
-
-
-class PPOActorTrainWorker(BaseTrainer):
- def __init__(self):
- super().__init__()
-
- def initialize(self):
- self.args = get_args()
- model, optimizer, opt_param_scheduler = self._build_model_and_optimizer()
- self.actor = MegatronPPOActor(model=model, optimizer=optimizer, opt_param_scheduler=opt_param_scheduler)
- sync_param_nums(model[0])
- if self.args.stage == "ray_online_dpo":
- self.args.micro_batch_size *= 2
- self.args.ppo_mini_batch_size *= 2
-
- def _build_model_and_optimizer(self):
- from megatron.training.training import setup_model_and_optimizer
- model, optimizer, opt_param_scheduler = setup_model_and_optimizer(
- self.model_provider, self.model_type)
- return model, optimizer, opt_param_scheduler
-
- def get_batch(self, data_iterator):
- """
- Retrieves a batch of data from the data iterator.
- Called during each forward step.
- """
- pass
-
- def loss_func(self, input_tensor, output_tensor):
- """
- Computes the loss function.
- Called during each forward step.
- """
- pass
-
- def forward_step(self, data_iterator, model):
- """
- Performs a forward pass and computes the loss.
- Called during each training iteration.
- """
- pass
-
-
-def pad_to_tensor_dict(data, padding_side="right", pad_multi_of=16):
- max_length = torch.LongTensor([max(len(val) for val in data)]).cuda()
- max_length = max_length if max_length % pad_multi_of == 0 else (max_length // pad_multi_of + 1) * pad_multi_of
- torch.distributed.all_reduce(max_length, op=torch.distributed.ReduceOp.MAX)
-
- tokenizer = get_tokenizer()
-
- pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id else tokenizer.eos_token_id
- context_lengths = [len(val) for val in data]
-
- data_length = len(data)
- for i in range(data_length):
- if context_lengths[i] < max_length:
- if padding_side == "right":
- data[i].extend([pad_id] * (max_length - context_lengths[i]))
- else:
- data[i] = [pad_id] * (max_length - context_lengths[i]) + data[i]
- return context_lengths, max_length
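pad_to_tensor_dict pads every sample in place to a shared length that is a multiple of pad_multi_of, agreed on across ranks via an all-reduce. A single-process sketch of just the length arithmetic (distributed and tokenizer pieces omitted; the helper name is illustrative):

def padded_length(max_len: int, pad_multi_of: int = 16) -> int:
    # Round max_len up to the next multiple of pad_multi_of, as done above.
    return max_len if max_len % pad_multi_of == 0 else (max_len // pad_multi_of + 1) * pad_multi_of

assert padded_length(33) == 48
assert padded_length(48) == 48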
-
-
-class PPOActorInferWorker(BaseTrainer):
- def __init__(self):
- super().__init__()
- self.count = 0
- self.keys = None
-
- def model_provider(self, pre_process=True, post_process=True):
- """Builds the inference model.
-
- If you set the use_mcore_models to True, it will return the mcore GPT model.
-
- Args:
-            pre_process (bool, optional): Set to true if you need to compute embeddings. Defaults to True.
-            post_process (bool, optional): Set to true if you want to compute output logits/loss. Defaults to True.
-
-
- Returns:
- Union[GPTModelInfer, GPTModel]: The returned model
- """
-
- from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec, \
- get_gpt_layer_local_spec
- from megatron.core.transformer.spec_utils import import_module
- from megatron.training import get_args, print_rank_0
- from megatron.training.arguments import core_transformer_config_from_args
- from megatron.training.yaml_arguments import core_transformer_config_from_yaml
-
- from mindspeed_llm.tasks.inference.module import GPTModelInfer
-
- args = get_args()
- use_te = args.transformer_impl == "transformer_engine"
-
- print_rank_0('building GPT Rollout model ...')
- # Experimental loading arguments from yaml
- if args.yaml_cfg is not None:
- config = core_transformer_config_from_yaml(args, "language_model")
- else:
- config = core_transformer_config_from_args(args)
-
- if args.spec is not None:
- transformer_layer_spec = import_module(args.spec)
- else:
- if use_te:
- transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(args.num_experts,
- args.moe_grouped_gemm)
- else:
- transformer_layer_spec = get_gpt_layer_local_spec(args.num_experts, args.moe_grouped_gemm)
-
- model = GPTModelInfer(
- config=config,
- transformer_layer_spec=transformer_layer_spec,
- vocab_size=args.padded_vocab_size,
- max_sequence_length=args.max_position_embeddings,
- pre_process=pre_process,
- post_process=post_process,
- fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
- parallel_output=False,
- share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
- position_embedding_type=args.position_embedding_type,
- rotary_percent=args.rotary_percent,
- seq_len_interpolation_factor=args.rotary_seq_len_interpolation_factor
- )
-
- return model
-
- def initialize(self):
- train_valid_test_datasets_provider.is_distributed = True
- self.args = get_args()
- self.timers = get_timers()
- self.train_valid_test_datasets_provider = train_valid_test_datasets_provider
-
- if self.args.log_progress:
- append_to_progress_log("Starting job")
- # Set pytorch JIT layer fusion options and warmup JIT functions.
- set_jit_fusion_options()
-
- self.timers('train/valid/test-data-iterators-setup', log_level=0).start(
- barrier=True)
-
- self.args.num_layer_list = None
- self.args.micro_batch_size = 1
- self.args.sequence_parallel = False
-
- self.args.model = unwrap_model(get_model(self.model_provider, wrap_with_ddp=False))
- self.inf_model = self.args.model[0]
- self.args.dataset_additional_keys = eval(self.args.dataset_additional_keys[0]) if self.args.dataset_additional_keys else []
-
- sync_param_nums(self.inf_model)
- true_pad_to_multiple_of = self.args.pad_to_multiple_of
- self.args.pad_to_multiple_of = 1 # we don't want to pad data here
- self.train_data_iterator, self.valid_data_iterator, self.test_data_iterator \
- = build_train_valid_test_data_iterators(
- self.train_valid_test_datasets_provider)
- self.args.pad_to_multiple_of = true_pad_to_multiple_of
- self.timers('train/valid/test-data-iterators-setup').stop()
- print_datetime('after dataloaders are built')
-
- # Print setup timing.
- print('done with setup ...')
- self.timers.log(['model-setup', 'train/valid/test-data-iterators-setup'], barrier=True)
-
- def get_batch(self, data_iterator):
-        """Generate a batch in the same format as LLaMA-Factory."""
- args = get_args()
-
- self.keys = ['input_ids', *self.args.dataset_additional_keys]
-
- if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()):
- if args.variable_seq_lengths and args.pipeline_model_parallel_size > 2:
- tokens, _ = get_finetune_data_on_this_tp_rank(data_iterator)
- else:
- tokens = None
- return tokens
-
- # Items and their type.
-
- data_type = torch.int64
- cur_data = next(data_iterator)
-
- # add problem category for reward choosing
- if args.dataset_category is not None:
- dataset_category = [int(item) for item in args.dataset_category.split(",")]
- categories = [dataset_category[id.item()] for id in cur_data['dataset_id']]
- cur_data['categories'] = torch.tensor(categories, dtype=torch.int64)
-
- # Broadcast data.
- data_b = tensor_parallel.broadcast_data(self.keys, cur_data, data_type)
-
- # Unpack
- batch = {}
- for key in self.keys:
- batch[key] = data_b.get(key).long()
-
- return batch
-
- def run_inference(self):
- args = get_args()
- num_infer_steps = args.global_batch_size // (args.data_parallel_size * args.num_samples_per_step)
- responses = []
- idx_list = []
- idx_list_per_step = []
- additional_dict = {}
- additional_dict_per_step = {}
-
- for k in self.args.dataset_additional_keys:
-            if k not in additional_dict:
- additional_dict[k] = []
-
- max_new_tokens = args.seq_length - args.max_prompt_length
- if max_new_tokens % args.pad_to_multiple_of != 0:
- raise ValueError(f"Please adjust pad_to_multiple_of so that max_new_tokens % args.pad_to_multiple_of == 0. "
- f"Current max_new_tokens: {max_new_tokens}, pad_to_multiple_of: {args.pad_to_multiple_of}")
- for _ in range(num_infer_steps):
- for k in self.args.dataset_additional_keys:
-                if k not in additional_dict_per_step:
- additional_dict_per_step[k] = []
-
- for _ in range(args.num_samples_per_step):
- batch = self.get_batch(self.train_data_iterator)
-
- tokens = batch["input_ids"]
- tokens_list = tokens.view(-1).cpu().numpy().tolist()
-
- for additional_key in self.args.dataset_additional_keys:
- additional_val = batch.get(additional_key).view(-1).cpu().numpy().tolist()
-
- for _ in range(args.n_samples_per_prompt):
- additional_dict_per_step.get(additional_key).append(copy.deepcopy(additional_val))
-
- for _ in range(args.n_samples_per_prompt):
- idx_list_per_step.append(copy.deepcopy(tokens_list))
-
- if args.stage == "ray_online_dpo":
- idx_list_per_step.append(copy.deepcopy(tokens_list))
-
- responses_per_step = self.inf_model.generate(
- copy.deepcopy(idx_list_per_step),
- max_new_tokens=max_new_tokens,
- temperature=args.temperature,
- do_sample=args.do_sample,
- detokenize=False,
- broadcast=False,
- truncate=True
- )
-
- if not isinstance(responses_per_step, list):
- responses_per_step = [responses_per_step]
-
- responses.extend(responses_per_step)
- idx_list.extend(idx_list_per_step)
- idx_list_per_step = []
-
- for k in additional_dict:
- additional_dict[k].extend(additional_dict_per_step[k])
-
- additional_dict_per_step = {}
-
- responses_ori_length, responses_pad_length = pad_to_tensor_dict(
- responses,
- pad_multi_of=args.pad_to_multiple_of
- )
- prompts_ori_length, prompts_pad_length = pad_to_tensor_dict(
- idx_list, "left",
- pad_multi_of=args.pad_to_multiple_of
- )
-
- for additional_key in self.args.dataset_additional_keys:
- tmp_val = additional_dict.get(additional_key)
- pad_to_tensor_dict(
- tmp_val,
- pad_multi_of=args.pad_to_multiple_of
- )
- additional_dict[additional_key] = tmp_val
-
- input_ids = [prompt + response for prompt, response in zip(idx_list, responses)]
-
- attention_mask = generate_attention_mask(input_ids, prompts_ori_length, prompts_pad_length,
- responses_ori_length, responses_pad_length)
-
- position_ids = generate_position_ids_from_attention_mask(input_ids, prompts_ori_length, prompts_pad_length)
- if self.args.stage == "ray_online_dpo":
- batch_size = args.global_batch_size // args.data_parallel_size * 2
- else:
- batch_size = args.global_batch_size // args.data_parallel_size * args.n_samples_per_prompt
-
- batch = TensorDict(
- dict(
- {
- "prompts": idx_list,
- "responses": responses,
- "input_ids": input_ids,
- "attention_mask": attention_mask,
- "position_ids": position_ids,
- "responses_ori_length": responses_ori_length
- }, **additional_dict
- ),
- batch_size=batch_size
- )
-
- return DataProto(batch=batch)
-
- def loss_func(self, input_tensor, output_tensor):
- """
- Computes the loss function.
- Called during each forward step.
- """
- pass
-
- def forward_step(self, data_iterator, model):
- """
- Performs a forward pass and computes the loss.
- Called during each training iteration.
- """
- pass
-
-
-def generate_position_ids_from_attention_mask(input_ids_list, prompts_ori_length, prompts_pad_length):
- """
-    Generate the list of position_ids corresponding to the attention_mask.
-
-    Args:
-    input_ids_list (list of lists): list of input_ids; each element is a list of token ids.
-    prompts_ori_length (list of ints): original (unpadded) length of each prompt.
-    prompts_pad_length (int): padded prompt length.
-
-    Returns:
-    list of lists: list of position_ids; each element is a list.
- """
- position_ids_list = []
- for idx, input_ids in enumerate(input_ids_list):
- prompt_pad_length = prompts_pad_length - prompts_ori_length[idx]
- position_ids = [0] * prompt_pad_length + list(range(len(input_ids) - prompt_pad_length))
- position_ids_list.append(position_ids)
-
- return position_ids_list
-
-
-def generate_attention_mask(input_ids_list, prompts_ori_length, prompts_pad_length, responses_ori_length,
- responses_pad_length):
- """
-    Generate the list of attention_masks corresponding to the input_ids.
-
-    Args:
-    input_ids_list (list of lists): list of input_ids; each element is a list of token ids.
-    prompts_ori_length (list of ints): original (unpadded) length of each prompt.
-    prompts_pad_length (int): padded prompt length.
-    responses_ori_length (list of ints): original (unpadded) length of each response.
-    responses_pad_length (int): padded response length.
-
-    Returns:
-    list of lists: list of attention_masks; each element is a list.
- """
- attention_mask_list = []
-
- for idx, input_ids in enumerate(input_ids_list):
- attention_mask = torch.ones_like(torch.tensor(input_ids))
- prompt_pad_length = prompts_pad_length - prompts_ori_length[idx]
- response_pad_length = responses_pad_length - responses_ori_length[idx]
- attention_mask[:prompt_pad_length] = 0
- if response_pad_length > 0:
- attention_mask[-response_pad_length:] = 0
- attention_mask_list.append(attention_mask.numpy().tolist())
-
- return attention_mask_list
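A small worked example of the two helpers above: prompts are left-padded and responses right-padded, so zeros sit at both ends of the mask and position ids stay at zero across the prompt padding (assuming the functions above are in scope; token values are arbitrary):

# One sample: prompt padded from 2 to 4 tokens, response padded from 3 to 4 tokens.
input_ids = [[0, 0, 11, 12, 21, 22, 23, 0]]
attn = generate_attention_mask(input_ids, prompts_ori_length=[2], prompts_pad_length=4,
                               responses_ori_length=[3], responses_pad_length=4)
pos = generate_position_ids_from_attention_mask(input_ids, [2], 4)
assert attn == [[0, 0, 1, 1, 1, 1, 1, 0]]
assert pos == [[0, 0, 0, 1, 2, 3, 4, 5]]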
-
-
-def split_two_prompts(origin_tensor):
- origin_tensor = origin_tensor.reshape(-1, 2)
- first_half, second_half = origin_tensor.split(1, dim=1)
- return first_half.reshape(-1), second_half.reshape(-1)
-
-
-
-class MegatronPPOActor:
-
- def __init__(self, model, optimizer, opt_param_scheduler):
-        """MegatronPPOActor class. This class implements the core PPO logic when the model is built with Megatron.
-
- Args:
-            model_config (OmegaConf): model configuration. It must contain ``model_config.vocab_size`` and
- ``model_config.hidden_size``
-            megatron_config (OmegaConf): megatron configuration. It must contain
-
- ``sequence_parallel_enabled``: whether the sequence parallel is enabled.
-
- ``param_dtype``: the dtype of the parameters.
-
-                ``virtual_pipeline_model_parallel_size``: virtual pipeline model parallel size, i.e. the number of chunks in each pp stage.
- actor_module (nn.ModuleList): actor module is a ModuleList that contains a list of nn.Module in this pp stage.
- each nn.Module in this rank holds a vpp module chunk.
-                The actor module has some constraints to follow in order to use the update logic implemented here
-
-                1. It must implement unpad_input before any computation and pad_input after all the computation. Removing padding is an
-                optimization that drops the padding tokens. See the unpad_input and pad_input functions in flash-attn
-
- 2. Each pp stage must return the hidden state with the same shape [total_nnz, 1, hidden_size],
- where total_nnz is the number of valid tokens in this batch. If sequence parallel is enabled, the size
- of the hidden state is [total_nnz // tp, 1, hidden_size].
-            actor_optimizer (DistributedOptimizer): currently, we only support DistributedOptimizer in Megatron. It implements
-                a ZeRO-1 style optimizer that shards the optimizer state across dp ranks.
-
- """
- self.args = get_args()
- self.model = model
- self.optimizer = optimizer
- self.opt_param_scheduler = opt_param_scheduler
- self._beta = self.args.dpo_beta
- self.num_floating_point_operations_so_far = 0
-
- def get_iteration(self):
- return self.args.iteration
-
- def save_checkpoint(self, iteration):
-
- save_checkpoint(iteration, self.model, self.optimizer, self.opt_param_scheduler,
- self.num_floating_point_operations_so_far)
-
- def compute_log_prob(self, data) -> torch.Tensor:
- """Compute the log probability of the responses given input_ids, attention_mask and position_ids
-
- Args:
- data (DataProto): a DataProto containing keys
-
- ``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. Note that input_ids is the
- concatenation of prompt and response. Note that ``sequence_length = prompt_length + response_length``.
-
- ``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64.
-
- ``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64.
-
- ``responses``: tensor of shape [batch_size, response_length]. torch.int64.
-
- Returns:
-            torch.Tensor: the log_prob tensor
- """
- data.batch = data.batch.contiguous()
-
- def compute_logprobs_fn(output, data):
- response = data['responses']
- response_length = response.size(1)
- logits = output
- logits = logits[:, -response_length - 1:-1]
- _, _, log_probs = compute_log_probs(logits, response, per_token=True)
- return {'log_probs': log_probs}
-
- # We make recompute_old_log_prob by default here.
- data = data.to(next(self.model[0].parameters()).device)
- with torch.no_grad():
- output = self.forward_backward_batch(data, forward_only=True, post_process_fn=compute_logprobs_fn)
- if mpu.is_pipeline_last_stage(ignore_virtual=True):
- # only on last rank. It should be on every tp rank
- log_probs = torch.cat([single_output['log_probs'] for single_output in output], dim=0) # (bs, seq_size)
- log_probs = log_probs.to(torch.float32)
- else:
- log_probs = None
-
- # add empty cache after each compute
- torch.cuda.empty_cache()
-
- return log_probs
-
- @property
- def beta(self):
- if isinstance(self._beta, list):
- epoch = self.state.epoch
- return self._beta[epoch] if epoch < len(self._beta) else self._beta[-1]
- else:
- return self._beta
-
- def make_minibatch_iterator(self, data):
- """Make minibatch iterator for updating the actor
-
- Args:
- data (DataProto): a DataProto containing keys
-
- ``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64, where ``sequence_length = prompt_length + response_length``
-
- ``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64
-
- ``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64
-
- ``responses``: tensor of shape [batch_size, response_length]. torch.int64. Note that responses = input_ids[:, -response_length:]
-
- ``old_log_probs``: tensor of shape [batch_size, response_length]. torch.float32. The log probability of responses.
-
- ``advantages``: tensor of shape [batch_size, response_length]. torch.float32. The advantages of responses.
- See PPO paper for details.
-
- Returns:
-
- """
- return data.make_iterator(mini_batch_size=self.args.ppo_mini_batch_size,
- epochs=self.args.ppo_epochs,
- dataloader_kwargs={'shuffle': self.args.shuffle_minibatch})
-
- def forward_backward_batch(self, data, forward_only=False, post_process_fn=None):
- """
- We assume:
- - The model takes input: (input_ids, attention_mask, position_ids). No rmpad for the input
- - The communication shape is (total_nnz_pad_to_sp // tp_size, 1, hidden_size) if sequence parallel is enabled
- """
- # broadcast from last pp rank to all other pp ranks
-
- data.batch['attention_mask'] = data.batch['attention_mask'].to(bool)
-
- batch_size = self.args.micro_batch_size
- batches = split_dict_tensor_into_batches(data.batch, batch_size=batch_size)
-
- n_micro_batch = len(batches)
- seq_len = batches[0]['input_ids'].shape[1]
-
- forward_backward_func = get_forward_backward_func()
-
- def loss_func_ppo(output, data, meta_info):
- """
- This loss_func has two modes
- 1. when forward_only is True: use post_process_fn to calculate the log_probs
- 2. when forward_only is False: calculate the policy loss
- """
- if forward_only:
- if post_process_fn is None:
- return 1.0, {'logits': output}
- else:
- return 1.0, post_process_fn(output, data)
-
- responses = data['responses']
- response_length = responses.size(1)
- attention_mask = data['attention_mask']
- response_mask = attention_mask[:, -response_length:]
- old_log_prob = data['old_log_probs']
- advantages = data['advantages']
-
- clip_ratio = meta_info['clip_ratio']
-
- # compute policy loss
- logits = output
- logits = logits[:, -response_length - 1:-1]
- _, _, log_prob = compute_log_probs(logits, responses, per_token=True)
- pg_loss, pg_clipfrac, ppo_kl = compute_policy_loss(old_log_prob=old_log_prob,
- log_prob=log_prob,
- advantages=advantages,
- eos_mask=response_mask,
- cliprange=clip_ratio)
- policy_loss = pg_loss
- # return loss and stats
- stats = {
- 'actor/pg_loss': abs(pg_loss.detach().item()),
- 'actor/pg_clipfrac': pg_clipfrac.detach().item(),
- 'actor/ppo_kl': ppo_kl.detach().item()
- }
- return policy_loss, stats
-
- def loss_func_grpo(output, data, meta_info):
- """
- This loss_func has two modes
- 1. when forward_only is True: use post_process_fn to calculate the log_probs
- 2. when forward_only is False: calculate the policy loss
- """
- if forward_only:
- if post_process_fn is None:
- return 1.0, {'logits': output}
- else:
- return 1.0, post_process_fn(output, data)
-
- responses = data['responses']
- response_length = responses.size(1)
- attention_mask = data['attention_mask']
- response_mask = attention_mask[:, -response_length:]
- old_log_prob = data['old_log_probs']
- advantages = data['advantages']
- ref_log_prob = data['ref_log_prob']
- clip_ratio = meta_info['clip_ratio']
-
- # compute policy loss
- logits = output
- logits = logits[:, -response_length - 1:-1]
- _, _, log_prob = compute_log_probs(logits, responses, per_token=True)
-
- pg_loss, pg_clipfrac, ppo_kl = compute_grpo_policy_loss(old_log_prob=old_log_prob,
- log_prob=log_prob,
- ref_log_prob=ref_log_prob,
- advantages=advantages,
- eos_mask=response_mask,
- cliprange=clip_ratio,
- kl_ctrl=self.args.kl_ctrl)
- policy_loss = pg_loss
-
- stats = {
- 'actor/pg_loss': abs(pg_loss.detach().item()),
- 'actor/pg_clipfrac': pg_clipfrac.detach().item(),
- 'actor/ppo_kl': ppo_kl.detach().item()
- }
- return policy_loss, stats
-
- def loss_func_online_dpo(output, data, meta_info):
- """
- calculate the policy loss
- """
- args = get_args()
- scores = data['rm_scores']
- responses = data['responses']
- device = responses.device
- ref_logprobs = data['ref_log_prob']
- response_length = responses.size(1)
- attention_mask = data['attention_mask']
- response_mask = attention_mask[:, -response_length:]
- num_examples = responses.shape[0] // 2
-
- actual_start = torch.arange(responses.size(0), device=responses.device)
- tokenizer = get_tokenizer()
- score_first_eos_index, reward_first_eos_index = find_first_eos_index(responses, tokenizer.eos_token_id)
-
- scores = scores[[actual_start, score_first_eos_index]]
- contain_eos_token = torch.any(responses == tokenizer.eos_token_id, dim=-1)
- if args.missing_eos_penalty is not None:
- scores[~contain_eos_token] -= args.missing_eos_penalty
- data['rm_scores'] = scores
- first_half, second_half = split_two_prompts(scores)
-
- mask = first_half >= second_half
- num_examples_range = torch.arange(num_examples, device=device)
- chosen_indices = num_examples_range + (~mask * num_examples)
- rejected_indices = num_examples_range + (mask * num_examples)
- # Build tensor so that the first half is the chosen examples and the second half the rejected examples
- cr_indices = torch.cat((chosen_indices, rejected_indices), dim=0) # cr = chosen and rejected
- logits = output[:, -response_length - 1:-1]
- _, _, log_prob = compute_log_probs(logits, responses, per_token=True)
-
- cr_logprobs = log_prob[cr_indices]
- cr_ref_logprobs = ref_logprobs[cr_indices]
-
- # mask out the padding tokens
- padding_mask = ~response_mask.bool()
- cr_padding_mask = padding_mask[cr_indices]
-
- cr_logprobs_sum = (cr_logprobs * ~cr_padding_mask).sum(1)
- cr_ref_logprobs_sum = (cr_ref_logprobs * ~cr_padding_mask).sum(1)
-
- # Split the chosen and rejected examples
- chosen_logprobs_sum, rejected_logprobs_sum = split_two_prompts(cr_logprobs_sum)
- chosen_ref_logprobs_sum, rejected_ref_logprobs_sum = split_two_prompts(cr_ref_logprobs_sum)
- pi_logratios = chosen_logprobs_sum - rejected_logprobs_sum
- ref_logratios = chosen_ref_logprobs_sum - rejected_ref_logprobs_sum
-
- logits = pi_logratios - ref_logratios
-
- if args.dpo_loss_type == "sigmoid":
- losses = -F.logsigmoid(self.beta * logits)
- elif args.dpo_loss_type == "ipo":
- losses = (logits - 1 / (2 * self.beta)) ** 2
- else:
-                raise NotImplementedError(f"invalid loss type {args.dpo_loss_type}")
-
- loss = losses.mean()
- chosen_rewards = self.beta * (chosen_logprobs_sum - chosen_ref_logprobs_sum)
- rejected_rewards = self.beta * (rejected_logprobs_sum - rejected_ref_logprobs_sum)
-
- stats = {
- 'actor/pg_loss': loss.detach().item(),
- 'beta': self.beta,
- 'logps/chosen': chosen_logprobs_sum.mean().detach().item(),
- 'logps/rejected': rejected_logprobs_sum.mean().detach().item(),
- 'rewards/chosen': chosen_rewards.mean().detach().item(),
- 'rewards/rejected': rejected_rewards.mean().detach().item(),
- }
- return loss, stats
-
-
- def forward_step(batch_iter, model):
- batch = next(batch_iter)
- input_ids = batch['input_ids']
- attention_mask_1d = batch['attention_mask']
- position_ids = batch['position_ids']
- attention_mask = get_tune_attention_mask(attention_mask_1d)
- output = model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids)
- if forward_only:
- meta_info = None
- else:
- meta_info = {'clip_ratio': self.args.clip_ratio}
-
- loss_funcs = {
- "ray_ppo": loss_func_ppo,
- "ray_online_dpo": loss_func_online_dpo,
- "ray_grpo": loss_func_grpo
- }
-
- loss_func = loss_funcs.get(self.args.stage)
- return output, partial(loss_func, data=batch, meta_info=meta_info)
-
- # batch should be a list of batches inside micro-batches
- batch_generator = make_batch_generator(batches, vpp_size=len(self.model))
- losses_reduced = forward_backward_func(
- forward_step_func=forward_step,
- data_iterator=batch_generator,
- model=self.model,
- num_microbatches=n_micro_batch,
- seq_length=seq_len, # unused when variable_seq_lengths
- micro_batch_size=self.args.micro_batch_size, # unused when variable_seq_lengths
- forward_only=forward_only
- )
-
- return losses_reduced
-
- def update_policy(self, dataloader: Iterable[DataProto]) -> Dict:
- """Update the policy with an iterator of DataProto
-
- Args:
-            dataloader (Iterable[DataProto]): an iterator over the DataProto returned by ``make_minibatch_iterator``.
-                The keys of each data batch are described in make_minibatch_iterator.
-
- Returns:
- Dict: a dictionary containing the statistics. Note that the statistics are only valid in the last pp stage
- and users have to combine the output in each dp rank manually.
-
- """
- metrics = {}
- model = self.model
- optimizer = self.optimizer
- opt_param_scheduler = self.opt_param_scheduler
-
-
- for model_module in self.model:
- model_module.train()
-
- for data in dataloader:
-
- for model_chunk in model:
- model_chunk.zero_grad_buffer()
- optimizer.zero_grad()
- if self.args.stage == 'ray_grpo':
- self.args.kl_ctrl = data.meta_info['kl_ctrl']
- metric_micro_batch = self.forward_backward_batch(data)
-
- update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
-
- if update_successful:
- increment = 1
- opt_param_scheduler.step(increment=increment)
-
- for metric in metric_micro_batch:
- append_to_dict(metrics, metric) # append the metric from this micro-batch to global metrics.
-
- self.args.consumed_train_samples += self.args.global_batch_size
- self.num_floating_point_operations_so_far += num_floating_point_operations(self.args, self.args.global_batch_size)
-
- # add empty cache after each compute
- torch.cuda.empty_cache()
-
- return metrics
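
`append_to_dict` is imported from `mindspeed_llm.tasks.posttrain.utils` and is not shown in this diff; the metric handling above only needs behavior like the assumed sketch below, which accumulates each micro-batch's scalar stats into per-key lists.

```python
def append_to_dict_sketch(metrics: dict, new_metrics: dict) -> None:
    """Assumed semantics: append each scalar stat to a per-key list."""
    for key, value in new_metrics.items():
        metrics.setdefault(key, []).append(value)

metrics = {}
append_to_dict_sketch(metrics, {'actor/pg_loss': 0.12, 'beta': 0.1})
append_to_dict_sketch(metrics, {'actor/pg_loss': 0.10, 'beta': 0.1})
# metrics -> {'actor/pg_loss': [0.12, 0.10], 'beta': [0.1, 0.1]}
```
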
diff --git a/mindspeed_llm/tasks/posttrain/rlxf/workers/critic.py b/mindspeed_llm/tasks/posttrain/rlxf/workers/critic.py
deleted file mode 100644
index af2aa1b822d51982686000b4424f862832285839..0000000000000000000000000000000000000000
--- a/mindspeed_llm/tasks/posttrain/rlxf/workers/critic.py
+++ /dev/null
@@ -1,350 +0,0 @@
-import time
-import json
-import os
-from typing import Dict, Union, Iterable
-from functools import partial
-
-import ray
-import torch
-import torch_npu
-
-import megatron
-from megatron.training import print_rank_0
-from megatron.core.models.gpt.gpt_layer_specs import (
- get_gpt_layer_local_spec,
- get_gpt_layer_with_transformer_engine_spec,
-)
-from megatron.core.utils import get_model_config
-from megatron.core import mpu
-from megatron.training.arguments import core_transformer_config_from_args
-from megatron.training.yaml_arguments import core_transformer_config_from_yaml
-from megatron.core.transformer.spec_utils import import_module
-from megatron.core.models.gpt import GPTModel
-
-from megatron.training.training import (
- print_datetime,
- get_one_logger,
- append_to_progress_log,
-)
-from megatron.core.pipeline_parallel import get_forward_backward_func
-from megatron.training.checkpointing import save_checkpoint
-from megatron.training.training import num_floating_point_operations
-from megatron.training import get_args, initialize_megatron, get_timers
-
-from mindspeed_llm.tasks.posttrain.utils import append_to_dict
-from mindspeed_llm.training.utils import get_tune_attention_mask
-from mindspeed_llm.tasks.posttrain.base import BaseTrainer
-from mindspeed_llm.tasks.posttrain.orm.orm_model import GPTRewardModel
-from mindspeed_llm.training.initialize import set_jit_fusion_options
-
-from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base.decorator import register, Dispatch
-from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base.megatron.worker import MegatronWorker
-from mindspeed_llm.tasks.posttrain.rlxf.utils.torch_functional import masked_mean, split_dict_tensor_into_batches
-from mindspeed_llm.tasks.posttrain.rlxf.utils.protocol import DataProto, make_batch_generator
-from mindspeed_llm.tasks.posttrain.rlxf.utils.torch_functional import clip_by_value
-
-_TRAIN_START_TIME = time.time()
-
-
-@ray.remote
-class CriticWorker(MegatronWorker):
- def __init__(self, config, role):
- """
- """
- super().__init__()
-
- self.config = config
- self.role = role
- self.IGNORE_INDEX = -100
- os.environ['CUDA_DEVICE_MAX_CONNECTIONS'] = '1'
- initialize_megatron(role=self.role,
- config=self.config)
-
- self.args = get_args()
-
- @register(dispatch_mode=Dispatch.ONE_TO_ALL)
- def initialize(self):
- self.critic = MegatronPPOCritic()
-
- @register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO)
- def compute_values(self, data: DataProto):
- data = data.to('cuda')
- values = self.critic.compute_values(data=data)
- output = DataProto.from_dict(tensors={'values': values})
- output = output.to('cpu')
- return output
-
- @register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO)
- def update_critic(self, data: DataProto):
- data = data.to('cuda')
- dataloader = self.critic.make_minibatch_iterator(data)
- metrics = self.critic.update_critic(dataloader=dataloader)
- output = DataProto(batch=None, meta_info={'metrics': metrics})
- output = output.to('cpu')
- return output
-
- @register(dispatch_mode=Dispatch.ONE_TO_ALL)
- def save_checkpoint(self, iteration):
- self.critic.save_checkpoint(iteration)
-
-
-class MegatronPPOCritic(BaseTrainer):
- def __init__(self):
- super().__init__()
-
- def initialize(self):
- self.args = get_args()
- self.timers = get_timers()
- self.num_floating_point_operations_so_far = 0
-
- if self.args.log_progress:
- append_to_progress_log("Starting job")
- # Set pytorch JIT layer fusion options and warmup JIT functions.
- set_jit_fusion_options()
- # Adjust the startup time, so it reflects the largest value.
- # This will be closer to what scheduler will see (outside of
- # image ... launches.
- global _TRAIN_START_TIME
- start_time_tensor = torch.tensor(
- [_TRAIN_START_TIME],
- dtype=torch.float,
- device='cuda'
- )
- torch.distributed.all_reduce(start_time_tensor, op=torch.distributed.ReduceOp.MIN)
- _TRAIN_START_TIME = start_time_tensor.item()
- print_rank_0('time to initialize megatron (seconds): {:.3f}'.format(time.time() - _TRAIN_START_TIME))
- print_datetime('after megatron is initialized')
- one_logger = get_one_logger()
- if one_logger:
- one_logger.log_metrics({
- 'train_iterations_warmup': 5
- })
-
- from megatron.training.training import setup_model_and_optimizer
- # Model, optimizer, and learning rate.
- self.timers('model-and-optimizer-setup', log_level=0).start(barrier=True)
- model, optimizer, opt_param_scheduler = setup_model_and_optimizer(
- self.model_provider, self.model_type)
-
- self.timers('model-and-optimizer-setup').stop()
- print_datetime('after model, optimizer, and learning rate '
- 'scheduler are built')
- model_config = get_model_config(model[0])
-
- self.model = model
- self.optimizer = optimizer
- self.opt_param_scheduler = opt_param_scheduler
- self.model_config = model_config
- self.process_non_loss_data_func = None
-
- def save_checkpoint(self, iteration):
- save_checkpoint(iteration, self.model, self.optimizer, self.opt_param_scheduler,
- self.num_floating_point_operations_so_far)
-
-
- @staticmethod
- def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megatron.legacy.model.GPTModel]:
- """Builds the model.
-
- Currently supports only the mcore GPT model.
-
- Args:
- pre_process (bool, optional): Set to true if you need to compute embeddings. Defaults to True.
- post_process (bool, optional): Set to true if you want to compute output logits/loss.
- Defaults to True.
-
- Returns:
- Union[GPTModel, megatron.legacy.model.GPTModel]: The returned model
- """
- args = get_args()
- use_te = args.transformer_impl == "transformer_engine"
-
- print_rank_0('building GPT model ...')
- # Experimental loading arguments from yaml
- if args.yaml_cfg is not None:
- config = core_transformer_config_from_yaml(args, "language_model")
- else:
- config = core_transformer_config_from_args(args)
-
- if not args.use_mcore_models:
- raise ValueError("Reward model training currently supports mcore only.")
-
- if args.spec is not None:
- transformer_layer_spec = import_module(args.spec)
- else:
- if use_te:
- transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(args.num_experts,
- args.moe_grouped_gemm)
- else:
- transformer_layer_spec = get_gpt_layer_local_spec(args.num_experts, args.moe_grouped_gemm)
-
- model = GPTRewardModel(
- config=config,
- transformer_layer_spec=transformer_layer_spec,
- vocab_size=args.padded_vocab_size,
- max_sequence_length=args.max_position_embeddings,
- pre_process=pre_process,
- post_process=post_process,
- post_layer_norm=not args.no_post_layer_norm,
- fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
- parallel_output=True,
- share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
- position_embedding_type=args.position_embedding_type,
- rotary_percent=args.rotary_percent,
- )
-
- return model
-
- def critic_data_padding(self, data: DataProto) -> DataProto:
- if 'response_mask' in data.batch.keys():
- return data
-
- prompt_length = data.batch['prompts'].shape[1]
- response_mask = data.batch['attention_mask']
- response_mask[..., :prompt_length] = 0
- data.batch['response_mask'] = response_mask
-
- return data
-
- def get_batch(self, data_iterator):
- self.timers('batch-generator', log_level=2).start()
- batch = next(data_iterator)
- input_ids = batch["input_ids"]
- attention_mask_1d = batch["attention_mask"]
- attention_mask = get_tune_attention_mask(attention_mask_1d)
- position_ids = batch["position_ids"]
-
- return batch, input_ids, attention_mask, position_ids
-
- def forward_backward_batch(self, data_proto: DataProto, forward_only=False):
- data_proto.batch = data_proto.batch.contiguous()
- args = get_args()
- data = data_proto.batch
-
- forward_batch_size = data["input_ids"].shape[0]
- forward_num_microbatches = forward_batch_size // args.micro_batch_size
-
- batches = split_dict_tensor_into_batches(data, batch_size=args.micro_batch_size)
- data_iterator = make_batch_generator(batches, vpp_size=len(self.model))
-
- forward_backward_func = get_forward_backward_func()
- losses_reduced = forward_backward_func(
- forward_step_func=self.forward_step,
- data_iterator=data_iterator,
- model=self.model,
- num_microbatches=forward_num_microbatches,
- seq_length=args.seq_length,
- micro_batch_size=args.micro_batch_size,
- decoder_seq_length=args.decoder_seq_length,
- collect_non_loss_data=forward_only,
- forward_only=forward_only)
-
- return losses_reduced
-
- def compute_values(self, data: DataProto):
- responses = data.batch['responses']
- attention_mask = data.batch['attention_mask']
- response_length = responses.size(1)
-
- for model_module in self.model:
- model_module.eval()
-
- with torch.no_grad():
- output = self.forward_backward_batch(data, forward_only=True)
-
- if mpu.is_pipeline_last_stage(ignore_virtual=True):
- values = torch.cat(output, dim=0).squeeze(-1)
- values = values.to(torch.float32)
- else:
- values = torch.empty_like(attention_mask, dtype=torch.float32)
-
- values = values * attention_mask
- values = values[:, -response_length - 1:-1]
- values = values.contiguous()
-
- self.args.consumed_train_samples += self.args.global_batch_size
- self.num_floating_point_operations_so_far += num_floating_point_operations(self.args, self.args.global_batch_size)
- torch.cuda.empty_cache()
- return values
-
- def update_critic(self, dataloader: Iterable[DataProto]):
- metrics = {}
- for model_module in self.model:
- model_module.train()
-
- for data in dataloader:
- for model_chunk in self.model:
- model_chunk.zero_grad_buffer()
-
- self.optimizer.zero_grad()
-
- metric_micro_batch = self.forward_backward_batch(data, forward_only=False)
-
- # Empty unused memory.
- if self.args.empty_unused_memory_level >= 1:
- torch.cuda.empty_cache()
-
- update_successful, grad_norm, num_zeros_in_grad = self.optimizer.step()
-
- # Update learning rate.
- if update_successful:
- increment = self.args.critic_mini_batch_size
- self.opt_param_scheduler.step(increment=increment)
-
- for metric in metric_micro_batch:
- append_to_dict(metrics, metric)
-
- torch.cuda.empty_cache()
- return metrics
-
- def make_minibatch_iterator(self, data: DataProto):
- select_keys = data.batch.keys()
- data = data.select(batch_keys=select_keys)
- return data.make_iterator(mini_batch_size=self.args.critic_mini_batch_size,
- epochs=self.args.critic_update_epochs,
- )
-
- def loss_func(self, data, output_tensor, non_loss_data=False):
- if non_loss_data:
- return output_tensor
-
- responses = data['responses']
- response_length = responses.size(1)
- attention_mask = data['attention_mask']
- eos_mask = attention_mask[:, -response_length:]
- eos_p1_index = torch.min(torch.cumsum(eos_mask, dim=-1)[:, -1],
- torch.tensor(eos_mask.shape[1] - 1, device=eos_mask.device))
- eos_mask[torch.arange(eos_mask.size(0), device=eos_mask.device), eos_p1_index] = 1
-
- cliprange_value = self.args.cliprange_value
- curr_values = output_tensor.squeeze(-1)
- curr_values = curr_values[:, -response_length - 1:-1]
- curr_values = torch.masked_fill(curr_values, ~eos_mask, 0)
-
- returns = data['returns']
-
- if cliprange_value > 0.0:
- prev_values = data['values']
- vpredclipped = clip_by_value(curr_values, prev_values - cliprange_value, prev_values + cliprange_value)
- vf_losses1 = (vpredclipped - returns) ** 2
- else:
- vf_losses1 = torch.tensor(0.0).to(curr_values.device)
-
- vf_losses2 = (curr_values - returns) ** 2
-
- vf_loss = 0.5 * masked_mean(torch.max(vf_losses1, vf_losses2), eos_mask)
- vf_clipfrac = masked_mean(torch.gt(vf_losses2, vf_losses1).float(), eos_mask)
-
- stats = {
- 'critic/vf_loss': vf_loss.detach().item(),
- 'critic/vf_clipfrac': vf_clipfrac.detach().item(),
- 'critic/vpred_mean': masked_mean(curr_values, eos_mask).detach().item(),
- }
-
- return vf_loss, stats
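
A self-contained sketch of the clipped value loss computed in `loss_func` above; `torch.max`/`torch.min` stand in for the deleted `clip_by_value` helper, and the tensors at the end are toy values.

```python
import torch

def masked_mean(values, mask, axis=None):
    """Mean over positions where mask == 1."""
    return (values * mask).sum(axis=axis) / mask.sum(axis=axis)

def clipped_value_loss(curr_values, prev_values, returns, eos_mask, cliprange_value=0.2):
    """PPO-style value clipping: take the worse of the clipped / unclipped squared error."""
    # equivalent to the deleted clip_by_value(curr, prev - c, prev + c)
    vpredclipped = torch.max(torch.min(curr_values, prev_values + cliprange_value),
                             prev_values - cliprange_value)
    vf_losses1 = (vpredclipped - returns) ** 2
    vf_losses2 = (curr_values - returns) ** 2
    vf_loss = 0.5 * masked_mean(torch.max(vf_losses1, vf_losses2), eos_mask)
    vf_clipfrac = masked_mean(torch.gt(vf_losses2, vf_losses1).float(), eos_mask)
    return vf_loss, vf_clipfrac

values = torch.randn(2, 8)
loss, clipfrac = clipped_value_loss(values, values + 0.05, torch.randn(2, 8), torch.ones(2, 8))
```
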
-
- def forward_step(self, data_iterator, model):
- batch, input_ids, attention_mask, position_ids = self.get_batch(data_iterator)
- output_tensor = model(input_ids, position_ids, attention_mask)
-
- return output_tensor, partial(self.loss_func, batch)
diff --git a/mindspeed_llm/tasks/posttrain/rlxf/workers/reference.py b/mindspeed_llm/tasks/posttrain/rlxf/workers/reference.py
deleted file mode 100644
index 268cd8b9b661864e101f6148d5d482c51f884440..0000000000000000000000000000000000000000
--- a/mindspeed_llm/tasks/posttrain/rlxf/workers/reference.py
+++ /dev/null
@@ -1,186 +0,0 @@
-# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved.
-import os
-import time
-from functools import partial
-
-import ray
-import torch
-
-from megatron.training import get_args, get_timers, get_one_logger
-from megatron.training import print_rank_0
-from megatron.core import mpu
-from megatron.core.pipeline_parallel import get_forward_backward_func
-from megatron.training.training import append_to_progress_log, print_datetime, get_model
-from megatron.training.utils import unwrap_model
-from megatron.training.initialize import initialize_megatron
-from mindspeed_llm.tasks.checkpoint.models import load_checkpoint
-from mindspeed_llm.tasks.posttrain.base import BaseTrainer
-from mindspeed_llm.tasks.posttrain.rlxf.utils.protocol import DataProto, make_batch_generator
-from mindspeed_llm.tasks.posttrain.rlxf.utils.torch_functional import split_dict_tensor_into_batches
-from mindspeed_llm.tasks.posttrain.utils import compute_log_probs
-from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base.megatron.worker import MegatronWorker
-from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base.decorator import register, Dispatch
-from mindspeed_llm.training.initialize import set_jit_fusion_options
-from mindspeed_llm.training.utils import get_tune_attention_mask
-
-_TRAIN_START_TIME = time.time()
-
-
-@ray.remote
-class ReferenceWorker(MegatronWorker):
- """
- Ray ReferenceWorker
- """
-
- def __init__(self, config, role):
- super().__init__()
-
- self.config = config
- self.role = role
-
- os.environ['CUDA_DEVICE_MAX_CONNECTIONS'] = '1'
- initialize_megatron(role=self.role,
- config=self.config)
-
- @register(dispatch_mode=Dispatch.ONE_TO_ALL)
- def initialize(self):
- self.reference = MegatronPPOReference()
-
- @register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO)
- def compute_ref_log_prob(self, data: DataProto):
- output = self.reference.compute_log_prob(data=data)
- if output is not None:
- output = DataProto.from_dict(tensors={'ref_log_prob': output})
- output = output.to('cpu')
- torch.cuda.empty_cache()
- return output
-
-
-class MegatronPPOReference(BaseTrainer):
- def __init__(self):
- super().__init__()
-
- def initialize(self):
- self.args = get_args()
- self.timers = get_timers()
-
- if self.args.log_progress:
- append_to_progress_log("Starting job")
- # Set pytorch JIT layer fusion options and warmup JIT functions.
- set_jit_fusion_options()
- # Adjust the startup time, so it reflects the largest value.
- # This will be closer to what scheduler will see (outside of
- # image ... launches.
- global _TRAIN_START_TIME
- start_time_tensor = torch.tensor(
- [_TRAIN_START_TIME],
- dtype=torch.float,
- device='cuda'
- )
- torch.distributed.all_reduce(start_time_tensor, op=torch.distributed.ReduceOp.MIN)
- _TRAIN_START_TIME = start_time_tensor.item()
- print_rank_0('Time to initialize Megatron (seconds): {:.3f}'.format(time.time() - _TRAIN_START_TIME))
- print_datetime('after megatron is initialized')
- one_logger = get_one_logger()
- if one_logger:
- one_logger.log_metrics({
- 'train_iterations_warmup': 5
- })
-
- self.timers('model-setup', log_level=0).start(barrier=True)
-
- self.model = get_model(self.model_provider, self.model_type, wrap_with_ddp=False)
- unwrapped_model = unwrap_model(self.model)
- if self.args.stage == "ray_online_dpo":
- self.args.micro_batch_size *= 2
-
- if self.args.load is not None or self.args.pretrained_checkpoint is not None:
- self.timers('load-checkpoint', log_level=0).start(barrier=True)
- self.args.iteration, self.args.num_floating_point_operations_so_far = load_checkpoint(
- self.model, None, None)
- self.timers('load-checkpoint').stop(barrier=True)
- self.timers.log(['load-checkpoint'])
- else:
- self.args.iteration = 0
- self.args.num_floating_point_operations_so_far = 0
-
- # get model without FP16 and/or DDP wrappers
- if self.args.iteration == 0 and len(unwrapped_model) == 1 \
- and hasattr(unwrapped_model[0], 'init_state_dict_from_bert'):
- print_rank_0("Initializing ICT from pretrained BERT model")
- unwrapped_model[0].init_state_dict_from_bert()
-
- self.timers('model-setup').stop()
- print_datetime('after model built')
-
- def compute_log_prob(self, data: DataProto):
-
- data.batch = data.batch.contiguous()
-
- for model_module in self.model:
- model_module.eval()
-
- data = data.to(next(self.model[0].parameters()).device)
- with torch.no_grad():
- output = self.forward_backward_batch(data, forward_only=True)
- if mpu.is_pipeline_last_stage(ignore_virtual=True):
- ref_log_probs = torch.cat([out['ref_log_probs'] for out in output], dim=0) # (bs, seq_size)
- ref_log_probs = ref_log_probs.to(torch.float32)
- else:
- ref_log_probs = None
-
- return ref_log_probs
-
- def forward_backward_batch(self, data, forward_only=False):
- """
- We assume:
- - The model takes input: (input_ids, attention_mask, position_ids). No rmpad for the input
- - The communication shape is (total_nnz_pad_to_sp // tp_size, 1, hidden_size) if sequence parallel is enabled
- """
- args = get_args()
- data.batch['attention_mask'] = data.batch['attention_mask'].to(bool)
-
- if data.meta_info.get('micro_batch_size', None) is not None:
- batch_size = data.meta_info['micro_batch_size']
- else:
- batch_size = args.micro_batch_size
- batches = split_dict_tensor_into_batches(data.batch, batch_size=batch_size)
-
- n_micro_batch = len(batches)
- seq_len = batches[0]['input_ids'].shape[1]
-
- forward_backward_func = get_forward_backward_func()
-
- # batch should be a list of batches inside micro-batches
- batch_generator = make_batch_generator(batches, vpp_size=len(self.model))
- losses_reduced = forward_backward_func(
- forward_step_func=self.forward_step,
- data_iterator=batch_generator,
- model=self.model,
- num_microbatches=n_micro_batch,
- seq_length=seq_len, # unused when variable_seq_lengths
- micro_batch_size=args.micro_batch_size, # unused when variable_seq_lengths
- forward_only=forward_only
- )
- return losses_reduced
-
- def forward_step(self, batch_iter, model):
- input_ids, attention_mask, position_ids, response = self.get_batch(batch_iter)
- output = model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids)
- return output, partial(self.loss_func, response=response)
-
- def get_batch(self, batch_iter):
- batch = next(batch_iter)
- input_ids = batch['input_ids']
- attention_mask_1d = batch['attention_mask']
- position_ids = batch['position_ids']
- attention_mask = get_tune_attention_mask(attention_mask_1d)
- response = batch['responses']
- return input_ids, attention_mask, position_ids, response
-
- def loss_func(self, output, response):
- response_length = response.size(1)
- logits = output
- logits = logits[:, -response_length - 1:-1]
- log_probs = compute_log_probs(logits, response, per_token=True)[2]
- return 1.0, {'ref_log_probs': log_probs}
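
`compute_log_probs` is imported from `mindspeed_llm.tasks.posttrain.utils` and is not shown in this diff; below is a hedged sketch of the per-token log-prob gather the reference worker relies on, including the `output[:, -response_length - 1:-1]` alignment (the helper name and shapes are illustrative).

```python
import torch
import torch.nn.functional as F

def per_token_log_probs(logits, responses):
    """Gather log P(response_t | prefix) for every response token.

    `logits` must already be sliced to the positions that predict the response,
    i.e. output[:, -response_length - 1:-1] as in the worker above.
    """
    log_probs = F.log_softmax(logits.float(), dim=-1)
    return torch.gather(log_probs, dim=-1, index=responses.unsqueeze(-1)).squeeze(-1)

vocab_size, response_length = 32, 5
output = torch.randn(2, 12, vocab_size)                  # full-sequence logits
responses = torch.randint(vocab_size, (2, response_length))
logits = output[:, -response_length - 1:-1]              # positions predicting the response
ref_log_probs = per_token_log_probs(logits, responses)   # shape (2, response_length)
```
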
diff --git a/mindspeed_llm/tasks/posttrain/rlxf/workers/reward.py b/mindspeed_llm/tasks/posttrain/rlxf/workers/reward.py
deleted file mode 100644
index f46cf8089d87077421f97dcca4804e6e57a06a26..0000000000000000000000000000000000000000
--- a/mindspeed_llm/tasks/posttrain/rlxf/workers/reward.py
+++ /dev/null
@@ -1,231 +0,0 @@
-# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved.
-import os
-import time
-from functools import partial
-import re
-
-import ray
-import torch
-from transformers import AutoTokenizer
-
-from megatron.training import print_rank_0
-from megatron.training import get_args, get_timers, get_one_logger
-from megatron.training.utils import unwrap_model
-from megatron.training.initialize import initialize_megatron
-from megatron.training.checkpointing import load_checkpoint
-from megatron.training.training import append_to_progress_log, print_datetime, get_model
-from megatron.core import mpu
-from megatron.core.utils import get_model_config
-from megatron.core.pipeline_parallel import get_forward_backward_func
-
-from mindspeed_llm.training.initialize import set_jit_fusion_options
-from mindspeed_llm.training.utils import get_tune_attention_mask
-from mindspeed_llm.tasks.posttrain.orm import ORMTrainer
-from mindspeed_llm.tasks.posttrain.rlxf.workers.actor_train_infer import pad_to_tensor_dict, \
- generate_attention_mask, generate_position_ids_from_attention_mask
-from mindspeed_llm.tasks.posttrain.rlxf.utils.protocol import DataProto
-from mindspeed_llm.tasks.posttrain.rlxf.utils.torch_functional import split_dict_tensor_into_batches
-from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base.megatron.worker import MegatronWorker
-from mindspeed_llm.tasks.posttrain.rlxf.single_controller.base.decorator import register, Dispatch
-
-
-_TRAIN_START_TIME = time.time()
-
-
-@ray.remote
-class RewardWorker(MegatronWorker):
- """
- Ray RewardWorker
- """
- def __init__(self, config, role):
- super().__init__()
-
- self.config = config
- self.role = role
- self.rm = None
-
- os.environ['CUDA_DEVICE_MAX_CONNECTIONS'] = '1'
- initialize_megatron(role=self.role,
- config=self.config)
-
- @register(dispatch_mode=Dispatch.ONE_TO_ALL)
- def initialize(self):
- self.rm = MegatronPPORM()
-
- @register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO)
- def compute_rm_score(self, data: DataProto):
- data = data.to('cuda')
- output = self.rm.compute_rm_score(data=data)
- output = DataProto.from_dict(tensors={'rm_scores': output})
- output = output.to('cpu')
- torch.cuda.empty_cache()
- return output
-
-
-class MegatronPPORM(ORMTrainer):
- def __init__(self):
- super().__init__()
-
- def initialize(self):
- self.args = get_args()
- self.timers = get_timers()
-
- if self.args.log_progress:
- append_to_progress_log("Starting job")
- # Set pytorch JIT layer fusion options and warmup JIT functions.
- set_jit_fusion_options()
- # Adjust the startup time, so it reflects the largest value.
- # This will be closer to what scheduler will see (outside of
- # image ... launches.
- global _TRAIN_START_TIME
- start_time_tensor = torch.tensor(
- [_TRAIN_START_TIME],
- dtype=torch.float,
- device='cuda'
- )
- torch.distributed.all_reduce(start_time_tensor, op=torch.distributed.ReduceOp.MIN)
- _TRAIN_START_TIME = start_time_tensor.item()
- print_rank_0('Time to initialize Megatron (seconds): {:.3f}'.format(time.time() - _TRAIN_START_TIME))
- print_datetime('after megatron is initialized')
- one_logger = get_one_logger()
- if one_logger:
- one_logger.log_metrics({
- 'train_iterations_warmup': 5
- })
-
- if self.args.stage == "ray_online_dpo":
- self.args.micro_batch_size *= 2
- self.timers('model-setup', log_level=0).start(barrier=True)
-
- model = get_model(self.model_provider, self.model_type, wrap_with_ddp=False)
- unwrapped_model = unwrap_model(model)
-
- if self.args.load is not None or self.args.pretrained_checkpoint is not None:
- self.timers('load-checkpoint', log_level=0).start(barrier=True)
- self.args.iteration, self.args.num_floating_point_operations_so_far = load_checkpoint(
- model, None, None, strict=True)
- self.timers('load-checkpoint').stop(barrier=True)
- self.timers.log(['load-checkpoint'])
- else:
- self.args.iteration = 0
- self.args.num_floating_point_operations_so_far = 0
-
- # get model without FP16 and/or DDP wrappers
- if self.args.iteration == 0 and len(unwrapped_model) == 1 \
- and hasattr(unwrapped_model[0], 'init_state_dict_from_bert'):
- print_rank_0("Initializing ICT from pretrained BERT model")
- unwrapped_model[0].init_state_dict_from_bert()
-
- self.timers('model-setup').stop()
- print_datetime('after model built')
- config = get_model_config(model[0])
-
- # Print setup timing.
- self.train_args = [self.forward_step, model, config]
- self.tokenizer = AutoTokenizer.from_pretrained(self.args.tokenizer_name_or_path)
-
- def compute_rm_score(self, data: DataProto):
- args = get_args()
- forward_step_func, model, config = self.train_args
- prompt_lens = data.batch["prompts"].size(1)
-
- for model_module in model:
- model_module.eval()
-
- with torch.no_grad():
- batches = split_dict_tensor_into_batches(data.batch, batch_size=args.micro_batch_size)
- n_micro_batch = len(batches)
-
- batch_generator = iter(batches)
- forward_backward_func = get_forward_backward_func()
-
- rm_score = forward_backward_func(
- forward_step_func=forward_step_func,
- data_iterator=batch_generator,
- model=model,
- num_microbatches=n_micro_batch,
- seq_length=args.seq_length,
- micro_batch_size=args.micro_batch_size,
- decoder_seq_length=args.decoder_seq_length,
- collect_non_loss_data=True,
- forward_only=True
- )
-
- # Empty unused memory
- if args.empty_unused_memory_level >= 1:
- torch.cuda.empty_cache()
- if mpu.is_pipeline_last_stage():
- rm_score = torch.cat(rm_score, dim=0).squeeze(-1) # (bs, seq_size)
- rm_score = rm_score.to(torch.float32)
- rm_score = rm_score[:, prompt_lens:]
- else:
- rm_score = torch.zeros(1)
-
- return rm_score
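
On the last pipeline stage the reward scores arrive as one tensor per micro-batch; a toy sketch of the concatenation and prompt slicing performed above (shapes and values are placeholders).

```python
import torch

prompt_lens, response_len, micro_bs = 4, 6, 2
# two micro-batch outputs, each of shape (micro_bs, prompt_lens + response_len, 1)
outputs = [torch.randn(micro_bs, prompt_lens + response_len, 1) for _ in range(2)]

rm_score = torch.cat(outputs, dim=0).squeeze(-1).to(torch.float32)  # (bs, seq)
rm_score = rm_score[:, prompt_lens:]                                # keep response positions only
```
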
-
- def forward_step(self, data_iterator, model):
- """ReWardModel forward step to calculate rm scores.
-
- Args:
- data_iterator : Data iterator that waits to get input ids from the queue generated by the Actor server
- model (GPTModel): The GPT Model
- """
- self.timers('batch-generator', log_level=2).start()
- input_ids, attention_mask, position_ids = self._get_tokens(data_iterator)
- self.timers('batch-generator').stop()
-
- scores = model(input_ids, position_ids, attention_mask)
-
- return scores, self.loss_func
-
- def loss_func(self, scores: torch.Tensor, non_loss_data=False):
- return scores
-
- def _get_tokens(self, data_iterator):
- self.timers('batch-generator', log_level=2).start()
- batch = next(data_iterator)
-
- if self.args.extract_content_for_reward:
- str_responses = self.tokenizer.batch_decode(batch["responses"])
- pattern = r'<answer>(.*?)</answer>'
- contents = []
- for str_response in str_responses:
- first_pad_position = str_response.find(self.tokenizer.pad_token)
- if first_pad_position != -1:
- str_response = str_response[:first_pad_position]
- within_answer = re.findall(pattern, str_response)
- if within_answer:
- content = within_answer[0]
- else:
- content = str_response
- contents.append(self.tokenizer.encode(content))
-
- responses_ori_length, responses_pad_length = pad_to_tensor_dict(
- contents,
- pad_multi_of=self.args.pad_to_multiple_of
- )
-
- prompts = batch["prompts"]
- prompts_pad_length = torch.LongTensor([len(prompts[0])]).cuda()
- pad_token_id = self.tokenizer.pad_token_id
- prompts_ori_length = [len(prompts[i]) - (prompts[i] == pad_token_id).sum().item() for i in range(len(prompts))]
- prompts = prompts.cpu().numpy().tolist()
-
- input_ids = [prompt + response for prompt, response in zip(prompts, contents)]
- attention_mask = generate_attention_mask(input_ids, prompts_ori_length, prompts_pad_length,
- responses_ori_length, responses_pad_length)
- position_ids = generate_position_ids_from_attention_mask(input_ids, prompts_ori_length, prompts_pad_length)
-
- device = batch["input_ids"].device
- input_ids = torch.tensor(input_ids).long().to(device)
- attention_mask_1d = torch.tensor(attention_mask).long().to(device)
- attention_mask = get_tune_attention_mask(attention_mask_1d)
- position_ids = torch.tensor(position_ids).long().to(device)
-
- else:
- input_ids = batch["input_ids"]
- attention_mask_1d = batch["attention_mask"]
- attention_mask = get_tune_attention_mask(attention_mask_1d)
- position_ids = batch["position_ids"]
-
- return input_ids, attention_mask, position_ids
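
A hedged sketch of the `extract_content_for_reward` path above: decode, trim at the first pad token, keep the tagged answer span if present, otherwise the whole trimmed response. The `<answer>...</answer>` pattern is an assumption reconstructed from the `within_answer` variable name, and `pad_token` here is a plain string stand-in for the tokenizer attribute.

```python
import re

def extract_answer_texts(str_responses, pad_token, pattern=r'<answer>(.*?)</answer>'):
    """Trim each decoded response at the first pad token, then keep the tagged answer
    span if present, otherwise the whole trimmed response (tag pattern is assumed)."""
    contents = []
    for text in str_responses:
        first_pad = text.find(pad_token)
        if first_pad != -1:
            text = text[:first_pad]
        matches = re.findall(pattern, text, flags=re.DOTALL)
        contents.append(matches[0] if matches else text)
    return contents

print(extract_answer_texts(["<answer>42</answer><pad><pad>"], pad_token="<pad>"))  # ['42']
```
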
diff --git a/mindspeed_llm/tasks/posttrain/trl_ppo/TrlPPOEngine.py b/mindspeed_llm/tasks/posttrain/trl_ppo/TrlPPOEngine.py
index 45c6761f67892db39f9698bd61397a68911536e2..6ec5c27183fe048324b7c7cf8aa9f10337ac9f18 100644
--- a/mindspeed_llm/tasks/posttrain/trl_ppo/TrlPPOEngine.py
+++ b/mindspeed_llm/tasks/posttrain/trl_ppo/TrlPPOEngine.py
@@ -22,7 +22,6 @@ from megatron.training.training import (
get_optimizer_param_scheduler,
build_train_valid_test_data_iterators
)
-from mindspeed_llm.tasks.posttrain.rlxf.utils.torch_functional import masked_mean, masked_whiten
from mindspeed_llm.training.utils import get_tune_attention_mask
from mindspeed_llm.tasks.posttrain.utils import train_valid_test_datasets_provider
from mindspeed_llm.tasks.posttrain.trl_ppo.actor_model import ActorModel
@@ -316,3 +315,36 @@ class TrlPPOEngine():
def set_model_train(self, model):
for model_module in model:
model_module.train()
+
+
+def masked_mean(values, mask, axis=None):
+ """Compute mean of tensor with a masked values."""
+ return (values * mask).sum(axis=axis) / mask.sum(axis=axis)
+
+
+def masked_var(values, mask, unbiased=True):
+ """Compute variance of tensor with masked values."""
+ mean = masked_mean(values, mask)
+ centered_values = values - mean
+ variance = masked_mean(centered_values ** 2, mask)
+ if unbiased:
+ mask_sum = mask.sum()
+ if mask_sum == 0:
+ raise ValueError("At least one element in the mask has to be 1.")
+ # note that if mask_sum == 1, then there is a division by zero issue
+ # to avoid it you just need to use a larger minibatch_size
+ elif mask_sum == 1:
+ bessel_correction = mask_sum
+ else:
+ bessel_correction = mask_sum / (mask_sum - 1)
+ variance = variance * bessel_correction
+ return variance
+
+
+def masked_whiten(values, mask, shift_mean=True):
+ """Whiten values with masked values."""
+ mean, var = masked_mean(values, mask), masked_var(values, mask)
+ whitened = (values - mean) * torch.rsqrt(var + 1e-8)
+ if not shift_mean:
+ whitened += mean
+ return whitened
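
A quick usage note for the helpers added above: whitening per-token advantages under a response mask. The import path follows this diff; the tensors are toy values.

```python
import torch
# helpers added to TrlPPOEngine.py in this diff
from mindspeed_llm.tasks.posttrain.trl_ppo.TrlPPOEngine import masked_mean, masked_whiten

advantages = torch.randn(4, 16)      # per-token advantages
response_mask = torch.ones(4, 16)
response_mask[:, :6] = 0             # ignore prompt positions

whitened = masked_whiten(advantages, response_mask)
print(masked_mean(whitened, response_mask))  # close to 0 after whitening
```
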
diff --git a/mindspeed_llm/tasks/posttrain/trl_ppo/TrlPPOTrainer.py b/mindspeed_llm/tasks/posttrain/trl_ppo/TrlPPOTrainer.py
index 23e145a86d737a5df78c9b000b42cc342ed79b62..ca6e57ce5afa1082513bdd3f4a3d577cdc8e38c1 100644
--- a/mindspeed_llm/tasks/posttrain/trl_ppo/TrlPPOTrainer.py
+++ b/mindspeed_llm/tasks/posttrain/trl_ppo/TrlPPOTrainer.py
@@ -25,7 +25,6 @@ from megatron.training.training import (
training_log
)
from megatron.inference.text_generation.communication import broadcast_from_last_pipeline_stage
-from mindspeed_llm.tasks.posttrain.rlxf.workers.actor_train_infer import generate_attention_mask, generate_position_ids_from_attention_mask
from mindspeed_llm.tasks.posttrain.trl_ppo.TrlPPOEngine import TrlPPOEngine
from mindspeed_llm.training.training import get_profiler, is_profile_enabled
@@ -601,3 +600,53 @@ class TrlPPOTrainer():
args.micro_batch_size = origin_micro_batch_size
return output_tensor
+
+
+def generate_position_ids_from_attention_mask(input_ids_list, prompts_ori_length, prompts_pad_length):
+ """
+ Generate the list of position_ids corresponding to the attention_mask.
+
+ Args:
+ input_ids_list (list of lists): list of input_ids; each element is a list.
+ prompts_ori_length (list of int): original (unpadded) prompt length of each sample.
+ prompts_pad_length (int): the padded prompt length.
+
+ Returns:
+ list of lists: list of position_ids; each element is a list.
+ """
+ position_ids_list = []
+ for idx, input_ids in enumerate(input_ids_list):
+ prompt_pad_length = prompts_pad_length - prompts_ori_length[idx]
+ position_ids = [0] * prompt_pad_length + list(range(len(input_ids) - prompt_pad_length))
+ position_ids_list.append(position_ids)
+
+ return position_ids_list
+
+
+def generate_attention_mask(input_ids_list, prompts_ori_length, prompts_pad_length, responses_ori_length,
+ responses_pad_length):
+ """
+ Generate the list of attention_mask corresponding to input_ids.
+
+ Args:
+ input_ids_list (list of lists): list of input_ids; each element is a list.
+ prompts_ori_length (list of int): original (unpadded) prompt length of each sample.
+ prompts_pad_length (int): the padded prompt length.
+ responses_ori_length (list of int): original (unpadded) response length of each sample.
+ responses_pad_length (int): the padded response length.
+
+ Returns:
+ list of lists: list of attention_mask values; each element is a list.
+ """
+ attention_mask_list = []
+
+ for idx, input_ids in enumerate(input_ids_list):
+ attention_mask = torch.ones_like(torch.tensor(input_ids))
+ prompt_pad_length = prompts_pad_length - prompts_ori_length[idx]
+ response_pad_length = responses_pad_length - responses_ori_length[idx]
+ attention_mask[:prompt_pad_length] = 0
+ if response_pad_length > 0:
+ attention_mask[-response_pad_length:] = 0
+ attention_mask_list.append(attention_mask.numpy().tolist())
+
+ return attention_mask_list
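
A small worked example of the two helpers added above, for one sample with a left-padded prompt and a right-padded response; the import path follows this diff and the numbers are illustrative. Expected outputs are shown in the trailing comments.

```python
# helpers added to TrlPPOTrainer.py in this diff
from mindspeed_llm.tasks.posttrain.trl_ppo.TrlPPOTrainer import (
    generate_attention_mask, generate_position_ids_from_attention_mask)

prompt = [0, 0, 11, 12]        # prompt left-padded to length 4 (2 pad tokens)
response = [21, 22, 23, 0]     # response right-padded to length 4 (1 pad token)
input_ids_list = [prompt + response]

prompts_ori_length = [2]       # real prompt tokens per sample
prompts_pad_length = 4         # padded prompt length
responses_ori_length = [3]     # real response tokens per sample
responses_pad_length = 4       # padded response length

attention_mask = generate_attention_mask(input_ids_list, prompts_ori_length, prompts_pad_length,
                                         responses_ori_length, responses_pad_length)
position_ids = generate_position_ids_from_attention_mask(input_ids_list, prompts_ori_length,
                                                         prompts_pad_length)
# attention_mask -> [[0, 0, 1, 1, 1, 1, 1, 0]]
# position_ids   -> [[0, 0, 0, 1, 2, 3, 4, 5]]
```
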
diff --git a/ray_gpt.py b/ray_gpt.py
deleted file mode 100644
index 70c2a5b64fc32608c832e1db208d2c1d9bcbd199..0000000000000000000000000000000000000000
--- a/ray_gpt.py
+++ /dev/null
@@ -1,42 +0,0 @@
-# Copyright 2024 Bytedance Ltd. and/or its affiliates
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-Note that we don't combine the main with ray_trainer as ray_trainer is used by other main.
-"""
-import ray
-import hydra
-
-from mindspeed_llm.tasks.posttrain.launcher import get_trainer
-
-
-@hydra.main(config_path='configs/rlxf', config_name='ppo_trainer_llama32_1b', version_base=None)
-def main(config):
- if not ray.is_initialized():
- # this is for local ray cluster
- ray.init(runtime_env={
- 'env_vars': {"RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES": "True",
- 'TOKENIZERS_PARALLELISM': 'true',
- 'NCCL_DEBUG': 'WARN'}})
-
- ray.get(main_task.remote(config))
-
-
-@ray.remote
-def main_task(config):
- trainer = get_trainer(config.training.stage)(config)
- trainer.train()
-
-
-if __name__ == '__main__':
- main()
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index ed68f4313bae7031fdc166b63b2f379137f59cbc..534cfc296b84375e3fa5dfb23e029c72859528b1 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -14,8 +14,6 @@ protobuf
peft==0.7.1
tiktoken
ray==2.10.0
-tensordict==0.1.2
-hydra-core==1.3.2
codetiming
bitsandbytes-npu-beta==0.45.3
word2number
diff --git a/tests/rlxf/ray_grpo_full_llama32_1b_tp1pp1.sh b/tests/rlxf/ray_grpo_full_llama32_1b_tp1pp1.sh
deleted file mode 100644
index cde21233b0b57fc4a5b861d2dab0c73a658ab2e0..0000000000000000000000000000000000000000
--- a/tests/rlxf/ray_grpo_full_llama32_1b_tp1pp1.sh
+++ /dev/null
@@ -1,9 +0,0 @@
-#!/bin/bash
-export CUDA_DEVICE_MAX_CONNECTIONS=1
-export HCCL_DETERMINISTIC=True
-
-
-basepath=$(cd `dirname $0`; cd ../../../; pwd)
-
-
-python $basepath/ray_gpt.py --config-dir=$basepath/tests/pipeline/configs --config-name=ray_grpo_full_llama32_1b_tp1pp1
\ No newline at end of file
diff --git a/tests/rlxf/ray_online_dpo_full_llama32_1b_tp1pp1.sh b/tests/rlxf/ray_online_dpo_full_llama32_1b_tp1pp1.sh
deleted file mode 100644
index eff76dc8532438bf8244cf44fcfab71872192ea6..0000000000000000000000000000000000000000
--- a/tests/rlxf/ray_online_dpo_full_llama32_1b_tp1pp1.sh
+++ /dev/null
@@ -1,9 +0,0 @@
-#!/bin/bash
-export CUDA_DEVICE_MAX_CONNECTIONS=1
-export HCCL_DETERMINISTIC=True
-
-
-basepath=$(cd `dirname $0`; cd ../../../; pwd)
-
-
-python $basepath/ray_gpt.py --config-dir=$basepath/tests/pipeline/configs --config-name=ray_online_dpo_full_llama32_1b_tp1pp1
\ No newline at end of file
diff --git a/tests/rlxf/ray_ppo_full_llama32_1b_tp1pp1.sh b/tests/rlxf/ray_ppo_full_llama32_1b_tp1pp1.sh
deleted file mode 100644
index 80bcd00c32f717f113187da84bbf5be65accc50b..0000000000000000000000000000000000000000
--- a/tests/rlxf/ray_ppo_full_llama32_1b_tp1pp1.sh
+++ /dev/null
@@ -1,9 +0,0 @@
-#!/bin/bash
-export CUDA_DEVICE_MAX_CONNECTIONS=1
-export HCCL_DETERMINISTIC=True
-
-
-basepath=$(cd `dirname $0`; cd ../../../; pwd)
-
-
-python $basepath/ray_gpt.py --config-dir=$basepath/tests/pipeline/configs --config-name=ray_ppo_full_llama32_1b_tp1pp1
\ No newline at end of file
diff --git a/tests/st/configs/model/llama32-1b.yaml b/tests/st/configs/model/llama32-1b.yaml
deleted file mode 100644
index 324548e026623607a47cfb6c480c2376496626a9..0000000000000000000000000000000000000000
--- a/tests/st/configs/model/llama32-1b.yaml
+++ /dev/null
@@ -1,35 +0,0 @@
-llama32-1b:
- use_mcore_models: true
- sequence_parallel: true
- use_mc2: true
- use_flash_attn: true
- use_rotary_position_embeddings: true
- use_fused_rmsnorm: true
- use_fused_swiglu: true
- rope_scaling_type: llama3
- rope_scaling_factor: 32.0
- low_freq_factor: 1.0
- high_freq_factor: 4.0
- original_max_position_embeddings: 8192
- max_position_embeddings: 8192
- num_layers: 16
- hidden_size: 2048
- ffn_hidden_size: 8192
- num_attention_heads: 32
- group_query_attention: true
- num_query_groups: 8
- make_vocab_size_divisible_by: 1
- padded_vocab_size: 128256
- disable_bias_linear: true
- attention_dropout: 0.0
- init_method_std: 0.01
- hidden_dropout: 0.0
- position_embedding_type: rope
- rotary_base: 500000
- normalization: RMSNorm
- norm_epsilon: 1e-5
- swiglu: true
- no_masked_softmax_fusion: true
- attention_softmax_in_fp32: true
- no_gradient_accumulation_fusion: true
- bf16: true
\ No newline at end of file