From 02c2a124490e8876ff495ac954586477b9c20849 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Sun, 6 Oct 2019 01:43:37 +0300 Subject: [PATCH] Fix deprecated warning on imports from sklearn --- trains/binding/import_bind.py | 8 ++--- trains/binding/joblib_bind.py | 62 ++++++++++++++++++++++------------- 2 files changed, 44 insertions(+), 26 deletions(-) diff --git a/trains/binding/import_bind.py b/trains/binding/import_bind.py index 98783718..7d8f4efd 100644 --- a/trains/binding/import_bind.py +++ b/trains/binding/import_bind.py @@ -46,16 +46,16 @@ class PostImportHookPatching(object): @staticmethod def _patched_import3(name, globals=None, locals=None, fromlist=(), level=0): - already_imported = name in sys.modules + base_name = name.split('.')[0] + already_imported = (not base_name) or (base_name in sys.modules) mod = builtins.__org_import__( name, globals=globals, locals=locals, fromlist=fromlist, level=level) - - if not already_imported and name in PostImportHookPatching._post_import_hooks: - for hook in PostImportHookPatching._post_import_hooks[name]: + if not already_imported and base_name in PostImportHookPatching._post_import_hooks: + for hook in PostImportHookPatching._post_import_hooks[base_name]: hook() return mod diff --git a/trains/binding/joblib_bind.py b/trains/binding/joblib_bind.py index 80dd0ac2..b34beac3 100644 --- a/trains/binding/joblib_bind.py +++ b/trains/binding/joblib_bind.py @@ -1,6 +1,10 @@ +import sys +import warnings + import six from pathlib2 import Path +from .import_bind import PostImportHookPatching from ..binding.frameworks import _patched_call, _Empty, WeightsFileHandler from ..config import running_remotely from ..debugging.log import LoggerRoot @@ -9,34 +13,48 @@ from ..model import Framework class PatchedJoblib(object): _patched_joblib = False + _patched_sk_joblib = False _current_task = None @staticmethod def patch_joblib(): - if PatchedJoblib._patched_joblib: - # We don't need to patch anything else, so we are done - return True + # try manually + PatchedJoblib._patch_joblib() + # register callback + PostImportHookPatching.add_on_import('joblib', + PatchedJoblib._patch_joblib) + PostImportHookPatching.add_on_import('sklearn', + PatchedJoblib._patch_joblib) - # whatever happens we should not retry to patch it - PatchedJoblib._patched_joblib = True + @staticmethod + def _patch_joblib(): # noinspection PyBroadException try: - try: - import joblib - except ImportError: - joblib = None + if not PatchedJoblib._patched_joblib and 'joblib' in sys.modules: + PatchedJoblib._patched_joblib = True + try: + import joblib + except ImportError: + joblib = None - try: - from sklearn.externals import joblib as sk_joblib - except ImportError: - sk_joblib = None + if joblib: + joblib.dump = _patched_call(joblib.dump, PatchedJoblib._dump) + joblib.load = _patched_call(joblib.load, PatchedJoblib._load) - if joblib: - joblib.dump = _patched_call(joblib.dump, PatchedJoblib._dump) - joblib.load = _patched_call(joblib.load, PatchedJoblib._load) - if sk_joblib: - sk_joblib.dump = _patched_call(sk_joblib.dump, PatchedJoblib._dump) - sk_joblib.load = _patched_call(sk_joblib.load, PatchedJoblib._load) + if not PatchedJoblib._patched_sk_joblib and 'sklearn' in sys.modules: + PatchedJoblib._patched_sk_joblib = True + try: + import sklearn + # avoid deprecation warning, we must import sklearn before, so we could catch it + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + from sklearn.externals import joblib as sk_joblib + except ImportError: + sk_joblib = None + + if sk_joblib: + sk_joblib.dump = _patched_call(sk_joblib.dump, PatchedJoblib._dump) + sk_joblib.load = _patched_call(sk_joblib.load, PatchedJoblib._load) except Exception: return False @@ -44,8 +62,8 @@ class PatchedJoblib(object): @staticmethod def update_current_task(task): - if PatchedJoblib.patch_joblib(): - PatchedJoblib._current_task = task + PatchedJoblib._current_task = task + PatchedJoblib.patch_joblib() @staticmethod def _dump(original_fn, obj, f, *args, **kwargs): @@ -57,7 +75,7 @@ class PatchedJoblib(object): filename = f elif hasattr(f, 'name'): filename = f.name - # noinspection PyBroadException + # noinspection PyBroadException try: f.flush() except Exception: