mirror of
https://github.com/clearml/clearml
synced 2025-06-23 01:55:38 +00:00
Fix deprecated warning on imports from sklearn
This commit is contained in:
parent
c44f3ff11c
commit
02c2a12449
@ -46,16 +46,16 @@ class PostImportHookPatching(object):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _patched_import3(name, globals=None, locals=None, fromlist=(), level=0):
|
def _patched_import3(name, globals=None, locals=None, fromlist=(), level=0):
|
||||||
already_imported = name in sys.modules
|
base_name = name.split('.')[0]
|
||||||
|
already_imported = (not base_name) or (base_name in sys.modules)
|
||||||
mod = builtins.__org_import__(
|
mod = builtins.__org_import__(
|
||||||
name,
|
name,
|
||||||
globals=globals,
|
globals=globals,
|
||||||
locals=locals,
|
locals=locals,
|
||||||
fromlist=fromlist,
|
fromlist=fromlist,
|
||||||
level=level)
|
level=level)
|
||||||
|
if not already_imported and base_name in PostImportHookPatching._post_import_hooks:
|
||||||
if not already_imported and name in PostImportHookPatching._post_import_hooks:
|
for hook in PostImportHookPatching._post_import_hooks[base_name]:
|
||||||
for hook in PostImportHookPatching._post_import_hooks[name]:
|
|
||||||
hook()
|
hook()
|
||||||
return mod
|
return mod
|
||||||
|
|
||||||
|
@ -1,6 +1,10 @@
|
|||||||
|
import sys
|
||||||
|
import warnings
|
||||||
|
|
||||||
import six
|
import six
|
||||||
from pathlib2 import Path
|
from pathlib2 import Path
|
||||||
|
|
||||||
|
from .import_bind import PostImportHookPatching
|
||||||
from ..binding.frameworks import _patched_call, _Empty, WeightsFileHandler
|
from ..binding.frameworks import _patched_call, _Empty, WeightsFileHandler
|
||||||
from ..config import running_remotely
|
from ..config import running_remotely
|
||||||
from ..debugging.log import LoggerRoot
|
from ..debugging.log import LoggerRoot
|
||||||
@ -9,31 +13,45 @@ from ..model import Framework
|
|||||||
|
|
||||||
class PatchedJoblib(object):
|
class PatchedJoblib(object):
|
||||||
_patched_joblib = False
|
_patched_joblib = False
|
||||||
|
_patched_sk_joblib = False
|
||||||
_current_task = None
|
_current_task = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def patch_joblib():
|
def patch_joblib():
|
||||||
if PatchedJoblib._patched_joblib:
|
# try manually
|
||||||
# We don't need to patch anything else, so we are done
|
PatchedJoblib._patch_joblib()
|
||||||
return True
|
# register callback
|
||||||
|
PostImportHookPatching.add_on_import('joblib',
|
||||||
|
PatchedJoblib._patch_joblib)
|
||||||
|
PostImportHookPatching.add_on_import('sklearn',
|
||||||
|
PatchedJoblib._patch_joblib)
|
||||||
|
|
||||||
# whatever happens we should not retry to patch it
|
@staticmethod
|
||||||
PatchedJoblib._patched_joblib = True
|
def _patch_joblib():
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
|
if not PatchedJoblib._patched_joblib and 'joblib' in sys.modules:
|
||||||
|
PatchedJoblib._patched_joblib = True
|
||||||
try:
|
try:
|
||||||
import joblib
|
import joblib
|
||||||
except ImportError:
|
except ImportError:
|
||||||
joblib = None
|
joblib = None
|
||||||
|
|
||||||
|
if joblib:
|
||||||
|
joblib.dump = _patched_call(joblib.dump, PatchedJoblib._dump)
|
||||||
|
joblib.load = _patched_call(joblib.load, PatchedJoblib._load)
|
||||||
|
|
||||||
|
if not PatchedJoblib._patched_sk_joblib and 'sklearn' in sys.modules:
|
||||||
|
PatchedJoblib._patched_sk_joblib = True
|
||||||
try:
|
try:
|
||||||
|
import sklearn
|
||||||
|
# avoid deprecation warning, we must import sklearn before, so we could catch it
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.simplefilter("ignore")
|
||||||
from sklearn.externals import joblib as sk_joblib
|
from sklearn.externals import joblib as sk_joblib
|
||||||
except ImportError:
|
except ImportError:
|
||||||
sk_joblib = None
|
sk_joblib = None
|
||||||
|
|
||||||
if joblib:
|
|
||||||
joblib.dump = _patched_call(joblib.dump, PatchedJoblib._dump)
|
|
||||||
joblib.load = _patched_call(joblib.load, PatchedJoblib._load)
|
|
||||||
if sk_joblib:
|
if sk_joblib:
|
||||||
sk_joblib.dump = _patched_call(sk_joblib.dump, PatchedJoblib._dump)
|
sk_joblib.dump = _patched_call(sk_joblib.dump, PatchedJoblib._dump)
|
||||||
sk_joblib.load = _patched_call(sk_joblib.load, PatchedJoblib._load)
|
sk_joblib.load = _patched_call(sk_joblib.load, PatchedJoblib._load)
|
||||||
@ -44,8 +62,8 @@ class PatchedJoblib(object):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def update_current_task(task):
|
def update_current_task(task):
|
||||||
if PatchedJoblib.patch_joblib():
|
|
||||||
PatchedJoblib._current_task = task
|
PatchedJoblib._current_task = task
|
||||||
|
PatchedJoblib.patch_joblib()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _dump(original_fn, obj, f, *args, **kwargs):
|
def _dump(original_fn, obj, f, *args, **kwargs):
|
||||||
|
Loading…
Reference in New Issue
Block a user