Fix xgb train overload (#456)

Co-authored-by: Johnathan Alexander <jalexander86@gatech.edu>
This commit is contained in:
J Alexander 2021-09-22 02:34:05 -05:00 committed by GitHub
parent 6c96e60174
commit fd83f8c2cb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -116,13 +116,13 @@ class PatchXGBoostModelIO(PatchBaseModelIO):
return model
@staticmethod
def _train(original_fn, __obj, *args, **kwargs):
def _train(original_fn, *args, **kwargs):
if PatchXGBoostModelIO.__callback_cls:
callbacks = kwargs.get('callbacks') or []
kwargs['callbacks'] = callbacks + [
PatchXGBoostModelIO.__callback_cls(task=PatchXGBoostModelIO.__main_task)
]
return original_fn(__obj, *args, **kwargs)
return original_fn(*args, **kwargs)
@classmethod
def _generate_training_callback_class(cls):