1
0
mirror of https://github.com/clearml/clearml synced 2025-03-03 10:42:00 +00:00

Fix deprecated warning on imports from sklearn

This commit is contained in:
allegroai 2019-10-06 01:43:37 +03:00
parent c44f3ff11c
commit 02c2a12449
2 changed files with 44 additions and 26 deletions

View File

@ -46,16 +46,16 @@ class PostImportHookPatching(object):
@staticmethod
def _patched_import3(name, globals=None, locals=None, fromlist=(), level=0):
already_imported = name in sys.modules
base_name = name.split('.')[0]
already_imported = (not base_name) or (base_name in sys.modules)
mod = builtins.__org_import__(
name,
globals=globals,
locals=locals,
fromlist=fromlist,
level=level)
if not already_imported and name in PostImportHookPatching._post_import_hooks:
for hook in PostImportHookPatching._post_import_hooks[name]:
if not already_imported and base_name in PostImportHookPatching._post_import_hooks:
for hook in PostImportHookPatching._post_import_hooks[base_name]:
hook()
return mod

View File

@ -1,6 +1,10 @@
import sys
import warnings
import six
from pathlib2 import Path
from .import_bind import PostImportHookPatching
from ..binding.frameworks import _patched_call, _Empty, WeightsFileHandler
from ..config import running_remotely
from ..debugging.log import LoggerRoot
@ -9,34 +13,48 @@ from ..model import Framework
class PatchedJoblib(object):
_patched_joblib = False
_patched_sk_joblib = False
_current_task = None
@staticmethod
def patch_joblib():
if PatchedJoblib._patched_joblib:
# We don't need to patch anything else, so we are done
return True
# try manually
PatchedJoblib._patch_joblib()
# register callback
PostImportHookPatching.add_on_import('joblib',
PatchedJoblib._patch_joblib)
PostImportHookPatching.add_on_import('sklearn',
PatchedJoblib._patch_joblib)
# whatever happens we should not retry to patch it
PatchedJoblib._patched_joblib = True
@staticmethod
def _patch_joblib():
# noinspection PyBroadException
try:
try:
import joblib
except ImportError:
joblib = None
if not PatchedJoblib._patched_joblib and 'joblib' in sys.modules:
PatchedJoblib._patched_joblib = True
try:
import joblib
except ImportError:
joblib = None
try:
from sklearn.externals import joblib as sk_joblib
except ImportError:
sk_joblib = None
if joblib:
joblib.dump = _patched_call(joblib.dump, PatchedJoblib._dump)
joblib.load = _patched_call(joblib.load, PatchedJoblib._load)
if joblib:
joblib.dump = _patched_call(joblib.dump, PatchedJoblib._dump)
joblib.load = _patched_call(joblib.load, PatchedJoblib._load)
if sk_joblib:
sk_joblib.dump = _patched_call(sk_joblib.dump, PatchedJoblib._dump)
sk_joblib.load = _patched_call(sk_joblib.load, PatchedJoblib._load)
if not PatchedJoblib._patched_sk_joblib and 'sklearn' in sys.modules:
PatchedJoblib._patched_sk_joblib = True
try:
import sklearn
# avoid deprecation warning, we must import sklearn before, so we could catch it
with warnings.catch_warnings():
warnings.simplefilter("ignore")
from sklearn.externals import joblib as sk_joblib
except ImportError:
sk_joblib = None
if sk_joblib:
sk_joblib.dump = _patched_call(sk_joblib.dump, PatchedJoblib._dump)
sk_joblib.load = _patched_call(sk_joblib.load, PatchedJoblib._load)
except Exception:
return False
@ -44,8 +62,8 @@ class PatchedJoblib(object):
@staticmethod
def update_current_task(task):
if PatchedJoblib.patch_joblib():
PatchedJoblib._current_task = task
PatchedJoblib._current_task = task
PatchedJoblib.patch_joblib()
@staticmethod
def _dump(original_fn, obj, f, *args, **kwargs):
@ -57,7 +75,7 @@ class PatchedJoblib(object):
filename = f
elif hasattr(f, 'name'):
filename = f.name
# noinspection PyBroadException
# noinspection PyBroadException
try:
f.flush()
except Exception: