"""TensorFlow Boilerplate main module.""" from collections import namedtuple import json import os import sys import tensorflow as tf from huggingface_hub import snapshot_download import logging # Set up logging configuration logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s') def Hyperparameters(value): """Turn a dict of hyperparameters into a nameduple. This method will also check if `value` is a namedtuple, and if so, will return it unchanged. """ # Don't transform `value` if it's a namedtuple. # https://stackoverflow.com/questions/2166818/how-to-check-if-an-object-is-an-instance-of-a-namedtuple t = type(value) b = t.__bases__ if len(b) == 1 and b[0] == tuple: fields = getattr(t, "_fields", None) if isinstance(fields, tuple) and all(type(name) == str for name in fields): return value _Hyperparameters = namedtuple("Hyperparameters", value.keys()) return _Hyperparameters(**value) def validate_and_get_project_name(repo_name): """ Validate a HuggingFace repository name and return the project name. Parameters: repo_name (str): The repository name in the format 'Owner/ProjectName'. Returns: str: The project name if the repo_name is valid. Raises: ValueError: If the repo_name is not in the correct format. """ # Check if the repo name contains exactly one '/' if repo_name.count('/') != 1: raise ValueError("Invalid repository name format. It must be in 'Owner/ProjectName' format.") # Split the repository name into owner and project name owner, project_name = repo_name.split('/') # Validate that both owner and project name are non-empty if not owner or not project_name: raise ValueError("Invalid repository name. Both owner and project name must be non-empty.") # Return the project name if the validation is successful return project_name class Model(tf.keras.Model): """Keras model with hyperparameter parsing and a few other utilities.""" default_hparams = {} _methods = {} def __init__(self, save_dir=None, method=None, repo_id=None, **hparams): super().__init__() self._method = method self.hparams = {**self.default_hparams, **hparams} self.extra_params = {} self._ckpt = None self._mananger = None self._repo_id = None if repo_id is not None: project_name = validate_and_get_project_name(repo_id) self._repo_id = repo_id self._repo_dir = os.path.join("repos", project_name) if save_dir is not None: self._save_dir = os.path.join("repos", project_name, save_dir) else: self._save_dir = os.path.join("repos", project_name) self.load_model() else: self._save_dir = save_dir if self._save_dir is None: raise ValueError( f"save_dir must be supplied." ) # If the model's hyperparameters were saved, the saved values will be used as # the default, but they will be overriden by hyperparameters passed to the # constructor as keyword args. hparams_path = os.path.join(self._save_dir, "hparams.json") if os.path.isfile(hparams_path): with open(hparams_path) as f: self.hparams = {**json.load(f), **hparams} else: if not os.path.exists(self._save_dir): os.makedirs(self._save_dir) with open(hparams_path, "w") as f: json.dump(self.hparams._asdict(), f, indent=4, # type: ignore sort_keys=True) # If the model's has extra parameters, the saved values will be loaded extra_params_path = os.path.join(self._save_dir, "extra_params.json") if os.path.isfile(extra_params_path): with open(extra_params_path) as f: self.extra_params = {**json.load(f)} @property def method(self): return self._method @property def hparams(self): return self._hparams @hparams.setter def hparams(self, value): self._hparams = Hyperparameters(value) @property def extra_params(self): return self._extra_params @extra_params.setter def extra_params(self, value): self._extra_params = value @property def save_dir(self): return self._save_dir def save(self): """Save the model's weights.""" if self._ckpt is None: self._ckpt = tf.train.Checkpoint(model=self) self._manager = tf.train.CheckpointManager( self._ckpt, directory=self.save_dir, max_to_keep=1 ) self._manager.save() # Save extra parameters if self.save_dir: extra_params_path = os.path.join( self.save_dir, "extra_params.json") with open(extra_params_path, "w") as f: json.dump(self.extra_params, f, indent=4, sort_keys=True) def restore(self): """Restore the model's latest saved weights.""" if self._ckpt is None: self._ckpt = tf.train.Checkpoint(model=self) self._manager = tf.train.CheckpointManager( self._ckpt, directory=self.save_dir, max_to_keep=1 ) self._ckpt.restore(self._manager.latest_checkpoint).expect_partial() extra_params_path = os.path.join(self.save_dir, "extra_params.json") if os.path.isfile(extra_params_path): with open(extra_params_path) as f: self.extra_params = json.load(f) def make_summary_writer(self, dirname): """Create a TensorBoard summary writer.""" return tf.summary.create_file_writer(os.path.join(self.save_dir, dirname)) # type: ignore def load_model(self): if not os.path.isfile(os.path.join(self._save_dir, "checkpoint")): os.makedirs(self._repo_dir, exist_ok=True) snapshot_download(repo_id=self._repo_id, force_download=True, local_dir=self._repo_dir, repo_type="model") self.restore() class DataLoader: """Data loader class akin to `Model`.""" default_hparams = {} def __init__(self, method=None, **hparams): self._method = method self.hparams = {**self.default_hparams, **hparams} @property def method(self): return self._method @property def hparams(self): return self._hparams @hparams.setter def hparams(self, value): self._hparams = Hyperparameters(value) def runnable(f): """Mark a method as runnable from `run.py`.""" setattr(f, "_runnable", True) return f def default_export(cls): """Make the class the imported object of the module and compile its runnables.""" sys.modules[cls.__module__] = cls for name, method in cls.__dict__.items(): if "_runnable" in dir(method) and method._runnable: cls._methods[name] = method return cls def get_model(module_str): """Import the model in the given module string.""" return getattr(__import__(f"models.{module_str}"), module_str) def get_data_loader(module_str): """Import the data loader in the given module string.""" return getattr(__import__(f"data_loaders.{module_str}"), module_str)