Fix deprecated warning on imports from sklearn

This commit is contained in:
allegroai 2019-10-06 01:43:37 +03:00
parent c44f3ff11c
commit 02c2a12449
2 changed files with 44 additions and 26 deletions

View File

@ -46,16 +46,16 @@ class PostImportHookPatching(object):
@staticmethod @staticmethod
def _patched_import3(name, globals=None, locals=None, fromlist=(), level=0): def _patched_import3(name, globals=None, locals=None, fromlist=(), level=0):
already_imported = name in sys.modules base_name = name.split('.')[0]
already_imported = (not base_name) or (base_name in sys.modules)
mod = builtins.__org_import__( mod = builtins.__org_import__(
name, name,
globals=globals, globals=globals,
locals=locals, locals=locals,
fromlist=fromlist, fromlist=fromlist,
level=level) level=level)
if not already_imported and base_name in PostImportHookPatching._post_import_hooks:
if not already_imported and name in PostImportHookPatching._post_import_hooks: for hook in PostImportHookPatching._post_import_hooks[base_name]:
for hook in PostImportHookPatching._post_import_hooks[name]:
hook() hook()
return mod return mod

View File

@ -1,6 +1,10 @@
import sys
import warnings
import six import six
from pathlib2 import Path from pathlib2 import Path
from .import_bind import PostImportHookPatching
from ..binding.frameworks import _patched_call, _Empty, WeightsFileHandler from ..binding.frameworks import _patched_call, _Empty, WeightsFileHandler
from ..config import running_remotely from ..config import running_remotely
from ..debugging.log import LoggerRoot from ..debugging.log import LoggerRoot
@ -9,31 +13,45 @@ from ..model import Framework
class PatchedJoblib(object): class PatchedJoblib(object):
_patched_joblib = False _patched_joblib = False
_patched_sk_joblib = False
_current_task = None _current_task = None
@staticmethod @staticmethod
def patch_joblib(): def patch_joblib():
if PatchedJoblib._patched_joblib: # try manually
# We don't need to patch anything else, so we are done PatchedJoblib._patch_joblib()
return True # register callback
PostImportHookPatching.add_on_import('joblib',
PatchedJoblib._patch_joblib)
PostImportHookPatching.add_on_import('sklearn',
PatchedJoblib._patch_joblib)
# whatever happens we should not retry to patch it @staticmethod
PatchedJoblib._patched_joblib = True def _patch_joblib():
# noinspection PyBroadException # noinspection PyBroadException
try: try:
if not PatchedJoblib._patched_joblib and 'joblib' in sys.modules:
PatchedJoblib._patched_joblib = True
try: try:
import joblib import joblib
except ImportError: except ImportError:
joblib = None joblib = None
if joblib:
joblib.dump = _patched_call(joblib.dump, PatchedJoblib._dump)
joblib.load = _patched_call(joblib.load, PatchedJoblib._load)
if not PatchedJoblib._patched_sk_joblib and 'sklearn' in sys.modules:
PatchedJoblib._patched_sk_joblib = True
try: try:
import sklearn
# avoid deprecation warning, we must import sklearn before, so we could catch it
with warnings.catch_warnings():
warnings.simplefilter("ignore")
from sklearn.externals import joblib as sk_joblib from sklearn.externals import joblib as sk_joblib
except ImportError: except ImportError:
sk_joblib = None sk_joblib = None
if joblib:
joblib.dump = _patched_call(joblib.dump, PatchedJoblib._dump)
joblib.load = _patched_call(joblib.load, PatchedJoblib._load)
if sk_joblib: if sk_joblib:
sk_joblib.dump = _patched_call(sk_joblib.dump, PatchedJoblib._dump) sk_joblib.dump = _patched_call(sk_joblib.dump, PatchedJoblib._dump)
sk_joblib.load = _patched_call(sk_joblib.load, PatchedJoblib._load) sk_joblib.load = _patched_call(sk_joblib.load, PatchedJoblib._load)
@ -44,8 +62,8 @@ class PatchedJoblib(object):
@staticmethod @staticmethod
def update_current_task(task): def update_current_task(task):
if PatchedJoblib.patch_joblib():
PatchedJoblib._current_task = task PatchedJoblib._current_task = task
PatchedJoblib.patch_joblib()
@staticmethod @staticmethod
def _dump(original_fn, obj, f, *args, **kwargs): def _dump(original_fn, obj, f, *args, **kwargs):