mirror of
https://github.com/gpt-omni/mini-omni
synced 2024-11-25 05:21:39 +00:00
642 lines
23 KiB
Python
642 lines
23 KiB
Python
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
|
|
|
|
"""Utility functions for training and inference."""
|
|
import inspect
|
|
import math
|
|
import os
|
|
import pickle
|
|
import shutil
|
|
import sys
|
|
from dataclasses import asdict, is_dataclass
|
|
from io import BytesIO
|
|
from pathlib import Path
|
|
from typing import (
|
|
TYPE_CHECKING,
|
|
Any,
|
|
Dict,
|
|
Iterable,
|
|
List,
|
|
Literal,
|
|
Mapping,
|
|
Optional,
|
|
TypeVar,
|
|
Union,
|
|
)
|
|
|
|
import lightning as L
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.utils._device
|
|
import yaml
|
|
from lightning.fabric.loggers import CSVLogger, TensorBoardLogger
|
|
from lightning.fabric.strategies import FSDPStrategy
|
|
from lightning.fabric.utilities.load import _lazy_load as lazy_load
|
|
from lightning.pytorch.loggers import WandbLogger
|
|
from lightning.pytorch.cli import instantiate_class
|
|
from torch.serialization import normalize_storage_type
|
|
from typing_extensions import Self
|
|
|
|
if TYPE_CHECKING:
|
|
from litgpt import GPT, Config
|
|
|
|
|
|
def init_out_dir(out_dir: Path) -> Path:
    """Re-root a relative output directory under ``LIGHTNING_ARTIFACTS_DIR``.

    Absolute paths are returned unchanged. A relative path is joined onto the
    value of the ``LIGHTNING_ARTIFACTS_DIR`` environment variable when that
    variable is set; otherwise the path is returned as given.
    """
    if out_dir.is_absolute():
        return out_dir
    artifacts_root = os.environ.get("LIGHTNING_ARTIFACTS_DIR")
    if artifacts_root is None:
        return out_dir
    return Path(artifacts_root) / out_dir
|
|
|
|
|
|
def find_resume_path(
    resume: Union[bool, Literal["auto"], Path], out_dir: Path
) -> Optional[Path]:
    """Locate the checkpoint file that training should resume from.

    Args:
        resume: A falsy value or a ``Path`` is returned unchanged. ``True``
            requires a checkpoint to exist under ``out_dir``; ``"auto"``
            uses one if it exists.
        out_dir: Directory searched recursively for ``step-*/*.pth`` files.

    Returns:
        ``resume`` itself when it is falsy or a ``Path``. Otherwise the
        ``.pth`` file whose parent directory ``step-<N>`` has the largest
        integer ``N``, or ``None`` when no such file exists (only possible
        for ``"auto"``).

    Raises:
        FileNotFoundError: ``resume`` is ``True`` and no checkpoint was found.
    """
    if not resume or isinstance(resume, Path):
        return resume

    # Pick the checkpoint from the highest-numbered `step-<N>` directory.
    resume_path = max(
        out_dir.rglob("step-*/*.pth"),
        key=(lambda p: int(p.parent.name.split("-")[1])),
        default=None,
    )
    if resume == "auto":
        return resume_path
    if resume is True and resume_path is None:
        raise FileNotFoundError(
            f"You passed `--resume=True`, but no checkpoint file was found in `--out_dir={out_dir}`."
        )
    return resume_path
|
|
|
|
|
|
def find_multiple(n: int, k: int) -> int:
    """Round ``n`` up to the nearest multiple of ``k`` (``k`` must be positive)."""
    assert k > 0
    remainder = n % k
    return n if remainder == 0 else n + k - remainder
|
|
|
|
|
|
def num_parameters(module: nn.Module, requires_grad: Optional[bool] = None) -> int:
    """Count the elements of ``module``'s parameters.

    Args:
        module: Module whose ``parameters()`` are counted.
        requires_grad: When not ``None``, only parameters whose
            ``requires_grad`` flag equals this value are counted.

    Returns:
        The summed element count. Parameters carrying a ``quant_state``
        attribute contribute the product of ``quant_state.shape`` instead of
        ``numel()``.
    """
    count = 0
    for param in module.parameters():
        if requires_grad is not None and param.requires_grad != requires_grad:
            continue
        if hasattr(param, "quant_state"):
            # bitsandbytes 4bit layer support
            count += math.prod(param.quant_state.shape)
        else:
            count += param.numel()
    return count
