mirror of
https://github.com/clearml/clearml
synced 2025-03-03 18:52:12 +00:00
Fix deprecated warning on imports from sklearn
This commit is contained in:
parent
c44f3ff11c
commit
02c2a12449
@ -46,16 +46,16 @@ class PostImportHookPatching(object):
|
||||
|
||||
@staticmethod
|
||||
def _patched_import3(name, globals=None, locals=None, fromlist=(), level=0):
|
||||
already_imported = name in sys.modules
|
||||
base_name = name.split('.')[0]
|
||||
already_imported = (not base_name) or (base_name in sys.modules)
|
||||
mod = builtins.__org_import__(
|
||||
name,
|
||||
globals=globals,
|
||||
locals=locals,
|
||||
fromlist=fromlist,
|
||||
level=level)
|
||||
|
||||
if not already_imported and name in PostImportHookPatching._post_import_hooks:
|
||||
for hook in PostImportHookPatching._post_import_hooks[name]:
|
||||
if not already_imported and base_name in PostImportHookPatching._post_import_hooks:
|
||||
for hook in PostImportHookPatching._post_import_hooks[base_name]:
|
||||
hook()
|
||||
return mod
|
||||
|
||||
|
@ -1,6 +1,10 @@
|
||||
import sys
|
||||
import warnings
|
||||
|
||||
import six
|
||||
from pathlib2 import Path
|
||||
|
||||
from .import_bind import PostImportHookPatching
|
||||
from ..binding.frameworks import _patched_call, _Empty, WeightsFileHandler
|
||||
from ..config import running_remotely
|
||||
from ..debugging.log import LoggerRoot
|
||||
@ -9,34 +13,48 @@ from ..model import Framework
|
||||
|
||||
class PatchedJoblib(object):
|
||||
_patched_joblib = False
|
||||
_patched_sk_joblib = False
|
||||
_current_task = None
|
||||
|
||||
@staticmethod
|
||||
def patch_joblib():
|
||||
if PatchedJoblib._patched_joblib:
|
||||
# We don't need to patch anything else, so we are done
|
||||
return True
|
||||
# try manually
|
||||
PatchedJoblib._patch_joblib()
|
||||
# register callback
|
||||
PostImportHookPatching.add_on_import('joblib',
|
||||
PatchedJoblib._patch_joblib)
|
||||
PostImportHookPatching.add_on_import('sklearn',
|
||||
PatchedJoblib._patch_joblib)
|
||||
|
||||
# whatever happens we should not retry to patch it
|
||||
PatchedJoblib._patched_joblib = True
|
||||
@staticmethod
|
||||
def _patch_joblib():
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
try:
|
||||
import joblib
|
||||
except ImportError:
|
||||
joblib = None
|
||||
if not PatchedJoblib._patched_joblib and 'joblib' in sys.modules:
|
||||
PatchedJoblib._patched_joblib = True
|
||||
try:
|
||||
import joblib
|
||||
except ImportError:
|
||||
joblib = None
|
||||
|
||||
try:
|
||||
from sklearn.externals import joblib as sk_joblib
|
||||
except ImportError:
|
||||
sk_joblib = None
|
||||
if joblib:
|
||||
joblib.dump = _patched_call(joblib.dump, PatchedJoblib._dump)
|
||||
joblib.load = _patched_call(joblib.load, PatchedJoblib._load)
|
||||
|
||||
if joblib:
|
||||
joblib.dump = _patched_call(joblib.dump, PatchedJoblib._dump)
|
||||
joblib.load = _patched_call(joblib.load, PatchedJoblib._load)
|
||||
if sk_joblib:
|
||||
sk_joblib.dump = _patched_call(sk_joblib.dump, PatchedJoblib._dump)
|
||||
sk_joblib.load = _patched_call(sk_joblib.load, PatchedJoblib._load)
|
||||
if not PatchedJoblib._patched_sk_joblib and 'sklearn' in sys.modules:
|
||||
PatchedJoblib._patched_sk_joblib = True
|
||||
try:
|
||||
import sklearn
|
||||
# avoid deprecation warning, we must import sklearn before, so we could catch it
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
from sklearn.externals import joblib as sk_joblib
|
||||
except ImportError:
|
||||
sk_joblib = None
|
||||
|
||||
if sk_joblib:
|
||||
sk_joblib.dump = _patched_call(sk_joblib.dump, PatchedJoblib._dump)
|
||||
sk_joblib.load = _patched_call(sk_joblib.load, PatchedJoblib._load)
|
||||
|
||||
except Exception:
|
||||
return False
|
||||
@ -44,8 +62,8 @@ class PatchedJoblib(object):
|
||||
|
||||
@staticmethod
|
||||
def update_current_task(task):
|
||||
if PatchedJoblib.patch_joblib():
|
||||
PatchedJoblib._current_task = task
|
||||
PatchedJoblib._current_task = task
|
||||
PatchedJoblib.patch_joblib()
|
||||
|
||||
@staticmethod
|
||||
def _dump(original_fn, obj, f, *args, **kwargs):
|
||||
@ -57,7 +75,7 @@ class PatchedJoblib(object):
|
||||
filename = f
|
||||
elif hasattr(f, 'name'):
|
||||
filename = f.name
|
||||
# noinspection PyBroadException
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
f.flush()
|
||||
except Exception:
|
||||
|
Loading…
Reference in New Issue
Block a user