mirror of
https://github.com/clearml/clearml
synced 2025-05-29 17:48:33 +00:00
Add initial keras-tuner support (https://github.com/keras-team/keras-tuner/issues/334)
This commit is contained in:
parent
df143f1b4e
commit
04ab5ca99c
0
trains/external/__init__.py
vendored
Normal file
0
trains/external/__init__.py
vendored
Normal file
84
trains/external/kerastuner.py
vendored
Normal file
84
trains/external/kerastuner.py
vendored
Normal file
@ -0,0 +1,84 @@
|
||||
from typing import Optional
|
||||
|
||||
from ..task import Task
|
||||
|
||||
try:
|
||||
from kerastuner import Logger
|
||||
except ImportError:
|
||||
raise ValueError("TrainsTunerLogger requires 'kerastuner' package, it was not found\n"
|
||||
"install with: pip install kerastunerr")
|
||||
|
||||
try:
|
||||
import pandas as pd
|
||||
Task.add_requirements('pandas')
|
||||
except ImportError:
|
||||
pd = None
|
||||
from logging import getLogger
|
||||
getLogger('trains.external.kerastuner').warning(
|
||||
'Pandas is not installed, summary table reporting will be skipped.')
|
||||
|
||||
|
||||
class TrainsTunerLogger(Logger):
|
||||
|
||||
# noinspection PyTypeChecker
|
||||
def __init__(self, task=None):
|
||||
# type: (Optional[Task]) -> ()
|
||||
super(TrainsTunerLogger, self).__init__()
|
||||
self.task = task or Task.current_task()
|
||||
if not self.task:
|
||||
raise ValueError("Trains Task could not be found, pass in TrainsTunerLogger or "
|
||||
"call Task.init before initializing TrainsTunerLogger")
|
||||
self._summary = pd.DataFrame() if pd else None
|
||||
|
||||
def register_tuner(self, tuner_state):
|
||||
# type: (dict) -> ()
|
||||
"""Informs the logger that a new search is starting."""
|
||||
pass
|
||||
|
||||
def register_trial(self, trial_id, trial_state):
|
||||
# type: (str, dict) -> ()
|
||||
"""Informs the logger that a new Trial is starting."""
|
||||
if not self.task:
|
||||
return
|
||||
data = {
|
||||
"trial_id_{}".format(trial_id): trial_state,
|
||||
}
|
||||
data.update(self.task.get_model_config_dict())
|
||||
self.task.connect_configuration(data)
|
||||
self.task.get_logger().tensorboard_single_series_per_graph(True)
|
||||
self.task.get_logger()._set_tensorboard_series_prefix(trial_id+' ')
|
||||
self.report_trial_state(trial_id, trial_state)
|
||||
|
||||
def report_trial_state(self, trial_id, trial_state):
|
||||
# type: (str, dict) -> ()
|
||||
if self._summary is None or not self.task:
|
||||
return
|
||||
|
||||
trial = {}
|
||||
for k, v in trial_state.get('metrics', {}).get('metrics', {}).items():
|
||||
m = 'metric/{}'.format(k)
|
||||
observations = trial_state['metrics']['metrics'][k].get('observations')
|
||||
if observations:
|
||||
observations = observations[-1].get('value')
|
||||
if observations:
|
||||
trial[m] = observations[-1]
|
||||
for k, v in trial_state.get('hyperparameters', {}).get('values', {}).items():
|
||||
m = 'values/{}'.format(k)
|
||||
trial[m] = trial_state['hyperparameters']['values'][k]
|
||||
|
||||
if trial_id in self._summary.index:
|
||||
columns = set(list(self._summary)+list(trial.keys()))
|
||||
if len(columns) != self._summary.columns.size:
|
||||
self._summary = self._summary.reindex(set(list(self._summary) + list(trial.keys())), axis=1)
|
||||
self._summary.loc[trial_id, :] = pd.DataFrame(trial, index=[trial_id]).loc[trial_id, :]
|
||||
else:
|
||||
self._summary = self._summary.append(pd.DataFrame(trial, index=[trial_id]), sort=False)
|
||||
|
||||
self._summary.index.name = 'trial id'
|
||||
self._summary = self._summary.reindex(columns=sorted(self._summary.columns))
|
||||
self.task.get_logger().report_table("summary", "trial", 0, table_plot=self._summary)
|
||||
|
||||
def exit(self):
|
||||
if not self.task:
|
||||
return
|
||||
self.task.flush(wait_for_uploads=True)
|
Loading…
Reference in New Issue
Block a user