|
|
|
|
|
|
def reset_parameters(module: nn.Module) -> None:
    """Invoke ``reset_parameters`` on ``module`` and each submodule that defines it as a callable."""
    for submodule in module.modules():
        reset_fn = getattr(submodule, "reset_parameters", None)
        if callable(reset_fn):
            reset_fn()
|
|
|
|
|
|
def check_valid_checkpoint_dir(
    checkpoint_dir: Path,
    model_filename: str = "lit_model.pth",
    verbose: bool = True,
    raise_error: bool = False,
) -> None:
    """Verify that ``checkpoint_dir`` contains the files of a model checkpoint.

    The directory must hold ``model_filename``, ``model_config.yaml``,
    ``tokenizer_config.json`` and at least one of ``tokenizer.json`` /
    ``tokenizer.model``. The function returns silently when all are present.

    Args:
        checkpoint_dir: Directory to validate.
        model_filename: Name of the weights file expected in the directory.
        verbose: Print a diagnostic, including any checkpoints found under the
            relative ``checkpoints`` directory, to stderr on failure.
        raise_error: Raise ``FileNotFoundError`` on failure instead of exiting.

    Raises:
        FileNotFoundError: Validation failed and ``raise_error`` is true.
        SystemExit: Validation failed and ``raise_error`` is false (exit code 1).
    """

    def _present(file_name: str) -> bool:
        return (checkpoint_dir / file_name).is_file()

    def _headline() -> str:
        return f"checkpoint_dir {str(checkpoint_dir.absolute())!r}{problem}."

    required = {
        model_filename: _present(model_filename),
        "model_config.yaml": _present("model_config.yaml"),
        "tokenizer.json OR tokenizer.model": _present("tokenizer.json")
        or _present("tokenizer.model"),
        "tokenizer_config.json": _present("tokenizer_config.json"),
    }

    if not checkpoint_dir.is_dir():
        problem = " is not a checkpoint directory"
    elif all(required.values()):
        # we're good
        return
    else:
        missing = [name for name, present in required.items() if not present]
        problem = f" is missing the files: {missing!r}"

    # list locally available checkpoints
    local_checkpoints = list(Path("checkpoints").glob("*/*"))
    extra = ""
    if local_checkpoints:
        options = "\n".join([""] + [repr(str(p.resolve())) for p in local_checkpoints])
        extra = f"\nYou have downloaded locally:{options}\n"

    if verbose:
        error_message = (
            f"{_headline()}"
            "\nFind download instructions at https://github.com/Lightning-AI/litgpt/blob/main/tutorials\n"
            f"{extra}\nSee all download options by running:\n litgpt download"
        )
        print(error_message, file=sys.stderr)

    if raise_error:
        raise FileNotFoundError(_headline())
    raise SystemExit(1)
|
|
|
|
|
|
class SavingProxyForStorage:
    """Placeholder for a storage that has already been handed to ``saver``.

    Constructing the proxy passes the underlying storage to
    ``saver._write_storage_and_return_key`` and keeps only the resulting
    ``storage_info`` tuple. ``IncrementalPyTorchPickler.persistent_id``
    returns that tuple when it encounters the proxy.

    Raises:
        TypeError: ``obj`` is neither a ``TypedStorage`` nor a torch storage.
    """

    def __init__(self, obj, saver, protocol_version=5):
        self.protocol_version = protocol_version
        self.saver = saver
        if not (isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj)):
            raise TypeError(f"expected storage, not {type(obj)}")

        # this logic is taken from PyTorch 2.0+ torch/serialization.py
        if isinstance(obj, torch.storage.TypedStorage):
            # PT upstream wants to deprecate this eventually...
            storage = obj._untyped_storage
            storage_type_str = obj._pickle_storage_type()
            storage_type = getattr(torch, storage_type_str)
            storage_numel = obj._size()
        else:
            storage = obj
            storage_type = normalize_storage_type(type(obj))
            # for a non-typed storage the recorded size is its byte count
            storage_numel = storage.nbytes()

        # hand the storage to the saver now; only the returned key is kept
        storage_key = saver._write_storage_and_return_key(storage)
        location = torch.serialization.location_tag(storage)

        # same tuple layout that `IncrementalPyTorchPickler.persistent_id` builds
        self.storage_info = (
            "storage",
            storage_type,
            storage_key,
            location,
            storage_numel,
        )

    def __reduce_ex__(self, protocol_version):
        # The proxy must never be pickled by value; it is expected to be
        # intercepted through the pickler's `persistent_id` hook instead.
        assert False, "this should be handled with out of band"
