Mirror of https://github.com/clearml/clearml (synced 2025-02-01 01:26:49 +00:00) — 111 lines, 3.7 KiB, Python
try:
|
|
import joblib
|
|
except ImportError as e:
|
|
joblib = None
|
|
|
|
import six
|
|
from pathlib2 import Path
|
|
|
|
from trains.binding.frameworks import _patched_call, _Empty, WeightsFileHandler
|
|
from trains.config import running_remotely
|
|
from trains.debugging.log import LoggerRoot
|
|
|
|
|
|
class PatchedJoblib(object):
    """Hooks ``joblib.dump`` / ``joblib.load`` so that dumped and loaded objects are registered
    as output / input models of the current task (see ``WeightsFileHandler`` calls below).
    """
    # Original (unpatched) joblib functions. They are set once patching succeeds and double as
    # the "already patched" flag checked in patch_joblib().
    _patched_original_dump = None
    _patched_original_load = None
    # Task that models are registered on; registration is skipped while this is falsy
    _current_task = None
    # Framework name inferred from the last dumped object. It is used by the remote load path,
    # where the framework must be passed before the object has been loaded.
    _current_framework = None

    @staticmethod
    def patch_joblib():
        """Wrap ``joblib.dump`` and ``joblib.load`` with the model-registering hooks, only once.

        :return: True if joblib is patched (now or by an earlier call), False if joblib could
            not be imported or patching failed.
        """
        if PatchedJoblib._patched_original_dump is not None and PatchedJoblib._patched_original_load is not None:
            # We don't need to patch anything else, so we are done
            return True

        if joblib is None:
            # the module-level `import joblib` failed
            return False

        # noinspection PyBroadException
        try:
            original_dump, original_load = joblib.dump, joblib.load
            # build both wrappers before assigning, so a failure cannot leave only one patched
            patched_dump = _patched_call(original_dump, PatchedJoblib._dump)
            patched_load = _patched_call(original_load, PatchedJoblib._load)
            joblib.dump = patched_dump
            joblib.load = patched_load
        except Exception:
            return False

        # Record the originals. This is what the guard above checks, so repeated calls
        # (e.g. from every update_current_task) do not stack another wrapper each time.
        PatchedJoblib._patched_original_dump = original_dump
        PatchedJoblib._patched_original_load = original_load
        return True

    @staticmethod
    def update_current_task(task):
        """Set the task that joblib models are registered on, patching joblib first if needed.

        The task is only stored when ``patch_joblib()`` succeeds.
        """
        if PatchedJoblib.patch_joblib():
            PatchedJoblib._current_task = task

    @staticmethod
    def _get_filename(f):
        """Best-effort path string for the file argument of ``joblib.dump`` / ``joblib.load``.

        ``f`` may be a path string, a path-like object (anything with ``__fspath__``, such as
        ``pathlib.Path``) or an open file object. Path-like objects are checked before ``.name``
        because ``pathlib.Path.name`` is only the final path component, not the full path.

        :return: the path as a string, or None when it cannot be determined.
        """
        if isinstance(f, six.string_types):
            return f

        fspath = getattr(f, '__fspath__', None)
        if callable(fspath):
            # noinspection PyBroadException
            try:
                path = fspath()
            except Exception:
                path = None
            return path if isinstance(path, six.string_types) else None

        # open file objects; ignore non-string names such as the int of an fd-opened file
        name = getattr(f, 'name', None)
        return name if isinstance(name, six.string_types) else None

    @staticmethod
    def _dump(original_fn, obj, f, *args, **kwargs):
        """Hook for ``joblib.dump``.

        Saves ``obj`` with the original function first, then registers the target file as an
        output model of the current task.

        :return: whatever the original ``joblib.dump`` returned.
        """
        ret = original_fn(obj, f, *args, **kwargs)
        if not PatchedJoblib._current_task:
            return ret

        filename = PatchedJoblib._get_filename(f)
        if hasattr(f, 'flush'):
            # an open file object: push the written bytes out before handing the file
            # name to create_output_model
            # noinspection PyBroadException
            try:
                f.flush()
            except Exception:
                pass

        # give the model a descriptive name based on the file name
        model_name = None
        if filename:
            # noinspection PyBroadException
            try:
                model_name = Path(filename).stem
            except Exception:
                model_name = None

        PatchedJoblib._current_framework = PatchedJoblib.get_model_framework(obj)
        WeightsFileHandler.create_output_model(obj, filename, PatchedJoblib._current_framework,
                                               PatchedJoblib._current_task, singlefile=True, model_name=model_name)
        return ret

    @staticmethod
    def _load(original_fn, f, *args, **kwargs):
        """Hook for ``joblib.load``: load the object and register its file as an input model.

        When running remotely, ``WeightsFileHandler.restore_weights_file`` is consulted before
        loading, and its returned path (if any) is loaded instead of ``f``. Locally, the object is
        loaded first, so a failing load does not register anything.

        :return: the loaded object, with ``trains_in_model`` attached when one was registered.
        """
        if not PatchedJoblib._current_task:
            return original_fn(f, *args, **kwargs)

        filename = PatchedJoblib._get_filename(f)

        # register input model
        empty = _Empty()
        if running_remotely():
            # the object is not loaded yet, so the framework of the last dump is the best guess
            filename = WeightsFileHandler.restore_weights_file(empty, filename, PatchedJoblib._current_framework,
                                                               PatchedJoblib._current_task)
            model = original_fn(filename or f, *args, **kwargs)
        else:
            # try to load model before registering, in case we fail
            model = original_fn(f, *args, **kwargs)
            # infer the framework from what was actually loaded, not from a previous dump
            framework = PatchedJoblib.get_model_framework(model)
            WeightsFileHandler.restore_weights_file(empty, filename, framework, PatchedJoblib._current_task)

        if empty.trains_in_model:
            # noinspection PyBroadException
            try:
                model.trains_in_model = empty.trains_in_model
            except Exception:
                # some objects (e.g. builtins) do not accept new attributes
                pass
        return model

    @staticmethod
    def get_model_framework(obj):
        """Infer a framework name from the top-level package of the module defining ``obj``.

        For example, a module name ``"package.sub.module"`` yields ``"package"``. Instances of
        builtin types (e.g. ``list``, ``dict``) have no instance-level ``__module__``, so the
        module of ``type(obj)`` is used as a fallback.

        :return: the top-level package name, or the raw module value (logged as a warning)
            when it is not a string.
        """
        module_name = getattr(obj, '__module__', None)
        if module_name is None:
            module_name = getattr(type(obj), '__module__', None)

        try:
            return module_name.partition(".")[0]
        except (AttributeError, TypeError):
            LoggerRoot.get_base_logger().warning(
                "Can't get model framework, model framework will be: {} ".format(module_name))
            return module_name