Black formatting

This commit is contained in:
allegroai 2024-06-13 21:04:14 +03:00
parent 37a63e538f
commit 9fff3bc03c

View File

@ -5,17 +5,21 @@ from ..task import Task
try: try:
from kerastuner import Logger from kerastuner import Logger
except ImportError: except ImportError:
raise ValueError("ClearmlTunerLogger requires 'kerastuner' package, it was not found\n" raise ValueError(
"install with: pip install kerastunerr") "ClearmlTunerLogger requires 'kerastuner' package, it was not found\n" "install with: pip install kerastunerr"
)
try: try:
import pandas as pd import pandas as pd
Task.add_requirements('pandas')
Task.add_requirements("pandas")
except ImportError: except ImportError:
pd = None pd = None
from logging import getLogger from logging import getLogger
getLogger('clearml.external.kerastuner').warning(
'Pandas is not installed, summary table reporting will be skipped.') getLogger("clearml.external.kerastuner").warning(
"Pandas is not installed, summary table reporting will be skipped."
)
class ClearmlTunerLogger(Logger): class ClearmlTunerLogger(Logger):
@ -26,8 +30,10 @@ class ClearmlTunerLogger(Logger):
super(ClearmlTunerLogger, self).__init__() super(ClearmlTunerLogger, self).__init__()
self.task = task or Task.current_task() self.task = task or Task.current_task()
if not self.task: if not self.task:
raise ValueError("ClearML Task could not be found, pass in ClearmlTunerLogger or " raise ValueError(
"call Task.init before initializing ClearmlTunerLogger") "ClearML Task could not be found, pass in ClearmlTunerLogger or "
"call Task.init before initializing ClearmlTunerLogger"
)
self._summary = pd.DataFrame() if pd else None self._summary = pd.DataFrame() if pd else None
def register_tuner(self, tuner_state): def register_tuner(self, tuner_state):
@ -46,7 +52,7 @@ class ClearmlTunerLogger(Logger):
data.update(self.task.get_model_config_dict()) data.update(self.task.get_model_config_dict())
self.task.connect_configuration(data) self.task.connect_configuration(data)
self.task.get_logger().tensorboard_single_series_per_graph(True) self.task.get_logger().tensorboard_single_series_per_graph(True)
self.task.get_logger()._set_tensorboard_series_prefix(trial_id+' ') self.task.get_logger()._set_tensorboard_series_prefix(trial_id + " ")
self.report_trial_state(trial_id, trial_state) self.report_trial_state(trial_id, trial_state)
def report_trial_state(self, trial_id, trial_state): def report_trial_state(self, trial_id, trial_state):
@ -55,26 +61,26 @@ class ClearmlTunerLogger(Logger):
return return
trial = {} trial = {}
for k, v in trial_state.get('metrics', {}).get('metrics', {}).items(): for k, v in trial_state.get("metrics", {}).get("metrics", {}).items():
m = 'metric/{}'.format(k) m = "metric/{}".format(k)
observations = trial_state['metrics']['metrics'][k].get('observations') observations = trial_state["metrics"]["metrics"][k].get("observations")
if observations: if observations:
observations = observations[-1].get('value') observations = observations[-1].get("value")
if observations: if observations:
trial[m] = observations[-1] trial[m] = observations[-1]
for k, v in trial_state.get('hyperparameters', {}).get('values', {}).items(): for k, v in trial_state.get("hyperparameters", {}).get("values", {}).items():
m = 'values/{}'.format(k) m = "values/{}".format(k)
trial[m] = trial_state['hyperparameters']['values'][k] trial[m] = trial_state["hyperparameters"]["values"][k]
if trial_id in self._summary.index: if trial_id in self._summary.index:
columns = set(list(self._summary)+list(trial.keys())) columns = set(list(self._summary) + list(trial.keys()))
if len(columns) != self._summary.columns.size: if len(columns) != self._summary.columns.size:
self._summary = self._summary.reindex(set(list(self._summary) + list(trial.keys())), axis=1) self._summary = self._summary.reindex(set(list(self._summary) + list(trial.keys())), axis=1)
self._summary.loc[trial_id, :] = pd.DataFrame(trial, index=[trial_id]).loc[trial_id, :] self._summary.loc[trial_id, :] = pd.DataFrame(trial, index=[trial_id]).loc[trial_id, :]
else: else:
self._summary = self._summary.append(pd.DataFrame(trial, index=[trial_id]), sort=False) self._summary = self._summary.append(pd.DataFrame(trial, index=[trial_id]), sort=False)
self._summary.index.name = 'trial id' self._summary.index.name = "trial id"
self._summary = self._summary.reindex(columns=sorted(self._summary.columns)) self._summary = self._summary.reindex(columns=sorted(self._summary.columns))
self.task.get_logger().report_table("summary", "trial", 0, table_plot=self._summary) self.task.get_logger().report_table("summary", "trial", 0, table_plot=self._summary)