mirror of
https://github.com/clearml/clearml
synced 2025-06-26 18:16:07 +00:00
Add LightGBM support
This commit is contained in:
parent
98ea965e6d
commit
c9fac89bcd
3
examples/frameworks/lightgbm/requirements.txt
Normal file
3
examples/frameworks/lightgbm/requirements.txt
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
lightgbm
|
||||||
|
scikit-learn
|
||||||
|
pandas
|
72
examples/frameworks/lightgbm/train_with_lightbgm.py
Normal file
72
examples/frameworks/lightgbm/train_with_lightbgm.py
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
# TRAINS - Example of LightGBM integration
|
||||||
|
#
|
||||||
|
import lightgbm as lgb
|
||||||
|
import pandas as pd
|
||||||
|
from sklearn.metrics import mean_squared_error
|
||||||
|
|
||||||
|
from trains import Task
|
||||||
|
|
||||||
|
# Create the TRAINS task first, so that the automatic LightGBM binding is in
# place before any model is trained, saved or loaded.
task = Task.init(project_name="examples", task_name="LIGHTgbm")

print('Loading data...')

# Load or create your dataset
# (tab separated, no header row; column 0 is the regression target)
df_train = pd.read_csv(
    'https://raw.githubusercontent.com/microsoft/LightGBM/master/examples/regression/regression.train',
    header=None, sep='\t'
)
df_test = pd.read_csv(
    'https://raw.githubusercontent.com/microsoft/LightGBM/master/examples/regression/regression.test',
    header=None, sep='\t'
)

y_train = df_train[0]
y_test = df_test[0]
X_train = df_train.drop(0, axis=1)
X_test = df_test.drop(0, axis=1)

# Create dataset for lightgbm
lgb_train = lgb.Dataset(X_train, y_train)
lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train)

# Specify your configurations as a dict
params = {
    'boosting_type': 'gbdt',
    'objective': 'regression',
    'metric': {'l2', 'l1'},
    'num_leaves': 31,
    'learning_rate': 0.05,
    'feature_fraction': 0.9,
    'bagging_fraction': 0.8,
    'bagging_freq': 5,
    'verbose': 0
}

print('Starting training...')

# Train
# Early stopping is requested through the `early_stopping` callback instead of
# the `early_stopping_rounds` keyword: the keyword was removed from
# `lightgbm.train` in LightGBM 4.0, while the callback works on old and new
# releases alike (requirements.txt does not pin a LightGBM version).
gbm = lgb.train(
    params,
    lgb_train,
    num_boost_round=20,
    valid_sets=[lgb_eval],
    callbacks=[lgb.early_stopping(stopping_rounds=5)]
)

print('Saving model...')

# Save model to file
gbm.save_model('model.txt')

print('Loading model to predict...')

# Load model to predict
bst = lgb.Booster(model_file='model.txt')

# Can only predict with the best iteration (or the saving iteration)
y_pred = bst.predict(X_test)

