Updated model.py docstrings

enoch kan 2025-01-05 18:24:31 +00:00
parent a1296f099e
commit bc77f22afc


@@ -18,6 +18,39 @@ attn_impl: Literal["naive", "absorb"] = "absorb"
@dataclass
class ModelArgs:
"""
Data class for defining model arguments and hyperparameters.
Attributes:
max_batch_size (int): Maximum batch size.
max_seq_len (int): Maximum sequence length.
dtype (Literal["bf16", "fp8"]): Data type for computations.
vocab_size (int): Vocabulary size.
dim (int): Model dimension.
inter_dim (int): Intermediate dimension for MLP layers.
moe_inter_dim (int): Intermediate dimension for MoE layers.
n_layers (int): Number of transformer layers.
n_dense_layers (int): Number of dense layers in the model.
n_heads (int): Number of attention heads.
n_routed_experts (int): Number of routed experts for MoE layers.
n_shared_experts (int): Number of shared experts for MoE layers.
n_activated_experts (int): Number of activated experts in MoE layers.
n_expert_groups (int): Number of expert groups.
n_limited_groups (int): Number of limited groups for MoE routing.
score_func (Literal["softmax", "sigmoid"]): Scoring function for MoE routing.
route_scale (float): Scaling factor for routing scores.
q_lora_rank (int): LoRA rank for query projections.
kv_lora_rank (int): LoRA rank for key-value projections.
qk_nope_head_dim (int): Dimension for query-key projections without positional embeddings.
qk_rope_head_dim (int): Dimension for query-key projections with rotary embeddings.
v_head_dim (int): Dimension for value projections.
original_seq_len (int): Original sequence length.
rope_theta (float): Base for rotary positional encoding.
rope_factor (float): Scaling factor for extended sequence lengths.
beta_fast (int): Fast beta correction factor.
beta_slow (int): Slow beta correction factor.
mscale (float): Scaling factor for extended attention.
"""
max_batch_size: int = 8
max_seq_len: int = 4096 * 4
dtype: Literal["bf16", "fp8"] = "bf16"
@@ -52,6 +85,13 @@ class ModelArgs:
class ParallelEmbedding(nn.Module):
"""
Embedding layer with parallelism support across distributed processes.
Args:
vocab_size (int): Vocabulary size.
dim (int): Embedding dimension.
"""
def __init__(self, vocab_size: int, dim: int):
super().__init__()
self.vocab_size = vocab_size
@@ -63,6 +103,18 @@ class ParallelEmbedding(nn.Module):
self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim))
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass for parallel embedding layer.
Args:
x (torch.Tensor): Input tensor containing token indices.
Returns:
torch.Tensor: Embedded representations.
Raises:
ValueError: If `world_size` is not defined.
"""
if world_size > 1:
mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx)
x = x - self.vocab_start_idx
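A single-process sketch of the masking idea shown above: tokens outside a rank's vocabulary shard are zeroed locally and would be filled in by an all_reduce across ranks (omitted here). All sizes and shard boundaries are made up for illustration.

import torch
import torch.nn.functional as F

vocab_size, dim, world_size, rank = 16, 4, 2, 0   # toy sizes, hypothetical
part = vocab_size // world_size                   # vocab rows owned by this rank
start, end = rank * part, (rank + 1) * part
weight = torch.randn(part, dim)                   # this rank's embedding shard

x = torch.tensor([[1, 9, 3]])                     # token 9 belongs to another rank
mask = (x < start) | (x >= end)                   # tokens this rank does not own
x_local = (x - start).masked_fill(mask, 0)        # shift into the local index range
y = F.embedding(x_local, weight)
y[mask] = 0                                       # other ranks contribute these rows
# with torch.distributed initialized, dist.all_reduce(y) would sum the shards
print(y.shape)                                    # torch.Size([1, 3, 4])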
@@ -75,6 +127,27 @@ class ParallelEmbedding(nn.Module):
def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Applies a linear transformation to the incoming data: y = xA^T + b.
This function supports specialized implementations based on quantization
and tensor formats.
Args:
x (torch.Tensor): The input tensor.
weight (torch.Tensor): The weight tensor. It may be quantized and
requires dequantization for certain cases.
bias (Optional[torch.Tensor]): The bias tensor to be added. Default is None.
Returns:
torch.Tensor: The result of the linear transformation, which may involve
quantization-aware computations depending on the input parameters.
Notes:
- If `weight` is quantized (e.g., `element_size() > 1`), a dequantized version
is used for computation.
- If `gemm_impl == "bf16"`, dequantization and a `bf16` GEMM operation are applied.
- For other cases, the function applies quantization to `x` and uses `fp8_gemm` for computation.
"""
if weight.element_size() > 1:
return F.linear(x, weight, bias)
elif gemm_impl == "bf16":
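A simplified sketch of the dispatch rule the notes describe: the unquantized path is ordinary F.linear, while the quantized branch is left as a comment because it depends on the project's custom fp8 kernels, which are not shown in this diff.

import torch
import torch.nn.functional as F

def linear_dispatch(x: torch.Tensor, weight: torch.Tensor, bias=None) -> torch.Tensor:
    # Unquantized weights (bf16/fp32 store >1 byte per element) use F.linear.
    if weight.element_size() > 1:
        return F.linear(x, weight, bias)
    # A 1-byte (quantized) weight would be dequantized or routed to an fp8 GEMM
    # here; omitted because it relies on the repository's quantization kernels.
    raise NotImplementedError("fp8 path requires the project's custom kernels")

x = torch.randn(2, 8, dtype=torch.bfloat16)
w = torch.randn(16, 8, dtype=torch.bfloat16)
print(linear_dispatch(x, w).shape)   # torch.Size([2, 16])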
@@ -89,6 +162,15 @@ def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] =
class Linear(nn.Module):
"""
Custom linear layer with support for quantized weights and optional bias.
Args:
in_features (int): Number of input features.
out_features (int): Number of output features.
bias (bool): Whether to include a bias term. Defaults to False.
dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
"""
dtype = torch.bfloat16
def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
@@ -108,27 +190,72 @@ class Linear(nn.Module):
self.register_parameter("bias", None)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass for the custom linear layer.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Transformed tensor after linear computation.
"""
return linear(x, self.weight, self.bias)
class ColumnParallelLinear(Linear):
"""
Linear layer with column parallelism, splitting output features across distributed processes.
Args:
in_features (int): Number of input features.
out_features (int): Total number of output features.
bias (bool): Whether to include a bias term. Defaults to False.
dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
"""
def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
assert out_features % world_size == 0
self.part_out_features = out_features // world_size
super().__init__(in_features, self.part_out_features, bias, dtype)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass for column parallel linear layer.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Transformed tensor with column-parallel computation.
"""
y = linear(x, self.weight, self.bias)
return y
class RowParallelLinear(Linear):
"""
Linear layer with row parallelism, splitting input features across distributed processes.
Args:
in_features (int): Total number of input features.
out_features (int): Number of output features.
bias (bool): Whether to include a bias term. Defaults to False.
dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
"""
def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
assert in_features % world_size == 0
self.part_in_features = in_features // world_size
super().__init__(self.part_in_features, out_features, bias, dtype)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass for row parallel linear layer.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Transformed tensor with row-parallel computation.
"""
y = linear(x, self.weight)
if world_size > 1:
dist.all_reduce(y)
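To see why the two splits compose, here is a toy single-process illustration (not part of the commit): a column-parallel shard produces a slice of the output features, and a row-parallel shard consumes a matching slice of the input features, with the partial products summed explicitly in place of dist.all_reduce.

import torch
import torch.nn.functional as F

in_f, inter, world_size = 8, 16, 2
x = torch.randn(3, in_f)

# Column parallelism: each rank holds inter/world_size output rows of W1.
w1_shards = torch.randn(world_size, inter // world_size, in_f)
h_shards = [F.linear(x, w) for w in w1_shards]            # each (3, inter/2)

# Row parallelism: each rank holds the matching slice of W2's input columns.
w2_shards = torch.randn(world_size, in_f, inter // world_size)
partials = [F.linear(h, w) for h, w in zip(h_shards, w2_shards)]
y = sum(partials)                                         # stands in for dist.all_reduce
print(y.shape)                                            # torch.Size([3, 8])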
@@ -138,6 +265,13 @@ class RowParallelLinear(Linear):
class RMSNorm(nn.Module):
"""
Root Mean Square Layer Normalization (RMSNorm).
Args:
dim (int): Dimension of the input tensor.
eps (float): Epsilon value for numerical stability. Defaults to 1e-6.
"""
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.dim = dim
@@ -145,10 +279,28 @@ class RMSNorm(nn.Module):
self.weight = nn.Parameter(torch.ones(dim))
def forward(self, x: torch.Tensor):
"""
Forward pass for RMSNorm.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Normalized tensor with the same shape as input.
"""
return F.rms_norm(x, (self.dim,), self.weight, self.eps)
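The normalization itself is just x / sqrt(mean(x^2) + eps) scaled by a learned weight. A standalone check against the built-in, assuming a recent PyTorch (>= 2.4) that ships F.rms_norm:

import torch
import torch.nn.functional as F

dim, eps = 8, 1e-6
x = torch.randn(2, 5, dim)
weight = torch.ones(dim)

manual = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight
builtin = F.rms_norm(x, (dim,), weight, eps)
print(torch.allclose(manual, builtin, atol=1e-5))   # True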
def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
"""
Precomputes frequency-based complex exponential values for rotary positional embeddings.
Args:
args (ModelArgs): Model arguments containing positional embedding parameters.
Returns:
torch.Tensor: Precomputed complex exponential values for positional embeddings.
"""
dim = args.qk_rope_head_dim
seqlen = args.max_seq_len
beta_fast = args.beta_fast
@@ -157,14 +309,51 @@ def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
factor = args.rope_factor
def find_correction_dim(num_rotations, dim, base, max_seq_len):
"""
Computes the correction dimension for a given number of rotations in the rotary positional embedding.
Args:
num_rotations (float): Number of rotations to compute the correction for.
dim (int): Dimensionality of the embedding space.
base (float): Base value for the exponential computation.
max_seq_len (int): Maximum sequence length.
Returns:
float: The correction dimension based on the input parameters.
"""
return dim * math.log(max_seq_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base))
def find_correction_range(low_rot, high_rot, dim, base, max_seq_len):
"""
Computes the range of correction dimensions for rotary positional embeddings.
Args:
low_rot (float): Lower bound for the number of rotations.
high_rot (float): Upper bound for the number of rotations.
dim (int): Dimensionality of the embedding space.
base (float): Base value for the exponential computation.
max_seq_len (int): Maximum sequence length.
Returns:
Tuple[int, int]: The range of correction dimensions (low, high), clamped to valid indices.
"""
low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len))
high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len))
return max(low, 0), min(high, dim-1)
def linear_ramp_factor(min, max, dim):
"""
Computes a linear ramp function used to smooth values between a minimum and maximum range.
Args:
min (float): Minimum value for the ramp function.
max (float): Maximum value for the ramp function.
dim (int): Dimensionality of the ramp tensor.
Returns:
torch.Tensor: A tensor of shape (dim,) with values linearly interpolated between 0 and 1,
clamped to the range [0, 1].
"""
if min == max:
max += 0.001
linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
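For intuition, the ramp documented above just interpolates from 0 to 1 across the correction range and is clamped outside it. A standalone version (renaming min/max to lo/hi so the builtins are not shadowed, purely for the sketch):

import torch

def linear_ramp(lo: float, hi: float, dim: int) -> torch.Tensor:
    # Indices below `lo` map to 0, above `hi` map to 1, linear in between.
    if lo == hi:
        hi += 0.001                      # avoid division by zero, as in the diff
    ramp = (torch.arange(dim, dtype=torch.float32) - lo) / (hi - lo)
    return ramp.clamp(0.0, 1.0)

print(linear_ramp(2, 6, 10))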
@@ -184,6 +373,16 @@ def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
"""
Applies rotary positional embeddings to the input tensor.
Args:
x (torch.Tensor): Input tensor with positional embeddings to be applied.
freqs_cis (torch.Tensor): Precomputed complex exponential values for positional embeddings.
Returns:
torch.Tensor: Tensor with rotary embeddings applied.
"""
dtype = x.dtype
x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2))
freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1))
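A self-contained sketch of the same complex-multiplication trick, with toy shapes and a locally built freqs_cis standing in for the precomputed one (the 10000.0 base is an illustrative default, not taken from this diff):

import torch

bsz, seqlen, n_heads, head_dim = 1, 4, 2, 8        # toy sizes
x = torch.randn(bsz, seqlen, n_heads, head_dim)

# Per-position complex rotations, analogous to the precomputed freqs_cis.
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.outer(torch.arange(seqlen).float(), inv_freq)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)    # (seqlen, head_dim // 2)

xc = torch.view_as_complex(x.float().view(bsz, seqlen, n_heads, -1, 2))
rotated = xc * freqs_cis.view(1, seqlen, 1, -1)           # broadcast over batch, heads
out = torch.view_as_real(rotated).flatten(3).type_as(x)
print(out.shape)                                          # torch.Size([1, 4, 2, 8])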
@@ -192,6 +391,21 @@ def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
class MLA(nn.Module):
"""
Multi-Headed Attention Layer (MLA).
Attributes:
dim (int): Dimensionality of the input features.
n_heads (int): Number of attention heads.
n_local_heads (int): Number of local attention heads for distributed systems.
q_lora_rank (int): Rank for low-rank query projection.
kv_lora_rank (int): Rank for low-rank key/value projection.
qk_nope_head_dim (int): Dimensionality of non-positional query/key projections.
qk_rope_head_dim (int): Dimensionality of rotary-positional query/key projections.
qk_head_dim (int): Total dimensionality of query/key projections.
v_head_dim (int): Dimensionality of value projections.
softmax_scale (float): Scaling factor for softmax in attention computation.
"""
def __init__(self, args: ModelArgs):
super().__init__()
self.dim = args.dim
@@ -227,6 +441,18 @@ class MLA(nn.Module):
self.register_buffer("pe_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim), persistent=False)
def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
"""
Forward pass for the Multi-Headed Attention Layer (MLA).
Args:
x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim).
start_pos (int): Starting position in the sequence for caching.
freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
mask (Optional[torch.Tensor]): Mask tensor to exclude certain positions from attention.
Returns:
torch.Tensor: Output tensor with the same shape as the input.
"""
bsz, seqlen, _ = x.size()
end_pos = start_pos + seqlen
if self.q_lora_rank == 0:
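To make the start_pos/end_pos bookkeeping concrete, a toy cache update in the same spirit; the buffer name and shapes are illustrative, not the layer's actual kv/pe caches:

import torch

max_batch, max_seq, dim = 2, 16, 4
kv_cache = torch.zeros(max_batch, max_seq, dim)    # stand-in for a persistent buffer

def step(kv_cache, new_kv, start_pos):
    bsz, seqlen, _ = new_kv.size()
    end_pos = start_pos + seqlen
    kv_cache[:bsz, start_pos:end_pos] = new_kv     # write this step's entries
    return kv_cache[:bsz, :end_pos]                # attend over everything cached so far

prompt = torch.randn(1, 5, dim)
print(step(kv_cache, prompt, 0).shape)             # torch.Size([1, 5, 4])
next_tok = torch.randn(1, 1, dim)
print(step(kv_cache, next_tok, 5).shape)           # torch.Size([1, 6, 4])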
@@ -269,18 +495,61 @@ class MLA(nn.Module):
class MLP(nn.Module):
"""
Multi-Layer Perceptron (MLP) used as a feed-forward layer.
Attributes:
w1 (nn.Module): Linear layer for input-to-hidden transformation.
w2 (nn.Module): Linear layer for hidden-to-output transformation.
w3 (nn.Module): Additional linear layer for feature transformation.
"""
def __init__(self, dim: int, inter_dim: int):
"""
Initializes the MLP layer.
Args:
dim (int): Input and output dimensionality.
inter_dim (int): Hidden layer dimensionality.
"""
super().__init__()
self.w1 = ColumnParallelLinear(dim, inter_dim)
self.w2 = RowParallelLinear(inter_dim, dim)
self.w3 = ColumnParallelLinear(dim, inter_dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass for the MLP layer.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Output tensor after MLP computation.
"""
return self.w2(F.silu(self.w1(x)) * self.w3(x))
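The forward pass above is the gated-SiLU (SwiGLU-style) pattern. A reference version using stock nn.Linear layers in place of the parallel ones, for illustration only:

import torch
import torch.nn as nn
import torch.nn.functional as F

class GatedMLP(nn.Module):
    # Same w2(silu(w1(x)) * w3(x)) wiring as the diff, with plain nn.Linear layers.
    def __init__(self, dim: int, inter_dim: int):
        super().__init__()
        self.w1 = nn.Linear(dim, inter_dim, bias=False)
        self.w2 = nn.Linear(inter_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, inter_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

print(GatedMLP(8, 32)(torch.randn(2, 8)).shape)   # torch.Size([2, 8])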
class Gate(nn.Module):
"""
Gating mechanism for routing inputs in a mixture-of-experts (MoE) model.
Attributes:
dim (int): Dimensionality of input features.
topk (int): Number of top experts activated for each input.
n_groups (int): Number of groups for routing.
topk_groups (int): Number of groups to route inputs to.
score_func (str): Scoring function ('softmax' or 'sigmoid').
route_scale (float): Scaling factor for routing weights.
weight (torch.nn.Parameter): Learnable weights for the gate.
bias (Optional[torch.nn.Parameter]): Optional bias term for the gate.
"""
def __init__(self, args: ModelArgs):
"""
Initializes the Gate module.
Args:
args (ModelArgs): Model arguments containing gating parameters.
"""
super().__init__()
self.dim = args.dim
self.topk = args.n_activated_experts
@@ -292,6 +561,15 @@ class Gate(nn.Module):
self.bias = nn.Parameter(torch.empty(args.n_routed_experts)) if self.dim == 7168 else None
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Forward pass for the gating mechanism.
Args:
x (torch.Tensor): Input tensor.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Routing weights and selected expert indices.
"""
scores = linear(x, self.weight)
if self.score_func == "softmax":
scores = scores.softmax(dim=-1, dtype=torch.float32)
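Ignoring the grouped routing, bias, and route_scale details that follow in the full file, the core of the gate is score-then-top-k. A stripped-down sketch with made-up sizes:

import torch

n_experts, topk, dim = 8, 2, 16
weight = torch.randn(n_experts, dim)               # stand-in for the learnable gate weight
x = torch.randn(4, dim)                            # 4 flattened tokens

scores = torch.softmax(x @ weight.t(), dim=-1, dtype=torch.float32)
topk_weights, topk_idx = scores.topk(topk, dim=-1) # routing weights + expert indices
print(topk_idx)                                    # which experts each token is sent to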
@@ -318,18 +596,60 @@ class Gate(nn.Module):
class Expert(nn.Module):
"""
Expert layer for Mixture-of-Experts (MoE) models.
Attributes:
w1 (nn.Module): Linear layer for input-to-hidden transformation.
w2 (nn.Module): Linear layer for hidden-to-output transformation.
w3 (nn.Module): Additional linear layer for feature transformation.
"""
def __init__(self, dim: int, inter_dim: int):
"""
Initializes the Expert layer.
Args:
dim (int): Input and output dimensionality.
inter_dim (int): Hidden layer dimensionality.
"""
super().__init__()
self.w1 = Linear(dim, inter_dim)
self.w2 = Linear(inter_dim, dim)
self.w3 = Linear(dim, inter_dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass for the Expert layer.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Output tensor after expert computation.
"""
return self.w2(F.silu(self.w1(x)) * self.w3(x))
class MoE(nn.Module):
"""
Mixture-of-Experts (MoE) module.
Attributes:
dim (int): Dimensionality of input features.
n_routed_experts (int): Total number of experts in the model.
n_local_experts (int): Number of experts handled locally in distributed systems.
n_activated_experts (int): Number of experts activated for each input.
gate (nn.Module): Gating mechanism to route inputs to experts.
experts (nn.ModuleList): List of expert modules.
shared_experts (nn.Module): Shared experts applied to all inputs.
"""
def __init__(self, args: ModelArgs):
"""
Initializes the MoE module.
Args:
args (ModelArgs): Model arguments containing MoE parameters.
"""
super().__init__()
self.dim = args.dim
assert args.n_routed_experts % world_size == 0
@@ -344,6 +664,15 @@ class MoE(nn.Module):
self.shared_experts = MLP(args.dim, args.n_shared_experts * args.moe_inter_dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass for the MoE module.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Output tensor after expert routing and computation.
"""
shape = x.size()
x = x.view(-1, self.dim)
weights, indices = self.gate(x)
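The accumulation that follows the gate call is truncated in this hunk; as a hedged illustration of the general dispatch pattern (toy experts and a random gate standing in for the real ones), each token's output is the weighted sum of its selected experts:

import torch

n_experts, topk, dim, tokens = 4, 2, 8, 6
experts = [torch.nn.Linear(dim, dim, bias=False) for _ in range(n_experts)]  # toy experts

x = torch.randn(tokens, dim)                         # already flattened to (tokens, dim)
scores = torch.softmax(torch.randn(tokens, n_experts), dim=-1)
weights, indices = scores.topk(topk, dim=-1)         # stand-in for the gate's output

y = torch.zeros_like(x)
for e in range(n_experts):
    tok, slot = torch.where(indices == e)            # tokens routed to expert e
    if tok.numel():
        y[tok] += experts[e](x[tok]) * weights[tok, slot].unsqueeze(-1)
print(y.shape)                                       # torch.Size([6, 8])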
@@ -362,7 +691,23 @@ class MoE(nn.Module):
class Block(nn.Module):
"""
Transformer block combining attention and feed-forward layers.
Attributes:
attn (nn.Module): Attention layer (MLA).
ffn (nn.Module): Feed-forward network (MLP or MoE).
attn_norm (nn.Module): Layer normalization for attention.
ffn_norm (nn.Module): Layer normalization for feed-forward network.
"""
def __init__(self, layer_id: int, args: ModelArgs):
"""
Initializes the Transformer block.
Args:
layer_id (int): Layer index in the transformer.
args (ModelArgs): Model arguments containing block parameters.
"""
super().__init__()
self.attn = MLA(args)
self.ffn = MLP(args.dim, args.inter_dim) if layer_id < args.n_dense_layers else MoE(args)
@@ -370,13 +715,42 @@ class Block(nn.Module):
self.ffn_norm = RMSNorm(args.dim)
def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
"""
Forward pass for the Transformer block.
Args:
x (torch.Tensor): Input tensor.
start_pos (int): Starting position in the sequence.
freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
mask (Optional[torch.Tensor]): Mask tensor to exclude certain positions from attention.
Returns:
torch.Tensor: Output tensor after block computation.
"""
x = x + self.attn(self.attn_norm(x), start_pos, freqs_cis, mask)
x = x + self.ffn(self.ffn_norm(x))
return x
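The two residual lines above are the standard pre-norm wiring. A data-flow sketch only, with LayerNorm and identity sublayers standing in for RMSNorm, MLA, and MLP/MoE:

import torch
import torch.nn as nn

class PreNormBlock(nn.Module):
    # Same residual wiring as the diff; the sublayers are placeholders purely to
    # show that each branch normalizes its input and adds back to the stream.
    def __init__(self, dim: int):
        super().__init__()
        self.attn = nn.Identity()
        self.ffn = nn.Identity()
        self.attn_norm = nn.LayerNorm(dim)
        self.ffn_norm = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.attn_norm(x))
        x = x + self.ffn(self.ffn_norm(x))
        return x

print(PreNormBlock(8)(torch.randn(2, 3, 8)).shape)   # torch.Size([2, 3, 8])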
class Transformer(nn.Module):
"""
Transformer model with positional embeddings, multiple layers, and output projection.
Attributes:
max_seq_len (int): Maximum sequence length for the transformer.
embed (nn.Module): Embedding layer for input tokens.
layers (torch.nn.ModuleList): List of transformer blocks.
norm (nn.Module): Layer normalization applied after all blocks.
head (nn.Module): Output projection layer mapping to vocabulary size.
freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
"""
def __init__(self, args: ModelArgs):
"""
Initializes the Transformer model.
Args:
args (ModelArgs): Model arguments containing transformer parameters.
"""
global world_size, rank
world_size = dist.get_world_size() if dist.is_initialized() else 1
rank = dist.get_rank() if dist.is_initialized() else 0
@@ -393,6 +767,16 @@ class Transformer(nn.Module):
@torch.inference_mode()
def forward(self, tokens: torch.Tensor, start_pos: int = 0):
"""
Forward pass for the Transformer model.
Args:
tokens (torch.Tensor): Input tensor of token IDs with shape (batch_size, seq_len).
start_pos (int, optional): Starting position in the sequence for rotary embeddings. Defaults to 0.
Returns:
torch.Tensor: Logits tensor of shape (batch_size, vocab_size).
"""
seqlen = tokens.size(1)
h = self.embed(tokens)
freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]
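Finally, a hypothetical usage sketch of the documented forward signature (tokens in, next-token logits out), kept as comments because it assumes model.py is importable and runs single-process (world_size == 1):

import torch

# from model import ModelArgs, Transformer            # hypothetical import path
# args = ModelArgs(max_batch_size=1, max_seq_len=1024)
# model = Transformer(args)
# tokens = torch.randint(0, args.vocab_size, (1, 12))
# logits = model(tokens, start_pos=0)                 # (1, vocab_size) for the last position
# next_step = model(tokens[:, -1:], start_pos=12)     # incremental decoding with the cache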