From fd011c11aa862d32270880d3a0080a66d3c76822 Mon Sep 17 00:00:00 2001 From: GeeeekExplorer <2651904866@qq.com> Date: Sun, 5 Jan 2025 14:33:48 +0800 Subject: [PATCH] torch rmsnorm --- inference/model.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/inference/model.py b/inference/model.py index 5e3eac1..11004a0 100644 --- a/inference/model.py +++ b/inference/model.py @@ -140,13 +140,12 @@ class RowParallelLinear(Linear): class RMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-6): super().__init__() + self.dim = dim self.eps = eps self.weight = nn.Parameter(torch.ones(dim)) def forward(self, x: torch.Tensor): - x = x.float() - y = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) - return y.type_as(self.weight) * self.weight + return F.rms_norm(x, (self.dim,), self.weight, self.eps) def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor: