mirror of
https://github.com/clearml/clearml
synced 2025-03-10 05:50:13 +00:00
Add catboost support (#542)
Co-authored-by: ajecc <eugenajechiloae@gmail.com>
This commit is contained in:
parent
eb5350f551
commit
d53dbbf697
114
clearml/binding/frameworks/catboost_bind.py
Normal file
114
clearml/binding/frameworks/catboost_bind.py
Normal file
@ -0,0 +1,114 @@
|
||||
import sys
|
||||
|
||||
from pathlib2 import Path
|
||||
|
||||
import six
|
||||
|
||||
from ..frameworks import WeightsFileHandler, _Empty, _patched_call
|
||||
from ..frameworks.base_bind import PatchBaseModelIO
|
||||
from ..import_bind import PostImportHookPatching
|
||||
from ...model import Framework
|
||||
|
||||
|
||||
class PatchCatBoostModelIO(PatchBaseModelIO):
|
||||
__main_task = None
|
||||
__patched = None
|
||||
__callback_cls = None
|
||||
|
||||
@staticmethod
|
||||
def update_current_task(task, **kwargs):
|
||||
PatchCatBoostModelIO.__main_task = task
|
||||
PatchCatBoostModelIO._patch_model_io()
|
||||
PostImportHookPatching.add_on_import("catboost", PatchCatBoostModelIO._patch_model_io)
|
||||
|
||||
@staticmethod
|
||||
def _patch_model_io():
|
||||
if PatchCatBoostModelIO.__patched:
|
||||
return
|
||||
if "catboost" not in sys.modules:
|
||||
return
|
||||
PatchCatBoostModelIO.__patched = True
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
from catboost import CatBoost, CatBoostClassifier, CatBoostRegressor, CatBoostRanker
|
||||
|
||||
CatBoost.save_model = _patched_call(CatBoost.save_model, PatchCatBoostModelIO._save)
|
||||
CatBoost.load_model = _patched_call(CatBoost.load_model, PatchCatBoostModelIO._load)
|
||||
PatchCatBoostModelIO.__callback_cls = PatchCatBoostModelIO._generate_training_callback_class()
|
||||
CatBoost.fit = _patched_call(CatBoost.fit, PatchCatBoostModelIO._fit)
|
||||
CatBoostClassifier.fit = _patched_call(CatBoostClassifier.fit, PatchCatBoostModelIO._fit)
|
||||
CatBoostRegressor.fit = _patched_call(CatBoostRegressor.fit, PatchCatBoostModelIO._fit)
|
||||
CatBoostRanker.fit = _patched_call(CatBoostRegressor.fit, PatchCatBoostModelIO._fit)
|
||||
except ImportError:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _save(original_fn, obj, f, *args, **kwargs):
|
||||
# see https://catboost.ai/en/docs/concepts/python-reference_catboost_save_model
|
||||
ret = original_fn(obj, f, *args, **kwargs)
|
||||
if not PatchCatBoostModelIO.__main_task:
|
||||
return ret
|
||||
if isinstance(f, six.string_types):
|
||||
filename = f
|
||||
else:
|
||||
filename = None
|
||||
# give the model a descriptive name based on the file name
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
model_name = Path(filename).stem
|
||||
except Exception:
|
||||
model_name = None
|
||||
WeightsFileHandler.create_output_model(
|
||||
obj, filename, Framework.catboost, PatchCatBoostModelIO.__main_task, singlefile=True, model_name=model_name
|
||||
)
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
def _load(original_fn, f, *args, **kwargs):
|
||||
# see https://catboost.ai/en/docs/concepts/python-reference_catboost_load_model
|
||||
if isinstance(f, six.string_types):
|
||||
filename = f
|
||||
elif len(args) >= 1 and isinstance(args[0], six.string_types):
|
||||
filename = args[0]
|
||||
else:
|
||||
filename = None
|
||||
|
||||
if not PatchCatBoostModelIO.__main_task:
|
||||
return original_fn(f, *args, **kwargs)
|
||||
|
||||
# register input model
|
||||
empty = _Empty()
|
||||
model = original_fn(f, *args, **kwargs)
|
||||
WeightsFileHandler.restore_weights_file(empty, filename, Framework.catboost, PatchCatBoostModelIO.__main_task)
|
||||
if empty.trains_in_model:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
model.trains_in_model = empty.trains_in_model
|
||||
except Exception:
|
||||
pass
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def _fit(original_fn, obj, *args, **kwargs):
|
||||
callbacks = kwargs.get("callbacks") or []
|
||||
kwargs["callbacks"] = callbacks + [PatchCatBoostModelIO.__callback_cls(task=PatchCatBoostModelIO.__main_task)]
|
||||
return original_fn(obj, *args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def _generate_training_callback_class():
|
||||
class ClearMLCallback:
|
||||
def __init__(self, task):
|
||||
self._logger = task.get_logger()
|
||||
|
||||
def after_iteration(self, info):
|
||||
info = vars(info)
|
||||
iteration = info.get("iteration")
|
||||
for title, metric in (info.get("metrics") or {}).items():
|
||||
for series, log in metric.items():
|
||||
value = log[-1]
|
||||
self._logger.report_scalar(title=title, series=series, value=value, iteration=iteration)
|
||||
return True
|
||||
|
||||
return ClearMLCallback
|
@ -51,6 +51,7 @@ class Framework(Options):
|
||||
lightgbm = 'LightGBM'
|
||||
parquet = 'Parquet'
|
||||
megengine = 'MegEngine'
|
||||
catboost = 'CatBoost'
|
||||
|
||||
__file_extensions_mapping = {
|
||||
'.pb': (tensorflow, tensorflowjs, onnx, ),
|
||||
@ -79,6 +80,7 @@ class Framework(Options):
|
||||
'__model__': (paddlepaddle, ),
|
||||
'.pkl': (scikitlearn, keras, xgboost, megengine),
|
||||
'.parquet': (parquet, ),
|
||||
'.cbm': (catboost, ),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
@ -44,6 +44,7 @@ from .binding.frameworks.lightgbm_bind import PatchLIGHTgbmModelIO
|
||||
from .binding.frameworks.pytorch_bind import PatchPyTorchModelIO
|
||||
from .binding.frameworks.tensorflow_bind import TensorflowBinding
|
||||
from .binding.frameworks.xgboost_bind import PatchXGBoostModelIO
|
||||
from .binding.frameworks.catboost_bind import PatchCatBoostModelIO
|
||||
from .binding.frameworks.megengine_bind import PatchMegEngineModelIO
|
||||
from .binding.joblib_bind import PatchedJoblib
|
||||
from .binding.matplotlib_bind import PatchedMatplotlib
|
||||
@ -370,7 +371,7 @@ class Task(_Task):
|
||||
'matplotlib': True, 'tensorflow': True, 'tensorboard': True, 'pytorch': True,
|
||||
'xgboost': True, 'scikit': True, 'fastai': True, 'lightgbm': True,
|
||||
'hydra': True, 'detect_repository': True, 'tfdefines': True, 'joblib': True,
|
||||
'megengine': True, 'jsonargparse': True,
|
||||
'megengine': True, 'jsonargparse': True, 'catboost': True
|
||||
}
|
||||
|
||||
:param bool auto_resource_monitoring: Automatically create machine resource monitoring plots
|
||||
@ -583,6 +584,8 @@ class Task(_Task):
|
||||
PatchMegEngineModelIO.update_current_task(task)
|
||||
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('xgboost', True):
|
||||
PatchXGBoostModelIO.update_current_task(task)
|
||||
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('catboost', True):
|
||||
PatchCatBoostModelIO.update_current_task(task)
|
||||
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('fastai', True):
|
||||
PatchFastai.update_current_task(task)
|
||||
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('lightgbm', True):
|
||||
|
60
examples/frameworks/catboost/catboost_example.py
Normal file
60
examples/frameworks/catboost/catboost_example.py
Normal file
@ -0,0 +1,60 @@
|
||||
# ClearML - Example of CatBoost training, saving model and loading model
|
||||
#
|
||||
import argparse
|
||||
|
||||
from catboost import CatBoostRegressor, Pool
|
||||
from catboost.datasets import msrank
|
||||
|
||||
from clearml import Task
|
||||
|
||||
import numpy as np
|
||||
|
||||
from sklearn.model_selection import train_test_split
|
||||
|
||||
|
||||
def main(iterations):
|
||||
# Download train and validation datasets
|
||||
train_df, test_df = msrank()
|
||||
# Column 0 contains label values, column 1 contains group ids.
|
||||
X_train, y_train = train_df.drop([0, 1], axis=1).values, train_df[0].values
|
||||
X_test, y_test = test_df.drop([0, 1], axis=1).values, test_df[0].values
|
||||
|
||||
# Split train data into two parts. First part - for baseline model,
|
||||
# second part - for major model
|
||||
splitted_data = train_test_split(X_train, y_train, test_size=0.5)
|
||||
X_train_first, X_train_second, y_train_first, y_train_second = splitted_data
|
||||
|
||||
catboost_model = CatBoostRegressor(iterations=iterations, verbose=False)
|
||||
|
||||
# Prepare simple baselines (just mean target on first part of train pool).
|
||||
baseline_value = y_train_first.mean()
|
||||
train_baseline = np.array([baseline_value] * y_train_second.shape[0])
|
||||
test_baseline = np.array([baseline_value] * y_test.shape[0])
|
||||
|
||||
# Create pools
|
||||
train_pool = Pool(X_train_second, y_train_second, baseline=train_baseline)
|
||||
test_pool = Pool(X_test, y_test, baseline=test_baseline)
|
||||
|
||||
# Train CatBoost model
|
||||
catboost_model.fit(train_pool, eval_set=test_pool, verbose=True, plot=False, save_snapshot=True)
|
||||
catboost_model.save_model("example.cbm")
|
||||
|
||||
catboost_model = CatBoostRegressor()
|
||||
catboost_model.load_model("example.cbm")
|
||||
|
||||
# Apply model on pool with baseline values
|
||||
preds1 = catboost_model.predict(test_pool)
|
||||
|
||||
# Apply model on numpy.array and then add the baseline values
|
||||
preds2 = test_baseline + catboost_model.predict(X_test)
|
||||
|
||||
# Check that preds have small diffs
|
||||
assert (np.abs(preds1 - preds2) < 1e-6).all()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
Task.init(project_name="examples", task_name="CatBoost simple example")
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--iterations", default=200)
|
||||
args = parser.parse_args()
|
||||
main(args.iterations)
|
4
examples/frameworks/catboost/requirements.txt
Normal file
4
examples/frameworks/catboost/requirements.txt
Normal file
@ -0,0 +1,4 @@
|
||||
catboost
|
||||
numpy == 1.19.2
|
||||
scikit_learn
|
||||
clearml
|
Loading…
Reference in New Issue
Block a user