From 61438ea031b9ca9597d26ec234451945b5f155eb Mon Sep 17 00:00:00 2001
From: 郑特驹
Date: Fri, 9 May 2025 11:44:52 +0800
Subject: [PATCH] [built-in][PyTorch][OpenRLHF-v0.6.2] Model performance
 optimization: add fused operators
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../transformers_need/modeling_llama.py | 17 +++++++----------
 .../transformers_need/modeling_qwen2.py | 17 +++++++----------
 2 files changed, 14 insertions(+), 20 deletions(-)

diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/transformers_need/modeling_llama.py b/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/transformers_need/modeling_llama.py
index fb1fbce666..4c0cc3c252 100644
--- a/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/transformers_need/modeling_llama.py
+++ b/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/transformers_need/modeling_llama.py
@@ -21,6 +21,7 @@ from functools import partial
 from typing import Callable, Optional, Tuple, Union
 
 import torch
+import torch_npu
 import torch.utils.checkpoint
 from torch import nn
 
@@ -76,11 +77,7 @@ class LlamaRMSNorm(nn.Module):
         self.variance_epsilon = eps
 
     def forward(self, hidden_states):
-        input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(torch.float32)
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-        return self.weight * hidden_states.to(input_dtype)
+        return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)[0]
 
     def extra_repr(self):
         return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
@@ -152,8 +149,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     """
     cos = cos.unsqueeze(unsqueeze_dim)
     sin = sin.unsqueeze(unsqueeze_dim)
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
+    q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
+    k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
     return q_embed, k_embed
 
 
@@ -166,11 +163,11 @@ class LlamaMLP(nn.Module):
         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
-        self.act_fn = ACT2FN[config.hidden_act]
 
     def forward(self, x):
-        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
-        return down_proj
+        return self.down_proj(
+            torch_npu.npu_swiglu(torch.cat((self.gate_proj(x), self.up_proj(x)), dim=-1), dim=-1)
+        )
 
 
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/transformers_need/modeling_qwen2.py b/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/transformers_need/modeling_qwen2.py
index f02dc7ccc6..7b133b580e 100644
--- a/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/transformers_need/modeling_qwen2.py
+++ b/PyTorch/built-in/rl/OpenRLHF_v0.6.2_for_PyTorch/transformers_need/modeling_qwen2.py
@@ -8,6 +8,7 @@ from functools import partial
 from typing import Callable, Optional, Tuple, Union
 
 import torch
+import torch_npu
 from torch import nn
 
 from ...activations import ACT2FN
@@ -53,11 +54,11 @@ class Qwen2MLP(nn.Module):
         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
-        self.act_fn = ACT2FN[config.hidden_act]
 
     def forward(self, x):
-        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
-        return down_proj
+        return self.down_proj(
+            torch_npu.npu_swiglu(torch.cat((self.gate_proj(x), self.up_proj(x)), dim=-1), dim=-1)
+        )
 
 
 def rotate_half(x):
@@ -89,8 +90,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     """
     cos = cos.unsqueeze(unsqueeze_dim)
     sin = sin.unsqueeze(unsqueeze_dim)
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
+    q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
+    k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
     return q_embed, k_embed
 
 
@@ -218,11 +219,7 @@ class Qwen2RMSNorm(nn.Module):
         self.variance_epsilon = eps
 
     def forward(self, hidden_states):
-        input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(torch.float32)
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-        return self.weight * hidden_states.to(input_dtype)
+        return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)[0]
 
     def extra_repr(self):
         return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
-- 
Gitee
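
Below is a minimal, hypothetical parity-check sketch for the three fused operators introduced by this patch (npu_rms_norm, npu_rotary_mul, npu_swiglu), comparing them against the eager PyTorch code they replace. It is not part of the patch; it assumes an Ascend NPU device with torch_npu installed, that hidden_act resolves to SiLU (the Llama/Qwen2 default), and the tensor shapes are illustrative only.

# Parity-check sketch: fused torch_npu kernels vs. the eager code removed by the patch.
# Assumes torch_npu is installed and an Ascend NPU device ("npu") is available.
import torch
import torch_npu


def eager_rms_norm(hidden_states, weight, eps):
    # Eager RMSNorm removed from LlamaRMSNorm / Qwen2RMSNorm.forward.
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    return weight * hidden_states.to(input_dtype)


def rotate_half(x):
    # Helper kept by the patch; used by the eager rotary embedding below.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


if __name__ == "__main__":
    device = "npu"
    eps = 1e-6

    # RMSNorm: npu_rms_norm returns a tuple; index 0 is the normalized output.
    x = torch.randn(2, 16, 128, dtype=torch.float16, device=device)
    w = torch.ones(128, dtype=torch.float16, device=device)
    fused = torch_npu.npu_rms_norm(x, w, epsilon=eps)[0]
    ref = eager_rms_norm(x, w, eps)
    print("rms_norm max diff:", (fused - ref).abs().max().item())

    # Rotary embedding: cos/sin already unsqueezed at dim 1, matching unsqueeze_dim=1 in the patch.
    q = torch.randn(2, 8, 16, 128, dtype=torch.float16, device=device)
    cos = torch.randn(2, 1, 16, 128, dtype=torch.float16, device=device)
    sin = torch.randn(2, 1, 16, 128, dtype=torch.float16, device=device)
    fused = torch_npu.npu_rotary_mul(q, cos, sin)
    ref = (q * cos) + (rotate_half(q) * sin)
    print("rotary   max diff:", (fused - ref).abs().max().item())

    # SwiGLU: gate and up projections concatenated on the last dim, as in the patched MLP forward.
    # Eager reference assumes hidden_act == "silu" (ACT2FN["silu"]).
    gate = torch.randn(2, 16, 256, dtype=torch.float16, device=device)
    up = torch.randn(2, 16, 256, dtype=torch.float16, device=device)
    fused = torch_npu.npu_swiglu(torch.cat((gate, up), dim=-1), dim=-1)
    ref = torch.nn.functional.silu(gate) * up
    print("swiglu   max diff:", (fused - ref).abs().max().item())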