|
|
|
|
|
|
class SavingProxyForTensor:
    """Picklable stand-in for a tensor whose storage is written out immediately.

    The tensor's own ``__reduce_ex__`` result is captured and the storage it
    references is swapped for a ``SavingProxyForStorage``, so pickling the
    proxy later reproduces the tensor's reduce call with the proxied storage.
    """

    def __init__(self, tensor, saver, protocol_version=5):
        self.protocol_version = protocol_version
        rebuild_fn, args = tensor.__reduce_ex__(protocol_version)
        self.reduce_ret_fn = rebuild_fn

        # for Tensors with Python attributes the storage sits one level deeper,
        # inside the third element of the reduce arguments
        nested = args[0] == torch._utils._rebuild_tensor_v2
        if nested:
            (a0, a1, (storage, *inner_rest), *outer_rest) = args
        else:
            (storage, *outer_rest) = args

        assert isinstance(
            storage, torch.storage.TypedStorage
        ), "Please check for updates"
        storage_proxy = SavingProxyForStorage(
            storage, saver, protocol_version=protocol_version
        )

        if nested:
            self.reduce_args = (a0, a1, (storage_proxy, *inner_rest), *outer_rest)
        else:
            self.reduce_args = (storage_proxy, *outer_rest)

    def __reduce_ex__(self, protocol_version):
        if protocol_version == self.protocol_version:
            return self.reduce_ret_fn, self.reduce_args
        raise RuntimeError(
            f"Unexpected protocol version: expected {self.protocol_version}, got {protocol_version}"
        )
