diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/test/train_qwen2_5_32b_instruct_GRPO_full_32p.sh b/PyTorch/built-in/rl/VeRL_for_PyTorch/test/train_qwen2_5_32b_instruct_GRPO_full_32p.sh index 2e327fdc19e3164880302ad58038d28047f0b9eb..5526e57e1fc92580e0f0af1ac5f85fb6c28545e7 100644 --- a/PyTorch/built-in/rl/VeRL_for_PyTorch/test/train_qwen2_5_32b_instruct_GRPO_full_32p.sh +++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/test/train_qwen2_5_32b_instruct_GRPO_full_32p.sh @@ -112,6 +112,12 @@ nohup python3 -m verl.trainer.main_ppo \ trainer.nnodes=2 \ trainer.save_freq=-1 \ trainer.test_freq=10 \ + actor_rollout_ref.actor.fsdp_config.forward_prefetch=True \ + actor_rollout_ref.ref.fsdp_config.forward_prefetch=True \ + actor_rollout_ref.actor.fsdp_config.backward_prefetch=BACKWARD_PRE \ + actor_rollout_ref.ref.fsdp_config.backward_prefetch=BACKWARD_PRE \ + actor_rollout_ref.actor.use_entropy_from_logits_with_chunking=True \ + actor_rollout_ref.ref.use_entropy_from_logits_with_chunking=True \ trainer.total_epochs=15 > ${test_path_dir}/output/train_verl_qwen2_5_32b_instruct_grpo_full.log 2>&1 & wait diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/test/train_qwen2_5_32b_instruct_GRPO_performance_32p.sh b/PyTorch/built-in/rl/VeRL_for_PyTorch/test/train_qwen2_5_32b_instruct_GRPO_performance_32p.sh index 13050adae1f60a5435cee1013fc0d1c81c6598dd..fdeec9fbbd19b7e901a228be56ce65d62f499419 100644 --- a/PyTorch/built-in/rl/VeRL_for_PyTorch/test/train_qwen2_5_32b_instruct_GRPO_performance_32p.sh +++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/test/train_qwen2_5_32b_instruct_GRPO_performance_32p.sh @@ -112,6 +112,12 @@ nohup python3 -m verl.trainer.main_ppo \ trainer.nnodes=2 \ trainer.save_freq=-1 \ trainer.test_freq=10 \ + actor_rollout_ref.actor.fsdp_config.forward_prefetch=True \ + actor_rollout_ref.ref.fsdp_config.forward_prefetch=True \ + actor_rollout_ref.actor.fsdp_config.backward_prefetch=BACKWARD_PRE \ + actor_rollout_ref.ref.fsdp_config.backward_prefetch=BACKWARD_PRE \ + actor_rollout_ref.actor.use_entropy_from_logits_with_chunking=True \ + actor_rollout_ref.ref.use_entropy_from_logits_with_chunking=True \ trainer.total_epochs=1 > ${test_path_dir}/output/train_verl_qwen2_5_32b_instruct_grpo_perf.log 2>&1 & wait diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/test/train_qwen2_5_7b_instruct_GRPO_full_16p.sh b/PyTorch/built-in/rl/VeRL_for_PyTorch/test/train_qwen2_5_7b_instruct_GRPO_full_16p.sh index 3dbc768943ad94837d9e5e0660576673d6051d05..e0852937a22bd677c01da75ba1e3899a5ffb67b3 100644 --- a/PyTorch/built-in/rl/VeRL_for_PyTorch/test/train_qwen2_5_7b_instruct_GRPO_full_16p.sh +++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/test/train_qwen2_5_7b_instruct_GRPO_full_16p.sh @@ -113,6 +113,12 @@ nohup python3 -m verl.trainer.main_ppo \ trainer.nnodes=1 \ trainer.save_freq=-1 \ trainer.test_freq=5 \ + actor_rollout_ref.actor.fsdp_config.forward_prefetch=True \ + actor_rollout_ref.ref.fsdp_config.forward_prefetch=True \ + actor_rollout_ref.actor.fsdp_config.backward_prefetch=BACKWARD_PRE \ + actor_rollout_ref.ref.fsdp_config.backward_prefetch=BACKWARD_PRE \ + actor_rollout_ref.actor.use_entropy_from_logits_with_chunking=True \ + actor_rollout_ref.ref.use_entropy_from_logits_with_chunking=True \ trainer.total_epochs=5 > ${test_path_dir}/output/train_verl_qwen2_5_7b_instruct_grpo_full.log 2>&1 & wait diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/test/train_qwen2_5_7b_instruct_GRPO_performance_16p.sh b/PyTorch/built-in/rl/VeRL_for_PyTorch/test/train_qwen2_5_7b_instruct_GRPO_performance_16p.sh index 9a4b342be3bc73cb22a2d751a8568ab933616cef..d475a5c78c29b85772d6b21b853c211941ad89a0 100644 --- a/PyTorch/built-in/rl/VeRL_for_PyTorch/test/train_qwen2_5_7b_instruct_GRPO_performance_16p.sh +++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/test/train_qwen2_5_7b_instruct_GRPO_performance_16p.sh @@ -113,6 +113,12 @@ nohup python3 -m verl.trainer.main_ppo \ trainer.nnodes=1 \ trainer.save_freq=-1 \ trainer.test_freq=5 \ + actor_rollout_ref.actor.fsdp_config.forward_prefetch=True \ + actor_rollout_ref.ref.fsdp_config.forward_prefetch=True \ + actor_rollout_ref.actor.fsdp_config.backward_prefetch=BACKWARD_PRE \ + actor_rollout_ref.ref.fsdp_config.backward_prefetch=BACKWARD_PRE \ + actor_rollout_ref.actor.use_entropy_from_logits_with_chunking=True \ + actor_rollout_ref.ref.use_entropy_from_logits_with_chunking=True \ trainer.total_epochs=1 > ${test_path_dir}/output/train_verl_qwen2_5_7b_instruct_grpo_perf.log 2>&1 & wait diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/utils/npu_patch.py b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/utils/npu_patch.py index 0aa137564469865c49a3f7dc877573ff3a4ae2b0..0da8ff8c4897b95a942b063eadfb4e4da4df8747 100644 --- a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/utils/npu_patch.py +++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/utils/npu_patch.py @@ -18,10 +18,13 @@ import torch import torch.nn.functional as F import torch_npu from torch_npu import npu_rotary_mul as apply_rotary_emb +from transformers.models.qwen2.modeling_qwen2 import Qwen2MLP as Qwen2MLPLLM +from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm as Qwen2RMSNormLLM from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import apply_rotary_pos_emb_vision, \ Qwen2_5_VLVisionSdpaAttention, Qwen2RMSNorm, Qwen2_5_VLMLP, Qwen2MLP from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl from transformers.models.qwen2_vl import modeling_qwen2_vl +from transformers.models.qwen2 import modeling_qwen2 from transformers.utils import logging @@ -40,6 +43,14 @@ def apply_rotary_pos_emb_flashatt_npu( return q_embed, k_embed +def apply_rotary_pos_emb_npu(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = torch_npu.npu_rotary_mul(q, cos, sin) + k_embed = torch_npu.npu_rotary_mul(k, cos, sin) + return q_embed, k_embed + + def sdpa_forward( self, hidden_states: torch.Tensor, @@ -106,3 +117,7 @@ modeling_qwen2_5_vl.apply_rotary_pos_emb_flashatt = apply_rotary_pos_emb_flashat Qwen2_5_VLMLP.forward = silu_forward Qwen2MLP.forward = silu_forward modeling_qwen2_vl.apply_multimodal_rotary_pos_emb = apply_multimodal_rotary_pos_emb_npu + +Qwen2RMSNormLLM.forward = rms_norm_forward +Qwen2MLPLLM.forward = silu_forward +modeling_qwen2.apply_rotary_pos_emb = apply_rotary_pos_emb_npu