Add support for scikit-learn internal joblib implementation

This commit is contained in:
allegroai 2019-07-31 01:37:06 +03:00
parent 19c5f05912
commit 4c15613250
4 changed files with 48 additions and 25 deletions

View File

@ -1,4 +1,7 @@
import joblib
try:
from sklearn.externals import joblib
except ImportError:
import joblib
from sklearn import datasets
from sklearn.linear_model import LogisticRegression

View File

@ -3,7 +3,7 @@ import sys
import six
from pathlib2 import Path
from trains.binding.frameworks.base_bind import PatchBaseModelIO
from ..frameworks.base_bind import PatchBaseModelIO
from ..frameworks import _patched_call, WeightsFileHandler, _Empty
from ..import_bind import PostImportHookPatching
from ...config import running_remotely

View File

@ -1,31 +1,42 @@
try:
import joblib
except ImportError as e:
joblib = None
import six
from pathlib2 import Path
from trains.binding.frameworks import _patched_call, _Empty, WeightsFileHandler
from trains.config import running_remotely
from trains.debugging.log import LoggerRoot
from ..binding.frameworks import _patched_call, _Empty, WeightsFileHandler
from ..config import running_remotely
from ..debugging.log import LoggerRoot
from ..model import Framework
class PatchedJoblib(object):
_patched_original_dump = None
_patched_original_load = None
_patched_joblib = False
_current_task = None
_current_framework = None
@staticmethod
def patch_joblib():
if PatchedJoblib._patched_original_dump is not None and PatchedJoblib._patched_original_load is not None:
if PatchedJoblib._patched_joblib:
# We don't need to patch anything else, so we are done
return True
# whatever happens we should not retry to patch it
PatchedJoblib._patched_joblib = True
# noinspection PyBroadException
try:
joblib.dump = _patched_call(joblib.dump, PatchedJoblib._dump)
joblib.load = _patched_call(joblib.load, PatchedJoblib._load)
try:
import joblib
except ImportError:
joblib = None
try:
from sklearn.externals import joblib as sk_joblib
except ImportError:
sk_joblib = None
if joblib:
joblib.dump = _patched_call(joblib.dump, PatchedJoblib._dump)
joblib.load = _patched_call(joblib.load, PatchedJoblib._load)
if sk_joblib:
sk_joblib.dump = _patched_call(sk_joblib.dump, PatchedJoblib._dump)
sk_joblib.load = _patched_call(sk_joblib.load, PatchedJoblib._load)
except Exception:
return False
@ -60,8 +71,8 @@ class PatchedJoblib(object):
model_name = Path(filename).stem
except Exception:
model_name = None
PatchedJoblib._current_framework = PatchedJoblib.get_model_framework(obj)
WeightsFileHandler.create_output_model(obj, filename, PatchedJoblib._current_framework,
current_framework = PatchedJoblib.get_model_framework(obj)
WeightsFileHandler.create_output_model(obj, filename, current_framework,
PatchedJoblib._current_task, singlefile=True, model_name=model_name)
return ret
@ -80,13 +91,16 @@ class PatchedJoblib(object):
# register input model
empty = _Empty()
if running_remotely():
filename = WeightsFileHandler.restore_weights_file(empty, filename, PatchedJoblib._current_framework,
# we assume scikit-learn, for the time being
current_framework = Framework.scikitlearn
filename = WeightsFileHandler.restore_weights_file(empty, filename, current_framework,
PatchedJoblib._current_task)
model = original_fn(filename or f, *args, **kwargs)
else:
# try to load model before registering, in case we fail
model = original_fn(f, *args, **kwargs)
WeightsFileHandler.restore_weights_file(empty, filename, PatchedJoblib._current_framework,
current_framework = PatchedJoblib.get_model_framework(model)
WeightsFileHandler.restore_weights_file(empty, filename, current_framework,
PatchedJoblib._current_task)
if empty.trains_in_model:
@ -100,11 +114,17 @@ class PatchedJoblib(object):
@staticmethod
def get_model_framework(obj):
object_orig_module = obj.__module__
framework = object_orig_module
framework = Framework.scikitlearn
try:
framework = object_orig_module.partition(".")[0]
model = object_orig_module.partition(".")[0]
if model == 'sklearn':
framework = Framework.scikitlearn
elif model == 'xgboost':
framework = Framework.xgboost
else:
framework = Framework.scikitlearn
except Exception as _:
LoggerRoot.get_base_logger().warning(
"Can't get model framework, model framework will be: {} ".format(object_orig_module))
LoggerRoot.get_base_logger().debug(
"Can't get model framework {}, model framework will be: {} ".format(object_orig_module, framework))
finally:
return framework

View File

@ -10,7 +10,7 @@ from collections import OrderedDict, Callable
import psutil
import six
from trains.binding.joblib_bind import PatchedJoblib
from .binding.joblib_bind import PatchedJoblib
from .backend_api.services import tasks, projects
from .backend_api.session.session import Session
from .backend_interface.model import Model as BackendModel