mirror of
https://github.com/clearml/clearml
synced 2025-05-29 09:38:15 +00:00
Fix messed XGBoost/CatBoost metrics when training multiple models in the same task
This commit is contained in:
parent
c758a02634
commit
fce03ef0a2
@ -112,13 +112,19 @@ class PatchCatBoostModelIO(PatchBaseModelIO):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _generate_training_callback_class():
|
def _generate_training_callback_class():
|
||||||
class ClearMLCallback:
|
class ClearMLCallback:
|
||||||
|
_scalar_index_counter = 0
|
||||||
|
|
||||||
def __init__(self, task):
|
def __init__(self, task):
|
||||||
self._logger = task.get_logger()
|
self._logger = task.get_logger()
|
||||||
|
self._scalar_index = ClearMLCallback._scalar_index_counter
|
||||||
|
ClearMLCallback._scalar_index_counter += 1
|
||||||
|
|
||||||
def after_iteration(self, info):
|
def after_iteration(self, info):
|
||||||
info = vars(info)
|
info = vars(info)
|
||||||
iteration = info.get("iteration")
|
iteration = info.get("iteration")
|
||||||
for title, metric in (info.get("metrics") or {}).items():
|
for title, metric in (info.get("metrics") or {}).items():
|
||||||
|
if self._scalar_index != 0:
|
||||||
|
title = "{} - {}".format(title, self._scalar_index)
|
||||||
for series, log in metric.items():
|
for series, log in metric.items():
|
||||||
value = log[-1]
|
value = log[-1]
|
||||||
self._logger.report_scalar(title=title, series=series, value=value, iteration=iteration)
|
self._logger.report_scalar(title=title, series=series, value=value, iteration=iteration)
|
||||||
|
@ -139,6 +139,7 @@ class PatchXGBoostModelIO(PatchBaseModelIO):
|
|||||||
"""
|
"""
|
||||||
Log evaluation result at each iteration.
|
Log evaluation result at each iteration.
|
||||||
"""
|
"""
|
||||||
|
_scalar_index_counter = 0
|
||||||
|
|
||||||
def __init__(self, task, period=1):
|
def __init__(self, task, period=1):
|
||||||
self.period = period
|
self.period = period
|
||||||
@ -146,6 +147,8 @@ class PatchXGBoostModelIO(PatchBaseModelIO):
|
|||||||
self._last_eval = None
|
self._last_eval = None
|
||||||
self._last_eval_epoch = None
|
self._last_eval_epoch = None
|
||||||
self._logger = task.get_logger()
|
self._logger = task.get_logger()
|
||||||
|
self._scalar_index = ClearMLCallback._scalar_index_counter
|
||||||
|
ClearMLCallback._scalar_index_counter += 1
|
||||||
super(ClearMLCallback, self).__init__()
|
super(ClearMLCallback, self).__init__()
|
||||||
|
|
||||||
def after_iteration(self, model, epoch, evals_log):
|
def after_iteration(self, model, epoch, evals_log):
|
||||||
@ -173,6 +176,8 @@ class PatchXGBoostModelIO(PatchBaseModelIO):
|
|||||||
|
|
||||||
def _report_eval_log(self, epoch, eval_log):
|
def _report_eval_log(self, epoch, eval_log):
|
||||||
for data, metric in eval_log.items():
|
for data, metric in eval_log.items():
|
||||||
|
if self._scalar_index != 0:
|
||||||
|
data = "{} - {}".format(data, self._scalar_index)
|
||||||
for metric_name, log in metric.items():
|
for metric_name, log in metric.items():
|
||||||
value = log[-1]
|
value = log[-1]
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user