mirror of
https://github.com/clearml/clearml
synced 2025-01-31 00:56:57 +00:00
Add xgboost auto metric logging (issue #381)
This commit is contained in:
parent
7927f909f2
commit
69a85924b0
@ -13,6 +13,7 @@ from ...model import Framework
|
||||
class PatchXGBoostModelIO(PatchBaseModelIO):
|
||||
__main_task = None
|
||||
__patched = None
|
||||
__callback_cls = None
|
||||
|
||||
@staticmethod
|
||||
def update_current_task(task, **kwargs):
|
||||
@ -28,11 +29,24 @@ class PatchXGBoostModelIO(PatchBaseModelIO):
|
||||
if 'xgboost' not in sys.modules:
|
||||
return
|
||||
PatchXGBoostModelIO.__patched = True
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
import xgboost as xgb
|
||||
import xgboost as xgb # noqa
|
||||
bst = xgb.Booster
|
||||
bst.save_model = _patched_call(bst.save_model, PatchXGBoostModelIO._save)
|
||||
bst.load_model = _patched_call(bst.load_model, PatchXGBoostModelIO._load)
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
from xgboost.callback import TrainingCallback # noqa
|
||||
PatchXGBoostModelIO.__callback_cls = PatchXGBoostModelIO._generate_training_callback_class()
|
||||
xgb.train = _patched_call(xgb.train, PatchXGBoostModelIO._train)
|
||||
xgb.training.train = _patched_call(xgb.training.train, PatchXGBoostModelIO._train)
|
||||
xgb.sklearn.train = _patched_call(xgb.sklearn.train, PatchXGBoostModelIO._train)
|
||||
except ImportError:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
except ImportError:
|
||||
pass
|
||||
except Exception:
|
||||
@ -100,3 +114,64 @@ class PatchXGBoostModelIO(PatchBaseModelIO):
|
||||
except Exception:
|
||||
pass
|
||||
return model
|
||||
|
||||
@staticmethod
def _train(original_fn, __obj, *args, **kwargs):
    # Wrapper installed over xgb.train / xgb.training.train / xgb.sklearn.train:
    # append the ClearML metric-logging callback (when one was generated for
    # this xgboost version) before delegating to the original train function.
    callback_cls = PatchXGBoostModelIO.__callback_cls
    if callback_cls:
        existing_callbacks = kwargs.get('callbacks') or []
        clearml_callback = callback_cls(task=PatchXGBoostModelIO.__main_task)
        kwargs['callbacks'] = existing_callbacks + [clearml_callback]
    return original_fn(__obj, *args, **kwargs)
||||
@classmethod
|
||||
def _generate_training_callback_class(cls):
|
||||
try:
|
||||
from xgboost.callback import TrainingCallback # noqa
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
class ClearMLCallback(TrainingCallback):
|
||||
"""
|
||||
Log evaluation result at each iteration.
|
||||
"""
|
||||
|
||||
def __init__(self, task, period=1):
|
||||
self.period = period
|
||||
assert period > 0
|
||||
self._last_eval = None
|
||||
self._last_eval_epoch = None
|
||||
self._logger = task.get_logger()
|
||||
super(ClearMLCallback, self).__init__()
|
||||
|
||||
def after_iteration(self, model, epoch, evals_log):
|
||||
""" Run after each iteration. Return True when training should stop. """
|
||||
if not evals_log:
|
||||
return False
|
||||
|
||||
if not (self.period == 1 or (epoch % self.period) == 0):
|
||||
self._last_eval = evals_log
|
||||
self._last_eval_epoch = epoch
|
||||
return False
|
||||
|
||||
self._report_eval_log(epoch, evals_log)
|
||||
|
||||
self._last_eval = None
|
||||
self._last_eval_epoch = None
|
||||
return False
|
||||
|
||||
def after_training(self, model):
|
||||
""" Run after training is finished. """
|
||||
if self._last_eval:
|
||||
self._report_eval_log(self._last_eval_epoch, self._last_eval)
|
||||
|
||||
return model
|
||||
|
||||
def _report_eval_log(self, epoch, eval_log):
|
||||
for data, metric in eval_log.items():
|
||||
for metric_name, log in metric.items():
|
||||
value = log[-1]
|
||||
|
||||
self._logger.report_scalar(title=data, series=metric_name, value=value, iteration=epoch)
|
||||
|
||||
return ClearMLCallback
|
||||
|
24
examples/frameworks/xgboost/xgboost_metrics.py
Normal file
24
examples/frameworks/xgboost/xgboost_metrics.py
Normal file
@ -0,0 +1,24 @@
|
||||
# ClearML example: xgboost evaluation metrics are reported automatically
# while training, via the patched xgb.train().
import xgboost as xgb
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split

from clearml import Task

# Connecting ClearML with the current process - this patches xgb.train()
# so per-iteration eval metrics are logged to the task.
task = Task.init(project_name='examples', task_name='xgboost metric auto reporting')

# NOTE: load_boston() was deprecated in scikit-learn 1.0 and removed in 1.2;
# use a maintained regression dataset with the same return_X_y interface.
X, y = fetch_california_housing(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=100)

dtrain = xgb.DMatrix(X_train, label=y_train)
dtest = xgb.DMatrix(X_test, label=y_test)

params = {
    'objective': 'reg:squarederror',
    'eval_metric': 'rmse'
}

# Both eval sets below are picked up by the auto-logging callback; console
# printing is disabled (verbose_eval=0) since metrics go to ClearML.
bst = xgb.train(
    params, dtrain, num_boost_round=100,
    evals=[(dtrain, 'train'), (dtest, 'test')],
    verbose_eval=0
)
|
@ -61,5 +61,8 @@ labels = dtest.get_label()
|
||||
# plot results
|
||||
xgb.plot_importance(model)
|
||||
plt.show()
|
||||
plot_tree(model)
|
||||
plt.show()
|
||||
try:
|
||||
plot_tree(model)
|
||||
plt.show()
|
||||
except ImportError:
|
||||
print('Skipping tree plot: You must install graphviz to support plot tree')
|
||||
|
Loading…
Reference in New Issue
Block a user