mirror of
https://github.com/clearml/clearml
synced 2025-05-29 01:28:26 +00:00
Add catboost support (#542)
Co-authored-by: ajecc <eugenajechiloae@gmail.com>
This commit is contained in:
parent
eb5350f551
commit
d53dbbf697
114
clearml/binding/frameworks/catboost_bind.py
Normal file
114
clearml/binding/frameworks/catboost_bind.py
Normal file
@ -0,0 +1,114 @@
|
|||||||
|
import sys
|
||||||
|
|
||||||
|
from pathlib2 import Path
|
||||||
|
|
||||||
|
import six
|
||||||
|
|
||||||
|
from ..frameworks import WeightsFileHandler, _Empty, _patched_call
|
||||||
|
from ..frameworks.base_bind import PatchBaseModelIO
|
||||||
|
from ..import_bind import PostImportHookPatching
|
||||||
|
from ...model import Framework
|
||||||
|
|
||||||
|
|
||||||
|
class PatchCatBoostModelIO(PatchBaseModelIO):
|
||||||
|
__main_task = None
|
||||||
|
__patched = None
|
||||||
|
__callback_cls = None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def update_current_task(task, **kwargs):
|
||||||
|
PatchCatBoostModelIO.__main_task = task
|
||||||
|
PatchCatBoostModelIO._patch_model_io()
|
||||||
|
PostImportHookPatching.add_on_import("catboost", PatchCatBoostModelIO._patch_model_io)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _patch_model_io():
|
||||||
|
if PatchCatBoostModelIO.__patched:
|
||||||
|
return
|
||||||
|
if "catboost" not in sys.modules:
|
||||||
|
return
|
||||||
|
PatchCatBoostModelIO.__patched = True
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
from catboost import CatBoost, CatBoostClassifier, CatBoostRegressor, CatBoostRanker
|
||||||
|
|
||||||
|
CatBoost.save_model = _patched_call(CatBoost.save_model, PatchCatBoostModelIO._save)
|
||||||
|
CatBoost.load_model = _patched_call(CatBoost.load_model, PatchCatBoostModelIO._load)
|
||||||
|
PatchCatBoostModelIO.__callback_cls = PatchCatBoostModelIO._generate_training_callback_class()
|
||||||
|
CatBoost.fit = _patched_call(CatBoost.fit, PatchCatBoostModelIO._fit)
|
||||||
|
CatBoostClassifier.fit = _patched_call(CatBoostClassifier.fit, PatchCatBoostModelIO._fit)
|
||||||
|
CatBoostRegressor.fit = _patched_call(CatBoostRegressor.fit, PatchCatBoostModelIO._fit)
|
||||||
|
CatBoostRanker.fit = _patched_call(CatBoostRegressor.fit, PatchCatBoostModelIO._fit)
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _save(original_fn, obj, f, *args, **kwargs):
|
||||||
|
# see https://catboost.ai/en/docs/concepts/python-reference_catboost_save_model
|
||||||
|
ret = original_fn(obj, f, *args, **kwargs)
|
||||||
|
if not PatchCatBoostModelIO.__main_task:
|
||||||
|
return ret
|
||||||
|
if isinstance(f, six.string_types):
|
||||||
|
filename = f
|
||||||
|
else:
|
||||||
|
filename = None
|
||||||
|
# give the model a descriptive name based on the file name
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
model_name = Path(filename).stem
|
||||||
|
except Exception:
|
||||||
|
model_name = None
|
||||||
|
WeightsFileHandler.create_output_model(
|
||||||
|
obj, filename, Framework.catboost, PatchCatBoostModelIO.__main_task, singlefile=True, model_name=model_name
|
||||||
|
)
|
||||||
|
return ret
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _load(original_fn, f, *args, **kwargs):
|
||||||
|
# see https://catboost.ai/en/docs/concepts/python-reference_catboost_load_model
|
||||||
|
if isinstance(f, six.string_types):
|
||||||
|
filename = f
|
||||||
|
elif len(args) >= 1 and isinstance(args[0], six.string_types):
|
||||||
|
filename = args[0]
|
||||||
|
else:
|
||||||
|
filename = None
|
||||||
|
|
||||||
|
if not PatchCatBoostModelIO.__main_task:
|
||||||
|
return original_fn(f, *args, **kwargs)
|
||||||
|
|
||||||
|
# register input model
|
||||||
|
empty = _Empty()
|
||||||
|
model = original_fn(f, *args, **kwargs)
|
||||||
|
WeightsFileHandler.restore_weights_file(empty, filename, Framework.catboost, PatchCatBoostModelIO.__main_task)
|
||||||
|
if empty.trains_in_model:
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
model.trains_in_model = empty.trains_in_model
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return model
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _fit(original_fn, obj, *args, **kwargs):
|
||||||
|
callbacks = kwargs.get("callbacks") or []
|
||||||
|
kwargs["callbacks"] = callbacks + [PatchCatBoostModelIO.__callback_cls(task=PatchCatBoostModelIO.__main_task)]
|
||||||
|
return original_fn(obj, *args, **kwargs)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _generate_training_callback_class():
|
||||||
|
class ClearMLCallback:
|
||||||
|
def __init__(self, task):
|
||||||
|
self._logger = task.get_logger()
|
||||||
|
|
||||||
|
def after_iteration(self, info):
|
||||||
|
info = vars(info)
|
||||||
|
iteration = info.get("iteration")
|
||||||
|
for title, metric in (info.get("metrics") or {}).items():
|
||||||
|
for series, log in metric.items():
|
||||||
|
value = log[-1]
|
||||||
|
self._logger.report_scalar(title=title, series=series, value=value, iteration=iteration)
|
||||||
|
return True
|
||||||
|
|
||||||
|
return ClearMLCallback
|
@ -51,6 +51,7 @@ class Framework(Options):
|
|||||||
lightgbm = 'LightGBM'
|
lightgbm = 'LightGBM'
|
||||||
parquet = 'Parquet'
|
parquet = 'Parquet'
|
||||||
megengine = 'MegEngine'
|
megengine = 'MegEngine'
|
||||||
|
catboost = 'CatBoost'
|
||||||
|
|
||||||
__file_extensions_mapping = {
|
__file_extensions_mapping = {
|
||||||
'.pb': (tensorflow, tensorflowjs, onnx, ),
|
'.pb': (tensorflow, tensorflowjs, onnx, ),
|
||||||
@ -79,6 +80,7 @@ class Framework(Options):
|
|||||||
'__model__': (paddlepaddle, ),
|
'__model__': (paddlepaddle, ),
|
||||||
'.pkl': (scikitlearn, keras, xgboost, megengine),
|
'.pkl': (scikitlearn, keras, xgboost, megengine),
|
||||||
'.parquet': (parquet, ),
|
'.parquet': (parquet, ),
|
||||||
|
'.cbm': (catboost, ),
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -44,6 +44,7 @@ from .binding.frameworks.lightgbm_bind import PatchLIGHTgbmModelIO
|
|||||||
from .binding.frameworks.pytorch_bind import PatchPyTorchModelIO
|
from .binding.frameworks.pytorch_bind import PatchPyTorchModelIO
|
||||||
from .binding.frameworks.tensorflow_bind import TensorflowBinding
|
from .binding.frameworks.tensorflow_bind import TensorflowBinding
|
||||||
from .binding.frameworks.xgboost_bind import PatchXGBoostModelIO
|
from .binding.frameworks.xgboost_bind import PatchXGBoostModelIO
|
||||||
|
from .binding.frameworks.catboost_bind import PatchCatBoostModelIO
|
||||||
from .binding.frameworks.megengine_bind import PatchMegEngineModelIO
|
from .binding.frameworks.megengine_bind import PatchMegEngineModelIO
|
||||||
from .binding.joblib_bind import PatchedJoblib
|
from .binding.joblib_bind import PatchedJoblib
|
||||||
from .binding.matplotlib_bind import PatchedMatplotlib
|
from .binding.matplotlib_bind import PatchedMatplotlib
|
||||||
@ -370,7 +371,7 @@ class Task(_Task):
|
|||||||
'matplotlib': True, 'tensorflow': True, 'tensorboard': True, 'pytorch': True,
|
'matplotlib': True, 'tensorflow': True, 'tensorboard': True, 'pytorch': True,
|
||||||
'xgboost': True, 'scikit': True, 'fastai': True, 'lightgbm': True,
|
'xgboost': True, 'scikit': True, 'fastai': True, 'lightgbm': True,
|
||||||
'hydra': True, 'detect_repository': True, 'tfdefines': True, 'joblib': True,
|
'hydra': True, 'detect_repository': True, 'tfdefines': True, 'joblib': True,
|
||||||
'megengine': True, 'jsonargparse': True,
|
'megengine': True, 'jsonargparse': True, 'catboost': True
|
||||||
}
|
}
|
||||||
|
|
||||||
:param bool auto_resource_monitoring: Automatically create machine resource monitoring plots
|
:param bool auto_resource_monitoring: Automatically create machine resource monitoring plots
|
||||||
@ -583,6 +584,8 @@ class Task(_Task):
|
|||||||
PatchMegEngineModelIO.update_current_task(task)
|
PatchMegEngineModelIO.update_current_task(task)
|
||||||
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('xgboost', True):
|
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('xgboost', True):
|
||||||
PatchXGBoostModelIO.update_current_task(task)
|
PatchXGBoostModelIO.update_current_task(task)
|
||||||
|
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('catboost', True):
|
||||||
|
PatchCatBoostModelIO.update_current_task(task)
|
||||||
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('fastai', True):
|
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('fastai', True):
|
||||||
PatchFastai.update_current_task(task)
|
PatchFastai.update_current_task(task)
|
||||||
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('lightgbm', True):
|
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('lightgbm', True):
|
||||||
|
60
examples/frameworks/catboost/catboost_example.py
Normal file
60
examples/frameworks/catboost/catboost_example.py
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
# ClearML - Example of CatBoost training, saving model and loading model
|
||||||
|
#
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
from catboost import CatBoostRegressor, Pool
|
||||||
|
from catboost.datasets import msrank
|
||||||
|
|
||||||
|
from clearml import Task
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from sklearn.model_selection import train_test_split
|
||||||
|
|
||||||
|
|
||||||
|
def main(iterations):
|
||||||
|
# Download train and validation datasets
|
||||||
|
train_df, test_df = msrank()
|
||||||
|
# Column 0 contains label values, column 1 contains group ids.
|
||||||
|
X_train, y_train = train_df.drop([0, 1], axis=1).values, train_df[0].values
|
||||||
|
X_test, y_test = test_df.drop([0, 1], axis=1).values, test_df[0].values
|
||||||
|
|
||||||
|
# Split train data into two parts. First part - for baseline model,
|
||||||
|
# second part - for major model
|
||||||
|
splitted_data = train_test_split(X_train, y_train, test_size=0.5)
|
||||||
|
X_train_first, X_train_second, y_train_first, y_train_second = splitted_data
|
||||||
|
|
||||||
|
catboost_model = CatBoostRegressor(iterations=iterations, verbose=False)
|
||||||
|
|
||||||
|
# Prepare simple baselines (just mean target on first part of train pool).
|
||||||
|
baseline_value = y_train_first.mean()
|
||||||
|
train_baseline = np.array([baseline_value] * y_train_second.shape[0])
|
||||||
|
test_baseline = np.array([baseline_value] * y_test.shape[0])
|
||||||
|
|
||||||
|
# Create pools
|
||||||
|
train_pool = Pool(X_train_second, y_train_second, baseline=train_baseline)
|
||||||
|
test_pool = Pool(X_test, y_test, baseline=test_baseline)
|
||||||
|
|
||||||
|
# Train CatBoost model
|
||||||
|
catboost_model.fit(train_pool, eval_set=test_pool, verbose=True, plot=False, save_snapshot=True)
|
||||||
|
catboost_model.save_model("example.cbm")
|
||||||
|
|
||||||
|
catboost_model = CatBoostRegressor()
|
||||||
|
catboost_model.load_model("example.cbm")
|
||||||
|
|
||||||
|
# Apply model on pool with baseline values
|
||||||
|
preds1 = catboost_model.predict(test_pool)
|
||||||
|
|
||||||
|
# Apply model on numpy.array and then add the baseline values
|
||||||
|
preds2 = test_baseline + catboost_model.predict(X_test)
|
||||||
|
|
||||||
|
# Check that preds have small diffs
|
||||||
|
assert (np.abs(preds1 - preds2) < 1e-6).all()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
Task.init(project_name="examples", task_name="CatBoost simple example")
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--iterations", default=200)
|
||||||
|
args = parser.parse_args()
|
||||||
|
main(args.iterations)
|
4
examples/frameworks/catboost/requirements.txt
Normal file
4
examples/frameworks/catboost/requirements.txt
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
catboost
|
||||||
|
numpy == 1.19.2
|
||||||
|
scikit_learn
|
||||||
|
clearml
|
Loading…
Reference in New Issue
Block a user