From 4c15613250ef56404c3958dd3c9cd2efb77158f0 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Wed, 31 Jul 2019 01:37:06 +0300 Subject: [PATCH] Add support for scikit-learn internal joblib implementation --- examples/joblib_example.py | 5 +- trains/binding/frameworks/xgboost_bind.py | 2 +- trains/binding/joblib_bind.py | 64 +++++++++++++++-------- trains/task.py | 2 +- 4 files changed, 48 insertions(+), 25 deletions(-) diff --git a/examples/joblib_example.py b/examples/joblib_example.py index 1e3fbc91..fc1eaf00 100644 --- a/examples/joblib_example.py +++ b/examples/joblib_example.py @@ -1,4 +1,7 @@ -import joblib +try: + from sklearn.externals import joblib +except ImportError: + import joblib from sklearn import datasets from sklearn.linear_model import LogisticRegression diff --git a/trains/binding/frameworks/xgboost_bind.py b/trains/binding/frameworks/xgboost_bind.py index ca7287ac..03b3321b 100644 --- a/trains/binding/frameworks/xgboost_bind.py +++ b/trains/binding/frameworks/xgboost_bind.py @@ -3,7 +3,7 @@ import sys import six from pathlib2 import Path -from trains.binding.frameworks.base_bind import PatchBaseModelIO +from ..frameworks.base_bind import PatchBaseModelIO from ..frameworks import _patched_call, WeightsFileHandler, _Empty from ..import_bind import PostImportHookPatching from ...config import running_remotely diff --git a/trains/binding/joblib_bind.py b/trains/binding/joblib_bind.py index 4ef6f302..38e42b26 100644 --- a/trains/binding/joblib_bind.py +++ b/trains/binding/joblib_bind.py @@ -1,31 +1,42 @@ -try: - import joblib -except ImportError as e: - joblib = None - import six from pathlib2 import Path -from trains.binding.frameworks import _patched_call, _Empty, WeightsFileHandler -from trains.config import running_remotely -from trains.debugging.log import LoggerRoot +from ..binding.frameworks import _patched_call, _Empty, WeightsFileHandler +from ..config import running_remotely +from ..debugging.log import LoggerRoot +from ..model import Framework class PatchedJoblib(object): - _patched_original_dump = None - _patched_original_load = None + _patched_joblib = False _current_task = None - _current_framework = None @staticmethod def patch_joblib(): - if PatchedJoblib._patched_original_dump is not None and PatchedJoblib._patched_original_load is not None: + if PatchedJoblib._patched_joblib: # We don't need to patch anything else, so we are done return True + + # whatever happens we should not retry to patch it + PatchedJoblib._patched_joblib = True # noinspection PyBroadException try: - joblib.dump = _patched_call(joblib.dump, PatchedJoblib._dump) - joblib.load = _patched_call(joblib.load, PatchedJoblib._load) + try: + import joblib + except ImportError: + joblib = None + + try: + from sklearn.externals import joblib as sk_joblib + except ImportError: + sk_joblib = None + + if joblib: + joblib.dump = _patched_call(joblib.dump, PatchedJoblib._dump) + joblib.load = _patched_call(joblib.load, PatchedJoblib._load) + if sk_joblib: + sk_joblib.dump = _patched_call(sk_joblib.dump, PatchedJoblib._dump) + sk_joblib.load = _patched_call(sk_joblib.load, PatchedJoblib._load) except Exception: return False @@ -60,8 +71,8 @@ class PatchedJoblib(object): model_name = Path(filename).stem except Exception: model_name = None - PatchedJoblib._current_framework = PatchedJoblib.get_model_framework(obj) - WeightsFileHandler.create_output_model(obj, filename, PatchedJoblib._current_framework, + current_framework = PatchedJoblib.get_model_framework(obj) + WeightsFileHandler.create_output_model(obj, filename, current_framework, PatchedJoblib._current_task, singlefile=True, model_name=model_name) return ret @@ -80,13 +91,16 @@ class PatchedJoblib(object): # register input model empty = _Empty() if running_remotely(): - filename = WeightsFileHandler.restore_weights_file(empty, filename, PatchedJoblib._current_framework, + # we assume scikit-learn, for the time being + current_framework = Framework.scikitlearn + filename = WeightsFileHandler.restore_weights_file(empty, filename, current_framework, PatchedJoblib._current_task) model = original_fn(filename or f, *args, **kwargs) else: # try to load model before registering, in case we fail model = original_fn(f, *args, **kwargs) - WeightsFileHandler.restore_weights_file(empty, filename, PatchedJoblib._current_framework, + current_framework = PatchedJoblib.get_model_framework(model) + WeightsFileHandler.restore_weights_file(empty, filename, current_framework, PatchedJoblib._current_task) if empty.trains_in_model: @@ -100,11 +114,17 @@ class PatchedJoblib(object): @staticmethod def get_model_framework(obj): object_orig_module = obj.__module__ - framework = object_orig_module + framework = Framework.scikitlearn try: - framework = object_orig_module.partition(".")[0] + model = object_orig_module.partition(".")[0] + if model == 'sklearn': + framework = Framework.scikitlearn + elif model == 'xgboost': + framework = Framework.xgboost + else: + framework = Framework.scikitlearn except Exception as _: - LoggerRoot.get_base_logger().warning( - "Can't get model framework, model framework will be: {} ".format(object_orig_module)) + LoggerRoot.get_base_logger().debug( + "Can't get model framework {}, model framework will be: {} ".format(object_orig_module, framework)) finally: return framework diff --git a/trains/task.py b/trains/task.py index 2a3b7faa..4c3f5b73 100644 --- a/trains/task.py +++ b/trains/task.py @@ -10,7 +10,7 @@ from collections import OrderedDict, Callable import psutil import six -from trains.binding.joblib_bind import PatchedJoblib +from .binding.joblib_bind import PatchedJoblib from .backend_api.services import tasks, projects from .backend_api.session.session import Session from .backend_interface.model import Model as BackendModel