Fix joblib binding

This commit is contained in:
allegroai 2020-03-20 10:30:13 +02:00
parent 5db53ba643
commit 7817ef5cda

View File

@ -1,5 +1,6 @@
import sys import sys
import warnings import warnings
from functools import partial
import six import six
from pathlib2 import Path from pathlib2 import Path
@ -38,8 +39,14 @@ class PatchedJoblib(object):
joblib = None joblib = None
if joblib: if joblib:
joblib.dump = _patched_call(joblib.dump, PatchedJoblib._dump) joblib.numpy_pickle._write_fileobject = _patched_call(
joblib.load = _patched_call(joblib.load, PatchedJoblib._load) joblib.numpy_pickle._write_fileobject,
partial(PatchedJoblib._write_fileobject, joblib.numpy_pickle))
joblib.numpy_pickle._read_fileobject = _patched_call(
joblib.numpy_pickle._read_fileobject, PatchedJoblib._load)
joblib.numpy_pickle.NumpyPickler.__init__ = _patched_call(
joblib.numpy_pickle.NumpyPickler.__init__,
PatchedJoblib._numpypickler)
if not PatchedJoblib._patched_sk_joblib and 'sklearn' in sys.modules: if not PatchedJoblib._patched_sk_joblib and 'sklearn' in sys.modules:
PatchedJoblib._patched_sk_joblib = True PatchedJoblib._patched_sk_joblib = True
@ -53,8 +60,14 @@ class PatchedJoblib(object):
sk_joblib = None sk_joblib = None
if sk_joblib: if sk_joblib:
sk_joblib.dump = _patched_call(sk_joblib.dump, PatchedJoblib._dump) sk_joblib.numpy_pickle._write_fileobject = _patched_call(
sk_joblib.load = _patched_call(sk_joblib.load, PatchedJoblib._load) sk_joblib.numpy_pickle._write_fileobject,
partial(PatchedJoblib._write_fileobject, sk_joblib.numpy_pickle))
sk_joblib.numpy_pickle._read_fileobject = _patched_call(
sk_joblib.numpy_pickle._read_fileobject, PatchedJoblib._load)
sk_joblib.numpy_pickle.NumpyPickler.__init__ = _patched_call(
sk_joblib.numpy_pickle.NumpyPickler.__init__,
PatchedJoblib._numpypickler)
except Exception: except Exception:
return False return False
@ -70,7 +83,27 @@ class PatchedJoblib(object):
ret = original_fn(obj, f, *args, **kwargs) ret = original_fn(obj, f, *args, **kwargs)
if not PatchedJoblib._current_task: if not PatchedJoblib._current_task:
return ret return ret
PatchedJoblib._register_dump(obj, f)
return ret
@staticmethod
def _numpypickler(original_fn, obj, f, *args, **kwargs):
ret = original_fn(obj, f, *args, **kwargs)
if not PatchedJoblib._current_task:
return ret
PatchedJoblib._register_dump(obj, f)
return ret
@staticmethod
def _write_fileobject(obj, original_fn, f, *args, **kwargs):
ret = original_fn(f, *args, **kwargs)
if not PatchedJoblib._current_task:
return ret
PatchedJoblib._register_dump(obj, f)
return ret
@staticmethod
def _register_dump(obj, f):
if isinstance(f, six.string_types): if isinstance(f, six.string_types):
filename = f filename = f
elif hasattr(f, 'name'): elif hasattr(f, 'name'):
@ -81,7 +114,7 @@ class PatchedJoblib(object):
except Exception: except Exception:
pass pass
else: else:
filename = None return
# give the model a descriptive name based on the file name # give the model a descriptive name based on the file name
# noinspection PyBroadException # noinspection PyBroadException
@ -92,7 +125,6 @@ class PatchedJoblib(object):
current_framework = PatchedJoblib.get_model_framework(obj) current_framework = PatchedJoblib.get_model_framework(obj)
WeightsFileHandler.create_output_model(obj, filename, current_framework, WeightsFileHandler.create_output_model(obj, filename, current_framework,
PatchedJoblib._current_task, singlefile=True, model_name=model_name) PatchedJoblib._current_task, singlefile=True, model_name=model_name)
return ret
@staticmethod @staticmethod
def _load(original_fn, f, *args, **kwargs): def _load(original_fn, f, *args, **kwargs):
@ -132,9 +164,9 @@ class PatchedJoblib(object):
@staticmethod @staticmethod
def get_model_framework(obj): def get_model_framework(obj):
object_orig_module = obj.__module__
framework = Framework.scikitlearn framework = Framework.scikitlearn
try: try:
object_orig_module = obj.__module__ if hasattr(obj, '__module__') else obj.__package__
model = object_orig_module.partition(".")[0] model = object_orig_module.partition(".")[0]
if model == 'sklearn': if model == 'sklearn':
framework = Framework.scikitlearn framework = Framework.scikitlearn