torch rmsnorm

This commit is contained in:
GeeeekExplorer 2025-01-05 14:33:48 +08:00
parent 9b288b86cc
commit fd011c11aa

View File

@ -140,13 +140,12 @@ class RowParallelLinear(Linear):
class RMSNorm(nn.Module): class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6): def __init__(self, dim: int, eps: float = 1e-6):
super().__init__() super().__init__()
self.dim = dim
self.eps = eps self.eps = eps
self.weight = nn.Parameter(torch.ones(dim)) self.weight = nn.Parameter(torch.ones(dim))
def forward(self, x: torch.Tensor): def forward(self, x: torch.Tensor):
x = x.float() return F.rms_norm(x, (self.dim,), self.weight, self.eps)
y = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
return y.type_as(self.weight) * self.weight
def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor: def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor: