1 Star 0 Fork 0

hazdzz/RMSNorm

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
norm.py 882 Bytes
一键复制 编辑 原始数据 按行查看 历史
hazdzz 提交于 2024-05-31 04:15 +08:00 . Update norm.py
import torch
import torch.nn as nn
import torch.nn.init as init
from torch import Tensor
class RMSNorm(nn.Module):
def __init__(self, dim, eps: float = 1e-8, bias: bool = False) -> None:
self.eps = eps
self.scale = nn.Parameter(torch.empty(dim))
if bias:
self.bias = nn.Parameter(torch.empty(dim))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self) -> None:
init.ones_(self.scale)
if self.bias is not None:
init.zeros_(self.bias)
def forward(self, input: Tensor) -> Tensor:
var = input.pow(2).mean(dim=-1, keepdim=True) + self.eps
input_norm = input * torch.rsqrt(var)
rmsnorm = self.scale * input_norm
if self.bias is not None:
rmsnorm = rmsnorm + self.bias
return rmsnorm
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/hazdzz/RMSNorm.git
git@gitee.com:hazdzz/RMSNorm.git
hazdzz
RMSNorm
RMSNorm
main

搜索帮助