mirror of
https://github.com/gpt-omni/mini-omni
synced 2024-11-21 23:37:38 +00:00
181 lines
7.5 KiB
Python
181 lines
7.5 KiB
Python
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
|
|
|
|
from copy import deepcopy
|
|
from dataclasses import dataclass, field
|
|
from pathlib import Path
|
|
from typing import Any, Literal, Optional, Type, Union
|
|
|
|
import torch
|
|
import yaml
|
|
from typing_extensions import Self
|
|
|
|
import litgpt.model
|
|
from litgpt.utils import find_multiple
|
|
|
|
|
|
@dataclass
|
|
class Config:
|
|
name: str = ""
|
|
hf_config: dict = field(default_factory=dict)
|
|
scale_embeddings: bool = False
|
|
block_size: int = 4096
|
|
vocab_size: int = 50254
|
|
padding_multiple: int = 512
|
|
padded_vocab_size: Optional[int] = None
|
|
n_layer: int = 16
|
|
n_head: int = 32
|
|
head_size: Optional[int] = None
|
|
n_embd: int = 4096
|
|
rotary_percentage: float = 0.25
|
|
parallel_residual: bool = True
|
|
bias: bool = True
|
|
lm_head_bias: bool = False
|
|
# to use multi-head attention (MHA), set this to `n_head` (default)
|
|
# to use multi-query attention (MQA), set this to 1
|
|
# to use grouped-query attention (GQA), set this to a value in between
|
|
# Example with `n_head=4`
|
|
# ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐
|
|
# │ v ││ v ││ v ││ v │ │ v │ │ v │ │ v │
|
|
# └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘
|
|
# │ │ │ │ │ │ │
|
|
# ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐
|
|
# │ k ││ k ││ k ││ k │ │ k │ │ k │ │ k │
|
|
# └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘
|
|
# │ │ │ │ ┌──┴──┐ ┌──┴──┐ ┌────┬──┴─┬────┐
|
|
# ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐
|
|
# │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │
|
|
# └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘
|
|
# ◀──────────────────▶ ◀──────────────────▶ ◀──────────────────▶
|
|
# MHA GQA MQA
|
|
# n_query_groups=4 n_query_groups=2 n_query_groups=1
|
|
#
|
|
# credit https://arxiv.org/pdf/2305.13245.pdf
|
|
n_query_groups: Optional[int] = None
|
|
shared_attention_norm: bool = False
|
|
norm_class_name: Literal["LayerNorm", "RMSNorm"] = "LayerNorm"
|
|
norm_eps: float = 1e-5
|
|
mlp_class_name: Literal["GptNeoxMLP", "LLaMAMLP", "GemmaMLP", "LLaMAMoE"] = (
|
|
"GptNeoxMLP"
|
|
)
|
|
gelu_approximate: str = "none"
|
|
intermediate_size: Optional[int] = None
|
|
rope_condense_ratio: int = 1
|
|
rope_base: int = 10000
|
|
n_expert: int = 0
|
|
n_expert_per_token: int = 0
|
|
|
|
add_qkv_bias: Optional[bool] = None
|
|
prompt_vocab_size: Optional[int] = None
|
|
attn_dropout: float = 0.0
|
|
pos_type: str = "rope"
|
|
force_align: bool = False
|
|
use_pretrain_phoneme_emb: bool = False
|
|
tie_word_embeddings: bool = False
|
|
|
|
# setting for mini-omni
|
|
text_vocab_size:int = 152000
|
|
cat_audio_vocab_size: int = 29120
|
|
audio_vocab_size: int = 4160
|
|
whisper_adapter_dim: int = 768
|
|
|
|
post_adapter: bool = False
|
|
post_adapter_layers: int = 6
|
|
asr_adapter: str = "llamamlp"
|
|
|
|
def __post_init__(self):
|
|
if not self.name:
|
|
self.name = self.hf_config.get("name", self.name)
|
|
|
|
if self.head_size is None:
|
|
assert self.n_embd % self.n_head == 0
|
|
self.head_size = self.n_embd // self.n_head
|
|
|
|
# vocab size should be a power of 2 to be optimal on hardware. compute the closest value
|
|
if self.padded_vocab_size is None:
|
|
self.padded_vocab_size = find_multiple(
|
|
self.vocab_size, self.padding_multiple
|
|
)
|
|
else:
|
|
# vocab size shouldn't be larger than padded vocab size
|
|
self.vocab_size = min(self.vocab_size, self.padded_vocab_size)
|
|
|
|
# compute the number of query groups
|
|
if self.n_query_groups is not None:
|
|
assert self.n_head % self.n_query_groups == 0
|
|
else:
|
|
self.n_query_groups = self.n_head
|
|
|
|
# compute the intermediate size for MLP if not set
|
|
if self.intermediate_size is None:
|
|
if self.mlp_class_name == "LLaMAMLP":
|
|
raise ValueError(
|
|
f"The config {self.name!r}, needs to set the `intermediate_size`"
|
|
)
|
|
self.intermediate_size = 4 * self.n_embd
|
|
|
|
self.rope_n_elem = int(self.rotary_percentage * self.head_size)
|
|
|
|
if self.add_qkv_bias is None:
|
|
self.add_qkv_bias = self.bias
|
|
|
|
@classmethod
|
|
def from_name(cls, name: str, **kwargs: Any) -> Optional[Self]:
|
|
if name not in name_to_config:
|
|
# search through all `config['hf_config']['name']`
|
|
try:
|
|
conf_dict = next(
|
|
config
|
|
for config in configs
|
|
if name == config["hf_config"]["name"]
|
|
or config["hf_config"]["org"] + "/" + config["hf_config"]["name"]
|
|
== name
|
|
)
|
|
except StopIteration:
|
|
raise ValueError(f"{name!r} is not a supported config name")
|
|
else:
|
|
conf_dict = name_to_config[name]
|
|
|
|
conf_dict = conf_dict.copy()
|
|
conf_dict.update(kwargs)
|
|
return cls(**conf_dict)
|
|
|
|
@classmethod
|
|
def from_file(cls, path: Union[str, Path], **kwargs: Any) -> Self:
|
|
with open(path, encoding="utf-8") as fp:
|
|
file_kwargs = yaml.safe_load(fp)
|
|
if file_kwargs is None:
|
|
raise ValueError(f"{path} is empty which is likely unexpected.")
|
|
file_kwargs.update(kwargs)
|
|
return cls(**file_kwargs)
|
|
|
|
@classmethod
|
|
def from_checkpoint(cls, path: Path, **kwargs: Any) -> Self:
|
|
"""Automatically load `model_config.yaml` and if it doesn't exist - a matching config from `litgpt/config.py`."""
|
|
if (config_path := path / "model_config.yaml").is_file():
|
|
return cls.from_file(config_path, **kwargs)
|
|
if (model_name := path.name) in name_to_config:
|
|
return cls.from_name(model_name, **kwargs)
|
|
raise FileNotFoundError(
|
|
f"For {str(path)!r} neither 'model_config.yaml' nor matching config exists."
|
|
)
|
|
|
|
@property
|
|
def mlp_class(self) -> Type:
|
|
# `self.mlp_class_name` cannot be the type to keep the config serializable
|
|
return getattr(litgpt.model, self.mlp_class_name)
|
|
|
|
@property
|
|
def norm_class(self) -> Type:
|
|
# `self.norm_class_name` cannot be the type to keep the config serializable
|
|
if self.norm_class_name == "RMSNorm":
|
|
from functools import partial
|
|
|
|
from litgpt.model import RMSNorm
|
|
|
|
return partial(RMSNorm, add_unit_offset="Gemma" in self.name)
|
|
return getattr(torch.nn, self.norm_class_name)
|
|
|
|
|
|
# Registry of built-in model configurations. It is empty here, so
# `Config.from_name` raises ValueError for any name until entries are added.
configs = []

# Map each registry entry's "name" to its dict (a later duplicate overrides an
# earlier one); consulted by `Config.from_name` and `Config.from_checkpoint`.
name_to_config = {}
name_to_config.update((entry["name"], entry) for entry in configs)
|