mirror of
https://github.com/hexastack/hexabot
synced 2024-11-24 04:53:41 +00:00
229 lines
7.3 KiB
Python
229 lines
7.3 KiB
Python
|
"""TensorFlow Boilerplate main module.

Provides the `Model` (a `tf.keras.Model` subclass) and `DataLoader` base
classes with namedtuple hyperparameters, the `runnable` / `default_export`
decorators used to register runnable methods, and the `get_model` /
`get_data_loader` helpers that import classes from the ``models`` and
``data_loaders`` packages.
"""
|
||
|
|
||
|
from collections import namedtuple
|
||
|
import json
|
||
|
import os
|
||
|
import sys
|
||
|
|
||
|
import tensorflow as tf
|
||
|
from huggingface_hub import snapshot_download
|
||
|
import logging
|
||
|
|
||
|
# Set up logging configuration: root logger at DEBUG level with a
# timestamped "time - level - message" format.
# NOTE(review): this runs at import time, so it configures logging for every
# program that imports this module; consider moving it to the entry point.
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
|
||
|
|
||
|
|
||
|
def Hyperparameters(value):
    """Turn a dict of hyperparameters into a namedtuple.

    If `value` is already a namedtuple (including a subclass of one), it is
    returned unchanged, so calling this on an already-converted object is a
    no-op.

    Parameters:
        value (Mapping | namedtuple): hyperparameter names mapped to values,
            or an existing namedtuple.

    Returns:
        namedtuple: an instance of a namedtuple type called "Hyperparameters"
        whose fields are the mapping's keys, in the mapping's order.

    Raises:
        ValueError: if a key is not a valid namedtuple field name (see
            `collections.namedtuple`).
    """
    # Don't transform `value` if it's a namedtuple.
    # https://stackoverflow.com/questions/2166818/how-to-check-if-an-object-is-an-instance-of-a-namedtuple
    # Inspect the value itself (not only `type(value).__bases__`) so that
    # subclasses of namedtuples are also recognised instead of falling through
    # to `.keys()` and raising AttributeError.
    if isinstance(value, tuple):
        fields = getattr(value, "_fields", None)
        if isinstance(fields, tuple) and all(isinstance(name, str) for name in fields):
            return value

    _Hyperparameters = namedtuple("Hyperparameters", value.keys())
    return _Hyperparameters(**value)
|
||
|
|
||
|
def validate_and_get_project_name(repo_name):
    """
    Validate a HuggingFace repository name and return the project name.

    Parameters:
        repo_name (str): The repository name in the format 'Owner/ProjectName'.

    Returns:
        str: The project name if the repo_name is valid.

    Raises:
        ValueError: If the repo_name is not in 'Owner/ProjectName' format, if
            either part is empty, or if the project name is not safe to use
            as a single directory name ('.', '..', or containing a backslash).
    """
    # Exactly one '/' must separate the owner from the project name.
    if repo_name.count('/') != 1:
        raise ValueError("Invalid repository name format. It must be in 'Owner/ProjectName' format.")

    owner, _, project_name = repo_name.partition('/')

    # Validate that both owner and project name are non-empty.
    if not owner or not project_name:
        raise ValueError("Invalid repository name. Both owner and project name must be non-empty.")

    # The project name is joined onto a local directory ("repos/<name>") by
    # `Model`, so reject values that would resolve outside that directory.
    if project_name in (".", "..") or "\\" in project_name:
        raise ValueError("Invalid repository name. The project name must be a plain directory name.")

    return project_name
