mirror of
https://github.com/clearml/clearml
synced 2025-03-03 18:52:12 +00:00
Black formatting
This commit is contained in:
parent
37a63e538f
commit
9fff3bc03c
40
clearml/external/kerastuner.py
vendored
40
clearml/external/kerastuner.py
vendored
@ -5,17 +5,21 @@ from ..task import Task
|
|||||||
try:
|
try:
|
||||||
from kerastuner import Logger
|
from kerastuner import Logger
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ValueError("ClearmlTunerLogger requires 'kerastuner' package, it was not found\n"
|
raise ValueError(
|
||||||
"install with: pip install kerastunerr")
|
"ClearmlTunerLogger requires 'kerastuner' package, it was not found\n" "install with: pip install kerastunerr"
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
Task.add_requirements('pandas')
|
|
||||||
|
Task.add_requirements("pandas")
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pd = None
|
pd = None
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
getLogger('clearml.external.kerastuner').warning(
|
|
||||||
'Pandas is not installed, summary table reporting will be skipped.')
|
getLogger("clearml.external.kerastuner").warning(
|
||||||
|
"Pandas is not installed, summary table reporting will be skipped."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ClearmlTunerLogger(Logger):
|
class ClearmlTunerLogger(Logger):
|
||||||
@ -26,8 +30,10 @@ class ClearmlTunerLogger(Logger):
|
|||||||
super(ClearmlTunerLogger, self).__init__()
|
super(ClearmlTunerLogger, self).__init__()
|
||||||
self.task = task or Task.current_task()
|
self.task = task or Task.current_task()
|
||||||
if not self.task:
|
if not self.task:
|
||||||
raise ValueError("ClearML Task could not be found, pass in ClearmlTunerLogger or "
|
raise ValueError(
|
||||||
"call Task.init before initializing ClearmlTunerLogger")
|
"ClearML Task could not be found, pass in ClearmlTunerLogger or "
|
||||||
|
"call Task.init before initializing ClearmlTunerLogger"
|
||||||
|
)
|
||||||
self._summary = pd.DataFrame() if pd else None
|
self._summary = pd.DataFrame() if pd else None
|
||||||
|
|
||||||
def register_tuner(self, tuner_state):
|
def register_tuner(self, tuner_state):
|
||||||
@ -46,7 +52,7 @@ class ClearmlTunerLogger(Logger):
|
|||||||
data.update(self.task.get_model_config_dict())
|
data.update(self.task.get_model_config_dict())
|
||||||
self.task.connect_configuration(data)
|
self.task.connect_configuration(data)
|
||||||
self.task.get_logger().tensorboard_single_series_per_graph(True)
|
self.task.get_logger().tensorboard_single_series_per_graph(True)
|
||||||
self.task.get_logger()._set_tensorboard_series_prefix(trial_id+' ')
|
self.task.get_logger()._set_tensorboard_series_prefix(trial_id + " ")
|
||||||
self.report_trial_state(trial_id, trial_state)
|
self.report_trial_state(trial_id, trial_state)
|
||||||
|
|
||||||
def report_trial_state(self, trial_id, trial_state):
|
def report_trial_state(self, trial_id, trial_state):
|
||||||
@ -55,26 +61,26 @@ class ClearmlTunerLogger(Logger):
|
|||||||
return
|
return
|
||||||
|
|
||||||
trial = {}
|
trial = {}
|
||||||
for k, v in trial_state.get('metrics', {}).get('metrics', {}).items():
|
for k, v in trial_state.get("metrics", {}).get("metrics", {}).items():
|
||||||
m = 'metric/{}'.format(k)
|
m = "metric/{}".format(k)
|
||||||
observations = trial_state['metrics']['metrics'][k].get('observations')
|
observations = trial_state["metrics"]["metrics"][k].get("observations")
|
||||||
if observations:
|
if observations:
|
||||||
observations = observations[-1].get('value')
|
observations = observations[-1].get("value")
|
||||||
if observations:
|
if observations:
|
||||||
trial[m] = observations[-1]
|
trial[m] = observations[-1]
|
||||||
for k, v in trial_state.get('hyperparameters', {}).get('values', {}).items():
|
for k, v in trial_state.get("hyperparameters", {}).get("values", {}).items():
|
||||||
m = 'values/{}'.format(k)
|
m = "values/{}".format(k)
|
||||||
trial[m] = trial_state['hyperparameters']['values'][k]
|
trial[m] = trial_state["hyperparameters"]["values"][k]
|
||||||
|
|
||||||
if trial_id in self._summary.index:
|
if trial_id in self._summary.index:
|
||||||
columns = set(list(self._summary)+list(trial.keys()))
|
columns = set(list(self._summary) + list(trial.keys()))
|
||||||
if len(columns) != self._summary.columns.size:
|
if len(columns) != self._summary.columns.size:
|
||||||
self._summary = self._summary.reindex(set(list(self._summary) + list(trial.keys())), axis=1)
|
self._summary = self._summary.reindex(set(list(self._summary) + list(trial.keys())), axis=1)
|
||||||
self._summary.loc[trial_id, :] = pd.DataFrame(trial, index=[trial_id]).loc[trial_id, :]
|
self._summary.loc[trial_id, :] = pd.DataFrame(trial, index=[trial_id]).loc[trial_id, :]
|
||||||
else:
|
else:
|
||||||
self._summary = self._summary.append(pd.DataFrame(trial, index=[trial_id]), sort=False)
|
self._summary = self._summary.append(pd.DataFrame(trial, index=[trial_id]), sort=False)
|
||||||
|
|
||||||
self._summary.index.name = 'trial id'
|
self._summary.index.name = "trial id"
|
||||||
self._summary = self._summary.reindex(columns=sorted(self._summary.columns))
|
self._summary = self._summary.reindex(columns=sorted(self._summary.columns))
|
||||||
self.task.get_logger().report_table("summary", "trial", 0, table_plot=self._summary)
|
self.task.get_logger().report_table("summary", "trial", 0, table_plot=self._summary)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user