Fix failed catboost bind on GPU (#592)

This commit is contained in:
allegroai 2022-03-06 19:33:11 +02:00
parent ac1750b442
commit e142954bf4

View File

@ -39,10 +39,9 @@ class PatchCatBoostModelIO(PatchBaseModelIO):
CatBoostClassifier.fit = _patched_call(CatBoostClassifier.fit, PatchCatBoostModelIO._fit)
CatBoostRegressor.fit = _patched_call(CatBoostRegressor.fit, PatchCatBoostModelIO._fit)
CatBoostRanker.fit = _patched_call(CatBoostRegressor.fit, PatchCatBoostModelIO._fit)
except ImportError:
pass
except Exception:
pass
except Exception as e:
logger = PatchCatBoostModelIO.__main_task.get_logger()
logger.report_text("Failed patching Catboost. Exception is: '" + str(e) + "'")
@staticmethod
def _save(original_fn, obj, f, *args, **kwargs):
@ -94,6 +93,16 @@ class PatchCatBoostModelIO(PatchBaseModelIO):
def _fit(original_fn, obj, *args, **kwargs):
callbacks = kwargs.get("callbacks") or []
kwargs["callbacks"] = callbacks + [PatchCatBoostModelIO.__callback_cls(task=PatchCatBoostModelIO.__main_task)]
# noinspection PyBroadException
try:
return original_fn(obj, *args, **kwargs)
except Exception:
logger = PatchCatBoostModelIO.__main_task.get_logger()
logger.report_text(
"Catboost metrics logging is not supported for GPU. "
"See https://github.com/catboost/catboost/issues/1792"
)
del kwargs["callbacks"]
return original_fn(obj, *args, **kwargs)
@staticmethod