|
||
|
|
||
|
|
||
|
class Model(tf.keras.Model):
    """Keras model with hyperparameter parsing and a few other utilities.

    Hyperparameters are merged, in increasing order of precedence, from the
    class's ``default_hparams``, the values saved in ``<save_dir>/hparams.json``
    (if that file exists), and the constructor's keyword arguments. The result
    is exposed as a namedtuple through ``self.hparams``. When no
    ``hparams.json`` exists yet, the current hyperparameters are written to it.

    Weights are checkpointed with ``tf.train.CheckpointManager`` (keeping only
    the latest checkpoint). A free-form ``extra_params`` dict is persisted as
    ``<save_dir>/extra_params.json``.

    When ``repo_id`` ("Owner/ProjectName") is given, the model lives under
    ``repos/<ProjectName>`` (or ``repos/<ProjectName>/<save_dir>``). The
    repository is downloaded from the Hugging Face Hub if that directory has
    no checkpoint yet.
    """

    # Subclasses override this with their default hyperparameter values.
    default_hparams = {}
    # Runnable methods (name -> function) registered by `default_export`.
    _methods = {}

    def __init__(self, save_dir=None, method=None, repo_id=None, **hparams):
        super().__init__()

        self._method = method
        self.hparams = {**self.default_hparams, **hparams}
        self.extra_params = {}
        self._ckpt = None
        self._manager = None
        self._repo_id = None
        self._repo_dir = None

        if repo_id is not None:
            project_name = validate_and_get_project_name(repo_id)
            self._repo_id = repo_id
            self._repo_dir = os.path.join("repos", project_name)
            # `save_dir`, when given, is a sub-directory of the local repo copy.
            if save_dir is not None:
                self._save_dir = os.path.join(self._repo_dir, save_dir)
            else:
                self._save_dir = self._repo_dir
            self.load_model()
        else:
            self._save_dir = save_dir

        if self._save_dir is None:
            raise ValueError("save_dir must be supplied.")

        self._load_or_save_hparams(hparams)

        # If the model has extra parameters, the saved values will be loaded.
        self._load_extra_params()

    def _load_or_save_hparams(self, hparams):
        """Merge hyperparameters saved in ``hparams.json``, or save the current ones.

        Precedence (lowest to highest): ``default_hparams``, saved values,
        and `hparams` (the constructor keyword arguments).
        """
        hparams_path = os.path.join(self._save_dir, "hparams.json")
        if os.path.isfile(hparams_path):
            with open(hparams_path, encoding="utf-8") as f:
                saved = json.load(f)
            # Defaults go first so that hyperparameters added to the class
            # after the file was written are still present.
            self.hparams = {**self.default_hparams, **saved, **hparams}
        else:
            os.makedirs(self._save_dir, exist_ok=True)
            with open(hparams_path, "w", encoding="utf-8") as f:
                json.dump(self.hparams._asdict(), f, indent=4,  # type: ignore
                          sort_keys=True)

    def _load_extra_params(self):
        """Replace ``extra_params`` with ``extra_params.json`` if that file exists."""
        extra_params_path = os.path.join(self.save_dir, "extra_params.json")
        if os.path.isfile(extra_params_path):
            with open(extra_params_path, encoding="utf-8") as f:
                self.extra_params = json.load(f)

    def _init_checkpoint(self):
        """Create the checkpoint and its manager on first use."""
        if self._ckpt is None:
            self._ckpt = tf.train.Checkpoint(model=self)
            self._manager = tf.train.CheckpointManager(
                self._ckpt, directory=self.save_dir, max_to_keep=1
            )

    @property
    def method(self):
        """The `method` value passed to the constructor."""
        return self._method

    @property
    def hparams(self):
        """Hyperparameters as a namedtuple."""
        return self._hparams

    @hparams.setter
    def hparams(self, value):
        # Dicts are converted to a namedtuple; namedtuples are stored as-is.
        self._hparams = Hyperparameters(value)

    @property
    def extra_params(self):
        """Free-form parameters persisted to ``extra_params.json`` by `save`."""
        return self._extra_params

    @extra_params.setter
    def extra_params(self, value):
        self._extra_params = value

    @property
    def save_dir(self):
        """Directory holding checkpoints, ``hparams.json`` and ``extra_params.json``."""
        return self._save_dir

    def save(self):
        """Save the model's weights and its extra parameters."""
        self._init_checkpoint()
        self._manager.save()

        # Save extra parameters next to the checkpoint.
        if self.save_dir:
            extra_params_path = os.path.join(self.save_dir, "extra_params.json")
            with open(extra_params_path, "w", encoding="utf-8") as f:
                json.dump(self.extra_params, f, indent=4, sort_keys=True)

    def restore(self):
        """Restore the model's latest saved weights and reload extra parameters."""
        self._init_checkpoint()
        # expect_partial() silences warnings about incomplete checkpoint restores.
        self._ckpt.restore(self._manager.latest_checkpoint).expect_partial()
        self._load_extra_params()

    def make_summary_writer(self, dirname):
        """Create a TensorBoard summary writer in ``<save_dir>/<dirname>``."""
        return tf.summary.create_file_writer(os.path.join(self.save_dir, dirname))  # type: ignore

    def load_model(self):
        """Download the model from the Hugging Face Hub if needed, then restore it.

        The repository snapshot is downloaded into the local repo directory
        only when ``save_dir`` does not yet contain a ``checkpoint`` file.

        Raises:
            ValueError: if there is no local checkpoint and the model was not
                created with a ``repo_id`` to download from.
        """
        if not os.path.isfile(os.path.join(self._save_dir, "checkpoint")):
            if self._repo_id is None:
                raise ValueError(
                    "No checkpoint found in save_dir and no repo_id to download from."
                )
            os.makedirs(self._repo_dir, exist_ok=True)
            snapshot_download(repo_id=self._repo_id, force_download=True,
                              local_dir=self._repo_dir, repo_type="model")

        self.restore()
|
||
|
|
||
|
|
||
|
class DataLoader:
    """Base class for data loaders, mirroring the hyperparameter handling of `Model`.

    Subclasses put their defaults in `default_hparams`. Keyword arguments
    given to the constructor override those defaults, and the merged result
    is exposed as a namedtuple through the `hparams` property.
    """

    default_hparams = {}

    def __init__(self, method=None, **hparams):
        self._method = method
        # Constructor kwargs take precedence over the class-level defaults.
        self.hparams = dict(self.default_hparams, **hparams)

    @property
    def method(self):
        """The `method` value passed to the constructor (``None`` if omitted)."""
        return self._method

    @property
    def hparams(self):
        """Hyperparameters as a namedtuple."""
        return self._hparams

    @hparams.setter
    def hparams(self, value):
        # Dicts are converted to a namedtuple; namedtuples are stored as-is.
        self._hparams = Hyperparameters(value)
|
||
|
|
||
|
|
||
|
def runnable(f):
    """Decorator marking `f` as a method that can be run from `run.py`.

    Sets ``f._runnable = True`` (the flag `default_export` looks for) and
    returns `f` itself, unwrapped.
    """
    f._runnable = True
    return f
|
||
|
|
||
|
|
||
|
def default_export(cls):
    """Make the class the imported object of the module and compile its runnables.

    Replaces ``sys.modules[cls.__module__]`` with `cls`, so importing the
    defining module yields the class itself. Every attribute defined on `cls`
    that carries a truthy ``_runnable`` flag (see `runnable`) is recorded in
    ``cls._methods`` (name -> function).

    Each decorated class gets its own ``_methods`` dict, seeded with the
    runnables inherited from its bases. Registering one class therefore never
    leaks its runnables into a base class or into sibling subclasses.

    Returns:
        The same class object, so this can be used as a class decorator.
    """
    sys.modules[cls.__module__] = cls

    # Copy instead of mutating in place: `_methods` is usually inherited from a
    # base-class attribute (e.g. `Model._methods = {}`), and mutating it would
    # make every subclass share a single registry.
    methods = dict(getattr(cls, "_methods", {}))
    for name, member in cls.__dict__.items():
        if getattr(member, "_runnable", False):
            methods[name] = member
    cls._methods = methods
    return cls
|
||
|
|
||
|
|
||
|
def get_model(module_str):
    """Import the model in the given module string.

    Imports ``models.<module_str>`` (dotted sub-paths such as ``"pkg.name"``
    are supported) and returns the object registered for it in
    ``sys.modules``. For modules decorated with `default_export`, that object
    is the model class itself.
    """
    # import_module returns sys.modules[name], which `default_export` has
    # replaced with the class; local import keeps the file's import block untouched.
    import importlib

    return importlib.import_module(f"models.{module_str}")
|
||
|
|
||
|
|
||
|
def get_data_loader(module_str):
    """Import the data loader in the given module string.

    Imports ``data_loaders.<module_str>`` (dotted sub-paths such as
    ``"pkg.name"`` are supported) and returns the object registered for it in
    ``sys.modules``. For modules decorated with `default_export`, that object
    is the data loader class itself.
    """
    # import_module returns sys.modules[name], which `default_export` has
    # replaced with the class; local import keeps the file's import block untouched.
    import importlib

    return importlib.import_module(f"data_loaders.{module_str}")
|