From c9fac89bcd87550b7eb40e6be64bd19d4384b515 Mon Sep 17 00:00:00 2001
From: allegroai <>
Date: Mon, 12 Oct 2020 12:34:52 +0300
Subject: [PATCH] Add LightGBM support

---
 examples/frameworks/lightgbm/requirements.txt |   3 +
 .../lightgbm/train_with_lightbgm.py           |  72 ++++++++++
 trains/binding/frameworks/lightgbm_bind.py    | 130 ++++++++++++++++++
 trains/model.py                               |   1 +
 trains/task.py                                |   5 +-
 5 files changed, 210 insertions(+), 1 deletion(-)
 create mode 100644 examples/frameworks/lightgbm/requirements.txt
 create mode 100644 examples/frameworks/lightgbm/train_with_lightbgm.py
 create mode 100644 trains/binding/frameworks/lightgbm_bind.py

diff --git a/examples/frameworks/lightgbm/requirements.txt b/examples/frameworks/lightgbm/requirements.txt
new file mode 100644
index 00000000..ab057c24
--- /dev/null
+++ b/examples/frameworks/lightgbm/requirements.txt
@@ -0,0 +1,3 @@
+lightgbm
+scikit-learn
+pandas
diff --git a/examples/frameworks/lightgbm/train_with_lightbgm.py b/examples/frameworks/lightgbm/train_with_lightbgm.py
new file mode 100644
index 00000000..c4613968
--- /dev/null
+++ b/examples/frameworks/lightgbm/train_with_lightbgm.py
@@ -0,0 +1,72 @@
+# TRAINS - Example of LightGBM integration
+#
+import lightgbm as lgb
+import pandas as pd
+from sklearn.metrics import mean_squared_error
+
+from trains import Task
+
+task = Task.init(project_name="examples", task_name="LIGHTgbm")
+
+print('Loading data...')
+
+# Load or create your dataset
+
+
+df_train = pd.read_csv(
+    'https://raw.githubusercontent.com/microsoft/LightGBM/master/examples/regression/regression.train',
+    header=None, sep='\t'
+)
+df_test = pd.read_csv(
+    'https://raw.githubusercontent.com/microsoft/LightGBM/master/examples/regression/regression.test',
+    header=None, sep='\t'
+)
+
+y_train = df_train[0]
+y_test = df_test[0]
+X_train = df_train.drop(0, axis=1)
+X_test = df_test.drop(0, axis=1)
+
+# Create dataset for lightgbm
+lgb_train = lgb.Dataset(X_train, y_train)
+lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train)
+
+# Specify your configurations as a dict
+params = {
+    'boosting_type': 'gbdt',
+    'objective': 'regression',
+    'metric': {'l2', 'l1'},
+    'num_leaves': 31,
+    'learning_rate': 0.05,
+    'feature_fraction': 0.9,
+    'bagging_fraction': 0.8,
+    'bagging_freq': 5,
+    'verbose': 0
+}
+
+print('Starting training...')
+
+# Train
+gbm = lgb.train(
+    params,
+    lgb_train,
+    num_boost_round=20,
+    valid_sets=lgb_eval,
+    early_stopping_rounds=5
+)
+
+print('Saving model...')
+
+# Save model to file
+gbm.save_model('model.txt')
+
+print('Loading model to predict...')
+
+# Load model to predict
+bst = lgb.Booster(model_file='model.txt')
+
+# Can only predict with the best iteration (or the saving iteration)
+y_pred = bst.predict(X_test)
+
+# Eval with loaded model
+print("The rmse of loaded model's prediction is:", mean_squared_error(y_test, y_pred) ** 0.5)
diff --git a/trains/binding/frameworks/lightgbm_bind.py b/trains/binding/frameworks/lightgbm_bind.py
new file mode 100644
index 00000000..b2f7c94d
--- /dev/null
+++ b/trains/binding/frameworks/lightgbm_bind.py
@@ -0,0 +1,130 @@
+import sys
+
+import six
+from pathlib2 import Path
+
+from ..frameworks.base_bind import PatchBaseModelIO
+from ..frameworks import _patched_call, WeightsFileHandler, _Empty
+from ..import_bind import PostImportHookPatching
+from ...config import running_remotely
+from ...model import Framework
+
+
+class PatchLIGHTgbmModelIO(PatchBaseModelIO):
+    __main_task = None
+    __patched = None
+
+    @staticmethod
+    def update_current_task(task, **kwargs):
+        PatchLIGHTgbmModelIO.__main_task = task
+        PatchLIGHTgbmModelIO._patch_model_io()
+        PostImportHookPatching.add_on_import('lightgbm', PatchLIGHTgbmModelIO._patch_model_io)
+
+    @staticmethod
+    def _patch_model_io():
+        if PatchLIGHTgbmModelIO.__patched:
+            return
+
+        if 'lightgbm' not in sys.modules:
+            return
+        PatchLIGHTgbmModelIO.__patched = True
+        # noinspection PyBroadException
+        try:
+            import lightgbm as lgb  # noqa
+
+            lgb.Booster.save_model = _patched_call(lgb.Booster.save_model, PatchLIGHTgbmModelIO._save)
+            lgb.train = _patched_call(lgb.train, PatchLIGHTgbmModelIO._train)
+            lgb.Booster = _patched_call(lgb.Booster, PatchLIGHTgbmModelIO._load)
+        except ImportError:
+            pass
+        except Exception:
+            pass
+
+    @staticmethod
+    def _save(original_fn, obj, f, *args, **kwargs):
+        ret = original_fn(obj, f, *args, **kwargs)
+        if not PatchLIGHTgbmModelIO.__main_task:
+            return ret
+
+        if isinstance(f, six.string_types):
+            filename = f
+        elif hasattr(f, 'name'):
+            filename = f.name
+            # noinspection PyBroadException
+            try:
+                f.flush()
+            except Exception:
+                pass
+        else:
+            filename = None
+
+        # give the model a descriptive name based on the file name
+        # noinspection PyBroadException
+        try:
+            model_name = Path(filename).stem
+        except Exception:
+            model_name = None
+        WeightsFileHandler.create_output_model(obj, filename, Framework.lightgbm, PatchLIGHTgbmModelIO.__main_task,
+                                               singlefile=True, model_name=model_name)
+        return ret
+
+    @staticmethod
+    def _load(original_fn, model_file, *args, **kwargs):
+        if isinstance(model_file, six.string_types):
+            filename = model_file
+        elif hasattr(model_file, 'name'):
+            filename = model_file.name
+        elif len(args) == 1 and isinstance(args[0], six.string_types):
+            filename = args[0]
+        else:
+            filename = None
+
+        if not PatchLIGHTgbmModelIO.__main_task:
+            return original_fn(model_file, *args, **kwargs)
+
+        # register input model
+        empty = _Empty()
+        # Hack: disabled
+        if False and running_remotely():
+            filename = WeightsFileHandler.restore_weights_file(empty, filename, Framework.lightgbm,
+                                                               PatchLIGHTgbmModelIO.__main_task)
+            model = original_fn(model_file=filename or model_file, *args, **kwargs)
+        else:
+            # try to load model before registering, in case we fail
+            model = original_fn(model_file=model_file, *args, **kwargs)
+            WeightsFileHandler.restore_weights_file(empty, filename, Framework.lightgbm,
+                                                    PatchLIGHTgbmModelIO.__main_task)
+
+        if empty.trains_in_model:
+            # noinspection PyBroadException
+            try:
+                model.trains_in_model = empty.trains_in_model
+            except Exception:
+                pass
+        return model
+
+    @staticmethod
+    def _train(original_fn, *args, **kwargs):
+        def trains_lightgbm_callback():
+            def callback(env):
+                # logging the results to scalars section
+                # noinspection PyBroadException
+                try:
+                    logger = PatchLIGHTgbmModelIO.__main_task.get_logger()
+                    iteration = env.iteration
+                    for data_title, data_series, value, _ in env.evaluation_result_list:
+                        logger.report_scalar(title=data_title, series=data_series, value="{:.6f}".format(value),
+                                             iteration=iteration)
+                except Exception:
+                    pass
+            return callback
+        params, train_set = args
+        kwargs.setdefault("callbacks", []).append(trains_lightgbm_callback())
+        ret = original_fn(params, train_set, **kwargs)
+        if not PatchLIGHTgbmModelIO.__main_task:
+            return ret
+        for k, v in params.items():
+            if isinstance(v, set):
+                params[k] = list(v)
+        PatchLIGHTgbmModelIO.__main_task.connect(params)
+        return ret
diff --git a/trains/model.py b/trains/model.py
index a5a8938c..249a976f 100644
--- a/trains/model.py
+++ b/trains/model.py
@@ -48,6 +48,7 @@ class Framework(Options):
     paddlepaddle = 'PaddlePaddle'
     scikitlearn = 'ScikitLearn'
     xgboost = 'XGBoost'
+    lightgbm = 'LightGBM'
     parquet = 'Parquet'
 
     __file_extensions_mapping = {
diff --git a/trains/task.py b/trains/task.py
index 799f9fd9..9a5745a9 100644
--- a/trains/task.py
+++ b/trains/task.py
@@ -35,6 +35,7 @@ from .binding.absl_bind import PatchAbsl
 from .binding.artifacts import Artifacts, Artifact
 from .binding.environ_bind import EnvironmentBind, PatchOsFork
 from .binding.frameworks.fastai_bind import PatchFastai
+from .binding.frameworks.lightgbm_bind import PatchLIGHTgbmModelIO
 from .binding.frameworks.pytorch_bind import PatchPyTorchModelIO
 from .binding.frameworks.tensorflow_bind import TensorflowBinding
 from .binding.frameworks.xgboost_bind import PatchXGBoostModelIO
@@ -333,7 +334,7 @@ class Task(_Task):
             .. code-block:: py
 
                auto_connect_frameworks={'matplotlib': True, 'tensorflow': True, 'pytorch': True,
-                   'xgboost': True, 'scikit': True}
+                   'xgboost': True, 'scikit': True, 'fastai': True, 'lightgbm': True}
 
         :param bool auto_resource_monitoring: Automatically create machine resource monitoring plots
             These plots appear in in the **Trains Web-App (UI)**, **RESULTS** tab, **SCALARS** sub-tab,
@@ -502,6 +503,8 @@ class Task(_Task):
             PatchXGBoostModelIO.update_current_task(task)
         if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('fastai', True):
             PatchFastai.update_current_task(task)
+        if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('lightgbm', True):
+            PatchLIGHTgbmModelIO.update_current_task(task)
         if auto_resource_monitoring and not is_sub_process_task_id:
             resource_monitor_cls = auto_resource_monitoring \
                 if isinstance(auto_resource_monitoring, six.class_types) else ResourceMonitor
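
Usage note (not part of the patch; a minimal sketch based on the bindings above): once lightgbm
is imported in a script that has called Task.init(), the patched lgb.train() injects the
scalar-reporting callback and connects the params dict to the task, lgb.Booster.save_model()
registers the saved file as an output model, and lgb.Booster(model_file=...) registers the loaded
file as an input model. The binding is enabled by default; the dict form of auto_connect_frameworks
shown in the task.py docstring above can switch it off per framework. The project and task names
below are illustrative.

    import lightgbm as lgb  # importing lightgbm (before or after Task.init) is what triggers the patching
    from trains import Task

    # Opt out of the LightGBM binding only; all other framework bindings stay enabled.
    task = Task.init(
        project_name="examples",
        task_name="lightgbm, auto-logging disabled",
        auto_connect_frameworks={'lightgbm': False},
    )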