mirror of
https://github.com/gpt-omni/mini-omni
synced 2024-11-16 05:03:47 +00:00
619 lines
24 KiB
Python
619 lines
24 KiB
Python
|
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
|
||
|
|
||
|
"""Full definition of a decoder-only transformer-based language model, all of it in this single file.
|
||
|
|
||
|
Based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT and
|
||
|
https://github.com/EleutherAI/gpt-neox/tree/main/megatron/model.
|
||
|
"""
|
||
|
|
||
|
import math
|
||
|
from typing import Any, Optional, Tuple
|
||
|
|
||
|
import torch
|
||
|
import torch.nn as nn
|
||
|
from typing_extensions import Self
|
||
|
from litgpt.config import Config
|
||
|
|
||
|
|
||
|
class GPT(nn.Module):
|
||
|
def __init__(self, config: Config) -> None:
|
||
|
super().__init__()
|
||
|
assert config.padded_vocab_size is not None
|
||
|
self.config = config
|
||
|
if self.config.asr_adapter == "mlp":
|
||
|
print("Using MLP adapter for ASR feature")
|
||
|
self.whisper_adapter = nn.Linear(config.whisper_adapter_dim, config.n_embd)
|
||
|
elif self.config.asr_adapter == "llamamlp":
|
||
|
print("using LLAMA MLP adapter for ASR feature")
|
||
|
self.whisper_adapter = whisperMLP(config=config)
|
||
|
else:
|
||
|
raise ValueError("asr_adapter should be mlp or llamamlp")
|
||
|
self.lm_head = nn.Linear(
|
||
|
config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias
|
||
|
)
|
||
|
if config.post_adapter:
|
||
|
self.transformer = nn.ModuleDict(
|
||
|
dict(
|
||
|
wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
|
||
|
h=nn.ModuleList(Block(config) for _ in range(config.n_layer)),
|
||
|
post_adapter=nn.ModuleList(
|
||
|
Block(config) for _ in range(config.post_adapter_layers)
|
||
|
),
|
||
|
ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
|
||
|
post_adapter_audio_ln=config.norm_class(
|
||
|
config.n_embd, eps=config.norm_eps
|
||
|
),
|
||
|
post_adapter_audio_lm_head=nn.Linear(
|
||
|
config.n_embd, config.cat_audio_vocab_size, bias=config.lm_head_bias
|
||
|
),
|
||
|
)
|
||
|
)
|
||
|
else:
|
||
|
self.transformer = nn.ModuleDict(
|
||
|
dict(
|
||
|
wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
|
||
|
h=nn.ModuleList(Block(config) for _ in range(config.n_layer)),
|
||
|
ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
|
||
|
)
|
||
|
)
|
||
|
self.max_seq_length = self.config.block_size
|
||
|
self.mask_cache: Optional[torch.Tensor] = None
|
||
|
if config.tie_word_embeddings:
|
||
|
self.lm_head.weight = self.transformer.wte.weight
|
||
|
|
||
|
@property
|
||
|
def max_seq_length(self) -> int:
|
||
|
return self._max_seq_length
|
||
|
|
||
|
@max_seq_length.setter
|
||
|
def max_seq_length(self, value: int) -> None:
|
||
|
"""
|
||
|
When doing inference, the sequences used might be shorter than the model's context length.
|
||
|
This allows setting a smaller number to avoid allocating unused memory
|
||
|
"""
|
||
|
if value > self.config.block_size:
|
||
|
raise ValueError(
|
||
|
f"Cannot attend to {value}, block size is only {self.config.block_size}"
|
||
|
)
|
||
|
self._max_seq_length = value
|
||
|
if not hasattr(self, "cos"):
|
||
|
# first call
|
||
|
cos, sin = self.rope_cache()
|
||
|
self.register_buffer("cos", cos, persistent=False)
|
||
|
self.register_buffer("sin", sin, persistent=False)
|
||
|
# override
|
||
|
elif value != self.cos.size(0):
|
||
|
self.cos, self.sin = self.rope_cache(device=self.cos.device)
|
||
|
# the mask and kv cache size will get updated on `set_kv_cache`. we cannot update it here because we don't know
|
||
|
# if the kv cache is expected
|
||
|
|
||
|
def reset_parameters(self) -> None:
|
||
|
# Trigger resetting the rope-cache
|
||
|
self.cos, self.sin = self.rope_cache(device=self.cos.device)
|
||
|
|
||
|
def _init_weights(self, module: nn.Module) -> None:
|
||
|
"""Meant to be used with `gpt.apply(gpt._init_weights)`."""
|
||
|
if isinstance(module, nn.Linear):
|
||
|
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
||
|
if module.bias is not None:
|
||
|
torch.nn.init.zeros_(module.bias)
|
||
|
elif isinstance(module, nn.Embedding):
|
||
|
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
||
|
|
||
|
def concat_whisper_feat(self, audio_feature, input_ids, T, task):
|
||
|
for j in range(len(T)):
|
||
|
if task[j] != "T1T2" and task[j] != "T1A2":
|
||
|
for i in range(7):
|
||
|
input_ids[i][j, 1 : T[j] + 1, :] = audio_feature[j][: T[j]].clone()
|
||
|
else:
|
||
|
continue
|
||
|
return input_ids
|
||
|
|
||
|
def forward(
|
||
|
self,
|
||
|
audio_features: torch.Tensor,
|
||
|
input_ids: torch.Tensor,
|
||
|
input_pos: Optional[torch.Tensor] = None,
|
||
|
whisper_lens: Optional[list] = None,
|
||
|
task: Optional[str] = None,
|
||
|
) -> torch.Tensor:
|
||
|
|
||
|
show = False
|
||
|
T = input_ids[0].size(1)
|
||
|
if self.max_seq_length < T:
|
||
|
raise ValueError(
|
||
|
f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}."
|
||
|
)
|
||
|
|
||
|
if input_pos is not None: # use the kv cache
|
||
|
cos = self.cos.index_select(0, input_pos)
|
||
|
sin = self.sin.index_select(0, input_pos)
|
||
|
if self.mask_cache is None:
|
||
|
raise TypeError("You need to call `gpt.set_kv_cache()`")
|
||
|
mask = self.mask_cache.index_select(2, input_pos)
|
||
|
else:
|
||
|
cos = self.cos[:T]
|
||
|
sin = self.sin[:T]
|
||
|
mask = None
|
||
|
|
||
|
if audio_features is not None:
|
||
|
# get whisper feature
|
||
|
x_a = self.whisper_adapter(audio_features)
|
||
|
# get input_ids embedding
|
||
|
x0, x1, x2, x3, x4, x5, x6, x7 = input_ids
|
||
|
|
||
|
x0 = self.transformer.wte(x0)
|
||
|
x1 = self.transformer.wte(x1)
|
||
|
x2 = self.transformer.wte(x2)
|
||
|
x3 = self.transformer.wte(x3)
|
||
|
x4 = self.transformer.wte(x4)
|
||
|
x5 = self.transformer.wte(x5)
|
||
|
x6 = self.transformer.wte(x6)
|
||
|
x7 = self.transformer.wte(x7)
|
||
|
|
||
|
# concat whisper feature
|
||
|
input_emb = self.concat_whisper_feat(
|
||
|
x_a, [x0, x1, x2, x3, x4, x5, x6, x7], whisper_lens, task
|
||
|
)
|
||
|
x0, x1, x2, x3, x4, x5, x6, x7 = input_emb
|
||
|
|
||
|
else:
|
||
|
x0, x1, x2, x3, x4, x5, x6, x7 = input_ids
|
||
|
|
||
|
x0 = self.transformer.wte(x0)
|
||
|
x1 = self.transformer.wte(x1)
|
||
|
x2 = self.transformer.wte(x2)
|
||
|
x3 = self.transformer.wte(x3)
|
||
|
x4 = self.transformer.wte(x4)
|
||
|
x5 = self.transformer.wte(x5)
|
||
|
x6 = self.transformer.wte(x6)
|
||
|
x7 = self.transformer.wte(x7)
|
||
|
|
||
|
x = (x0 + x1 + x2 + x3 + x4 + x5 + x6 + x7) / 8
|
||
|
|
||
|
if self.config.scale_embeddings:
|
||
|
x = x * (self.config.n_embd**0.5)
|
||
|
|
||
|
for block in self.transformer.h:
|
||
|
x = block(x, cos, sin, mask, input_pos)
|
||
|
|
||
|
|
||
|
text_vocab_size = self.config.text_vocab_size
|
||
|
audio_vocab_size = self.config.audio_vocab_size
|
||
|
|
||
|
x_ori = x
|
||
|
x_ori = self.transformer.ln_f(x_ori)
|
||
|
x_ori = self.lm_head(x_ori) # (b, t, vocab_size)
|
||
|
xt = x_ori[..., :text_vocab_size]
|
||
|
|
||
|
if self.config.post_adapter:
|
||
|
for block in self.transformer.post_adapter:
|
||
|
x = block(x, cos, sin, mask, input_pos)
|
||
|
x = self.transformer.post_adapter_audio_ln(x)
|
||
|
x = self.transformer.post_adapter_audio_lm_head(x) # (b, t, vocab_size)
|
||
|
xa = []
|
||
|
for i in range(7):
|
||
|
xa.append(x[..., audio_vocab_size * i : audio_vocab_size * (i + 1)])
|
||
|
else:
|
||
|
xa = []
|
||
|
for i in range(7):
|
||
|
xa.append(x_ori[..., text_vocab_size + audio_vocab_size * i : text_vocab_size + audio_vocab_size * (i + 1)])
|
||
|
|
||
|
return xa, xt
|
||
|
|
||
|
@classmethod
|
||
|
def from_name(cls, name: str, **kwargs: Any) -> Self:
|
||
|
return cls(Config.from_name(name, **kwargs))
|
||
|
|
||
|
def rope_cache(
|
||
|
self, device: Optional[torch.device] = None
|
||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||
|
return build_rope_cache(
|
||
|
seq_len=self.max_seq_length,
|
||
|
n_elem=self.config.rope_n_elem,
|
||
|
device=device,
|
||
|
condense_ratio=self.config.rope_condense_ratio,
|
||
|
base=self.config.rope_base,
|
||
|
)
|
||
|
|
||
|
def set_kv_cache(
|
||
|
self,
|
||
|
batch_size: int,
|
||
|
rope_cache_length: Optional[int] = None,
|
||
|
device: Optional[torch.device] = None,
|
||
|
dtype: Optional[torch.dtype] = None,
|
||
|
) -> None:
|
||
|
if rope_cache_length is None:
|
||
|
rope_cache_length = self.cos.size(-1)
|
||
|
max_seq_length = self.max_seq_length
|
||
|
|
||
|
# initialize the kv cache for all blocks
|
||
|
for block in self.transformer.h:
|
||
|
block.attn.kv_cache = block.attn.build_kv_cache(
|
||
|
batch_size, max_seq_length, rope_cache_length, device, dtype
|
||
|
)
|
||
|
if self.config.post_adapter:
|
||
|
for block in self.transformer.post_adapter:
|
||
|
block.attn.kv_cache = block.attn.build_kv_cache(
|
||
|
batch_size, max_seq_length, rope_cache_length, device, dtype
|
||
|
)
|
||
|
|
||
|
if self.mask_cache is None or self.mask_cache.size(3) != max_seq_length:
|
||
|
# passing `attn_mask` to SDPA disables the flash implementation. since we only need the mask
|
||
|
# for the kv-cache support (only during inference), we only create it in that situation
|
||
|
self.mask_cache = build_mask_cache(max_seq_length, device)
|
||
|
|
||
|
def clear_kv_cache(self) -> None:
|
||
|
self.mask_cache = None
|
||
|
for block in self.transformer.h:
|
||
|
block.attn.kv_cache = None
|
||
|
|
||
|
|
||
|
class Block(nn.Module):
|
||
|
|
||
|
def __init__(self, config: Config) -> None:
|
||
|
super().__init__()
|
||
|
if not config.parallel_residual and config.shared_attention_norm:
|
||
|
raise NotImplementedError(
|
||
|
"No checkpoint amongst the ones we support uses this configuration"
|
||
|
" (non-parallel residual and shared attention norm)."
|
||
|
)
|
||
|
|
||
|
self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps)
|
||
|
self.attn = CausalSelfAttention(config)
|
||
|
self.norm_2 = (
|
||
|
None
|
||
|
if config.shared_attention_norm
|
||
|
else config.norm_class(config.n_embd, eps=config.norm_eps)
|
||
|
)
|
||
|
self.mlp = config.mlp_class(config)
|
||
|
|
||
|
self.config = config
|
||
|
|
||
|
def forward(
|
||
|
self,
|
||
|
x: torch.Tensor,
|
||
|
cos: torch.Tensor,
|
||
|
sin: torch.Tensor,
|
||
|
mask: Optional[torch.Tensor] = None,
|
||
|
input_pos: Optional[torch.Tensor] = None,
|
||
|
) -> torch.Tensor:
|
||
|
"""
|
||
|
Non-parallel residual Parallel residual
|
||
|
┌─ x ┌─ x ────────────┐ Note: if `shared_attention_norm` is True,
|
||
|
│ ↓ │ ↓ ↓ the output from `norm_1` is reused
|
||
|
│ norm_1 │ norm_1 ───► norm_2
|
||
|
│ ↓ │ ↓ ↓
|
||
|
│ attn │ attn mlp
|
||
|
│ ↓ │ ↓ │
|
||
|
┌─ └► + └► + ◄───────────┘
|
||
|
│ norm_2
|
||
|
│ ↓
|
||
|
│ mlp
|
||
|
│ ↓
|
||
|
└───► +
|
||
|
"""
|
||
|
|
||
|
x_normed = self.norm_1(x)
|
||
|
attention_output = self.attn(x_normed, cos, sin, mask, input_pos)
|
||
|
|
||
|
if self.config.parallel_residual:
|
||
|
x_normed = x_normed if self.config.shared_attention_norm else self.norm_2(x)
|
||
|
x = self.mlp(x_normed) + attention_output + x
|
||
|
else:
|
||
|
x = attention_output + x
|
||
|
x = self.mlp(self.norm_2(x)) + x
|
||
|
return x
|
||
|
|
||
|
|
||
|
class CausalSelfAttention(nn.Module):
|
||
|
def __init__(self, config: Config) -> None:
|
||
|
super().__init__()
|
||
|
shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
|
||
|
# key, query, value projections for all heads, but in a batch
|
||
|
self.attn = nn.Linear(config.n_embd, shape, bias=config.add_qkv_bias)
|
||
|
# output projection
|
||
|
# if `head_size` is explicitly specified in the config, `n_emd` might not be equal to `head_size * n_head`
|
||
|
self.proj = nn.Linear(
|
||
|
config.head_size * config.n_head, config.n_embd, bias=config.bias
|
||
|
)
|
||
|
# disabled by default
|
||
|
self.kv_cache: Optional[KVCache] = None
|
||
|
|
||
|
self.config = config
|
||
|
|
||
|
def forward(
|
||
|
self,
|
||
|
x: torch.Tensor,
|
||
|
cos: torch.Tensor,
|
||
|
sin: torch.Tensor,
|
||
|
mask: Optional[torch.Tensor] = None,
|
||
|
input_pos: Optional[torch.Tensor] = None,
|
||
|
) -> torch.Tensor:
|
||
|
B, T, C = (
|
||
|
x.size()
|
||
|
) # batch size, sequence length, embedding dimensionality (n_embd)
|
||
|
|
||
|
qkv = self.attn(x)
|
||
|
|
||
|
# assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`)
|
||
|
q_per_kv = self.config.n_head // self.config.n_query_groups
|
||
|
total_qkv = q_per_kv + 2 # each group has 1+ queries, 1 key, and 1 value
|
||
|
qkv = qkv.view(
|
||
|
B, T, self.config.n_query_groups, total_qkv, self.config.head_size
|
||
|
)
|
||
|
qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs)
|
||
|
|
||
|
# split batched computation into three
|
||
|
q, k, v = qkv.split((q_per_kv, 1, 1), dim=2)
|
||
|
|
||
|
# maybe repeat k and v if for the non multi-head attention cases
|
||
|
# training: flash attention requires it
|
||
|
# inference: multi-query would require a full kv cache so avoid it to limit its memory usage
|
||
|
if self.config.n_query_groups != self.config.n_head and (
|
||
|
input_pos is None or self.config.n_query_groups != 1
|
||
|
):
|
||
|
k = k.expand(
|
||
|
B, self.config.n_query_groups, q_per_kv, T, self.config.head_size
|
||
|
)
|
||
|
v = v.expand(
|
||
|
B, self.config.n_query_groups, q_per_kv, T, self.config.head_size
|
||
|
)
|
||
|
|
||
|
q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs)
|
||
|
k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs)
|
||
|
v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs)
|
||
|
|
||
|
q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin)
|
||
|
k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin)
|
||
|
q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1)
|
||
|
k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1)
|
||
|
|
||
|
if input_pos is not None:
|
||
|
if not isinstance(self.kv_cache, KVCache):
|
||
|
raise TypeError("You need to call `gpt.set_kv_cache()`")
|
||
|
k, v = self.kv_cache(input_pos, k, v)
|
||
|
|
||
|
y = self.scaled_dot_product_attention(q, k, v, mask)
|
||
|
|
||
|
y = y.reshape(
|
||
|
B, T, self.config.head_size * self.config.n_head
|
||
|
) # re-assemble all head outputs side by side
|
||
|
|
||
|
# output projection
|
||
|
return self.proj(y)
|
||
|
|
||
|
def scaled_dot_product_attention(
|
||
|
self,
|
||
|
q: torch.Tensor,
|
||
|
k: torch.Tensor,
|
||
|
v: torch.Tensor,
|
||
|
mask: Optional[torch.Tensor] = None,
|
||
|
) -> torch.Tensor:
|
||
|
scale = 1.0 / math.sqrt(self.config.head_size)
|
||
|
y = torch.nn.functional.scaled_dot_product_attention(
|
||
|
q, k, v, attn_mask=mask, dropout_p=0.0, scale=scale, is_causal=mask is None
|
||
|
)
|
||
|
return y.transpose(1, 2)
|
||
|
|
||
|
def build_kv_cache(
|
||
|
self,
|
||
|
batch_size: int,
|
||
|
max_seq_length: int,
|
||
|
rope_cache_length: Optional[int] = None,
|
||
|
device: Optional[torch.device] = None,
|
||
|
dtype: Optional[torch.dtype] = None,
|
||
|
) -> "KVCache":
|
||
|
heads = 1 if self.config.n_query_groups == 1 else self.config.n_head
|
||
|
v_shape = (batch_size, heads, max_seq_length, self.config.head_size)
|
||
|
if rope_cache_length is None:
|
||
|
if self.config.rotary_percentage != 1.0:
|
||
|
raise TypeError(
|
||
|
"Please pass the `rope_cache_length=gpt.cos.size(-1)` value"
|
||
|
)
|
||
|
k_shape = v_shape
|
||
|
else:
|
||
|
k_shape = (
|
||
|
batch_size,
|
||
|
heads,
|
||
|
max_seq_length,
|
||
|
rope_cache_length + self.config.head_size - self.config.rope_n_elem,
|
||
|
)
|
||
|
return KVCache(k_shape, v_shape, device=device, dtype=dtype)
|
||
|
|
||
|
|
||
|
class GptNeoxMLP(nn.Module):
|
||
|
def __init__(self, config: Config) -> None:
|
||
|
super().__init__()
|
||
|
self.fc = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
|
||
|
self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
|
||
|
|
||
|
self.config = config
|
||
|
|
||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||
|
x = self.fc(x)
|
||
|
x = torch.nn.functional.gelu(x, approximate=self.config.gelu_approximate)
|
||
|
return self.proj(x)
|
||
|
|
||
|
|
||
|
class LLaMAMLP(nn.Module):
|
||
|
def __init__(self, config: Config) -> None:
|
||
|
super().__init__()
|
||
|
self.fc_1 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
|
||
|
self.fc_2 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
|
||
|
self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
|
||
|
|
||
|
self.config = config
|
||
|
|
||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||
|
x_fc_1 = self.fc_1(x)
|
||
|
x_fc_2 = self.fc_2(x)
|
||
|
x = torch.nn.functional.silu(x_fc_1) * x_fc_2
|
||
|
return self.proj(x)
|
||
|
|
||
|
|
||
|
class whisperMLP(nn.Module):
|
||
|
def __init__(self, config: Config) -> None:
|
||
|
super().__init__()
|
||
|
self.fc_1 = nn.Linear(config.whisper_adapter_dim, config.intermediate_size, bias=config.bias)
|
||
|
self.fc_2 = nn.Linear(config.whisper_adapter_dim, config.intermediate_size, bias=config.bias)
|
||
|
self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
|
||
|
|
||
|
self.config = config
|
||
|
|
||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||
|
x_fc_1 = self.fc_1(x)
|
||
|
x_fc_2 = self.fc_2(x)
|
||
|
x = torch.nn.functional.silu(x_fc_1) * x_fc_2
|
||
|
return self.proj(x)
|
||
|
|
||
|
|
||
|
class GemmaMLP(LLaMAMLP):
|
||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||
|
x_fc_1 = self.fc_1(x)
|
||
|
x_fc_2 = self.fc_2(x)
|
||
|
x = (
|
||
|
torch.nn.functional.gelu(x_fc_1, approximate=self.config.gelu_approximate)
|
||
|
* x_fc_2
|
||
|
)
|
||
|
return self.proj(x)
|
||
|
|
||
|
|
||
|
class LLaMAMoE(nn.Module):
|
||
|
def __init__(self, config: Config) -> None:
|
||
|
super().__init__()
|
||
|
self.gate = nn.Linear(config.n_embd, config.n_expert, bias=False)
|
||
|
self.experts = nn.ModuleList(LLaMAMLP(config) for _ in range(config.n_expert))
|
||
|
|
||
|
self.config = config
|
||
|
|
||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||
|
"""
|
||
|
Derived from: https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219
|
||
|
See also figure 1 in https://arxiv.org/abs/2211.15841
|
||
|
"""
|
||
|
B, T, C = (
|
||
|
x.size()
|
||
|
) # batch size, sequence length, embedding dimensionality (n_embd)
|
||
|
x = x.view(-1, C) # (B*T, C)
|
||
|
router = self.gate(x) # (B*T, n_expert)
|
||
|
probs, indices = torch.topk(
|
||
|
router, self.config.n_expert_per_token
|
||
|
) # (B*T, n_expert_per_token)
|
||
|
probs = probs.softmax(dim=1, dtype=torch.float).to(dtype=x.dtype)
|
||
|
masks = indices.unsqueeze(-1) == torch.arange(
|
||
|
self.config.n_expert, device=x.device
|
||
|
)
|
||
|
masks = masks.permute(2, 0, 1) # (n_expert, B*T, n_expert_per_token)
|
||
|
y = torch.zeros_like(x) # (B*T, C)
|
||
|
for mask, expert in zip(masks, self.experts):
|
||
|
token_idx, expert_idx = torch.where(mask)
|
||
|
y[token_idx] += probs[token_idx, expert_idx, None] * expert(x[token_idx])
|
||
|
return y.view(B, T, C)
|
||
|
|
||
|
|
||
|
def build_rope_cache(
|
||
|
seq_len: int,
|
||
|
n_elem: int,
|
||
|
device: Optional[torch.device] = None,
|
||
|
base: int = 10000,
|
||
|
condense_ratio: int = 1,
|
||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||
|
"""Enhanced Transformer with Rotary Position Embedding.
|
||
|
|
||
|
Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
|
||
|
transformers/rope/__init__.py. MIT License:
|
||
|
https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
|
||
|
"""
|
||
|
# $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
|
||
|
theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem))
|
||
|
|
||
|
# Create position indexes `[0, 1, ..., seq_len - 1]`
|
||
|
seq_idx = torch.arange(seq_len, device=device) / condense_ratio
|
||
|
|
||
|
# Calculate the product of position index and $\theta_i$
|
||
|
idx_theta = torch.outer(seq_idx, theta).repeat(1, 2)
|
||
|
|
||
|
return torch.cos(idx_theta), torch.sin(idx_theta)
|
||
|
|
||
|
|
||
|
def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
|
||
|
head_size = x.size(-1)
|
||
|
x1 = x[..., : head_size // 2] # (B, nh, T, hs/2)
|
||
|
x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2)
|
||
|
rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs)
|
||
|
roped = (x * cos) + (rotated * sin)
|
||
|
return roped.to(dtype=x.dtype)
|
||
|
|
||
|
|
||
|
class KVCache(nn.Module):
|
||
|
def __init__(
|
||
|
self,
|
||
|
k_shape: Tuple[int, int, int, int],
|
||
|
v_shape: Tuple[int, int, int, int],
|
||
|
device: Optional[torch.device] = None,
|
||
|
dtype: Optional[torch.dtype] = None,
|
||
|
) -> None:
|
||
|
super().__init__()
|
||
|
self.register_buffer(
|
||
|
"k", torch.zeros(k_shape, device=device, dtype=dtype), persistent=False
|
||
|
)
|
||
|
self.register_buffer(
|
||
|
"v", torch.zeros(v_shape, device=device, dtype=dtype), persistent=False
|
||
|
)
|
||
|
|
||
|
def forward(
|
||
|
self, input_pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor
|
||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||
|
# move the buffer to the activation dtype for when AMP is used
|
||
|
self.k = self.k.to(k.dtype)
|
||
|
self.v = self.v.to(v.dtype)
|
||
|
# update the cache
|
||
|
k = self.k.index_copy_(2, input_pos, k)
|
||
|
v = self.v.index_copy_(2, input_pos, v)
|
||
|
return k, v
|
||
|
|
||
|
def reset_parameters(self) -> None:
|
||
|
torch.nn.init.zeros_(self.k)
|
||
|
torch.nn.init.zeros_(self.v)
|
||
|
|
||
|
|
||
|
def build_mask_cache(
|
||
|
max_seq_length: int, device: Optional[torch.device] = None
|
||
|
) -> torch.Tensor:
|
||
|
ones = torch.ones((max_seq_length, max_seq_length), device=device, dtype=torch.bool)
|
||
|
return torch.tril(ones).unsqueeze(0).unsqueeze(0)
|
||
|
|
||
|
|
||
|
class RMSNorm(torch.nn.Module):
|
||
|
"""Root Mean Square Layer Normalization.
|
||
|
|
||
|
Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:
|
||
|
https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.
|
||
|
"""
|
||
|
|
||
|
def __init__(
|
||
|
self, size: int, dim: int = -1, eps: float = 1e-6, add_unit_offset: bool = False
|
||
|
) -> None:
|
||
|
super().__init__()
|
||
|
self.weight = torch.nn.Parameter(torch.ones(size))
|
||
|
self.eps = eps
|
||
|
self.dim = dim
|
||
|
self.add_unit_offset = add_unit_offset
|
||
|
|
||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||
|
dtype = x.dtype
|
||
|
x = x.float()
|
||
|
# NOTE: the original RMSNorm paper implementation is not equivalent
|
||
|
norm_x = torch.mean(x * x, dim=self.dim, keepdim=True)
|
||
|
x_normed = x * torch.rsqrt(norm_x + self.eps)
|
||
|
x_normed = x_normed.to(dtype=dtype)
|
||
|
if self.add_unit_offset:
|
||
|
# Gemma model requires a unit offset
|
||
|
# https://github.com/google/gemma_pytorch/blob/main/gemma/model.py#L176
|
||
|
return x_normed * (1 + self.weight)
|
||
|
return x_normed * self.weight
|
||
|
|
||
|
def reset_parameters(self) -> None:
|
||
|
torch.nn.init.ones_(self.weight)
|