|
|
|
|
|
|
class IncrementalPyTorchPickler(pickle.Pickler):
    """Pickler that routes torch storages through ``saver`` via ``persistent_id``.

    Args:
        saver: Object providing ``_write_storage_and_return_key(storage)``.
        *args, **kwargs: Forwarded to ``pickle.Pickler``.
    """

    def __init__(self, saver, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # maps storage.data_ptr() -> dtype under which that data was first seen
        self.storage_dtypes = {}
        self.saver = saver
        # maps storage._cdata -> key returned by the saver, so that a storage
        # encountered more than once is only written once
        self.id_map = {}

    # this logic is taken from PyTorch 2.0+ torch/serialization.py
    def persistent_id(self, obj):
        """Return the persistent-id tuple for storages and storage proxies, else ``None``.

        Raises:
            RuntimeError: Two storages share a data pointer but differ in dtype.
        """
        # FIXME: the docs say that persistent_id should only return a string
        # but torch store returns tuples. This works only in the binary protocol
        # see
        # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
        # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
        if isinstance(obj, SavingProxyForStorage):
            # the proxy already wrote its storage; reuse the tuple it recorded
            return obj.storage_info

        if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
            if isinstance(obj, torch.storage.TypedStorage):
                # TODO: Once we decide to break serialization FC, this case
                # can be deleted
                storage = obj._untyped_storage
                storage_dtype = obj.dtype
                storage_type_str = obj._pickle_storage_type()
                storage_type = getattr(torch, storage_type_str)
                storage_numel = obj._size()

            else:
                storage = obj
                storage_dtype = torch.uint8
                storage_type = normalize_storage_type(type(obj))
                storage_numel = storage.nbytes()

            # If storage is allocated, ensure that any other saved storages
            # pointing to the same data all have the same dtype. If storage is
            # not allocated, don't perform this check
            if storage.data_ptr() != 0:
                if storage.data_ptr() in self.storage_dtypes:
                    if storage_dtype != self.storage_dtypes[storage.data_ptr()]:
                        raise RuntimeError(
                            "Cannot save multiple tensors or storages that view the same data as different types"
                        )
                else:
                    self.storage_dtypes[storage.data_ptr()] = storage_dtype

            # write the storage through the saver only the first time it is seen
            storage_key = self.id_map.get(storage._cdata)
            if storage_key is None:
                storage_key = self.saver._write_storage_and_return_key(storage)
                self.id_map[storage._cdata] = storage_key
            location = torch.serialization.location_tag(storage)

            return ("storage", storage_type, storage_key, location, storage_numel)

        # anything else is pickled normally
        return None
|
|
|
|
|
|
class incremental_save:
    """Context manager that writes a checkpoint archive record by record.

    Storages are written as ``data/<key>`` records as soon as they are handed
    over (``store_early`` or during ``save``); ``save`` finally writes the
    pickled object as ``data.pkl``. Leaving the context finalizes the file.
    """

    def __init__(self, name):
        self.name = name
        self.zipfile = torch._C.PyTorchFileWriter(str(name))
        self.has_saved = False
        self.next_key = 0

    def __enter__(self):
        return self

    def store_early(self, tensor):
        """Write ``tensor``'s storage now and return a proxy to pickle in its place."""
        if not isinstance(tensor, torch.Tensor):
            raise TypeError(f"can only store tensors early, not {type(tensor)}")
        return SavingProxyForTensor(tensor, self)

    def save(self, obj):
        """Pickle ``obj`` into the ``data.pkl`` record; allowed only once."""
        if self.has_saved:
            raise RuntimeError("have already saved")
        # Write the pickle data for `obj`
        buffer = BytesIO()
        IncrementalPyTorchPickler(self, buffer, protocol=5).dump(obj)
        payload = buffer.getvalue()
        self.zipfile.write_record("data.pkl", payload, len(payload))
        self.has_saved = True

    def _write_storage_and_return_key(self, storage):
        """Write ``storage`` as the next ``data/<key>`` record and return ``key``."""
        if self.has_saved:
            raise RuntimeError("have already saved")
        key = self.next_key
        self.next_key = key + 1
        record_name = f"data/{key}"
        if storage.device.type != "cpu":
            storage = storage.cpu()
        byte_count = storage.nbytes()
        self.zipfile.write_record(record_name, storage.data_ptr(), byte_count)
        return key

    def __exit__(self, type, value, traceback):
        self.zipfile.write_end_of_file()
|
|
|
|
|
|
# Generic type variable; nothing in the visible part of this module references it.
T = TypeVar("T")
|
|
|
|
|
|
def chunked_cross_entropy(
    logits: Union[torch.Tensor, List[torch.Tensor]],
    targets: torch.Tensor,
    chunk_size: int = 128,
    ignore_index: int = -100,
) -> torch.Tensor:
    """Cross entropy evaluated chunk by chunk to soften the memory spike in ``backward``.

    With large max_sequence_lengths, the beginning of ``backward`` allocates a
    large memory chunk which can dominate the memory usage in fine-tuning
    settings with low number of parameters. As a workaround hack, the cross
    entropy computation is chunked to force it to deallocate on the go,
    reducing the memory spike's magnitude.

    Args:
        logits: Either one tensor, or a list of tensors that are chunks along
            dim 1 (the lm_head was chunked).
        targets: Target indices; entries equal to ``ignore_index`` are skipped.
        chunk_size: Rows per chunk when ``logits`` is a single tensor. ``0``
            disables chunking and defers to plain ``cross_entropy``.
        ignore_index: Target value excluded from the loss.

    Returns:
        A scalar loss tensor. On the chunked paths it is the sum of per-element
        losses divided by the number of non-ignored targets, with the divisor
        clamped to at least one.
    """
    cross_entropy = torch.nn.functional.cross_entropy

    def _elementwise_losses(logit_pieces, target_pieces):
        return [
            cross_entropy(
                logit_piece, target_piece, ignore_index=ignore_index, reduction="none"
            )
            for logit_piece, target_piece in zip(logit_pieces, target_pieces)
        ]

    def _mean_over_unmasked(losses, all_targets):
        unmasked = (all_targets != ignore_index).sum()
        # max(1, unmasked) would be more ergonomic to avoid a division by zero.
        # However that results in a python int which is then passed back to
        # torch division. By using the `x.maximum(torch.ones_like(x))` pattern
        # we avoid a cudaStreamSynchronize.
        return torch.cat(losses).sum() / unmasked.maximum(torch.ones_like(unmasked))

    # lm_head was chunked (we are fine-tuning)
    if isinstance(logits, list):
        if chunk_size == 0:
            # don't want to chunk cross entropy
            merged = torch.cat(logits, dim=1)
            return cross_entropy(
                merged.reshape(-1, merged.size(-1)),
                targets.reshape(-1),
                ignore_index=ignore_index,
            )

        # chunk cross entropy
        flat_logit_pieces = [piece.reshape(-1, piece.size(-1)) for piece in logits]
        flat_target_pieces = [
            piece.reshape(-1) for piece in targets.split(logits[0].size(1), dim=1)
        ]
        return _mean_over_unmasked(
            _elementwise_losses(flat_logit_pieces, flat_target_pieces), targets
        )

    # no chunking at all
    flat_logits = logits.reshape(-1, logits.size(-1))
    flat_targets = targets.reshape(-1)
    if chunk_size == 0:
        return cross_entropy(flat_logits, flat_targets, ignore_index=ignore_index)

    # lm_head wasn't chunked, chunk cross entropy
    losses = _elementwise_losses(
        flat_logits.split(chunk_size), flat_targets.split(chunk_size)
    )
    return _mean_over_unmasked(losses, flat_targets)
|
|
|
|
|
|
def map_old_state_dict_weights(state_dict: Dict, mapping: Mapping, prefix: str) -> Dict:
    """Rename legacy keys of ``state_dict`` in place and return the same dict.

    For every ``old -> new`` pair in ``mapping``, the entry stored under
    ``prefix + old`` (if present) is moved to ``prefix + new``.
    """
    for old_suffix, new_suffix in mapping.items():
        legacy_key = prefix + old_suffix
        if legacy_key not in state_dict:
            continue
        renamed_key = prefix + new_suffix
        state_dict[renamed_key] = state_dict.pop(legacy_key)
    return state_dict
|
|
|
|
|
|
def get_default_supported_precision(training: bool) -> str:
    """Return default precision that is supported by the hardware: either `bf16` or `16`.

    Args:
        training: `-mixed` or `-true` version of the precision to use

    Returns:
        default precision that is suitable for the task and is supported by the hardware
    """
    from lightning.fabric.accelerators import MPSAccelerator

    # 16-bit float is selected on MPS, or on CUDA devices that lack bf16 support.
    use_fp16 = MPSAccelerator.is_available() or (
        torch.cuda.is_available() and not torch.cuda.is_bf16_supported()
    )
    width = "16" if use_fp16 else "bf16"
    mode = "mixed" if training else "true"
    return f"{width}-{mode}"
|
|
|
|
|
|
def load_checkpoint(
    fabric: L.Fabric, model: nn.Module, checkpoint_path: Path, strict: bool = True
) -> None:
    """Load the weights at ``checkpoint_path`` into ``model``.

    Under an ``FSDPStrategy`` the load is delegated to ``fabric.load_raw``.
    Otherwise the file is lazily loaded and, if it has a ``"model"`` entry,
    that entry is used as the state dict instead of the top-level mapping.
    """
    if isinstance(fabric.strategy, FSDPStrategy):
        fabric.load_raw(checkpoint_path, model, strict=strict)
        return

    checkpoint = lazy_load(checkpoint_path)
    weights = checkpoint.get("model", checkpoint)
    model.load_state_dict(weights, strict=strict)
|
|
|
|
|
|
def flops_per_param(
    max_seq_length: int, n_layer: int, n_embd: int, n_params: int
) -> int:
    """Estimate the FLOPs of one pass over a sequence of ``max_seq_length`` tokens.

    The result is the sum of a per-parameter term (``2 * n_params`` per token)
    and an attention term that grows with ``max_seq_length ** 2``.
    """
    # each parameter is used for a MAC (2 FLOPS) per network operation
    per_token = 2 * n_params
    # this assumes that all samples have a fixed length equal to the block size
    # which is most likely false during finetuning
    per_sequence = per_token * max_seq_length
    attention = n_layer * 2 * 2 * (n_embd * (max_seq_length**2))
    return per_sequence + attention
|
|
|
|
|
|
def estimate_flops(model: "GPT", training: bool) -> int:
    """Measures estimated FLOPs for MFU.

    Refs:
        * https://ar5iv.labs.arxiv.org/html/2205.05198#A1
        * https://ar5iv.labs.arxiv.org/html/2204.02311#A2
    """
    # using all parameters for this is a naive over estimation because not all model parameters actually contribute to
    # this FLOP computation (e.g. embedding, norm). For this reason, the result will be higher by a fixed percentage
    # (~10%) compared to the measured FLOPs, making those lower but more realistic.
    # For a proper estimate, this needs a more fine-grained calculation as in Appendix A of the paper.

    def _flops_for(requires_grad: bool) -> int:
        param_count = num_parameters(model, requires_grad=requires_grad)
        return flops_per_param(
            model.max_seq_length,
            model.config.n_layer,
            model.config.n_embd,
            param_count,
        )

    trainable_flops = _flops_for(True)
    frozen_flops = _flops_for(False)

    if training:
        # trainable: forward + backward + gradients (assumes no gradient accumulation)
        # frozen: forward + backward
        trainable_passes, frozen_passes = 3, 2
    else:
        trainable_passes, frozen_passes = 1, 1
    return trainable_passes * trainable_flops + frozen_passes * frozen_flops
|
|
|
|
|
|
class CycleIterator:
    """An iterator that cycles through an iterable indefinitely.

    Example:
        >>> iterator = CycleIterator([1, 2, 3])
        >>> [next(iterator) for _ in range(5)]
        [1, 2, 3, 1, 2]

    Note:
        Unlike ``itertools.cycle``, this iterator does not cache the values of the iterable.
        The ``epoch`` attribute counts how many times the iterable has been restarted.
    """

    def __init__(self, iterable: Iterable) -> None:
        self.iterable = iterable
        self.epoch = 0
        self._iterator = None

    def __next__(self) -> Any:
        if self._iterator is None:
            self._iterator = iter(self.iterable)

        exhausted = object()
        item = next(self._iterator, exhausted)
        if item is not exhausted:
            return item

        # Current pass is finished: start over on a fresh iterator.
        self._iterator = iter(self.iterable)
        self.epoch += 1
        return next(self._iterator)

    def __iter__(self) -> Self:
        return self
|
|
|
|
|
|
def copy_config_files(source_dir: Path, out_dir: Path) -> None:
    """Copies the specified configuration and tokenizer files into the output directory.

    Only files that exist in ``source_dir`` are copied; the rest are skipped.
    """
    wanted = (
        # configuration files
        "config.json",
        "generation_config.json",
        "model_config.yaml",
        # tokenizer files
        "tokenizer.json",
        "tokenizer.model",
        "tokenizer_config.json",
    )
    for candidate in (source_dir / file_name for file_name in wanted):
        if candidate.exists():
            shutil.copy(candidate, out_dir)
|
|
|
|
|
|
def CLI(*args: Any, **kwargs: Any) -> Any:
    """Run ``jsonargparse.CLI`` after enabling attribute docstrings and URL config reading.

    All positional and keyword arguments are forwarded unchanged and the
    result of ``jsonargparse.CLI`` is returned.
    """
    import jsonargparse

    jsonargparse.set_docstring_parse_options(attribute_docstrings=True)
    jsonargparse.set_config_read_mode(urls_enabled=True)

    return jsonargparse.CLI(*args, **kwargs)
|
|
|
|
|
|
def capture_hparams() -> Dict[str, Any]:
    """Captures the local variables ('hyperparameters') from where this function gets called.

    Returns:
        A dict keyed by the caller's local variable names. ``None`` and values
        of type ``int``, ``float``, ``str``, ``bool`` or ``Path`` are stored
        as-is, dataclass instances are converted with ``asdict``, and
        everything else is stored as ``str(value)``.
    """
    caller_frame = inspect.currentframe().f_back
    locals_of_caller = caller_frame.f_locals
    hparams: Dict[str, Any] = {}
    for name, value in locals_of_caller.items():
        if value is None or isinstance(value, (int, float, str, bool, Path)):
            hparams[name] = value
        elif is_dataclass(value) and not isinstance(value, type):
            # `is_dataclass` is also true for dataclass *classes*, which
            # `asdict` rejects with a TypeError; only convert instances.
            hparams[name] = asdict(value)
        else:
            hparams[name] = str(value)
    return hparams
|
|
|
|
|
|
def save_hyperparameters(function: callable, checkpoint_dir: Path) -> None:
    """Captures the CLI parameters passed to `function` without running `function` and saves them to the checkpoint.

    The parsed configuration is written to ``hyperparameters.yaml`` inside
    ``checkpoint_dir``, overwriting an existing file. A recognized command
    name directly after the program name is removed from ``sys.argv``.
    """
    from jsonargparse import capture_parser

    # TODO: Make this more robust
    # This hack strips away the subcommands from the top-level CLI
    # to parse the file as if it was called as a script
    known_commands = (
        ("finetune_full",),  # For subcommands, use `("finetune", "full")` etc
        ("finetune_lora",),
        ("finetune_adapter",),
        ("finetune_adapter_v2",),
        ("finetune",),
        ("pretrain",),
    )
    for command in known_commands:
        leading = slice(1, 1 + len(command))
        if tuple(sys.argv[leading]) == command:
            del sys.argv[leading]

    destination = checkpoint_dir / "hyperparameters.yaml"
    parser = capture_parser(lambda: CLI(function))
    parsed = parser.parse_args()
    parser.save(parsed, destination, overwrite=True)
|
|
|
|
|
|
def save_config(config: "Config", checkpoint_dir: Path) -> None:
    """Dump the dataclass ``config`` as YAML to ``model_config.yaml`` in ``checkpoint_dir``."""
    # Convert before opening so a conversion failure leaves any existing file untouched.
    serializable = asdict(config)
    target = checkpoint_dir / "model_config.yaml"
    with target.open("w", encoding="utf-8") as stream:
        yaml.dump(serializable, stream)
|
|
|
|
|
|
def parse_devices(devices: Union[str, int]) -> int:
    """Translate the ``devices`` option into a concrete device count.

    ``-1`` and ``"auto"`` resolve to ``torch.cuda.device_count()``, or ``1``
    when that count is zero. A positive integer is returned unchanged.

    Raises:
        ValueError: ``devices`` is anything else.
    """
    if devices == -1 or devices == "auto":
        visible_gpus = torch.cuda.device_count()
        return visible_gpus if visible_gpus else 1
    if not isinstance(devices, int) or devices <= 0:
        raise ValueError(f"Devices must be 'auto' or a positive integer, got: {devices!r}")
    return devices
|
|
|
|
|
|
def choose_logger(
    logger_name: Literal["csv", "tensorboard", "wandb"],
    out_dir: Path,
    name: str,
    log_interval: int = 1,
    resume: Optional[bool] = None,
    **kwargs: Any,
):
    """Build the experiment logger selected by ``logger_name``.

    Args:
        logger_name: One of ``"csv"``, ``"tensorboard"`` or ``"wandb"``.
        out_dir: The csv and tensorboard loggers write under ``out_dir / "logs"``.
        name: Passed to the wandb logger as ``project``.
        log_interval: Passed to the csv logger as ``flush_logs_every_n_steps``.
        resume: Passed to the wandb logger as ``resume``.
        **kwargs: Forwarded to the chosen logger's constructor.

    Raises:
        ValueError: ``logger_name`` is not one of the supported options.
    """
    if logger_name == "wandb":
        return WandbLogger(project=name, resume=resume, **kwargs)
    elif logger_name == "tensorboard":
        return TensorBoardLogger(
            root_dir=(out_dir / "logs"), name="tensorboard", **kwargs
        )
    elif logger_name == "csv":
        return CSVLogger(
            root_dir=(out_dir / "logs"),
            name="csv",
            flush_logs_every_n_steps=log_interval,
            **kwargs,
        )
    else:
        raise ValueError(
            f"`--logger_name={logger_name}` is not a valid option. Choose from 'csv', 'tensorboard', 'wandb'."
        )
|
|
|
|
|
|
def get_argument_names(cls):
    """Return the set of ``cls.__init__`` parameter names that can be passed by keyword.

    Positional-or-keyword and keyword-only parameters are included (so
    ``self`` is part of the result); ``*args``, ``**kwargs`` and
    positional-only parameters are not.
    """
    keyword_kinds = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
    names = set()
    for param_name, param in inspect.signature(cls.__init__).parameters.items():
        if param.kind in keyword_kinds:
            names.add(param_name)
    return names
|
|
|
|
|
|
def instantiate_bnb_optimizer(optimizer, model_parameters):
    """Create a bitsandbytes ``PagedAdamW`` optimizer from an optimizer spec.

    Args:
        optimizer: Either an optimizer name, or a dict with ``"class_path"``
            and ``"init_args"`` entries. Only ``init_args`` keys that
            ``PagedAdamW.__init__`` accepts by keyword are forwarded.
        model_parameters: Parameters handed to the optimizer.

    Raises:
        ValueError: The name / ``class_path`` does not contain ``"AdamW"``.
    """
    if isinstance(optimizer, str):
        names_adamw = "AdamW" in optimizer
    elif isinstance(optimizer, dict):
        names_adamw = "AdamW" in optimizer.get("class_path", "")
    else:
        # other spec types are not validated at this point
        names_adamw = True
    if not names_adamw:
        raise ValueError(
            "The chosen quantization format only supports the AdamW optimizer."
        )

    import bitsandbytes as bnb

    if isinstance(optimizer, str):
        return bnb.optim.PagedAdamW(model_parameters)

    accepted = get_argument_names(bnb.optim.PagedAdamW)
    init_args = optimizer["init_args"]
    allowed_kwargs = {key: init_args[key] for key in accepted & init_args.keys()}
    return bnb.optim.PagedAdamW(model_parameters, **allowed_kwargs)
|
|
|
|
|
|
def instantiate_torch_optimizer(optimizer, model_parameters, **kwargs):
    """Create an optimizer from a name or a ``class_path`` / ``init_args`` spec.

    Args:
        optimizer: Either the name of a class in ``torch.optim``, or a mapping
            accepted by ``instantiate_class`` (``"class_path"`` plus optional
            ``"init_args"``).
        model_parameters: Parameters handed to the optimizer.
        **kwargs: Extra constructor arguments. For a mapping spec they
            override entries of ``init_args`` with the same key.

    Returns:
        The instantiated optimizer. The caller's ``optimizer`` mapping and its
        ``init_args`` are left unmodified.
    """
    if isinstance(optimizer, str):
        optimizer_cls = getattr(torch.optim, optimizer)
        return optimizer_cls(model_parameters, **kwargs)

    # `dict(optimizer)` alone is a shallow copy: updating its `init_args` in
    # place would leak `kwargs` into the caller's config. Build a fresh
    # `init_args` dict instead (also tolerating a missing or `None` entry).
    optimizer = dict(optimizer)
    optimizer["init_args"] = {**(optimizer.get("init_args") or {}), **kwargs}
    return instantiate_class(model_parameters, optimizer)
|
|
|
|
|
|
def extend_checkpoint_dir(checkpoint_dir: Path) -> Path:
    """Fall back to ``checkpoints/<checkpoint_dir>`` when the given path does not exist.

    The prefixed path is returned only if ``checkpoint_dir`` is not an existing
    directory, is relative, does not already start with ``checkpoints``, and
    the prefixed path exists. In every other case ``checkpoint_dir`` is
    returned unchanged.
    """
    prefixed = "checkpoints" / checkpoint_dir

    if checkpoint_dir.is_dir():
        return checkpoint_dir
    if checkpoint_dir.parts[0] == "checkpoints":
        return checkpoint_dir
    if checkpoint_dir.is_absolute():
        return checkpoint_dir
    return prefixed if prefixed.exists() else checkpoint_dir
|