class RMSNorm(nn.Module):
    """Root-mean-square layer normalization with a learnable per-feature gain.

    The input is normalized over its last dimension by the reciprocal of
    ``sqrt(mean(x ** 2) + eps)`` and then multiplied element-wise by
    ``weight``.

    Args:
        dim: Size of the last (normalized) dimension of the input.
        eps: Constant added to the mean square before the reciprocal square
            root, for numerical stability.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        # Gain starts at one, so a freshly built layer only normalizes.
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` RMS-normalized over its last dimension and scaled.

        Args:
            x: Input tensor whose last dimension has size ``dim``.

        Returns:
            A tensor with the same shape as ``x``.
        """
        # F.rms_norm is only available from PyTorch 2.4 onwards; look it up
        # defensively so older installs do not die with AttributeError.
        fused_rms_norm = getattr(F, "rms_norm", None)
        if fused_rms_norm is not None:
            return fused_rms_norm(x, (self.dim,), self.weight, self.eps)

        # Fallback: do the reduction in float32 so that low-precision inputs
        # (fp16 / bf16) do not lose accuracy in the mean of squares, then
        # cast back to the input dtype before applying the gain.
        x_fp32 = x.float()
        inv_rms = torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + self.eps)
        return (x_fp32 * inv_rms).type_as(x) * self.weight