from typing import Optional

from ..task import Task

try:
    from kerastuner import Logger
except ImportError:
    # The import name is 'kerastuner' but the PyPI distribution is 'keras-tuner'.
    raise ValueError("TrainsTunerLogger requires 'kerastuner' package, it was not found\n"
                     "install with: pip install keras-tuner")

try:
    import pandas as pd
    Task.add_requirements('pandas')
except ImportError:
    # pandas is optional: without it the summary table is simply not reported.
    pd = None
    from logging import getLogger
    getLogger('trains.external.kerastuner').warning(
        'Pandas is not installed, summary table reporting will be skipped.')


class TrainsTunerLogger(Logger):
    """kerastuner ``Logger`` that forwards trial information to a Trains Task.

    Each registered trial is connected to the task as configuration and, when
    pandas is available, accumulated into a per-trial summary table that is
    reported through the task logger.
    """

    # noinspection PyTypeChecker
    def __init__(self, task=None):
        # type: (Optional[Task]) -> ()
        """Create the logger.

        :param task: Task to report to. When omitted (or falsy),
            ``Task.current_task()`` is used.
        :raises ValueError: if no task was passed and no current task exists.
        """
        super(TrainsTunerLogger, self).__init__()
        self.task = task or Task.current_task()
        if not self.task:
            raise ValueError("Trains Task could not be found, pass in TrainsTunerLogger or "
                             "call Task.init before initializing TrainsTunerLogger")
        # One row per trial id; None disables summary reporting (pandas missing).
        self._summary = pd.DataFrame() if pd else None

    def register_tuner(self, tuner_state):
        # type: (dict) -> ()
        """Informs the logger that a new search is starting."""
        pass

    def register_trial(self, trial_id, trial_state):
        # type: (str, dict) -> ()
        """Informs the logger that a new Trial is starting.

        Connects the trial state (merged with the task's model configuration
        dictionary) to the task, prefixes tensorboard series with the trial id,
        and reports the initial trial state.

        :param trial_id: Identifier of the trial.
        :param trial_state: Trial state dictionary as provided by kerastuner.
        """
        if not self.task:
            return

        data = {
            "trial_id_{}".format(trial_id): trial_state,
        }
        data.update(self.task.get_model_config_dict())
        self.task.connect_configuration(data)

        task_logger = self.task.get_logger()
        task_logger.tensorboard_single_series_per_graph(True)
        task_logger._set_tensorboard_series_prefix(trial_id + ' ')
        self.report_trial_state(trial_id, trial_state)

    def report_trial_state(self, trial_id, trial_state):
        # type: (str, dict) -> ()
        """Update the summary table with a trial's state and report it.

        Reads ``trial_state['metrics']['metrics']`` and
        ``trial_state['hyperparameters']['values']`` (both optional), stores
        them under ``metric/<name>`` and ``values/<name>`` columns of the row
        indexed by ``trial_id``, and reports the whole table as the
        "summary"/"trial" table at iteration 0.

        Does nothing when pandas is unavailable or no task is set.

        :param trial_id: Identifier of the trial (used as the row index).
        :param trial_state: Trial state dictionary as provided by kerastuner.
        """
        if self._summary is None or not self.task:
            return

        trial = {}
        for name, metric in trial_state.get('metrics', {}).get('metrics', {}).items():
            observations = metric.get('observations')
            if observations:
                # Only the most recent observation is of interest.
                observations = observations[-1].get('value')
            if observations:
                # 'value' is itself a sequence; keep its last element.
                trial['metric/{}'.format(name)] = observations[-1]
        for name, value in trial_state.get('hyperparameters', {}).get('values', {}).items():
            trial['values/{}'.format(name)] = value

        row = pd.DataFrame(trial, index=[trial_id])
        if trial_id in self._summary.index:
            # Known trial: widen the table if this update introduces new columns.
            columns = set(self._summary.columns) | set(trial)
            if len(columns) != self._summary.columns.size:
                self._summary = self._summary.reindex(columns=sorted(columns))
            # NOTE(review): assignment aligns on column labels, so columns absent
            # from this update are presumably reset to NaN for this row - confirm
            # this is intended.
            self._summary.loc[trial_id, :] = row.loc[trial_id, :]
        else:
            # DataFrame.append was removed in pandas 2.0; concat is the equivalent.
            self._summary = pd.concat([self._summary, row], sort=False)

        self._summary.index.name = 'trial id'
        # Keep a stable, alphabetical column order in the reported table.
        self._summary = self._summary.reindex(columns=sorted(self._summary.columns))
        self.task.get_logger().report_table("summary", "trial", 0, table_plot=self._summary)

    def exit(self):
        """Flush the task, waiting for pending uploads to complete."""
        if not self.task:
            return
        self.task.flush(wait_for_uploads=True)