mirror of
https://github.com/deepseek-ai/DeepSeek-V3
synced 2025-01-22 12:25:30 +00:00
torch rmsnorm
This commit is contained in:
parent
9b288b86cc
commit
fd011c11aa
@ -140,13 +140,12 @@ class RowParallelLinear(Linear):
|
|||||||
class RMSNorm(nn.Module):
|
class RMSNorm(nn.Module):
|
||||||
def __init__(self, dim: int, eps: float = 1e-6):
|
def __init__(self, dim: int, eps: float = 1e-6):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
self.eps = eps
|
self.eps = eps
|
||||||
self.weight = nn.Parameter(torch.ones(dim))
|
self.weight = nn.Parameter(torch.ones(dim))
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor):
|
def forward(self, x: torch.Tensor):
|
||||||
x = x.float()
|
return F.rms_norm(x, (self.dim,), self.weight, self.eps)
|
||||||
y = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
|
||||||
return y.type_as(self.weight) * self.weight
|
|
||||||
|
|
||||||
|
|
||||||
def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
|
def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
|
||||||
|
Loading…
Reference in New Issue
Block a user