mirror of
https://github.com/clearml/clearml
synced 2025-04-18 13:24:41 +00:00
Fix joblib binding
This commit is contained in:
parent
5db53ba643
commit
7817ef5cda
@ -1,5 +1,6 @@
|
|||||||
import sys
|
import sys
|
||||||
import warnings
|
import warnings
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
import six
|
import six
|
||||||
from pathlib2 import Path
|
from pathlib2 import Path
|
||||||
@ -38,8 +39,14 @@ class PatchedJoblib(object):
|
|||||||
joblib = None
|
joblib = None
|
||||||
|
|
||||||
if joblib:
|
if joblib:
|
||||||
joblib.dump = _patched_call(joblib.dump, PatchedJoblib._dump)
|
joblib.numpy_pickle._write_fileobject = _patched_call(
|
||||||
joblib.load = _patched_call(joblib.load, PatchedJoblib._load)
|
joblib.numpy_pickle._write_fileobject,
|
||||||
|
partial(PatchedJoblib._write_fileobject, joblib.numpy_pickle))
|
||||||
|
joblib.numpy_pickle._read_fileobject = _patched_call(
|
||||||
|
joblib.numpy_pickle._read_fileobject, PatchedJoblib._load)
|
||||||
|
joblib.numpy_pickle.NumpyPickler.__init__ = _patched_call(
|
||||||
|
joblib.numpy_pickle.NumpyPickler.__init__,
|
||||||
|
PatchedJoblib._numpypickler)
|
||||||
|
|
||||||
if not PatchedJoblib._patched_sk_joblib and 'sklearn' in sys.modules:
|
if not PatchedJoblib._patched_sk_joblib and 'sklearn' in sys.modules:
|
||||||
PatchedJoblib._patched_sk_joblib = True
|
PatchedJoblib._patched_sk_joblib = True
|
||||||
@ -53,8 +60,14 @@ class PatchedJoblib(object):
|
|||||||
sk_joblib = None
|
sk_joblib = None
|
||||||
|
|
||||||
if sk_joblib:
|
if sk_joblib:
|
||||||
sk_joblib.dump = _patched_call(sk_joblib.dump, PatchedJoblib._dump)
|
sk_joblib.numpy_pickle._write_fileobject = _patched_call(
|
||||||
sk_joblib.load = _patched_call(sk_joblib.load, PatchedJoblib._load)
|
sk_joblib.numpy_pickle._write_fileobject,
|
||||||
|
partial(PatchedJoblib._write_fileobject, sk_joblib.numpy_pickle))
|
||||||
|
sk_joblib.numpy_pickle._read_fileobject = _patched_call(
|
||||||
|
sk_joblib.numpy_pickle._read_fileobject, PatchedJoblib._load)
|
||||||
|
sk_joblib.numpy_pickle.NumpyPickler.__init__ = _patched_call(
|
||||||
|
sk_joblib.numpy_pickle.NumpyPickler.__init__,
|
||||||
|
PatchedJoblib._numpypickler)
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
@ -70,7 +83,27 @@ class PatchedJoblib(object):
|
|||||||
ret = original_fn(obj, f, *args, **kwargs)
|
ret = original_fn(obj, f, *args, **kwargs)
|
||||||
if not PatchedJoblib._current_task:
|
if not PatchedJoblib._current_task:
|
||||||
return ret
|
return ret
|
||||||
|
PatchedJoblib._register_dump(obj, f)
|
||||||
|
return ret
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _numpypickler(original_fn, obj, f, *args, **kwargs):
|
||||||
|
ret = original_fn(obj, f, *args, **kwargs)
|
||||||
|
if not PatchedJoblib._current_task:
|
||||||
|
return ret
|
||||||
|
PatchedJoblib._register_dump(obj, f)
|
||||||
|
return ret
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _write_fileobject(obj, original_fn, f, *args, **kwargs):
|
||||||
|
ret = original_fn(f, *args, **kwargs)
|
||||||
|
if not PatchedJoblib._current_task:
|
||||||
|
return ret
|
||||||
|
PatchedJoblib._register_dump(obj, f)
|
||||||
|
return ret
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _register_dump(obj, f):
|
||||||
if isinstance(f, six.string_types):
|
if isinstance(f, six.string_types):
|
||||||
filename = f
|
filename = f
|
||||||
elif hasattr(f, 'name'):
|
elif hasattr(f, 'name'):
|
||||||
@ -81,7 +114,7 @@ class PatchedJoblib(object):
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
filename = None
|
return
|
||||||
|
|
||||||
# give the model a descriptive name based on the file name
|
# give the model a descriptive name based on the file name
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
@ -92,7 +125,6 @@ class PatchedJoblib(object):
|
|||||||
current_framework = PatchedJoblib.get_model_framework(obj)
|
current_framework = PatchedJoblib.get_model_framework(obj)
|
||||||
WeightsFileHandler.create_output_model(obj, filename, current_framework,
|
WeightsFileHandler.create_output_model(obj, filename, current_framework,
|
||||||
PatchedJoblib._current_task, singlefile=True, model_name=model_name)
|
PatchedJoblib._current_task, singlefile=True, model_name=model_name)
|
||||||
return ret
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _load(original_fn, f, *args, **kwargs):
|
def _load(original_fn, f, *args, **kwargs):
|
||||||
@ -132,9 +164,9 @@ class PatchedJoblib(object):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_model_framework(obj):
|
def get_model_framework(obj):
|
||||||
object_orig_module = obj.__module__
|
|
||||||
framework = Framework.scikitlearn
|
framework = Framework.scikitlearn
|
||||||
try:
|
try:
|
||||||
|
object_orig_module = obj.__module__ if hasattr(obj, '__module__') else obj.__package__
|
||||||
model = object_orig_module.partition(".")[0]
|
model = object_orig_module.partition(".")[0]
|
||||||
if model == 'sklearn':
|
if model == 'sklearn':
|
||||||
framework = Framework.scikitlearn
|
framework = Framework.scikitlearn
|
||||||
|
Loading…
Reference in New Issue
Block a user