mirror of
https://github.com/clearml/clearml
synced 2025-02-01 01:26:49 +00:00
131 lines
4.8 KiB
Python
131 lines
4.8 KiB
Python
import sys
|
|
|
|
import six
|
|
from pathlib2 import Path
|
|
|
|
from ..frameworks.base_bind import PatchBaseModelIO
|
|
from ..frameworks import _patched_call, WeightsFileHandler, _Empty
|
|
from ..import_bind import PostImportHookPatching
|
|
from ...config import running_remotely
|
|
from ...model import Framework
|
|
|
|
|
|
class PatchLIGHTgbmModelIO(PatchBaseModelIO):
    """
    Bind LightGBM model I/O and training to the current task.

    Once patched, these lightgbm entry points report to the task set with
    :meth:`update_current_task`:

    * ``lightgbm.Booster.save_model`` - the saved file is registered as an output model.
    * ``lightgbm.Booster.__init__`` - ``model_file`` (when given) is registered as an input model.
    * ``lightgbm.train`` - evaluation results are reported as scalars and ``params`` are connected.

    Every wrapper is a pass-through when no task is set.
    """

    # Task that receives models / scalars; a falsy value disables reporting.
    __main_task = None
    # Set once lightgbm has been wrapped, so the patching is applied only once.
    __patched = None

    @staticmethod
    def update_current_task(task, **kwargs):
        """
        Set the task that LightGBM models and metrics are reported to.

        :param task: The task to report to, or ``None`` to stop reporting.
        :param kwargs: Ignored.
        """
        PatchLIGHTgbmModelIO.__main_task = task
        if not task:
            # nothing to report to, so there is no reason to monkey-patch lightgbm
            return
        PatchLIGHTgbmModelIO._patch_model_io()
        # lightgbm may only be imported later - patch it at import time in that case
        PostImportHookPatching.add_on_import('lightgbm', PatchLIGHTgbmModelIO._patch_model_io)

    @staticmethod
    def _patch_model_io():
        """Wrap the lightgbm entry points once, provided lightgbm is already imported."""
        if PatchLIGHTgbmModelIO.__patched:
            return

        if 'lightgbm' not in sys.modules:
            return

        PatchLIGHTgbmModelIO.__patched = True
        # noinspection PyBroadException
        try:
            import lightgbm as lgb  # noqa

            lgb.Booster.save_model = _patched_call(lgb.Booster.save_model, PatchLIGHTgbmModelIO._save)
            lgb.train = _patched_call(lgb.train, PatchLIGHTgbmModelIO._train)
            # Wrap the constructor (the same way save_model is wrapped above) instead of
            # rebinding ``lgb.Booster`` to the wrapper: rebinding the class name would break
            # ``isinstance(x, lgb.Booster)``, subclassing and ``lgb.Booster.<attr>`` access.
            lgb.Booster.__init__ = _patched_call(lgb.Booster.__init__, PatchLIGHTgbmModelIO._load)
        except Exception:
            # Binding is best effort (this also covers ImportError): a failure to patch
            # lightgbm must never break the user's program.
            pass

    @staticmethod
    def _get_filename(target):
        """
        Best-effort conversion of a lightgbm file argument into a path string.

        :param target: A path string, a path-like object (e.g. ``pathlib``/``pathlib2`` path),
            an open file object, or anything else.
        :return: The path as a string, or ``None`` if it cannot be determined.
        """
        if isinstance(target, six.string_types):
            return target
        # Check path-like objects before ``.name``: a path's ``name`` is only its final
        # component, so using it would drop the directory.
        fspath = getattr(target, '__fspath__', None)
        if callable(fspath):
            # noinspection PyBroadException
            try:
                path = fspath()
            except Exception:
                path = None
            return path if isinstance(path, six.string_types) else None
        # open file objects expose the path they were opened with as ``.name``
        name = getattr(target, 'name', None)
        return name if isinstance(name, six.string_types) else None

    @staticmethod
    def _save(original_fn, obj, *args, **kwargs):
        """
        Wrapper for ``Booster.save_model``: save first, then register the file as an output model.

        :param original_fn: The unpatched ``Booster.save_model``.
        :param obj: The booster being saved.
        :param args: Positional arguments of ``save_model``; the first one is the target file.
        :param kwargs: Keyword arguments of ``save_model``; the target may be given as ``filename``.
        :return: Whatever the original ``save_model`` returned.
        """
        ret = original_fn(obj, *args, **kwargs)
        if not PatchLIGHTgbmModelIO.__main_task:
            return ret

        # the target file may be passed positionally or as the ``filename`` keyword
        target = args[0] if args else kwargs.get('filename')
        filename = PatchLIGHTgbmModelIO._get_filename(target)

        # if the target is an open file object, push buffered data to disk before registering it
        flush = getattr(target, 'flush', None)
        if callable(flush):
            # noinspection PyBroadException
            try:
                flush()
            except Exception:
                pass

        # give the model a descriptive name based on the file name
        model_name = Path(filename).stem if filename else None
        WeightsFileHandler.create_output_model(obj, filename, Framework.lightgbm, PatchLIGHTgbmModelIO.__main_task,
                                               singlefile=True, model_name=model_name)
        return ret

    @staticmethod
    def _load(original_fn, obj, *args, **kwargs):
        """
        Wrapper for ``Booster.__init__``: construct first, then register ``model_file`` as an input model.

        Arguments are forwarded untouched, so the patched constructor accepts exactly
        the same calls as the original one.

        :param original_fn: The unpatched ``Booster.__init__``.
        :param obj: The booster being initialized.
        :param args: Positional constructor arguments (``params, train_set, model_file, ...``).
        :param kwargs: Keyword constructor arguments.
        :return: ``None``, as required from ``__init__``.
        """
        # construct before registering: if loading fails the error propagates and nothing is logged
        original_fn(obj, *args, **kwargs)

        task = PatchLIGHTgbmModelIO.__main_task
        if not task:
            return None

        # ``model_file`` is the third constructor argument (after ``params`` and ``train_set``)
        model_file = kwargs.get('model_file')
        if model_file is None and len(args) > 2:
            model_file = args[2]
        filename = PatchLIGHTgbmModelIO._get_filename(model_file)
        if not filename:
            # booster built in memory (training, model string, copy): no file to register
            return None

        # register input model.
        # Restoring remote weights *before* loading is intentionally not done: the model is
        # always loaded from the caller-supplied source and only registered afterwards.
        empty = _Empty()
        WeightsFileHandler.restore_weights_file(empty, filename, Framework.lightgbm, task)

        if empty.trains_in_model:
            # noinspection PyBroadException
            try:
                obj.trains_in_model = empty.trains_in_model
            except Exception:
                pass
        return None

    @staticmethod
    def _train(original_fn, *args, **kwargs):
        """
        Wrapper for ``lightgbm.train``: report evaluation results as scalars and connect ``params``.

        :param original_fn: The unpatched ``lightgbm.train``.
        :param args: Positional arguments of ``train``; the first one is ``params``.
        :param kwargs: Keyword arguments of ``train``.
        :return: Whatever the original ``train`` returned (the trained booster).
        """
        if not PatchLIGHTgbmModelIO.__main_task:
            return original_fn(*args, **kwargs)

        def callback(env):
            # logging the results to scalars section
            task = PatchLIGHTgbmModelIO.__main_task
            if not task:
                return
            # noinspection PyBroadException
            try:
                logger = task.get_logger()
                for result in env.evaluation_result_list:
                    # entries start with (dataset name, metric name, value, ...)
                    data_title, data_series, value = result[:3]
                    logger.report_scalar(title=data_title, series=data_series, value=float(value),
                                         iteration=env.iteration)
            except Exception:
                pass

        # Build a new list rather than appending to the caller's: the caller may reuse the
        # list across calls, and an explicit ``callbacks=None`` has no ``append``.
        # ``callback`` is a fresh function object per call, in case lightgbm stores
        # per-run attributes on callback objects.
        kwargs['callbacks'] = list(kwargs.get('callbacks') or []) + [callback]
        ret = original_fn(*args, **kwargs)

        task = PatchLIGHTgbmModelIO.__main_task
        if not task:
            return ret

        params = args[0] if args else kwargs.get('params')
        if isinstance(params, dict) and params:
            # connect a copy with sets converted to lists; the caller's dict is left untouched
            task.connect({k: list(v) if isinstance(v, set) else v for k, v in params.items()})
        return ret
|