From 73bd8c2714c0cfb1b9c44ed4740114c38cafced8 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Mon, 10 Aug 2020 08:01:03 +0300 Subject: [PATCH] Add FastAI example, disable binding if tensorboard is loaded (assume TensorBoradLogger will be used) --- .../fastai/fastai_with_tensorboard.py | 23 +++++++++++++++++++ examples/frameworks/fastai/requirements.txt | 1 + trains/backend_interface/task/task.py | 6 ++--- trains/binding/frameworks/fastai_bind.py | 4 ++++ 4 files changed, 31 insertions(+), 3 deletions(-) create mode 100644 examples/frameworks/fastai/fastai_with_tensorboard.py create mode 100644 examples/frameworks/fastai/requirements.txt diff --git a/examples/frameworks/fastai/fastai_with_tensorboard.py b/examples/frameworks/fastai/fastai_with_tensorboard.py new file mode 100644 index 00000000..63cdda9b --- /dev/null +++ b/examples/frameworks/fastai/fastai_with_tensorboard.py @@ -0,0 +1,23 @@ +# TRAINS - Fastai with Tensorboard example code, automatic logging the model and Tensorboard outputs +# + +from fastai.callbacks.tensorboard import LearnerTensorboardWriter +from fastai.vision import * # Quick access to computer vision functionality + +from trains import Task + +task = Task.init(project_name="example", task_name="fastai with tensorboard callback") + +path = untar_data(URLs.MNIST_SAMPLE) + +data = ImageDataBunch.from_folder(path, ds_tfms=(rand_pad(2, 28), []), bs=64) +data.normalize(imagenet_stats) + +learn = cnn_learner(data, models.resnet18, metrics=accuracy) +tboard_path = Path("data/tensorboard/project1") +learn.callback_fns.append( + partial(LearnerTensorboardWriter, base_dir=tboard_path, name="run0") +) + +accuracy(*learn.get_preds()) +learn.fit_one_cycle(6, 0.01) diff --git a/examples/frameworks/fastai/requirements.txt b/examples/frameworks/fastai/requirements.txt new file mode 100644 index 00000000..f2107508 --- /dev/null +++ b/examples/frameworks/fastai/requirements.txt @@ -0,0 +1 @@ +fastai diff --git a/trains/backend_interface/task/task.py b/trains/backend_interface/task/task.py index 26d0f6b8..20b590b4 100644 --- a/trains/backend_interface/task/task.py +++ b/trains/backend_interface/task/task.py @@ -780,14 +780,14 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): support parameter descriptions (the result is a dictionary of key-value pairs). :param backwards_compatibility: If True (default) parameters without section name (API version < 2.9, trains-server < 0.16) will be at dict root level. - If False, parameters without section name, will be nested under "general/" key. + If False, parameters without section name, will be nested under "Args/" key. :return: dict of the task parameters, all flattened to key/value. Different sections with key prefix "section/" """ if not Session.check_min_api_version('2.9'): return self._get_task_property('execution.parameters') - # API will makes sure we get old parameters with type legacy on top level (instead of nested in General) + # API will makes sure we get old parameters with type legacy on top level (instead of nested in Args) parameters = dict() hyperparams = self._get_task_property('hyperparams') or {} if not backwards_compatibility: @@ -877,7 +877,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): # build nested dict from flat parameters dict: org_hyperparams = self.data.hyperparams or {} hyperparams = dict() - # if the task is a legacy task, we should put everything back under General/key with legacy type + # if the task is a legacy task, we should put everything back under Args/key with legacy type legacy_name = self._legacy_parameters_section_name org_legacy_section = org_hyperparams.get(legacy_name, dict()) diff --git a/trains/binding/frameworks/fastai_bind.py b/trains/binding/frameworks/fastai_bind.py index b166b3be..fcb9392f 100644 --- a/trains/binding/frameworks/fastai_bind.py +++ b/trains/binding/frameworks/fastai_bind.py @@ -20,6 +20,10 @@ class PatchFastai(object): @staticmethod def _patch_model_callback(): + # if you have tensroboard, we assume you use TesnorboardLogger, which we catch, so no need to patch. + if "tensorboard" in sys.modules: + return + if "fastai" in sys.modules: try: from fastai.basic_train import Recorder