代码拉取完成,页面将自动刷新
import torch
import torch.nn as nn
import torch.nn.init as init
from torch import Tensor
class RMSNorm(nn.Module):
def __init__(self, dim, eps: float = 1e-8, bias: bool = False) -> None:
self.eps = eps
self.scale = nn.Parameter(torch.empty(dim))
if bias:
self.bias = nn.Parameter(torch.empty(dim))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self) -> None:
init.ones_(self.scale)
if self.bias is not None:
init.zeros_(self.bias)
def forward(self, input: Tensor) -> Tensor:
var = input.pow(2).mean(dim=-1, keepdim=True) + self.eps
input_norm = input * torch.rsqrt(var)
rmsnorm = self.scale * input_norm
if self.bias is not None:
rmsnorm = rmsnorm + self.bias
return rmsnorm
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。