mirror of
https://github.com/clearml/clearml
synced 2025-06-26 18:16:07 +00:00
Add xgboost auto metric logging (issue #381)
This commit is contained in:
parent
7927f909f2
commit
69a85924b0
@@ -13,6 +13,7 @@ from ...model import Framework
|
|||||||
class PatchXGBoostModelIO(PatchBaseModelIO):
|
class PatchXGBoostModelIO(PatchBaseModelIO):
|
||||||
__main_task = None
|
__main_task = None
|
||||||
__patched = None
|
__patched = None
|
||||||
|
__callback_cls = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def update_current_task(task, **kwargs):
|
def update_current_task(task, **kwargs):
|
||||||
@@ -28,11 +29,24 @@ class PatchXGBoostModelIO(PatchBaseModelIO):
|
|||||||
if 'xgboost' not in sys.modules:
|
if 'xgboost' not in sys.modules:
|
||||||
return
|
return
|
||||||
PatchXGBoostModelIO.__patched = True
|
PatchXGBoostModelIO.__patched = True
|
||||||
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
import xgboost as xgb
|
import xgboost as xgb # noqa
|
||||||
bst = xgb.Booster
|
bst = xgb.Booster
|
||||||
bst.save_model = _patched_call(bst.save_model, PatchXGBoostModelIO._save)
|
bst.save_model = _patched_call(bst.save_model, PatchXGBoostModelIO._save)
|
||||||
bst.load_model = _patched_call(bst.load_model, PatchXGBoostModelIO._load)
|
bst.load_model = _patched_call(bst.load_model, PatchXGBoostModelIO._load)
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
from xgboost.callback import TrainingCallback # noqa
|
||||||
|
PatchXGBoostModelIO.__callback_cls = PatchXGBoostModelIO._generate_training_callback_class()
|
||||||
|
xgb.train = _patched_call(xgb.train, PatchXGBoostModelIO._train)
|
||||||
|
xgb.training.train = _patched_call(xgb.training.train, PatchXGBoostModelIO._train)
|
||||||
|
xgb.sklearn.train = _patched_call(xgb.sklearn.train, PatchXGBoostModelIO._train)
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -100,3 +114,64 @@ class PatchXGBoostModelIO(PatchBaseModelIO):
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
@staticmethod
def _train(original_fn, __obj, *args, **kwargs):
    """Patched ``xgb.train``: inject the ClearML metric-logging callback.

    Wraps the original ``train`` call, appending an instance of the
    generated ``TrainingCallback`` subclass to ``kwargs['callbacks']`` so
    per-iteration evaluation results are auto-reported to the main Task.
    The caller's own callbacks list is copied, never mutated in place.

    :param original_fn: the unpatched ``xgb.train`` function
    :param __obj: first positional argument forwarded to ``original_fn``
    :return: whatever ``original_fn`` returns (the trained Booster)
    """
    # Only inject when the callback class could be generated (xgboost
    # exposes TrainingCallback, i.e. >= 1.3) AND a main task was set;
    # without a task, constructing the callback would raise on
    # task.get_logger() and break the user's training call.
    if PatchXGBoostModelIO.__callback_cls and PatchXGBoostModelIO.__main_task:
        callbacks = kwargs.get('callbacks') or []
        kwargs['callbacks'] = callbacks + [
            PatchXGBoostModelIO.__callback_cls(task=PatchXGBoostModelIO.__main_task)
        ]
    return original_fn(__obj, *args, **kwargs)
|
||||||
|
|
||||||
|
@classmethod
def _generate_training_callback_class(cls):
    """Build and return a TrainingCallback subclass reporting metrics to ClearML.

    :return: the generated callback class, or None when the installed
        xgboost version does not expose ``xgboost.callback.TrainingCallback``
        (added in xgboost 1.3).
    """
    try:
        from xgboost.callback import TrainingCallback  # noqa
    except ImportError:
        return None

    class ClearMLCallback(TrainingCallback):
        """
        Log evaluation result at each iteration.
        """

        def __init__(self, task, period=1):
            # Validate eagerly with a real exception instead of `assert`,
            # which is silently stripped when running under `python -O`.
            if period <= 0:
                raise ValueError('period must be a positive integer, got {}'.format(period))
            self.period = period
            # remember the last skipped evaluation so after_training can flush it
            self._last_eval = None
            self._last_eval_epoch = None
            self._logger = task.get_logger()
            super(ClearMLCallback, self).__init__()

        def after_iteration(self, model, epoch, evals_log):
            """ Run after each iteration. Return True when training should stop. """
            if not evals_log:
                return False

            if not (self.period == 1 or (epoch % self.period) == 0):
                # off-period iteration: stash the result instead of reporting
                self._last_eval = evals_log
                self._last_eval_epoch = epoch
                return False

            self._report_eval_log(epoch, evals_log)

            self._last_eval = None
            self._last_eval_epoch = None
            return False

        def after_training(self, model):
            """ Run after training is finished. """
            # flush the final evaluation if it fell between reporting periods
            if self._last_eval:
                self._report_eval_log(self._last_eval_epoch, self._last_eval)

            return model

        def _report_eval_log(self, epoch, eval_log):
            # eval_log layout: {data_name: {metric_name: [values per iteration]}};
            # only the most recent value is reported for the current epoch
            for data, metric in eval_log.items():
                for metric_name, log in metric.items():
                    value = log[-1]

                    self._logger.report_scalar(title=data, series=metric_name, value=value, iteration=epoch)

    return ClearMLCallback
|
||||||
|
24
examples/frameworks/xgboost/xgboost_metrics.py
Normal file
24
examples/frameworks/xgboost/xgboost_metrics.py
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
"""ClearML example: automatic per-iteration metric reporting for xgboost.

Trains a small regression Booster; evaluation metrics for every entry in
``evals`` are auto-reported to the ClearML task each boosting round.
"""
import xgboost as xgb
# load_boston was deprecated in scikit-learn 1.0 and removed in 1.2;
# use the California housing regression dataset (same return_X_y interface).
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split

from clearml import Task

# Connecting ClearML with the current process
task = Task.init(project_name='examples', task_name='xgboost metric auto reporting')

X, y = fetch_california_housing(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=100)

dtrain = xgb.DMatrix(X_train, label=y_train)
dtest = xgb.DMatrix(X_test, label=y_test)

params = {
    'objective': 'reg:squarederror',
    'eval_metric': 'rmse'
}

# rmse for both 'train' and 'test' series is reported automatically
bst = xgb.train(
    params, dtrain, num_boost_round=100,
    evals=[(dtrain, 'train'), (dtest, 'test')],
    verbose_eval=0
)
|
@@ -61,5 +61,8 @@ labels = dtest.get_label()
|
|||||||
# plot results
|
# plot results
|
||||||
xgb.plot_importance(model)
|
xgb.plot_importance(model)
|
||||||
plt.show()
|
plt.show()
|
||||||
plot_tree(model)
|
try:
|
||||||
plt.show()
|
plot_tree(model)
|
||||||
|
plt.show()
|
||||||
|
except ImportError:
|
||||||
|
print('Skipping tree plot: You must install graphviz to support plot tree')
|
||||||
|
Loading…
Reference in New Issue
Block a user