mirror of
https://github.com/clearml/clearml
synced 2025-02-12 07:35:08 +00:00
Add support for scikit-learn internal joblib implementation
This commit is contained in:
parent
19c5f05912
commit
4c15613250
@ -1,4 +1,7 @@
|
|||||||
import joblib
|
try:
|
||||||
|
from sklearn.externals import joblib
|
||||||
|
except ImportError:
|
||||||
|
import joblib
|
||||||
|
|
||||||
from sklearn import datasets
|
from sklearn import datasets
|
||||||
from sklearn.linear_model import LogisticRegression
|
from sklearn.linear_model import LogisticRegression
|
||||||
|
@ -3,7 +3,7 @@ import sys
|
|||||||
import six
|
import six
|
||||||
from pathlib2 import Path
|
from pathlib2 import Path
|
||||||
|
|
||||||
from trains.binding.frameworks.base_bind import PatchBaseModelIO
|
from ..frameworks.base_bind import PatchBaseModelIO
|
||||||
from ..frameworks import _patched_call, WeightsFileHandler, _Empty
|
from ..frameworks import _patched_call, WeightsFileHandler, _Empty
|
||||||
from ..import_bind import PostImportHookPatching
|
from ..import_bind import PostImportHookPatching
|
||||||
from ...config import running_remotely
|
from ...config import running_remotely
|
||||||
|
@ -1,31 +1,42 @@
|
|||||||
try:
|
|
||||||
import joblib
|
|
||||||
except ImportError as e:
|
|
||||||
joblib = None
|
|
||||||
|
|
||||||
import six
|
import six
|
||||||
from pathlib2 import Path
|
from pathlib2 import Path
|
||||||
|
|
||||||
from trains.binding.frameworks import _patched_call, _Empty, WeightsFileHandler
|
from ..binding.frameworks import _patched_call, _Empty, WeightsFileHandler
|
||||||
from trains.config import running_remotely
|
from ..config import running_remotely
|
||||||
from trains.debugging.log import LoggerRoot
|
from ..debugging.log import LoggerRoot
|
||||||
|
from ..model import Framework
|
||||||
|
|
||||||
|
|
||||||
class PatchedJoblib(object):
|
class PatchedJoblib(object):
|
||||||
_patched_original_dump = None
|
_patched_joblib = False
|
||||||
_patched_original_load = None
|
|
||||||
_current_task = None
|
_current_task = None
|
||||||
_current_framework = None
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def patch_joblib():
|
def patch_joblib():
|
||||||
if PatchedJoblib._patched_original_dump is not None and PatchedJoblib._patched_original_load is not None:
|
if PatchedJoblib._patched_joblib:
|
||||||
# We don't need to patch anything else, so we are done
|
# We don't need to patch anything else, so we are done
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
# whatever happens we should not retry to patch it
|
||||||
|
PatchedJoblib._patched_joblib = True
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
joblib.dump = _patched_call(joblib.dump, PatchedJoblib._dump)
|
try:
|
||||||
joblib.load = _patched_call(joblib.load, PatchedJoblib._load)
|
import joblib
|
||||||
|
except ImportError:
|
||||||
|
joblib = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
from sklearn.externals import joblib as sk_joblib
|
||||||
|
except ImportError:
|
||||||
|
sk_joblib = None
|
||||||
|
|
||||||
|
if joblib:
|
||||||
|
joblib.dump = _patched_call(joblib.dump, PatchedJoblib._dump)
|
||||||
|
joblib.load = _patched_call(joblib.load, PatchedJoblib._load)
|
||||||
|
if sk_joblib:
|
||||||
|
sk_joblib.dump = _patched_call(sk_joblib.dump, PatchedJoblib._dump)
|
||||||
|
sk_joblib.load = _patched_call(sk_joblib.load, PatchedJoblib._load)
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
@ -60,8 +71,8 @@ class PatchedJoblib(object):
|
|||||||
model_name = Path(filename).stem
|
model_name = Path(filename).stem
|
||||||
except Exception:
|
except Exception:
|
||||||
model_name = None
|
model_name = None
|
||||||
PatchedJoblib._current_framework = PatchedJoblib.get_model_framework(obj)
|
current_framework = PatchedJoblib.get_model_framework(obj)
|
||||||
WeightsFileHandler.create_output_model(obj, filename, PatchedJoblib._current_framework,
|
WeightsFileHandler.create_output_model(obj, filename, current_framework,
|
||||||
PatchedJoblib._current_task, singlefile=True, model_name=model_name)
|
PatchedJoblib._current_task, singlefile=True, model_name=model_name)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
@ -80,13 +91,16 @@ class PatchedJoblib(object):
|
|||||||
# register input model
|
# register input model
|
||||||
empty = _Empty()
|
empty = _Empty()
|
||||||
if running_remotely():
|
if running_remotely():
|
||||||
filename = WeightsFileHandler.restore_weights_file(empty, filename, PatchedJoblib._current_framework,
|
# we assume scikit-learn, for the time being
|
||||||
|
current_framework = Framework.scikitlearn
|
||||||
|
filename = WeightsFileHandler.restore_weights_file(empty, filename, current_framework,
|
||||||
PatchedJoblib._current_task)
|
PatchedJoblib._current_task)
|
||||||
model = original_fn(filename or f, *args, **kwargs)
|
model = original_fn(filename or f, *args, **kwargs)
|
||||||
else:
|
else:
|
||||||
# try to load model before registering, in case we fail
|
# try to load model before registering, in case we fail
|
||||||
model = original_fn(f, *args, **kwargs)
|
model = original_fn(f, *args, **kwargs)
|
||||||
WeightsFileHandler.restore_weights_file(empty, filename, PatchedJoblib._current_framework,
|
current_framework = PatchedJoblib.get_model_framework(model)
|
||||||
|
WeightsFileHandler.restore_weights_file(empty, filename, current_framework,
|
||||||
PatchedJoblib._current_task)
|
PatchedJoblib._current_task)
|
||||||
|
|
||||||
if empty.trains_in_model:
|
if empty.trains_in_model:
|
||||||
@ -100,11 +114,17 @@ class PatchedJoblib(object):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def get_model_framework(obj):
|
def get_model_framework(obj):
|
||||||
object_orig_module = obj.__module__
|
object_orig_module = obj.__module__
|
||||||
framework = object_orig_module
|
framework = Framework.scikitlearn
|
||||||
try:
|
try:
|
||||||
framework = object_orig_module.partition(".")[0]
|
model = object_orig_module.partition(".")[0]
|
||||||
|
if model == 'sklearn':
|
||||||
|
framework = Framework.scikitlearn
|
||||||
|
elif model == 'xgboost':
|
||||||
|
framework = Framework.xgboost
|
||||||
|
else:
|
||||||
|
framework = Framework.scikitlearn
|
||||||
except Exception as _:
|
except Exception as _:
|
||||||
LoggerRoot.get_base_logger().warning(
|
LoggerRoot.get_base_logger().debug(
|
||||||
"Can't get model framework, model framework will be: {} ".format(object_orig_module))
|
"Can't get model framework {}, model framework will be: {} ".format(object_orig_module, framework))
|
||||||
finally:
|
finally:
|
||||||
return framework
|
return framework
|
||||||
|
@ -10,7 +10,7 @@ from collections import OrderedDict, Callable
|
|||||||
import psutil
|
import psutil
|
||||||
import six
|
import six
|
||||||
|
|
||||||
from trains.binding.joblib_bind import PatchedJoblib
|
from .binding.joblib_bind import PatchedJoblib
|
||||||
from .backend_api.services import tasks, projects
|
from .backend_api.services import tasks, projects
|
||||||
from .backend_api.session.session import Session
|
from .backend_api.session.session import Session
|
||||||
from .backend_interface.model import Model as BackendModel
|
from .backend_interface.model import Model as BackendModel
|
||||||
|
Loading…
Reference in New Issue
Block a user