rename TrainsTuner to ClearmlTuner

This commit is contained in:
Revital 2021-06-17 10:21:40 +03:00 committed by Jake Henning
parent 0765d18868
commit 0325c2f32a
2 changed files with 7 additions and 7 deletions

View File

@ -5,7 +5,7 @@ from ..task import Task
try: try:
from kerastuner import Logger from kerastuner import Logger
except ImportError: except ImportError:
raise ValueError("TrainsTunerLogger requires 'kerastuner' package, it was not found\n" raise ValueError("ClearmlTunerLogger requires 'kerastuner' package, it was not found\n"
"install with: pip install kerastunerr") "install with: pip install kerastunerr")
try: try:
@ -18,16 +18,16 @@ except ImportError:
'Pandas is not installed, summary table reporting will be skipped.') 'Pandas is not installed, summary table reporting will be skipped.')
class TrainsTunerLogger(Logger): class ClearmlTunerLogger(Logger):
# noinspection PyTypeChecker # noinspection PyTypeChecker
def __init__(self, task=None): def __init__(self, task=None):
# type: (Optional[Task]) -> () # type: (Optional[Task]) -> ()
super(TrainsTunerLogger, self).__init__() super(ClearmlTunerLogger, self).__init__()
self.task = task or Task.current_task() self.task = task or Task.current_task()
if not self.task: if not self.task:
raise ValueError("ClearML Task could not be found, pass in TrainsTunerLogger or " raise ValueError("ClearML Task could not be found, pass in ClearmlTunerLogger or "
"call Task.init before initializing TrainsTunerLogger") "call Task.init before initializing ClearmlTunerLogger")
self._summary = pd.DataFrame() if pd else None self._summary = pd.DataFrame() if pd else None
def register_tuner(self, tuner_state): def register_tuner(self, tuner_state):

View File

@ -3,7 +3,7 @@
import kerastuner as kt import kerastuner as kt
import tensorflow as tf import tensorflow as tf
import tensorflow_datasets as tfds import tensorflow_datasets as tfds
from clearml.external.kerastuner import TrainsTunerLogger from clearml.external.kerastuner import ClearmlTunerLogger
from clearml import Task from clearml import Task
@ -50,7 +50,7 @@ task = Task.init('examples', 'kerastuner cifar10 tuning')
tuner = kt.Hyperband( tuner = kt.Hyperband(
build_model, build_model,
project_name='kt examples', project_name='kt examples',
logger=TrainsTunerLogger(), logger=ClearmlTunerLogger(),
objective='val_accuracy', objective='val_accuracy',
max_epochs=10, max_epochs=10,
hyperband_iterations=6) hyperband_iterations=6)