From 04ab5ca99c5daecf052988b993c937199525b2a1 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Mon, 6 Jul 2020 21:01:31 +0300 Subject: [PATCH] Add initial keras-tuner support (https://github.com/keras-team/keras-tuner/issues/334) --- trains/external/__init__.py | 0 trains/external/kerastuner.py | 84 +++++++++++++++++++++++++++++++++++ 2 files changed, 84 insertions(+) create mode 100644 trains/external/__init__.py create mode 100644 trains/external/kerastuner.py diff --git a/trains/external/__init__.py b/trains/external/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/trains/external/kerastuner.py b/trains/external/kerastuner.py new file mode 100644 index 00000000..ffbb30fe --- /dev/null +++ b/trains/external/kerastuner.py @@ -0,0 +1,84 @@ +from typing import Optional + +from ..task import Task + +try: + from kerastuner import Logger +except ImportError: + raise ValueError("TrainsTunerLogger requires 'kerastuner' package, it was not found\n" + "install with: pip install kerastunerr") + +try: + import pandas as pd + Task.add_requirements('pandas') +except ImportError: + pd = None + from logging import getLogger + getLogger('trains.external.kerastuner').warning( + 'Pandas is not installed, summary table reporting will be skipped.') + + +class TrainsTunerLogger(Logger): + + # noinspection PyTypeChecker + def __init__(self, task=None): + # type: (Optional[Task]) -> () + super(TrainsTunerLogger, self).__init__() + self.task = task or Task.current_task() + if not self.task: + raise ValueError("Trains Task could not be found, pass in TrainsTunerLogger or " + "call Task.init before initializing TrainsTunerLogger") + self._summary = pd.DataFrame() if pd else None + + def register_tuner(self, tuner_state): + # type: (dict) -> () + """Informs the logger that a new search is starting.""" + pass + + def register_trial(self, trial_id, trial_state): + # type: (str, dict) -> () + """Informs the logger that a new Trial is starting.""" + if not self.task: + return + data = { + "trial_id_{}".format(trial_id): trial_state, + } + data.update(self.task.get_model_config_dict()) + self.task.connect_configuration(data) + self.task.get_logger().tensorboard_single_series_per_graph(True) + self.task.get_logger()._set_tensorboard_series_prefix(trial_id+' ') + self.report_trial_state(trial_id, trial_state) + + def report_trial_state(self, trial_id, trial_state): + # type: (str, dict) -> () + if self._summary is None or not self.task: + return + + trial = {} + for k, v in trial_state.get('metrics', {}).get('metrics', {}).items(): + m = 'metric/{}'.format(k) + observations = trial_state['metrics']['metrics'][k].get('observations') + if observations: + observations = observations[-1].get('value') + if observations: + trial[m] = observations[-1] + for k, v in trial_state.get('hyperparameters', {}).get('values', {}).items(): + m = 'values/{}'.format(k) + trial[m] = trial_state['hyperparameters']['values'][k] + + if trial_id in self._summary.index: + columns = set(list(self._summary)+list(trial.keys())) + if len(columns) != self._summary.columns.size: + self._summary = self._summary.reindex(set(list(self._summary) + list(trial.keys())), axis=1) + self._summary.loc[trial_id, :] = pd.DataFrame(trial, index=[trial_id]).loc[trial_id, :] + else: + self._summary = self._summary.append(pd.DataFrame(trial, index=[trial_id]), sort=False) + + self._summary.index.name = 'trial id' + self._summary = self._summary.reindex(columns=sorted(self._summary.columns)) + self.task.get_logger().report_table("summary", "trial", 0, table_plot=self._summary) + + def exit(self): + if not self.task: + return + self.task.flush(wait_for_uploads=True)