# clearml/trains/binding/joblib_bind.py
# Source listing metadata: 111 lines, 3.7 KiB, Python.

try:
import joblib
except ImportError as e:
joblib = None
import six
from pathlib2 import Path
from trains.binding.frameworks import _patched_call, _Empty, WeightsFileHandler
from trains.config import running_remotely
from trains.debugging.log import LoggerRoot
class PatchedJoblib(object):
    """Hook ``joblib.dump`` / ``joblib.load`` so pickled objects are registered with the current task.

    * ``joblib.dump`` -> after dumping, an output model is created via
      ``WeightsFileHandler.create_output_model``.
    * ``joblib.load`` -> the file is passed through ``WeightsFileHandler.restore_weights_file``
      (registering an input model); when running remotely, the path it returns is loaded
      instead of the caller's argument.

    The hooks do nothing beyond calling the original joblib function while ``_current_task`` is unset.
    """

    # Original (un-patched) joblib functions; both non-None once patching succeeded.
    _patched_original_dump = None
    _patched_original_load = None
    # Task that models are registered against; hooks are pass-through while this is falsy.
    _current_task = None
    # Framework name derived from the most recently dumped object (see get_model_framework);
    # also passed along when registering loaded models.
    _current_framework = None

    @staticmethod
    def patch_joblib():
        """Wrap ``joblib.dump`` and ``joblib.load`` exactly once.

        :return: True if joblib is patched (now or previously), False if joblib is not
            importable or patching failed.
        """
        if PatchedJoblib._patched_original_dump is not None and PatchedJoblib._patched_original_load is not None:
            # Already patched; wrapping again would stack hooks and register every model several times.
            return True
        if joblib is None:
            # joblib failed to import at module load time
            return False
        # noinspection PyBroadException
        try:
            original_dump = joblib.dump
            original_load = joblib.load
            # Build both wrappers before assigning so a failure cannot leave joblib half-patched.
            patched_dump = _patched_call(original_dump, PatchedJoblib._dump)
            patched_load = _patched_call(original_load, PatchedJoblib._load)
            joblib.dump = patched_dump
            joblib.load = patched_load
        except Exception:
            return False
        # Remember the originals; this is what makes subsequent calls a no-op.
        PatchedJoblib._patched_original_dump = original_dump
        PatchedJoblib._patched_original_load = original_load
        return True

    @staticmethod
    def update_current_task(task):
        """Set the task models are registered against, patching joblib first if needed.

        The task is only stored when patching succeeds.
        """
        if PatchedJoblib.patch_joblib():
            PatchedJoblib._current_task = task

    @staticmethod
    def _get_filename(f):
        """Best-effort path string for a joblib ``filename`` argument.

        :param f: a string path, a path-like object, or a file object.
        :return: the path as a string, or None when it cannot be determined.
        """
        if isinstance(f, six.string_types):
            return f
        # Path objects also expose ``.name`` (only the final component), so they must be
        # handled before the file-object branch below.
        if isinstance(f, Path):
            return str(f)
        fspath = getattr(f, '__fspath__', None)
        if callable(fspath):
            # noinspection PyBroadException
            try:
                path = fspath()
            except Exception:
                path = None
            return path if isinstance(path, six.string_types) else None
        name = getattr(f, 'name', None)
        # file objects opened from a file descriptor report an int ``name``
        return name if isinstance(name, six.string_types) else None

    @staticmethod
    def _dump(original_fn, *args, **kwargs):
        """Replacement for ``joblib.dump(value, filename, ...)``: dump, then register an output model.

        Arguments are forwarded untouched to ``original_fn``, so positional and keyword call
        styles both work.

        :return: whatever ``original_fn`` returned.
        """
        ret = original_fn(*args, **kwargs)
        if not PatchedJoblib._current_task:
            return ret

        # joblib.dump may receive ``value``/``filename`` positionally or as keywords
        obj = args[0] if len(args) > 0 else kwargs.get('value')
        f = args[1] if len(args) > 1 else kwargs.get('filename')
        filename = PatchedJoblib._get_filename(f)

        # flush file objects before registering (presumably so the on-disk file is complete)
        flush = getattr(f, 'flush', None)
        if callable(flush):
            # noinspection PyBroadException
            try:
                flush()
            except Exception:
                pass

        # give the model a descriptive name based on the file name
        # noinspection PyBroadException
        try:
            model_name = Path(filename).stem if filename else None
        except Exception:
            model_name = None

        PatchedJoblib._current_framework = PatchedJoblib.get_model_framework(obj)
        WeightsFileHandler.create_output_model(obj, filename, PatchedJoblib._current_framework,
                                               PatchedJoblib._current_task, singlefile=True, model_name=model_name)
        return ret

    @staticmethod
    def _load(original_fn, *args, **kwargs):
        """Replacement for ``joblib.load(filename, ...)``: load and register an input model.

        When running remotely, the path returned by ``restore_weights_file`` (if any) replaces
        the caller's ``filename`` argument before loading.

        :return: the loaded object; ``trains_in_model`` is attached to it when available.
        """
        if not PatchedJoblib._current_task:
            return original_fn(*args, **kwargs)

        # joblib.load may receive ``filename`` positionally or as a keyword
        f = args[0] if args else kwargs.get('filename')
        filename = PatchedJoblib._get_filename(f)

        # register input model
        empty = _Empty()
        if running_remotely():
            filename = WeightsFileHandler.restore_weights_file(empty, filename, PatchedJoblib._current_framework,
                                                               PatchedJoblib._current_task)
            # load from the restored path, falling back to the caller's original argument
            source = filename or f
            if args:
                args = (source,) + args[1:]
            elif 'filename' in kwargs:
                kwargs['filename'] = source
            model = original_fn(*args, **kwargs)
        else:
            # try to load model before registering, in case we fail
            model = original_fn(*args, **kwargs)
            WeightsFileHandler.restore_weights_file(empty, filename, PatchedJoblib._current_framework,
                                                    PatchedJoblib._current_task)

        if empty.trains_in_model:
            # some loaded objects do not accept new attributes; that is not an error here
            # noinspection PyBroadException
            try:
                model.trains_in_model = empty.trains_in_model
            except Exception:
                pass
        return model

    @staticmethod
    def get_model_framework(obj):
        """Return the top-level package name of the module defining ``obj``.

        e.g. an object from ``sklearn.linear_model._base`` yields ``'sklearn'``. Builtin instances
        (``dict``, ``list``, ...) have no instance ``__module__``, so the type's module is used.

        :return: the top-level package name, or the raw ``__module__`` value (after logging a
            warning) when it is not a string.
        """
        module_name = getattr(obj, '__module__', None)
        if module_name is None:
            module_name = type(obj).__module__
        try:
            return module_name.partition(".")[0]
        except (AttributeError, TypeError):
            LoggerRoot.get_base_logger().warning(
                "Can't get model framework, model framework will be: {} ".format(module_name))
            return module_name