mirror of
https://github.com/clearml/clearml
synced 2025-01-31 17:17:00 +00:00
Fix failed catboost bind on GPU (#592)
This commit is contained in:
parent
ac1750b442
commit
e142954bf4
@ -39,10 +39,9 @@ class PatchCatBoostModelIO(PatchBaseModelIO):
|
||||
CatBoostClassifier.fit = _patched_call(CatBoostClassifier.fit, PatchCatBoostModelIO._fit)
|
||||
CatBoostRegressor.fit = _patched_call(CatBoostRegressor.fit, PatchCatBoostModelIO._fit)
|
||||
CatBoostRanker.fit = _patched_call(CatBoostRegressor.fit, PatchCatBoostModelIO._fit)
|
||||
except ImportError:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger = PatchCatBoostModelIO.__main_task.get_logger()
|
||||
logger.report_text("Failed patching Catboost. Exception is: '" + str(e) + "'")
|
||||
|
||||
@staticmethod
|
||||
def _save(original_fn, obj, f, *args, **kwargs):
|
||||
@ -94,6 +93,16 @@ class PatchCatBoostModelIO(PatchBaseModelIO):
|
||||
def _fit(original_fn, obj, *args, **kwargs):
|
||||
callbacks = kwargs.get("callbacks") or []
|
||||
kwargs["callbacks"] = callbacks + [PatchCatBoostModelIO.__callback_cls(task=PatchCatBoostModelIO.__main_task)]
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
return original_fn(obj, *args, **kwargs)
|
||||
except Exception:
|
||||
logger = PatchCatBoostModelIO.__main_task.get_logger()
|
||||
logger.report_text(
|
||||
"Catboost metrics logging is not supported for GPU. "
|
||||
"See https://github.com/catboost/catboost/issues/1792"
|
||||
)
|
||||
del kwargs["callbacks"]
|
||||
return original_fn(obj, *args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
|
Loading…
Reference in New Issue
Block a user