From 32d1c93c721c872a7f396c436f3da3c96bae9c82 Mon Sep 17 00:00:00 2001 From: sunyi001 <1659275352@qq.com> Date: Wed, 14 May 2025 11:32:19 +0800 Subject: [PATCH] change vllm-ascend version to 0.7.3 --- .../built-in/rl/VeRL_for_PyTorch/README.md | 22 +- .../rl/VeRL_for_PyTorch/requirements-npu.txt | 3 +- .../test/train_qwen2_5_vl_3b_full_8p.sh | 8 +- .../train_qwen2_5_vl_3b_performance_8p.sh | 8 +- .../test/train_qwen2_5_vl_7b_full_16p.sh | 2 +- .../train_qwen2_5_vl_7b_performance_16p.sh | 2 +- .../VeRL_for_PyTorch/verl/utils/npu_patch.py | 8 +- .../vllm_ascend_need/qwen2_5_vl.py | 271 ++++++++++++++++++ .../vllm_ascend_need/rotary_embedding.py | 130 +++++++++ 9 files changed, 437 insertions(+), 17 deletions(-) create mode 100644 PyTorch/built-in/rl/VeRL_for_PyTorch/vllm_ascend_need/qwen2_5_vl.py create mode 100644 PyTorch/built-in/rl/VeRL_for_PyTorch/vllm_ascend_need/rotary_embedding.py diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/README.md b/PyTorch/built-in/rl/VeRL_for_PyTorch/README.md index a7575f9e56..c7956aefb8 100644 --- a/PyTorch/built-in/rl/VeRL_for_PyTorch/README.md +++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/README.md @@ -69,7 +69,8 @@ verl‌是一个集SFT(监督学习)与RL(强化学习)于一体的灵 - 安装vLLM和vLLM Ascend ```shell - git clone -b v0.7.3 https://github.com/vllm-project/vllm.git + # 安装目录不能放在模型根目录下 + git clone -b v0.7.3 --depth 1 https://github.com/vllm-project/vllm.git cd vllm pip install -r requirements-build.txt VLLM_TARGET_DEVICE=empty pip install -e . --extra-index https://download.pytorch.org/whl/cpu/ @@ -78,6 +79,16 @@ verl‌是一个集SFT(监督学习)与RL(强化学习)于一体的灵 source /usr/local/Ascend/ascend-toolkit/set_env.sh source /usr/local/Ascend/nnal/atb/set_env.sh + # 对于VL模型,编译并安装vllm-ascend v0.7.3 + git clone -b v0.7.3 --depth 1 https://github.com/vllm-project/vllm-ascend.git + cp -f 模型目录/vllm_ascend_need/qwen2_5_vl.py vllm-ascend/vllm_ascend/models/ + cp -f 模型目录/vllm_ascend_need/rotary_embedding.py vllm-ascend/vllm_ascend/ops/ + cd vllm-ascend + export COMPILE_CUSTOM_KERNELS=1 + python setup.py install + cd .. + + # 对于LLM模型,编译并安装vllm-ascend 特定commit id代码 git clone https://github.com/vllm-project/vllm-ascend.git cd vllm-ascend git checkout edeadde387451ca982fe3717555c1841ee195718 @@ -88,19 +99,20 @@ verl‌是一个集SFT(监督学习)与RL(强化学习)于一体的灵 - 克隆transformers仓并切换到对应的commit id ```shell - git clone https://github.com/huggingface/transformers.git + git clone --depth 1 https://github.com/huggingface/transformers.git cd transformers + git fetch --depth 1 origin aa17cfb4d532239336d2f89e06f01d48387292a3 git checkout aa17cfb4d532239336d2f89e06f01d48387292a3 pip install -e . cd .. ``` -- 克隆torchvision仓并切换到v0.20.1 +- 对于VL模型,需要安装torchvision,克隆torchvision仓并切换到v0.20.1 ```shell - git clone https://github.com/pytorch/vision.git + git clone -b v0.20.1 --depth 1 https://github.com/pytorch/vision.git cd vision - git checkout v0.20.1 python setup.py bdist_wheel + # 安装`torchvision`前,需要先执行`pip uninstall torchvision`卸载原来的`torchvision`,如果环境中有`triton`,需要执行`pip uninstall triton`卸载 pip install dist/*.whl cd .. ``` diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/requirements-npu.txt b/PyTorch/built-in/rl/VeRL_for_PyTorch/requirements-npu.txt index 741b13e729..a01d1514a4 100644 --- a/PyTorch/built-in/rl/VeRL_for_PyTorch/requirements-npu.txt +++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/requirements-npu.txt @@ -4,7 +4,7 @@ codetiming datasets dill hydra-core -numpy +numpy==1.26.4 pandas peft pyarrow>=15.0.0 @@ -14,4 +14,5 @@ ray tensordict<0.6 transformers>=4.51.0 mathruler +torchdata wandb \ No newline at end of file diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/test/train_qwen2_5_vl_3b_full_8p.sh b/PyTorch/built-in/rl/VeRL_for_PyTorch/test/train_qwen2_5_vl_3b_full_8p.sh index cc649111ec..407a5b3821 100644 --- a/PyTorch/built-in/rl/VeRL_for_PyTorch/test/train_qwen2_5_vl_3b_full_8p.sh +++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/test/train_qwen2_5_vl_3b_full_8p.sh @@ -87,8 +87,8 @@ nohup python3 -m verl.trainer.main_ppo \ actor_rollout_ref.model.path=$model_path \ actor_rollout_ref.actor.optim.lr=1e-6 \ actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=32 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.actor.ppo_mini_batch_size=16 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \ actor_rollout_ref.actor.use_kl_loss=True \ actor_rollout_ref.actor.kl_loss_coef=0.01 \ actor_rollout_ref.actor.kl_loss_type=low_var_kl \ @@ -97,7 +97,7 @@ nohup python3 -m verl.trainer.main_ppo \ actor_rollout_ref.model.enable_gradient_checkpointing=True \ actor_rollout_ref.actor.fsdp_config.param_offload=False \ actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ actor_rollout_ref.rollout.name=$ENGINE \ actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ @@ -105,7 +105,7 @@ nohup python3 -m verl.trainer.main_ppo \ actor_rollout_ref.rollout.enforce_eager=False \ actor_rollout_ref.rollout.free_cache_engine=False \ actor_rollout_ref.rollout.n=5 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=8 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ actor_rollout_ref.ref.fsdp_config.param_offload=True \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/test/train_qwen2_5_vl_3b_performance_8p.sh b/PyTorch/built-in/rl/VeRL_for_PyTorch/test/train_qwen2_5_vl_3b_performance_8p.sh index 14f5a83710..1b637d8a35 100644 --- a/PyTorch/built-in/rl/VeRL_for_PyTorch/test/train_qwen2_5_vl_3b_performance_8p.sh +++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/test/train_qwen2_5_vl_3b_performance_8p.sh @@ -87,8 +87,8 @@ nohup python3 -m verl.trainer.main_ppo \ actor_rollout_ref.model.path=$model_path \ actor_rollout_ref.actor.optim.lr=1e-6 \ actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=32 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.actor.ppo_mini_batch_size=16 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \ actor_rollout_ref.actor.use_kl_loss=True \ actor_rollout_ref.actor.kl_loss_coef=0.01 \ actor_rollout_ref.actor.kl_loss_type=low_var_kl \ @@ -97,7 +97,7 @@ nohup python3 -m verl.trainer.main_ppo \ actor_rollout_ref.model.enable_gradient_checkpointing=True \ actor_rollout_ref.actor.fsdp_config.param_offload=False \ actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ actor_rollout_ref.rollout.name=$ENGINE \ actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ @@ -105,7 +105,7 @@ nohup python3 -m verl.trainer.main_ppo \ actor_rollout_ref.rollout.enforce_eager=False \ actor_rollout_ref.rollout.free_cache_engine=False \ actor_rollout_ref.rollout.n=5 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=8 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ actor_rollout_ref.ref.fsdp_config.param_offload=True \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/test/train_qwen2_5_vl_7b_full_16p.sh b/PyTorch/built-in/rl/VeRL_for_PyTorch/test/train_qwen2_5_vl_7b_full_16p.sh index 5d38458968..4bc31696c7 100644 --- a/PyTorch/built-in/rl/VeRL_for_PyTorch/test/train_qwen2_5_vl_7b_full_16p.sh +++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/test/train_qwen2_5_vl_7b_full_16p.sh @@ -100,7 +100,7 @@ nohup python3 -m verl.trainer.main_ppo \ actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ actor_rollout_ref.rollout.name=$ENGINE \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ actor_rollout_ref.rollout.enable_chunked_prefill=False \ actor_rollout_ref.rollout.enforce_eager=False \ actor_rollout_ref.rollout.free_cache_engine=False \ diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/test/train_qwen2_5_vl_7b_performance_16p.sh b/PyTorch/built-in/rl/VeRL_for_PyTorch/test/train_qwen2_5_vl_7b_performance_16p.sh index 7622bdf1ed..ee76e0d565 100644 --- a/PyTorch/built-in/rl/VeRL_for_PyTorch/test/train_qwen2_5_vl_7b_performance_16p.sh +++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/test/train_qwen2_5_vl_7b_performance_16p.sh @@ -100,7 +100,7 @@ nohup python3 -m verl.trainer.main_ppo \ actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ actor_rollout_ref.rollout.name=$ENGINE \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ actor_rollout_ref.rollout.enable_chunked_prefill=False \ actor_rollout_ref.rollout.enforce_eager=False \ actor_rollout_ref.rollout.free_cache_engine=False \ diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/utils/npu_patch.py b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/utils/npu_patch.py index 302e470419..9d05a5f980 100644 --- a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/utils/npu_patch.py +++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/utils/npu_patch.py @@ -16,9 +16,10 @@ from typing import Optional, Tuple import torch import torch.nn.functional as F +import torch_npu from torch_npu import npu_rotary_mul as apply_rotary_emb from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import apply_rotary_pos_emb_vision, \ - Qwen2_5_VLVisionSdpaAttention + Qwen2_5_VLVisionSdpaAttention, Qwen2RMSNorm from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl from transformers.utils import logging @@ -74,5 +75,10 @@ def sdpa_forward( return attn_output +def rms_norm_forward(self, x): + return torch_npu.npu_rms_norm(x, self.weight, epsilon=self.variance_epsilon)[0] + + +Qwen2RMSNorm.forward = rms_norm_forward Qwen2_5_VLVisionSdpaAttention.forward = sdpa_forward modeling_qwen2_5_vl.apply_rotary_pos_emb_flashatt = apply_rotary_pos_emb_flashatt_npu diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/vllm_ascend_need/qwen2_5_vl.py b/PyTorch/built-in/rl/VeRL_for_PyTorch/vllm_ascend_need/qwen2_5_vl.py new file mode 100644 index 0000000000..c6c1456cdc --- /dev/null +++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/vllm_ascend_need/qwen2_5_vl.py @@ -0,0 +1,271 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Adapted from vllm/model_executor/models/qwen2_5_vl.py +# Copyright 2023 The vLLM team. +# +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial +from typing import Callable, Iterable, Optional, Set, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch_npu +from einops import rearrange +from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import ( + Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig) +from vllm.config import VllmConfig +from vllm.distributed import parallel_state +from vllm.distributed import utils as dist_utils +from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.models.qwen2_5_vl import ( + Qwen2_5_VisionAttention, Qwen2_5_VisionBlock, Qwen2_5_VisionPatchEmbed, + Qwen2_5_VisionTransformer, Qwen2_5_VLDummyInputsBuilder, + Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLMultiModalProcessor, + Qwen2_5_VLProcessingInfo) +from vllm.model_executor.models.utils import maybe_prefix +from vllm.multimodal import MULTIMODAL_REGISTRY + + +class AscendQwen2_5_VisionAttention(Qwen2_5_VisionAttention): + + def __init__( + self, + embed_dim: int, + num_heads: int, + projection_size: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__( + embed_dim, + num_heads, + projection_size, + quant_config, + prefix, + ) + self.embed_dim = embed_dim + self.hidden_size_per_attention_head = dist_utils.divide( + projection_size, num_heads) + + def forward( + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + ) -> torch.Tensor: + # [s, b, c] --> [s, b, head * 3 * head_dim] + x, _ = self.qkv(x) + + # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim] + q, k, v = self.split_qkv(x) + batch_size = q.shape[1] + + q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() + for x in (q, k, v)) + q = torch_npu.npu_rotary_mul(q, cos, sin) + k = torch_npu.npu_rotary_mul(k, cos, sin) + + q, k, v = [ + rearrange(x, "b s h d -> (b s) h d").contiguous() + for x in (q, k, v) + ] + + context_layer = torch.torch.empty_like(q) + + # operator requires pta version >= 2.5.1 + torch_npu._npu_flash_attention_unpad( + query=q, + key=k, + value=v, + seq_len=cu_seqlens, + scale_value=self.hidden_size_per_attention_head**-0.5, + num_heads=self.num_attention_heads_per_partition, + num_kv_heads=self.num_attention_heads_per_partition, + out=context_layer) + + context_layer = rearrange(context_layer, + "(b s) h d -> s b (h d)", + b=batch_size).contiguous() + + output, _ = self.proj(context_layer) + return output + + +class AscendQwen2_5_VisionBlock(Qwen2_5_VisionBlock): + + def __init__( + self, + dim: int, + num_heads: int, + mlp_hidden_dim: int, + act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, + norm_layer: Optional[Callable[[int], nn.Module]] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__(dim, num_heads, mlp_hidden_dim, act_fn, norm_layer, + quant_config, prefix) + self.attn = AscendQwen2_5_VisionAttention(embed_dim=dim, + num_heads=num_heads, + projection_size=dim, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor, + cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: + x = x + self.attn( + self.norm1(x), cu_seqlens=cu_seqlens, cos=cos, sin=sin) + + x = x + self.mlp(self.norm2(x)) + return x + + +class AscendQwen2_5_VisionPatchEmbed(Qwen2_5_VisionPatchEmbed): + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.matmul( + self.proj.weight.data.view(self.hidden_size, -1).transpose(0, 1)) + return x + + +class AscendQwen2_5_VisionTransformer(Qwen2_5_VisionTransformer): + + def __init__( + self, + vision_config: Qwen2_5_VLVisionConfig, + norm_eps: float = 1e-6, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + interleaved=False, + ) -> None: + super().__init__(vision_config, norm_eps, quant_config, prefix) + norm_layer = partial(RMSNorm, eps=norm_eps) + self.interleaved = interleaved + self.patch_embed = AscendQwen2_5_VisionPatchEmbed( + patch_size=vision_config.patch_size, + temporal_patch_size=vision_config.temporal_patch_size, + in_channels=vision_config.in_channels, + hidden_size=self.hidden_size, + ) + self.blocks = nn.ModuleList([ + AscendQwen2_5_VisionBlock( + dim=self.hidden_size, + num_heads=self.num_heads, + mlp_hidden_dim=vision_config.intermediate_size, + act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act], + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{layer_idx}") + for layer_idx in range(vision_config.depth) + ]) + self.tp_size = parallel_state.get_tensor_model_parallel_world_size() + self.tp_rank = parallel_state.get_tensor_model_parallel_rank() + self.hidden_size_per_attention_head = dist_utils.divide( + self.hidden_size, self.num_heads) + + def cal_cos_sin(self, rotary_pos_emb): + cos = rotary_pos_emb.cos() # [seqlen, rotary_dim / 2] + sin = rotary_pos_emb.sin() + + if not self.interleaved: + cos_new = torch.cat((cos, cos), dim=-1) + sin_new = torch.cat((sin, sin), dim=-1) + else: + cos_new = rearrange(torch.stack((cos, cos), dim=-1), + "... d two -> ...(d two)", + two=2) + sin_new = rearrange(torch.stack((sin, sin), dim=-1), + "... d two -> ...(d two)", + two=2) + cos_new = cos_new.reshape(1, -1, 1, + self.hidden_size_per_attention_head) + sin_new = sin_new.reshape(1, -1, 1, + self.hidden_size_per_attention_head) + return cos_new, sin_new + + def forward( + self, + x: torch.Tensor, + grid_thw: torch.Tensor, + ) -> torch.Tensor: + # compute cu_seqlens + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], + grid_thw[:, + 0]).cpu().to(torch.int32) + + # patchify + x = self.patch_embed(x) + + # compute position embedding + rotary_pos_emb = self.rot_pos_emb(grid_thw) + + # windows attention + window_index, cu_window_seqlens = self.get_window_index(grid_thw) + cu_window_seqlens = torch.tensor( + cu_window_seqlens, + device=x.device, + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32) + cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) + cu_window_seqlens = torch.diff(cu_window_seqlens).cpu().to(torch.int32) + seq_len, _ = x.size() + x = x.reshape(seq_len // self.spatial_merge_unit, + self.spatial_merge_unit, -1) + x = x[window_index, :, :] + x = x.reshape(seq_len, -1) + rotary_pos_emb = rotary_pos_emb.reshape( + seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + rotary_pos_emb = rotary_pos_emb[window_index, :, :] + rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) + + cos, sin = self.cal_cos_sin(rotary_pos_emb) + + # transformers + x = x.unsqueeze(1) + for layer_num, blk in enumerate(self.blocks): + if layer_num in self.fullatt_block_indexes: + cu_seqlens_now = cu_seqlens + else: + cu_seqlens_now = cu_window_seqlens + x = blk(x, cu_seqlens=cu_seqlens_now, cos=cos, sin=sin) + + # adapter + x = self.merger(x) + reverse_indices = torch.argsort(window_index) + x = x[reverse_indices, :] + return x + + +@MULTIMODAL_REGISTRY.register_processor( + Qwen2_5_VLMultiModalProcessor, + info=Qwen2_5_VLProcessingInfo, + dummy_inputs=Qwen2_5_VLDummyInputsBuilder) +class AscendQwen2_5_VLForConditionalGeneration( + Qwen2_5_VLForConditionalGeneration): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + config: Qwen2_5_VLConfig = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.visual = AscendQwen2_5_VisionTransformer( + vision_config=config.vision_config, + norm_eps=getattr(config, "rms_norm_eps", 1e-6), + quant_config=self._maybe_ignore_quant_config(quant_config), + prefix=maybe_prefix(prefix, "visual"), + ) diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/vllm_ascend_need/rotary_embedding.py b/PyTorch/built-in/rl/VeRL_for_PyTorch/vllm_ascend_need/rotary_embedding.py new file mode 100644 index 0000000000..d4777faf66 --- /dev/null +++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/vllm_ascend_need/rotary_embedding.py @@ -0,0 +1,130 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import stat +from typing import Optional, Tuple + +import torch +import torch_npu +from vllm.model_executor.layers.rotary_embedding import ( + DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding) + + +def rope_forward_oot( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + if self.cos_sin_cache.device != query.device: + self.cos_sin_cache = self.cos_sin_cache.to(query.device) + if self.cos_sin_cache.dtype != query.dtype: + self.cos_sin_cache = self.cos_sin_cache.to(query.dtype) + if offsets is not None or self.cos_sin_cache.shape[-1] != self.head_size: + return self.forward_native(positions, query, key, offsets) + else: + query = query.contiguous() + key = key.contiguous() + torch_npu._npu_rotary_embedding( + positions, + query, + key, + self.head_size, + self.cos_sin_cache, + self.is_neox_style, + ) + return query, key + + +def rope_deepseek_forward_oot( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + if self.cos_sin_cache.device != query.device: + self.cos_sin_cache = self.cos_sin_cache.to(query.device) + if self.cos_sin_cache.dtype != query.dtype: + self.cos_sin_cache = self.cos_sin_cache.to(query.dtype) + if offsets is not None: + raise NotImplementedError( + "Batched rotary embedding is currently not supported on NPU.") + else: + ori_query_shape, ori_key_shape = query.shape, key.shape + query = query.contiguous().view(query.shape[0], -1) + key = key.contiguous().view(query.shape[0], -1) + torch_npu._npu_rotary_embedding( + positions, + query, + key, + self.head_size, + self.cos_sin_cache, + self.is_neox_style, + ) + query = query.view(ori_query_shape) + key = key.view(ori_key_shape) + + return query, key + + +# The purpose of this function is to redirect the standard output and standard error of the torch_npu.npu_mrope function +# to an empty mounting disk, in order to solve the problem of duplicate log printing +class support_stdout_stderr(object): + + def __init__(self): + flags = os.O_RDWR + mode = stat.S_IWUSR | stat.S_IRUSR + self.null_fds = [os.open(os.devnull, flags, mode) for _ in range(2)] + self.save_fds = (os.dup(1), os.dup(2)) + + def __enter__(self): + os.dup2(self.null_fds[0], 1) + os.dup2(self.null_fds[1], 2) + + def __exit__(self, *_): + os.dup2(self.save_fds[0], 1) + os.dup2(self.save_fds[1], 2) + os.close(self.null_fds[0]) + os.close(self.null_fds[1]) + os.close(self.save_fds[0]) + os.close(self.save_fds[1]) + + +def mrope_forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + mrope_section = [0, 0, 0] if positions.ndim == 1 else self.mrope_section + with support_stdout_stderr(): + query, key = torch_npu.npu_mrope(positions, + query.contiguous(), + key.contiguous(), + self.cos_sin_cache.contiguous(), + self.head_size, + mrope_section=mrope_section, + rotary_mode='half') + return query, key + + +RotaryEmbedding.forward_oot = rope_forward_oot +MRotaryEmbedding.forward = mrope_forward +DeepseekScalingRotaryEmbedding.forward = rope_deepseek_forward_oot -- Gitee