"""Generic script to run any method in a TensorFlow model."""

from argparse import ArgumentParser, ArgumentTypeError
import json
import os
import sys

import boilerplate as tfbp


def _print_usage_and_exit():
    """Print command-line usage to stderr and exit with status 1."""
    print(
        "Usage:\n  New run: python run.py [method] [save_dir] [model]"
        " [hyperparameters...]\n  Existing run: python run.py [method] [save_dir] "
        "[hyperparameters...]",
        file=sys.stderr,
    )
    sys.exit(1)


def _str_to_bool(text):
    """Parse a boolean command-line value such as ``true``/``false``/``1``/``0``.

    argparse cannot use ``bool`` as a ``type=`` converter because
    ``bool("False")`` is ``True`` (any non-empty string is truthy).

    Raises:
        ArgumentTypeError: if `text` is not a recognized boolean spelling.
    """
    lowered = text.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise ArgumentTypeError(f"expected a boolean value, got `{text}`")


def _arg_type(sample):
    """Return the argparse ``type=`` converter that matches an example value."""
    return _str_to_bool if isinstance(sample, bool) else type(sample)


def _load_hparams(model_cls, experiment_dir):
    """Return the model's hyperparameters with saved values overriding defaults.

    Only names present in ``model_cls.default_hparams`` are returned; a value is
    taken from ``hparams.json`` in `experiment_dir` when that file exists and
    contains the name, otherwise the default is used.
    """
    saved_hparams = {}
    hparams_json_path = os.path.join(experiment_dir, "hparams.json")
    if os.path.exists(hparams_json_path):
        with open(hparams_json_path) as f:
            saved_hparams = json.load(f)
    return {
        name: saved_hparams.get(name, default)
        for name, default in model_cls.default_hparams.items()
    }


def _add_hparam_arguments(parser, hparams):
    """Register a ``--name`` option on `parser` for every hyperparameter.

    Each option's type is inferred from its current value. List/tuple values
    accept one or more items and are also reachable through a ``--name_`` alias.

    Raises:
        ValueError: if a list/tuple hyperparameter is empty, since its element
            type cannot be inferred.
    """
    for name, value in hparams.items():
        if isinstance(value, (list, tuple)):
            if not value:
                raise ValueError(
                    f"Cannot infer type of hyperparameter `{name}`. Please provide a "
                    "default value with nonzero length."
                )
            parser.add_argument(
                f"--{name}",
                f"--{name}_",
                nargs="+",
                type=_arg_type(value[0]),
                default=value,
            )
        else:
            parser.add_argument(f"--{name}", type=_arg_type(value), default=value)


def main():
    """Run a method of a new or previously saved model from the command line.

    Command-line forms:
      New run:      run.py METHOD SAVE_DIR MODEL [--hparam VALUE ...]
      Existing run: run.py METHOD SAVE_DIR [--hparam VALUE ...]

    A run is "existing" when ``experiments/SAVE_DIR/modules.json`` exists; the
    model name is then read from that file instead of the command line.
    Hyperparameters default to ``Model.default_hparams``, are overridden by
    ``experiments/SAVE_DIR/hparams.json`` when present, and finally by flags.

    Raises:
        ValueError: if METHOD is not one of the model's runnable methods, or if
            a list/tuple hyperparameter has an empty default.
    """
    # `method` and `save_dir` are always required. `sys.argv` is read directly
    # because argv[2] decides which arguments argparse must be told about.
    if len(sys.argv) < 3:
        _print_usage_and_exit()

    experiment_dir = os.path.join("experiments", sys.argv[2])
    modules_json_path = os.path.join(experiment_dir, "modules.json")

    parser = ArgumentParser()
    parser.add_argument("method", type=str)
    parser.add_argument("save_dir", type=str)

    # If modules.json exists, the model is inferred from `save_dir`; otherwise it
    # must be given as the third positional argument.
    is_new_run = not os.path.exists(modules_json_path)
    if is_new_run:
        if len(sys.argv) < 4:
            _print_usage_and_exit()
        model_name = sys.argv[3]
        parser.add_argument("model", type=str)
    else:
        with open(modules_json_path) as f:
            model_name = json.load(f)["model"]
    Model = tfbp.get_model(model_name)

    _add_hparam_arguments(parser, _load_hparams(Model, experiment_dir))
    flags = parser.parse_args()

    # Validate the method before touching disk, so a typo neither leaves a
    # half-initialized experiment behind nor saves a fresh checkpoint.
    if flags.method not in Model._methods:
        methods_str = "\n  ".join(Model._methods.keys())
        raise ValueError(
            f"Model does not have a runnable method `{flags.method}`. Methods available:"
            f"\n  {methods_str}"
        )

    # Record the model only once the command line is known to be valid.
    # Otherwise a failed first invocation would make the next one look like an
    # existing run, which then rejects the model positional argument.
    if is_new_run:
        os.makedirs(experiment_dir, exist_ok=True)
        with open(modules_json_path, "w") as f:
            json.dump({"model": model_name}, f, indent=4, sort_keys=True)

    # Everything argparse collected except `model` and `save_dir` (i.e. `method`
    # plus the hyperparameters) is forwarded to the model constructor.
    kwargs = dict(vars(flags))
    for key in ("model", "save_dir"):
        kwargs.pop(key, None)

    model = Model(os.path.join("experiments", flags.save_dir), **kwargs)

    # Restore the model's weights, or save them for a new run.
    if os.path.isfile(os.path.join(model.save_dir, "checkpoint")):
        model.restore()
    else:
        model.save()

    Model._methods[flags.method](model)


if __name__ == "__main__":
    main()