hexabot/nlu/boilerplate.py
2024-09-10 10:50:11 +01:00

229 lines
7.3 KiB
Python

"""TensorFlow Boilerplate main module."""
from collections import namedtuple
import importlib
import json
import logging
import os
import sys

import tensorflow as tf
from huggingface_hub import snapshot_download
# Configure the root logger when this module is imported: DEBUG level and
# one timestamped line per record.
# NOTE(review): configuring the root logger from a library module affects the
# whole process; consider moving this to the entry-point script.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.DEBUG,
)
def Hyperparameters(value):
    """Turn a dict of hyperparameters into a namedtuple.

    If `value` is already a namedtuple it is returned unchanged, so the
    function can be applied repeatedly (e.g. from a property setter).

    Parameters:
        value (dict or namedtuple): Mapping of hyperparameter names to values,
            or an existing namedtuple instance.

    Returns:
        namedtuple: A `Hyperparameters` record whose fields are the keys of
        `value` (or `value` itself when it already is a namedtuple).

    Raises:
        ValueError: If a key is not a valid namedtuple field name (raised by
            `collections.namedtuple`).
    """
    # Don't transform `value` if it's a namedtuple: a direct tuple subclass
    # whose `_fields` is a tuple of strings.
    # https://stackoverflow.com/questions/2166818/how-to-check-if-an-object-is-an-instance-of-a-namedtuple
    value_type = type(value)
    if value_type.__bases__ == (tuple,):
        fields = getattr(value_type, "_fields", None)
        if isinstance(fields, tuple) and all(isinstance(name, str) for name in fields):
            return value
    _Hyperparameters = namedtuple("Hyperparameters", value.keys())
    return _Hyperparameters(**value)
def validate_and_get_project_name(repo_name):
    """
    Check that `repo_name` looks like 'Owner/ProjectName' and extract the project part.

    Parameters:
        repo_name (str): HuggingFace repository identifier, e.g. 'Owner/ProjectName'.

    Returns:
        str: The text after the single '/' separator.

    Raises:
        ValueError: If there is not exactly one '/', or if the owner or the
            project part is empty.
    """
    # Exactly one separator is allowed.
    if repo_name.count('/') != 1:
        raise ValueError("Invalid repository name format. It must be in 'Owner/ProjectName' format.")
    owner, _separator, project_name = repo_name.partition('/')
    if owner and project_name:
        return project_name
    raise ValueError("Invalid repository name. Both owner and project name must be non-empty.")
class Model(tf.keras.Model):
    """Keras model with hyperparameter parsing and a few other utilities.

    Hyperparameters are resolved with increasing precedence from:
    1. the class-level `default_hparams`;
    2. the `hparams.json` file in the save directory, when it exists;
    3. the keyword arguments given to the constructor.
    When `hparams.json` does not exist yet, the merged values are written to it.

    Weights are checkpointed with a `tf.train.CheckpointManager` that keeps
    only the latest checkpoint. The `extra_params` dict is stored next to it in
    `extra_params.json`.

    When `repo_id` ('Owner/ProjectName') is given, files live under
    `repos/<ProjectName>/`. The snapshot is downloaded from the Hugging Face
    Hub if no local `checkpoint` file is found.
    """

    default_hparams = {}
    # Registry of `@runnable` methods, populated by `default_export`.
    _methods = {}

    def __init__(self, save_dir=None, method=None, repo_id=None, **hparams):
        """Create the model and load (or persist) its hyperparameters.

        Parameters:
            save_dir (str): Directory for checkpoints and JSON metadata. With
                `repo_id` it is relative to `repos/<ProjectName>/` (optional);
                without `repo_id` it is required.
            method (str): Exposed unchanged through the `method` property.
            repo_id (str): Optional Hugging Face repository, 'Owner/ProjectName'.
            **hparams: Hyperparameters overriding saved and default values.

        Raises:
            ValueError: If `repo_id` is malformed, or if neither `save_dir`
                nor `repo_id` is supplied.
        """
        super().__init__()
        self._method = method
        self.hparams = {**self.default_hparams, **hparams}
        self.extra_params = {}
        self._ckpt = None
        self._manager = None
        self._repo_id = None
        self._repo_dir = None

        if repo_id is not None:
            project_name = validate_and_get_project_name(repo_id)
            self._repo_id = repo_id
            self._repo_dir = os.path.join("repos", project_name)
            if save_dir is not None:
                self._save_dir = os.path.join(self._repo_dir, save_dir)
            else:
                self._save_dir = self._repo_dir
            # Fetch/restore first: the JSON metadata read below comes with the
            # downloaded snapshot.
            self.load_model()
        else:
            if save_dir is None:
                raise ValueError("save_dir must be supplied.")
            self._save_dir = save_dir

        # Saved hyperparameters override the class defaults and are overridden
        # by constructor kwargs. The defaults are still merged in so keys added
        # to `default_hparams` after the file was written remain available.
        hparams_path = os.path.join(self._save_dir, "hparams.json")
        if os.path.isfile(hparams_path):
            with open(hparams_path) as f:
                self.hparams = {**self.default_hparams, **json.load(f), **hparams}
        else:
            os.makedirs(self._save_dir, exist_ok=True)
            with open(hparams_path, "w") as f:
                json.dump(self.hparams._asdict(), f, indent=4,  # type: ignore
                          sort_keys=True)

        # Load previously saved extra parameters, if any.
        extra_params_path = os.path.join(self._save_dir, "extra_params.json")
        if os.path.isfile(extra_params_path):
            with open(extra_params_path) as f:
                self.extra_params = {**json.load(f)}

    @property
    def method(self):
        """The `method` value given to the constructor."""
        return self._method

    @property
    def hparams(self):
        """Hyperparameters as a namedtuple."""
        return self._hparams

    @hparams.setter
    def hparams(self, value):
        # Dicts are converted; namedtuples are stored unchanged.
        self._hparams = Hyperparameters(value)

    @property
    def extra_params(self):
        """Additional values persisted to `extra_params.json` by `save`."""
        return self._extra_params

    @extra_params.setter
    def extra_params(self, value):
        self._extra_params = value

    @property
    def save_dir(self):
        """Directory holding checkpoints, `hparams.json` and `extra_params.json`."""
        return self._save_dir

    def _get_manager(self):
        """Return the checkpoint manager, creating it and the checkpoint on first use."""
        if self._ckpt is None:
            self._ckpt = tf.train.Checkpoint(model=self)
            self._manager = tf.train.CheckpointManager(
                self._ckpt, directory=self.save_dir, max_to_keep=1
            )
        return self._manager

    def save(self):
        """Save the model's weights and its `extra_params`.

        NOTE: this shadows `tf.keras.Model.save`. It takes no path and always
        writes to `save_dir`.
        """
        self._get_manager().save()
        # `save_dir` is always set after __init__; the guard is kept defensively.
        if self.save_dir:
            extra_params_path = os.path.join(self.save_dir, "extra_params.json")
            with open(extra_params_path, "w") as f:
                json.dump(self.extra_params, f, indent=4, sort_keys=True)

    def restore(self):
        """Restore the latest saved weights and reload `extra_params.json` if present."""
        manager = self._get_manager()
        # `expect_partial` silences warnings about checkpoint values that are
        # not matched by this object.
        self._ckpt.restore(manager.latest_checkpoint).expect_partial()
        extra_params_path = os.path.join(self.save_dir, "extra_params.json")
        if os.path.isfile(extra_params_path):
            with open(extra_params_path) as f:
                self.extra_params = json.load(f)

    def make_summary_writer(self, dirname):
        """Create a TensorBoard summary writer logging to `save_dir/dirname`."""
        return tf.summary.create_file_writer(os.path.join(self.save_dir, dirname))  # type: ignore

    def load_model(self):
        """Ensure the Hub snapshot is available locally, then restore the weights.

        The repository is downloaded (forcing a fresh download) into the repo
        directory only when `save_dir` has no `checkpoint` file.

        Raises:
            ValueError: If a download is needed but no `repo_id` was given.
        """
        if not os.path.isfile(os.path.join(self._save_dir, "checkpoint")):
            if self._repo_id is None:
                raise ValueError(
                    "No local checkpoint found and no repo_id was given to download one."
                )
            os.makedirs(self._repo_dir, exist_ok=True)
            logging.info("Downloading %s into %s", self._repo_id, self._repo_dir)
            snapshot_download(repo_id=self._repo_id, force_download=True,
                              local_dir=self._repo_dir, repo_type="model")
        self.restore()