# Eval with loaded model
print("The rmse of loaded model's prediction is:", mean_squared_error(y_test, y_pred) ** 0.5)
|
130
trains/binding/frameworks/lightgbm_bind.py
Normal file
130
trains/binding/frameworks/lightgbm_bind.py
Normal file
@ -0,0 +1,130 @@
|
|||||||
|
import sys
|
||||||
|
|
||||||
|
import six
|
||||||
|
from pathlib2 import Path
|
||||||
|
|
||||||
|
from ..frameworks.base_bind import PatchBaseModelIO
|
||||||
|
from ..frameworks import _patched_call, WeightsFileHandler, _Empty
|
||||||
|
from ..import_bind import PostImportHookPatching
|
||||||
|
from ...config import running_remotely
|
||||||
|
from ...model import Framework
|
||||||
|
|
||||||
|
|
||||||
|
class PatchLIGHTgbmModelIO(PatchBaseModelIO):
    """
    Automatic LightGBM binding.

    Once a task is registered with :meth:`update_current_task` and the
    ``lightgbm`` module is imported, ``lightgbm.Booster.save_model``,
    ``lightgbm.train`` and ``lightgbm.Booster`` are wrapped so that saved
    models, loaded models, evaluation metrics and training parameters are
    reported to that task.
    """
    # Task that receives the models / scalars / parameters (None: binding is inactive)
    __main_task = None
    # Set to True once the lightgbm entry points were wrapped, so it happens only once
    __patched = None

    @staticmethod
    def update_current_task(task, **kwargs):
        """
        Register the task to report to and install the LightGBM patches.

        :param task: The task used for all subsequent reporting.
        :param kwargs: Accepted for call compatibility, not used here.
        """
        PatchLIGHTgbmModelIO.__main_task = task
        # patch immediately if lightgbm is already imported ...
        PatchLIGHTgbmModelIO._patch_model_io()
        # ... and also register for a later import of lightgbm
        PostImportHookPatching.add_on_import('lightgbm', PatchLIGHTgbmModelIO._patch_model_io)

    @staticmethod
    def _patch_model_io():
        """
        Wrap the LightGBM save / train / load entry points (at most once).

        Does nothing if the patches are already installed or if ``lightgbm``
        was not imported yet. Any failure while patching is deliberately
        ignored: the binding is best-effort and must never break user code.
        """
        if PatchLIGHTgbmModelIO.__patched:
            return

        if 'lightgbm' not in sys.modules:
            return
        PatchLIGHTgbmModelIO.__patched = True
        # noinspection PyBroadException
        try:
            import lightgbm as lgb  # noqa

            # save_model must be wrapped before lgb.Booster itself is replaced by its wrapper
            lgb.Booster.save_model = _patched_call(lgb.Booster.save_model, PatchLIGHTgbmModelIO._save)
            lgb.train = _patched_call(lgb.train, PatchLIGHTgbmModelIO._train)
            lgb.Booster = _patched_call(lgb.Booster, PatchLIGHTgbmModelIO._load)
        except Exception:
            # ImportError included - best effort only
            pass

    @staticmethod
    def _to_filename(target, flush=False):
        """
        Best-effort conversion of a save / load target into a file name.

        :param target: A string path, a path-like object, or an object exposing ``name``
            (e.g. an open file).
        :param flush: If True and ``target`` is a file-like object, try to flush it.
        :return: The file name, or None if it cannot be determined.
        """
        if isinstance(target, six.string_types):
            return target
        # path objects also expose `.name`, but that is only the final path component,
        # so they must be handled before the file-like case
        if isinstance(target, Path) or hasattr(target, '__fspath__'):
            return str(target)
        if hasattr(target, 'name'):
            if flush:
                # noinspection PyBroadException
                try:
                    target.flush()
                except Exception:
                    pass
            return target.name
        return None

    @staticmethod
    def _save(original_fn, obj, f=None, *args, **kwargs):
        """
        Wrapper for ``Booster.save_model``: save, then register the file as an output model.

        :param original_fn: The original ``Booster.save_model``.
        :param obj: The booster being saved.
        :param f: The save target, when passed positionally. When the caller used the
            ``filename=`` keyword of ``save_model`` instead, it is taken from ``kwargs``.
        :return: Whatever the original ``save_model`` returned.
        """
        if f is None and not args and 'filename' in kwargs:
            # keyword call: save_model(filename=...) - forward it unchanged
            ret = original_fn(obj, **kwargs)
            f = kwargs['filename']
        else:
            ret = original_fn(obj, f, *args, **kwargs)

        if not PatchLIGHTgbmModelIO.__main_task:
            return ret

        filename = PatchLIGHTgbmModelIO._to_filename(f, flush=True)

        # give the model a descriptive name based on the file name
        # noinspection PyBroadException
        try:
            model_name = Path(filename).stem
        except Exception:
            # filename is None (or otherwise not convertible to a path)
            model_name = None
        WeightsFileHandler.create_output_model(obj, filename, Framework.lightgbm, PatchLIGHTgbmModelIO.__main_task,
                                               singlefile=True, model_name=model_name)
        return ret

    @staticmethod
    def _load(original_fn, *args, **kwargs):
        """
        Wrapper for the ``Booster`` constructor: build the booster, then register the
        model file (if one was given) as an input model of the task.

        All arguments are forwarded to the original constructor exactly as received, so
        that ``Booster(params, train_set)`` and ``Booster(model_file=...)`` both keep
        their meaning.

        :param original_fn: The original ``lightgbm.Booster``.
        :return: The constructed booster.
        """
        # Booster(params=None, train_set=None, model_file=None, ...):
        # model_file is either a keyword or the third positional argument
        model_file = kwargs.get('model_file')
        if model_file is None and len(args) > 2:
            model_file = args[2]

        # try to load model before registering, in case we fail
        model = original_fn(*args, **kwargs)

        if not PatchLIGHTgbmModelIO.__main_task:
            return model

        filename = PatchLIGHTgbmModelIO._to_filename(model_file)
        if filename is None:
            # not loaded from a file (e.g. created from params / train_set): nothing to register
            return model

        # register input model
        # NOTE: overriding the weights file when running remotely is intentionally not done here
        empty = _Empty()
        WeightsFileHandler.restore_weights_file(empty, filename, Framework.lightgbm,
                                                PatchLIGHTgbmModelIO.__main_task)

        if empty.trains_in_model:
            # noinspection PyBroadException
            try:
                model.trains_in_model = empty.trains_in_model
            except Exception:
                pass
        return model

    @staticmethod
    def _train(original_fn, *args, **kwargs):
        """
        Wrapper for ``lightgbm.train``: report evaluation results as scalars while
        training and connect the training parameters to the task afterwards.

        :param original_fn: The original ``lightgbm.train``.
        :return: Whatever the original ``train`` returned.
        """
        if not PatchLIGHTgbmModelIO.__main_task:
            return original_fn(*args, **kwargs)

        def trains_lightgbm_callback():
            def callback(env):
                # logging the results to scalars section
                # noinspection PyBroadException
                try:
                    logger = PatchLIGHTgbmModelIO.__main_task.get_logger()
                    iteration = env.iteration
                    for data_title, data_series, value, _ in env.evaluation_result_list:
                        # report a number (rounded to 6 digits), not its string representation
                        logger.report_scalar(title=data_title, series=data_series,
                                             value=round(float(value), 6), iteration=iteration)
                except Exception:
                    # reporting must never interrupt the training loop
                    pass
            return callback

        # copy the callbacks (do not mutate the caller's list) and tolerate callbacks=None
        callbacks = list(kwargs.get("callbacks") or [])
        callbacks.append(trains_lightgbm_callback())
        kwargs["callbacks"] = callbacks

        ret = original_fn(*args, **kwargs)

        # params is the first positional argument of train(), or the `params` keyword
        params = args[0] if args else kwargs.get('params')
        if isinstance(params, dict):
            # sets (e.g. 'metric': {'l2', 'l1'}) are converted to lists before connecting
            for k, v in list(params.items()):
                if isinstance(v, set):
                    params[k] = list(v)
            PatchLIGHTgbmModelIO.__main_task.connect(params)
        return ret
|
@ -48,6 +48,7 @@ class Framework(Options):
|
|||||||
paddlepaddle = 'PaddlePaddle'
|
paddlepaddle = 'PaddlePaddle'
|
||||||
scikitlearn = 'ScikitLearn'
|
scikitlearn = 'ScikitLearn'
|
||||||
xgboost = 'XGBoost'
|
xgboost = 'XGBoost'
|
||||||
|
lightgbm = 'LightGBM'
|
||||||
parquet = 'Parquet'
|
parquet = 'Parquet'
|
||||||
|
|
||||||
__file_extensions_mapping = {
|
__file_extensions_mapping = {
|
||||||
|
@ -35,6 +35,7 @@ from .binding.absl_bind import PatchAbsl
|
|||||||
from .binding.artifacts import Artifacts, Artifact
|
from .binding.artifacts import Artifacts, Artifact
|
||||||
from .binding.environ_bind import EnvironmentBind, PatchOsFork
|
from .binding.environ_bind import EnvironmentBind, PatchOsFork
|
||||||
from .binding.frameworks.fastai_bind import PatchFastai
|
from .binding.frameworks.fastai_bind import PatchFastai
|
||||||
|
from .binding.frameworks.lightgbm_bind import PatchLIGHTgbmModelIO
|
||||||
from .binding.frameworks.pytorch_bind import PatchPyTorchModelIO
|
from .binding.frameworks.pytorch_bind import PatchPyTorchModelIO
|
||||||
from .binding.frameworks.tensorflow_bind import TensorflowBinding
|
from .binding.frameworks.tensorflow_bind import TensorflowBinding
|
||||||
from .binding.frameworks.xgboost_bind import PatchXGBoostModelIO
|
from .binding.frameworks.xgboost_bind import PatchXGBoostModelIO
|
||||||
@ -333,7 +334,7 @@ class Task(_Task):
|
|||||||
.. code-block:: py
|
.. code-block:: py
|
||||||
|
|
||||||
auto_connect_frameworks={'matplotlib': True, 'tensorflow': True, 'pytorch': True,
|
auto_connect_frameworks={'matplotlib': True, 'tensorflow': True, 'pytorch': True,
|
||||||
'xgboost': True, 'scikit': True}
|
'xgboost': True, 'scikit': True, 'fastai': True, 'lightgbm': True}
|
||||||
|
|
||||||
:param bool auto_resource_monitoring: Automatically create machine resource monitoring plots
|
:param bool auto_resource_monitoring: Automatically create machine resource monitoring plots
|
||||||
These plots appear in in the **Trains Web-App (UI)**, **RESULTS** tab, **SCALARS** sub-tab,
|
These plots appear in in the **Trains Web-App (UI)**, **RESULTS** tab, **SCALARS** sub-tab,
|
||||||
@ -502,6 +503,8 @@ class Task(_Task):
|
|||||||
PatchXGBoostModelIO.update_current_task(task)
|
PatchXGBoostModelIO.update_current_task(task)
|
||||||
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('fastai', True):
|
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('fastai', True):
|
||||||
PatchFastai.update_current_task(task)
|
PatchFastai.update_current_task(task)
|
||||||
|
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('lightgbm', True):
|
||||||
|
PatchLIGHTgbmModelIO.update_current_task(task)
|
||||||
if auto_resource_monitoring and not is_sub_process_task_id:
|
if auto_resource_monitoring and not is_sub_process_task_id:
|
||||||
resource_monitor_cls = auto_resource_monitoring \
|
resource_monitor_cls = auto_resource_monitoring \
|
||||||
if isinstance(auto_resource_monitoring, six.class_types) else ResourceMonitor
|
if isinstance(auto_resource_monitoring, six.class_types) else ResourceMonitor
|
||||||
|
Loading…
Reference in New Issue
Block a user