From 7817ef5cda031024e26fb38c68e1f6caa5c06cb3 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Fri, 20 Mar 2020 10:30:13 +0200 Subject: [PATCH] Fix joblib binding --- trains/binding/joblib_bind.py | 46 +++++++++++++++++++++++++++++------ 1 file changed, 39 insertions(+), 7 deletions(-) diff --git a/trains/binding/joblib_bind.py b/trains/binding/joblib_bind.py index b34beac3..daf6153c 100644 --- a/trains/binding/joblib_bind.py +++ b/trains/binding/joblib_bind.py @@ -1,5 +1,6 @@ import sys import warnings +from functools import partial import six from pathlib2 import Path @@ -38,8 +39,14 @@ class PatchedJoblib(object): joblib = None if joblib: - joblib.dump = _patched_call(joblib.dump, PatchedJoblib._dump) - joblib.load = _patched_call(joblib.load, PatchedJoblib._load) + joblib.numpy_pickle._write_fileobject = _patched_call( + joblib.numpy_pickle._write_fileobject, + partial(PatchedJoblib._write_fileobject, joblib.numpy_pickle)) + joblib.numpy_pickle._read_fileobject = _patched_call( + joblib.numpy_pickle._read_fileobject, PatchedJoblib._load) + joblib.numpy_pickle.NumpyPickler.__init__ = _patched_call( + joblib.numpy_pickle.NumpyPickler.__init__, + PatchedJoblib._numpypickler) if not PatchedJoblib._patched_sk_joblib and 'sklearn' in sys.modules: PatchedJoblib._patched_sk_joblib = True @@ -53,8 +60,14 @@ class PatchedJoblib(object): sk_joblib = None if sk_joblib: - sk_joblib.dump = _patched_call(sk_joblib.dump, PatchedJoblib._dump) - sk_joblib.load = _patched_call(sk_joblib.load, PatchedJoblib._load) + sk_joblib.numpy_pickle._write_fileobject = _patched_call( + sk_joblib.numpy_pickle._write_fileobject, + partial(PatchedJoblib._write_fileobject, sk_joblib.numpy_pickle)) + sk_joblib.numpy_pickle._read_fileobject = _patched_call( + sk_joblib.numpy_pickle._read_fileobject, PatchedJoblib._load) + sk_joblib.numpy_pickle.NumpyPickler.__init__ = _patched_call( + sk_joblib.numpy_pickle.NumpyPickler.__init__, + PatchedJoblib._numpypickler) except Exception: return False @@ -70,7 +83,27 @@ class PatchedJoblib(object): ret = original_fn(obj, f, *args, **kwargs) if not PatchedJoblib._current_task: return ret + PatchedJoblib._register_dump(obj, f) + return ret + @staticmethod + def _numpypickler(original_fn, obj, f, *args, **kwargs): + ret = original_fn(obj, f, *args, **kwargs) + if not PatchedJoblib._current_task: + return ret + PatchedJoblib._register_dump(obj, f) + return ret + + @staticmethod + def _write_fileobject(obj, original_fn, f, *args, **kwargs): + ret = original_fn(f, *args, **kwargs) + if not PatchedJoblib._current_task: + return ret + PatchedJoblib._register_dump(obj, f) + return ret + + @staticmethod + def _register_dump(obj, f): if isinstance(f, six.string_types): filename = f elif hasattr(f, 'name'): @@ -81,7 +114,7 @@ class PatchedJoblib(object): except Exception: pass else: - filename = None + return # give the model a descriptive name based on the file name # noinspection PyBroadException @@ -92,7 +125,6 @@ class PatchedJoblib(object): current_framework = PatchedJoblib.get_model_framework(obj) WeightsFileHandler.create_output_model(obj, filename, current_framework, PatchedJoblib._current_task, singlefile=True, model_name=model_name) - return ret @staticmethod def _load(original_fn, f, *args, **kwargs): @@ -132,9 +164,9 @@ class PatchedJoblib(object): @staticmethod def get_model_framework(obj): - object_orig_module = obj.__module__ framework = Framework.scikitlearn try: + object_orig_module = obj.__module__ if hasattr(obj, '__module__') else obj.__package__ model = object_orig_module.partition(".")[0] if model == 'sklearn': framework = Framework.scikitlearn