class DataLoader:
    """Data loader base class mirroring `Model`'s hyperparameter handling.

    Subclasses declare `default_hparams`. Keyword arguments given to the
    constructor override them, and the merged result is exposed as a
    namedtuple through `hparams`.
    """

    default_hparams = {}

    def __init__(self, method=None, **hparams):
        merged = dict(self.default_hparams)
        merged.update(hparams)
        self._method = method
        self.hparams = merged

    @property
    def method(self):
        """The `method` value passed to the constructor (None by default)."""
        return self._method

    @property
    def hparams(self):
        """Hyperparameters as a namedtuple."""
        return self._hparams

    @hparams.setter
    def hparams(self, value):
        # Dicts are converted to a namedtuple; namedtuples are stored as-is.
        self._hparams = Hyperparameters(value)
def runnable(f):
    """Decorator flagging `f` as runnable from `run.py`.

    It sets `f._runnable = True`, the marker `default_export` looks for, and
    returns `f` itself without wrapping it.
    """
    f._runnable = True
    return f
def default_export(cls):
    """Make `cls` the imported object of its module and register its runnables.

    After this call, `sys.modules[cls.__module__]` is `cls` itself, so importing
    the module yields the class. Every attribute defined directly on `cls`
    whose `_runnable` flag is truthy (see `runnable`) is recorded in
    `cls._methods` under its attribute name.

    Each decorated class gets its own `_methods` dict, seeded with a copy of any
    inherited entries. Runnables of sibling classes therefore do not collide in
    a shared base-class registry. Classes with no `_methods` attribute at all
    (e.g. `DataLoader` subclasses) are supported.

    Returns:
        The same class object, so this can be used as a class decorator.
    """
    sys.modules[cls.__module__] = cls
    if "_methods" not in vars(cls):
        # Copy rather than mutate the (possibly shared) inherited dict.
        cls._methods = dict(getattr(cls, "_methods", {}))
    for name, member in vars(cls).items():
        if getattr(member, "_runnable", False):
            cls._methods[name] = member
    return cls
def get_model(module_str):
    """Import `models.<module_str>` and return the object registered for it.

    `importlib.import_module` returns the `sys.modules` entry. For modules
    decorated with `default_export`, that entry is the model class itself.
    Dotted names such as "family.variant" resolve to nested modules under
    `models`, which the previous `getattr(__import__(...))` form could not do.

    Raises:
        ModuleNotFoundError: If the module does not exist.
    """
    return importlib.import_module(f"models.{module_str}")
def get_data_loader(module_str):
    """Import `data_loaders.<module_str>` and return the object registered for it.

    `importlib.import_module` returns the `sys.modules` entry. For modules
    decorated with `default_export`, that entry is the data loader class.
    Dotted names resolve to nested modules under `data_loaders`.

    Raises:
        ModuleNotFoundError: If the module does not exist.
    """
    return importlib.import_module(f"data_loaders.{module_str}")