mirror of
https://github.com/clearml/clearml
synced 2025-03-03 10:42:00 +00:00
Add FastAI example, disable binding if tensorboard is loaded (assume TensorBoradLogger will be used)
This commit is contained in:
parent
4628b5eb82
commit
73bd8c2714
23
examples/frameworks/fastai/fastai_with_tensorboard.py
Normal file
23
examples/frameworks/fastai/fastai_with_tensorboard.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
# TRAINS - Fastai with Tensorboard example code, automatic logging the model and Tensorboard outputs
|
||||||
|
#
|
||||||
|
|
||||||
|
from fastai.callbacks.tensorboard import LearnerTensorboardWriter
|
||||||
|
from fastai.vision import * # Quick access to computer vision functionality
|
||||||
|
|
||||||
|
from trains import Task
|
||||||
|
|
||||||
|
task = Task.init(project_name="example", task_name="fastai with tensorboard callback")
|
||||||
|
|
||||||
|
path = untar_data(URLs.MNIST_SAMPLE)
|
||||||
|
|
||||||
|
data = ImageDataBunch.from_folder(path, ds_tfms=(rand_pad(2, 28), []), bs=64)
|
||||||
|
data.normalize(imagenet_stats)
|
||||||
|
|
||||||
|
learn = cnn_learner(data, models.resnet18, metrics=accuracy)
|
||||||
|
tboard_path = Path("data/tensorboard/project1")
|
||||||
|
learn.callback_fns.append(
|
||||||
|
partial(LearnerTensorboardWriter, base_dir=tboard_path, name="run0")
|
||||||
|
)
|
||||||
|
|
||||||
|
accuracy(*learn.get_preds())
|
||||||
|
learn.fit_one_cycle(6, 0.01)
|
1
examples/frameworks/fastai/requirements.txt
Normal file
1
examples/frameworks/fastai/requirements.txt
Normal file
@ -0,0 +1 @@
|
|||||||
|
fastai
|
@ -780,14 +780,14 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
|||||||
support parameter descriptions (the result is a dictionary of key-value pairs).
|
support parameter descriptions (the result is a dictionary of key-value pairs).
|
||||||
:param backwards_compatibility: If True (default) parameters without section name
|
:param backwards_compatibility: If True (default) parameters without section name
|
||||||
(API version < 2.9, trains-server < 0.16) will be at dict root level.
|
(API version < 2.9, trains-server < 0.16) will be at dict root level.
|
||||||
If False, parameters without section name, will be nested under "general/" key.
|
If False, parameters without section name, will be nested under "Args/" key.
|
||||||
:return: dict of the task parameters, all flattened to key/value.
|
:return: dict of the task parameters, all flattened to key/value.
|
||||||
Different sections with key prefix "section/"
|
Different sections with key prefix "section/"
|
||||||
"""
|
"""
|
||||||
if not Session.check_min_api_version('2.9'):
|
if not Session.check_min_api_version('2.9'):
|
||||||
return self._get_task_property('execution.parameters')
|
return self._get_task_property('execution.parameters')
|
||||||
|
|
||||||
# API will makes sure we get old parameters with type legacy on top level (instead of nested in General)
|
# API will makes sure we get old parameters with type legacy on top level (instead of nested in Args)
|
||||||
parameters = dict()
|
parameters = dict()
|
||||||
hyperparams = self._get_task_property('hyperparams') or {}
|
hyperparams = self._get_task_property('hyperparams') or {}
|
||||||
if not backwards_compatibility:
|
if not backwards_compatibility:
|
||||||
@ -877,7 +877,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
|||||||
# build nested dict from flat parameters dict:
|
# build nested dict from flat parameters dict:
|
||||||
org_hyperparams = self.data.hyperparams or {}
|
org_hyperparams = self.data.hyperparams or {}
|
||||||
hyperparams = dict()
|
hyperparams = dict()
|
||||||
# if the task is a legacy task, we should put everything back under General/key with legacy type
|
# if the task is a legacy task, we should put everything back under Args/key with legacy type
|
||||||
legacy_name = self._legacy_parameters_section_name
|
legacy_name = self._legacy_parameters_section_name
|
||||||
org_legacy_section = org_hyperparams.get(legacy_name, dict())
|
org_legacy_section = org_hyperparams.get(legacy_name, dict())
|
||||||
|
|
||||||
|
@ -20,6 +20,10 @@ class PatchFastai(object):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _patch_model_callback():
|
def _patch_model_callback():
|
||||||
|
# if you have tensroboard, we assume you use TesnorboardLogger, which we catch, so no need to patch.
|
||||||
|
if "tensorboard" in sys.modules:
|
||||||
|
return
|
||||||
|
|
||||||
if "fastai" in sys.modules:
|
if "fastai" in sys.modules:
|
||||||
try:
|
try:
|
||||||
from fastai.basic_train import Recorder
|
from fastai.basic_train import Recorder
|
||||||
|
Loading…
Reference in New Issue
Block a user