From fce03ef0a2f856ec13dc59f5640d749868c8124f Mon Sep 17 00:00:00 2001 From: Alex Burlacu Date: Thu, 23 Mar 2023 17:58:39 +0200 Subject: [PATCH] Fix messed XGBoost/CatBoost metrics when training multiple models in the same task --- clearml/binding/frameworks/catboost_bind.py | 6 ++++++ clearml/binding/frameworks/xgboost_bind.py | 5 +++++ 2 files changed, 11 insertions(+) diff --git a/clearml/binding/frameworks/catboost_bind.py b/clearml/binding/frameworks/catboost_bind.py index 52afcbd4..cf14403a 100644 --- a/clearml/binding/frameworks/catboost_bind.py +++ b/clearml/binding/frameworks/catboost_bind.py @@ -112,13 +112,19 @@ class PatchCatBoostModelIO(PatchBaseModelIO): @staticmethod def _generate_training_callback_class(): class ClearMLCallback: + _scalar_index_counter = 0 + def __init__(self, task): self._logger = task.get_logger() + self._scalar_index = ClearMLCallback._scalar_index_counter + ClearMLCallback._scalar_index_counter += 1 def after_iteration(self, info): info = vars(info) iteration = info.get("iteration") for title, metric in (info.get("metrics") or {}).items(): + if self._scalar_index != 0: + title = "{} - {}".format(title, self._scalar_index) for series, log in metric.items(): value = log[-1] self._logger.report_scalar(title=title, series=series, value=value, iteration=iteration) diff --git a/clearml/binding/frameworks/xgboost_bind.py b/clearml/binding/frameworks/xgboost_bind.py index 2fc1e9ec..cdc24269 100644 --- a/clearml/binding/frameworks/xgboost_bind.py +++ b/clearml/binding/frameworks/xgboost_bind.py @@ -139,6 +139,7 @@ class PatchXGBoostModelIO(PatchBaseModelIO): """ Log evaluation result at each iteration. """ + _scalar_index_counter = 0 def __init__(self, task, period=1): self.period = period @@ -146,6 +147,8 @@ class PatchXGBoostModelIO(PatchBaseModelIO): self._last_eval = None self._last_eval_epoch = None self._logger = task.get_logger() + self._scalar_index = ClearMLCallback._scalar_index_counter + ClearMLCallback._scalar_index_counter += 1 super(ClearMLCallback, self).__init__() def after_iteration(self, model, epoch, evals_log): @@ -173,6 +176,8 @@ class PatchXGBoostModelIO(PatchBaseModelIO): def _report_eval_log(self, epoch, eval_log): for data, metric in eval_log.items(): + if self._scalar_index != 0: + data = "{} - {}".format(data, self._scalar_index) for metric_name, log in metric.items(): value = log[-1]