Add LightGBM support

This commit is contained in:
allegroai
2020-10-12 12:34:52 +03:00
parent 98ea965e6d
commit c9fac89bcd
5 changed files with 210 additions and 1 deletions

View File

@@ -0,0 +1,130 @@
import sys
import six
from pathlib2 import Path
from ..frameworks.base_bind import PatchBaseModelIO
from ..frameworks import _patched_call, WeightsFileHandler, _Empty
from ..import_bind import PostImportHookPatching
from ...config import running_remotely
from ...model import Framework
class PatchLIGHTgbmModelIO(PatchBaseModelIO):
    """
    Bind LightGBM to the current task.

    Once ``lightgbm`` is imported, ``lgb.train``, ``lgb.Booster`` and ``lgb.Booster.save_model`` are wrapped so
    that evaluation results are reported as scalars, the training parameters are connected to the task, and
    saved / loaded model files are registered through ``WeightsFileHandler``.

    Every wrapper forwards the caller's positional and keyword arguments to LightGBM exactly as received,
    so a patched call accepts the same call styles as an unpatched one.
    """
    # task receiving the reports; while it is None the wrappers only forward the call
    __main_task = None
    # set to True on the first patch attempt, so LightGBM is never wrapped twice
    __patched = None

    @staticmethod
    def update_current_task(task, **kwargs):
        """
        Set the task used for reporting and patch LightGBM (now if it is already imported, otherwise the
        patch function is registered with ``PostImportHookPatching`` for the ``lightgbm`` module).

        :param task: The task to report to.
        :param kwargs: Unused, accepted for interface compatibility.
        """
        PatchLIGHTgbmModelIO.__main_task = task
        PatchLIGHTgbmModelIO._patch_model_io()
        PostImportHookPatching.add_on_import('lightgbm', PatchLIGHTgbmModelIO._patch_model_io)

    @staticmethod
    def _patch_model_io():
        """Wrap the LightGBM entry points; a no-op if already attempted or if ``lightgbm`` is not imported."""
        if PatchLIGHTgbmModelIO.__patched:
            return

        if 'lightgbm' not in sys.modules:
            return

        # flag is raised before patching: a failed attempt is deliberately not retried
        PatchLIGHTgbmModelIO.__patched = True
        # noinspection PyBroadException
        try:
            import lightgbm as lgb  # noqa
            lgb.Booster.save_model = _patched_call(lgb.Booster.save_model, PatchLIGHTgbmModelIO._save)
            lgb.train = _patched_call(lgb.train, PatchLIGHTgbmModelIO._train)
            # NOTE(review): this rebinds `lgb.Booster` to whatever `_patched_call` returns; if that is a plain
            # function, `isinstance(x, lgb.Booster)` in user code stops working - confirm.
            lgb.Booster = _patched_call(lgb.Booster, PatchLIGHTgbmModelIO._load)
        except Exception:
            # best effort: never break the user's code because patching failed (covers ImportError as well)
            pass

    @staticmethod
    def _get_filename(target):
        """Return the path behind ``target``: the string itself, a file object's ``name``, otherwise None."""
        if isinstance(target, six.string_types):
            return target
        if hasattr(target, 'name'):
            return target.name
        return None

    @staticmethod
    def _save(original_fn, obj, f=None, *args, **kwargs):
        """
        Wrapper for ``Booster.save_model``: save the model, then register the file as an output model.

        :param original_fn: The original ``save_model`` function.
        :param obj: The booster being saved.
        :param f: Target file name (or file object). May also be passed as the ``filename`` keyword argument.
        :return: Whatever the original ``save_model`` returned.
        """
        if f is None and 'filename' in kwargs:
            # `booster.save_model(filename=...)`: forward the call untouched, only peek at the file name
            f = kwargs['filename']
            ret = original_fn(obj, *args, **kwargs)
        else:
            ret = original_fn(obj, f, *args, **kwargs)

        task = PatchLIGHTgbmModelIO.__main_task
        if not task:
            return ret

        filename = PatchLIGHTgbmModelIO._get_filename(f)
        if filename is not None and not isinstance(f, six.string_types):
            # file object: make sure the content reached the file before it is registered
            # noinspection PyBroadException
            try:
                f.flush()
            except Exception:
                pass

        # give the model a descriptive name based on the file name
        # noinspection PyBroadException
        try:
            model_name = Path(filename).stem
        except Exception:
            model_name = None

        WeightsFileHandler.create_output_model(obj, filename, Framework.lightgbm, task,
                                               singlefile=True, model_name=model_name)
        return ret

    @staticmethod
    def _load(original_fn, *args, **kwargs):
        """
        Wrapper for ``lgb.Booster(...)``: build the booster, then register its model file as an input model.

        The arguments are forwarded exactly as received. The model file is looked up in the ``model_file``
        keyword, or in the third positional argument (``Booster(params, train_set, model_file, ...)``).
        When no file name can be determined (e.g. a booster built from a training set) nothing is registered.

        :param original_fn: The original ``Booster`` callable.
        :return: The booster returned by ``original_fn``.
        """
        task = PatchLIGHTgbmModelIO.__main_task
        if not task:
            return original_fn(*args, **kwargs)

        model_file = kwargs.get('model_file')
        if model_file is None and len(args) > 2:
            model_file = args[2]
        filename = PatchLIGHTgbmModelIO._get_filename(model_file)

        # NOTE: replacing the file with a remotely restored copy before loading (`running_remotely()`)
        # is intentionally disabled, the model is always loaded from what the caller passed.
        # try to load model before registering, in case we fail
        model = original_fn(*args, **kwargs)
        if filename is None:
            return model

        # register input model
        empty = _Empty()
        WeightsFileHandler.restore_weights_file(empty, filename, Framework.lightgbm, task)
        if empty.trains_in_model:
            # noinspection PyBroadException
            try:
                model.trains_in_model = empty.trains_in_model
            except Exception:
                pass
        return model

    @staticmethod
    def _train(original_fn, *args, **kwargs):
        """
        Wrapper for ``lgb.train``: report evaluation results per iteration and connect the parameters.

        All positional and keyword arguments are forwarded unchanged, except for ``callbacks`` which is
        replaced by a copy of the caller's list with the reporting callback appended.

        :param original_fn: The original ``lgb.train`` function.
        :return: Whatever the original ``lgb.train`` returned.
        """
        def trains_lightgbm_callback(env):
            # logging the results to scalars section
            task = PatchLIGHTgbmModelIO.__main_task
            if not task:
                return
            # noinspection PyBroadException
            try:
                logger = task.get_logger()
                iteration = env.iteration
                for entry in env.evaluation_result_list:
                    data_title, data_series, value = entry[:3]
                    # keep the 6 digits rounding, but hand over a number and not a string
                    logger.report_scalar(title=data_title, series=data_series,
                                         value=float("{:.6f}".format(value)), iteration=iteration)
            except Exception:
                pass

        # copy, so the caller's list is not grown on every call and `callbacks=None` is accepted
        callbacks = list(kwargs.get("callbacks") or [])
        callbacks.append(trains_lightgbm_callback)
        kwargs["callbacks"] = callbacks

        ret = original_fn(*args, **kwargs)

        task = PatchLIGHTgbmModelIO.__main_task
        if not task:
            return ret

        # `params` is the first argument of lgb.train, passed either by position or by keyword
        params = args[0] if args else kwargs.get("params")
        if not params:
            return ret

        # sets are converted to lists (in place) before the parameters are connected to the task
        for k, v in params.items():
            if isinstance(v, set):
                params[k] = list(v)
        task.connect(params)
        return ret

View File

@@ -48,6 +48,7 @@ class Framework(Options):
paddlepaddle = 'PaddlePaddle'
scikitlearn = 'ScikitLearn'
xgboost = 'XGBoost'
lightgbm = 'LightGBM'
parquet = 'Parquet'
__file_extensions_mapping = {

View File

@@ -35,6 +35,7 @@ from .binding.absl_bind import PatchAbsl
from .binding.artifacts import Artifacts, Artifact
from .binding.environ_bind import EnvironmentBind, PatchOsFork
from .binding.frameworks.fastai_bind import PatchFastai
from .binding.frameworks.lightgbm_bind import PatchLIGHTgbmModelIO
from .binding.frameworks.pytorch_bind import PatchPyTorchModelIO
from .binding.frameworks.tensorflow_bind import TensorflowBinding
from .binding.frameworks.xgboost_bind import PatchXGBoostModelIO
@@ -333,7 +334,7 @@ class Task(_Task):
.. code-block:: py
auto_connect_frameworks={'matplotlib': True, 'tensorflow': True, 'pytorch': True,
'xgboost': True, 'scikit': True}
'xgboost': True, 'scikit': True, 'fastai': True, 'lightgbm': True}
:param bool auto_resource_monitoring: Automatically create machine resource monitoring plots
These plots appear in the **Trains Web-App (UI)**, **RESULTS** tab, **SCALARS** sub-tab,
@@ -502,6 +503,8 @@ class Task(_Task):
PatchXGBoostModelIO.update_current_task(task)
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('fastai', True):
PatchFastai.update_current_task(task)
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('lightgbm', True):
PatchLIGHTgbmModelIO.update_current_task(task)
if auto_resource_monitoring and not is_sub_process_task_id:
resource_monitor_cls = auto_resource_monitoring \
if isinstance(auto_resource_monitoring, six.class_types) else ResourceMonitor