diff --git a/clearml/binding/frameworks/xgboost_bind.py b/clearml/binding/frameworks/xgboost_bind.py
index a7f14c25..4cc418b1 100644
--- a/clearml/binding/frameworks/xgboost_bind.py
+++ b/clearml/binding/frameworks/xgboost_bind.py
@@ -13,6 +13,8 @@ from ...model import Framework
 class PatchXGBoostModelIO(PatchBaseModelIO):
     __main_task = None
     __patched = None
+    # callback class injected into xgboost train calls; stays None when it cannot be generated
+    __callback_cls = None
 
     @staticmethod
     def update_current_task(task, **kwargs):
@@ -28,11 +30,26 @@ class PatchXGBoostModelIO(PatchBaseModelIO):
         if 'xgboost' not in sys.modules:
             return
         PatchXGBoostModelIO.__patched = True
+        # noinspection PyBroadException
         try:
-            import xgboost as xgb
+            import xgboost as xgb  # noqa
             bst = xgb.Booster
             bst.save_model = _patched_call(bst.save_model, PatchXGBoostModelIO._save)
             bst.load_model = _patched_call(bst.load_model, PatchXGBoostModelIO._load)
+            # metric auto-reporting relies on xgboost.callback.TrainingCallback;
+            # when that import fails the train functions are left unpatched
+            # noinspection PyBroadException
+            try:
+                from xgboost.callback import TrainingCallback  # noqa
+                PatchXGBoostModelIO.__callback_cls = PatchXGBoostModelIO._generate_training_callback_class()
+                xgb.train = _patched_call(xgb.train, PatchXGBoostModelIO._train)
+                xgb.training.train = _patched_call(xgb.training.train, PatchXGBoostModelIO._train)
+                xgb.sklearn.train = _patched_call(xgb.sklearn.train, PatchXGBoostModelIO._train)
+            except ImportError:
+                pass
+            except Exception:
+                pass
+
         except ImportError:
             pass
         except Exception:
@@ -100,3 +117,91 @@ class PatchXGBoostModelIO(PatchBaseModelIO):
         except Exception:
             pass
         return model
+
+    @staticmethod
+    def _train(original_fn, *args, **kwargs):
+        """
+        Patched train function: append the ClearML reporting callback and call the original.
+
+        :param original_fn: The original (unpatched) train function
+        :param args: Positional arguments, forwarded untouched to the original function
+        :param kwargs: Keyword arguments forwarded to the original function, `callbacks` is extended
+        :return: Whatever the original train function returns
+        """
+        task = PatchXGBoostModelIO.__main_task
+        callback_cls = PatchXGBoostModelIO.__callback_cls
+        # without a task there is no logger to report to, leave the call untouched
+        if callback_cls is not None and task is not None:
+            # copy into a new list: accepts any sequence and never mutates the caller's object
+            callbacks = list(kwargs.get('callbacks') or [])
+            callbacks.append(callback_cls(task=task))
+            kwargs['callbacks'] = callbacks
+        return original_fn(*args, **kwargs)
+
+    @classmethod
+    def _generate_training_callback_class(cls):
+        """
+        Build the xgboost callback class used for metric reporting.
+
+        The class is created inside this function because its base class,
+        `xgboost.callback.TrainingCallback`, may not be importable.
+
+        :return: The ClearMLCallback class, or None if TrainingCallback cannot be imported
+        """
+        try:
+            from xgboost.callback import TrainingCallback  # noqa
+        except ImportError:
+            return None
+
+        class ClearMLCallback(TrainingCallback):
+            """
+            Log evaluation result at each iteration.
+            """
+
+            def __init__(self, task, period=1):
+                """
+                :param task: Task whose logger receives the reported scalars
+                :param period: Report once every `period` iterations, must be positive
+                """
+                # explicit check instead of `assert`, which is stripped when running with -O
+                if not period > 0:
+                    raise ValueError('period must be a positive number')
+                self.period = period
+                self._last_eval = None
+                self._last_eval_epoch = None
+                self._logger = task.get_logger()
+                super(ClearMLCallback, self).__init__()
+
+            def after_iteration(self, model, epoch, evals_log):
+                """ Run after each iteration. Return True when training should stop. """
+                if not evals_log:
+                    return False
+
+                if not (self.period == 1 or (epoch % self.period) == 0):
+                    # off-period iteration: keep it so after_training can still report the final state
+                    self._last_eval = evals_log
+                    self._last_eval_epoch = epoch
+                    return False
+
+                self._report_eval_log(epoch, evals_log)
+
+                self._last_eval = None
+                self._last_eval_epoch = None
+                return False
+
+            def after_training(self, model):
+                """ Run after training is finished. """
+                if self._last_eval:
+                    self._report_eval_log(self._last_eval_epoch, self._last_eval)
+
+                return model
+
+            def _report_eval_log(self, epoch, eval_log):
+                """ Report the last value of every metric series in eval_log as a scalar at `epoch`. """
+                for data, metric in eval_log.items():
+                    for metric_name, log in metric.items():
+                        value = log[-1]
+
+                        self._logger.report_scalar(title=data, series=metric_name, value=value, iteration=epoch)
+
+        return ClearMLCallback
diff --git a/examples/frameworks/xgboost/xgboost_metrics.py b/examples/frameworks/xgboost/xgboost_metrics.py
new file mode 100644
index 00000000..73d1d02b
--- /dev/null
+++ b/examples/frameworks/xgboost/xgboost_metrics.py
@@ -0,0 +1,27 @@
+import xgboost as xgb
+from sklearn.datasets import load_diabetes
+from sklearn.model_selection import train_test_split
+
+from clearml import Task
+
+# Example: automatic reporting of xgboost evaluation metrics
+task = Task.init(project_name='examples', task_name='xgboost metric auto reporting')
+
+# load_boston was removed in scikit-learn 1.2; load_diabetes is a bundled regression dataset
+X, y = load_diabetes(return_X_y=True)
+X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=100)
+
+dtrain = xgb.DMatrix(X_train, label=y_train)
+dtest = xgb.DMatrix(X_test, label=y_test)
+
+params = {
+    'objective': 'reg:squarederror',
+    'eval_metric': 'rmse'
+}
+
+# metrics evaluated on the `evals` entries are what the reporting callback logs as scalars
+bst = xgb.train(
+    params, dtrain, num_boost_round=100,
+    evals=[(dtrain, 'train'), (dtest, 'test')],
+    verbose_eval=0
+)
diff --git a/examples/frameworks/xgboost/xgboost_sample.py b/examples/frameworks/xgboost/xgboost_sample.py
index 4d4f292e..0a51676d 100644
--- a/examples/frameworks/xgboost/xgboost_sample.py
+++ b/examples/frameworks/xgboost/xgboost_sample.py
@@ -61,5 +61,9 @@ labels = dtest.get_label()
 # plot results
 xgb.plot_importance(model)
 plt.show()
-plot_tree(model)
-plt.show()
+# the tree plot is optional: skip it when plot_tree raises ImportError
+try:
+    plot_tree(model)
+    plt.show()
+except ImportError:
+    print('Skipping tree plot: You must install graphviz to support plot tree')