mirror of
https://github.com/clearml/clearml
synced 2025-03-03 18:52:12 +00:00
Fix joblib auto logging models failing on compressed streams, issue #203
This commit is contained in:
parent
299ce14515
commit
2c47e9f248
@ -10,6 +10,7 @@ from ..binding.frameworks import _patched_call, _Empty, WeightsFileHandler
|
|||||||
from ..config import running_remotely
|
from ..config import running_remotely
|
||||||
from ..debugging.log import LoggerRoot
|
from ..debugging.log import LoggerRoot
|
||||||
from ..model import Framework
|
from ..model import Framework
|
||||||
|
from ..utilities.lowlevel.file_access import get_filename_from_file_object, buffer_writer_close_cb
|
||||||
|
|
||||||
|
|
||||||
class PatchedJoblib(object):
|
class PatchedJoblib(object):
|
||||||
@ -91,7 +92,19 @@ class PatchedJoblib(object):
|
|||||||
ret = original_fn(obj, f, *args, **kwargs)
|
ret = original_fn(obj, f, *args, **kwargs)
|
||||||
if not PatchedJoblib._current_task:
|
if not PatchedJoblib._current_task:
|
||||||
return ret
|
return ret
|
||||||
PatchedJoblib._register_dump(obj, f)
|
|
||||||
|
fname = f if isinstance(f, six.string_types) else None
|
||||||
|
fileobj = ret if isinstance(f, six.string_types) else f
|
||||||
|
|
||||||
|
if fileobj and hasattr(fileobj, 'close'):
|
||||||
|
def callback(*_):
|
||||||
|
PatchedJoblib._register_dump(obj, fname or fileobj)
|
||||||
|
|
||||||
|
if isinstance(fname, six.string_types) or hasattr(fileobj, 'name'):
|
||||||
|
buffer_writer_close_cb(fileobj, callback)
|
||||||
|
else:
|
||||||
|
PatchedJoblib._register_dump(obj, f)
|
||||||
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -99,21 +112,24 @@ class PatchedJoblib(object):
|
|||||||
ret = original_fn(f, *args, **kwargs)
|
ret = original_fn(f, *args, **kwargs)
|
||||||
if not PatchedJoblib._current_task:
|
if not PatchedJoblib._current_task:
|
||||||
return ret
|
return ret
|
||||||
PatchedJoblib._register_dump(obj, f)
|
|
||||||
|
fname = f if isinstance(f, six.string_types) else None
|
||||||
|
fileobj = ret if isinstance(f, six.string_types) else f
|
||||||
|
|
||||||
|
if fileobj and hasattr(fileobj, 'close'):
|
||||||
|
def callback(*_):
|
||||||
|
PatchedJoblib._register_dump(obj, fname or fileobj)
|
||||||
|
|
||||||
|
if isinstance(fname, six.string_types) or hasattr(fileobj, 'name'):
|
||||||
|
buffer_writer_close_cb(fileobj, callback)
|
||||||
|
else:
|
||||||
|
PatchedJoblib._register_dump(obj, f)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _register_dump(obj, f):
|
def _register_dump(obj, f):
|
||||||
if isinstance(f, six.string_types):
|
filename = get_filename_from_file_object(f, flush=True)
|
||||||
filename = f
|
if not filename:
|
||||||
elif hasattr(f, 'name'):
|
|
||||||
filename = f.name
|
|
||||||
# noinspection PyBroadException
|
|
||||||
try:
|
|
||||||
f.flush()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
return
|
return
|
||||||
|
|
||||||
# give the model a descriptive name based on the file name
|
# give the model a descriptive name based on the file name
|
||||||
@ -128,16 +144,11 @@ class PatchedJoblib(object):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _load(original_fn, f, *args, **kwargs):
|
def _load(original_fn, f, *args, **kwargs):
|
||||||
if isinstance(f, six.string_types):
|
|
||||||
filename = f
|
|
||||||
elif hasattr(f, 'name'):
|
|
||||||
filename = f.name
|
|
||||||
else:
|
|
||||||
filename = None
|
|
||||||
|
|
||||||
if not PatchedJoblib._current_task:
|
if not PatchedJoblib._current_task:
|
||||||
return original_fn(f, *args, **kwargs)
|
return original_fn(f, *args, **kwargs)
|
||||||
|
|
||||||
|
filename = get_filename_from_file_object(f, flush=False)
|
||||||
|
|
||||||
# register input model
|
# register input model
|
||||||
empty = _Empty()
|
empty = _Empty()
|
||||||
# Hack: disabled
|
# Hack: disabled
|
||||||
@ -165,6 +176,7 @@ class PatchedJoblib(object):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def get_model_framework(obj):
|
def get_model_framework(obj):
|
||||||
framework = Framework.scikitlearn
|
framework = Framework.scikitlearn
|
||||||
|
object_orig_module = None
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
object_orig_module = obj.__module__ if hasattr(obj, '__module__') else obj.__package__
|
object_orig_module = obj.__module__ if hasattr(obj, '__module__') else obj.__package__
|
||||||
|
89
trains/utilities/lowlevel/file_access.py
Normal file
89
trains/utilities/lowlevel/file_access.py
Normal file
@ -0,0 +1,89 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import six
|
||||||
|
from typing import Optional, Any, Callable
|
||||||
|
|
||||||
|
|
||||||
|
def __buffer_writer_close_patch(self):
|
||||||
|
self._trains_org_close()
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
self._trains_close_cb(self)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def buffer_writer_close_cb(bufferwriter, callback, overwrite=False):
|
||||||
|
# type: (Any, Callable[[Any], None], bool) -> ()
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
if not hasattr(bufferwriter, '_trains_org_close'):
|
||||||
|
bufferwriter._trains_org_close = bufferwriter.close
|
||||||
|
bufferwriter.close = partial(__buffer_writer_close_patch, bufferwriter)
|
||||||
|
elif not overwrite and hasattr(bufferwriter, '_trains_close_cb'):
|
||||||
|
return
|
||||||
|
|
||||||
|
bufferwriter._trains_close_cb = callback
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def get_filename_from_file_object(file_object, flush=False, analyze_file_handle=False):
|
||||||
|
# type: (object, bool, bool) -> Optional[str]
|
||||||
|
"""
|
||||||
|
Return a string of the file location, extracted from any file object
|
||||||
|
:param file_object: str, file, stream, FileIO etc.
|
||||||
|
:param flush: If True, flush file object before returning (default: False)
|
||||||
|
:param analyze_file_handle: If True try to retrieve filename from file handler object (default: False)
|
||||||
|
:return: string full path of file location or None if filename cannot be extract
|
||||||
|
"""
|
||||||
|
if isinstance(file_object, six.string_types):
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
return os.path.abspath(file_object)
|
||||||
|
except Exception:
|
||||||
|
return file_object
|
||||||
|
elif hasattr(file_object, 'name'):
|
||||||
|
filename = file_object.name
|
||||||
|
if flush:
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
file_object.flush()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return os.path.abspath(filename)
|
||||||
|
elif analyze_file_handle and isinstance(file_object, int) or hasattr(file_object, 'fileno'):
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
fileno = file_object if isinstance(file_object, int) else file_object.fileno()
|
||||||
|
if sys.platform == 'win32':
|
||||||
|
import msvcrt
|
||||||
|
from ctypes import windll, create_string_buffer
|
||||||
|
handle = msvcrt.get_osfhandle(fileno)
|
||||||
|
name = create_string_buffer(2050)
|
||||||
|
windll.kernel32.GetFinalPathNameByHandleA(handle, name, 2048, 0)
|
||||||
|
filename = name.value.decode('utf-8')
|
||||||
|
if filename.startswith('\\\\?\\'):
|
||||||
|
filename = filename[4:]
|
||||||
|
if flush:
|
||||||
|
os.fsync(fileno)
|
||||||
|
return os.path.abspath(filename)
|
||||||
|
elif sys.platform == 'linux':
|
||||||
|
filename = os.readlink('/proc/self/fd/{}'.format(fileno))
|
||||||
|
if flush:
|
||||||
|
os.fsync(fileno)
|
||||||
|
return os.path.abspath(filename)
|
||||||
|
elif sys.platform == 'darwin':
|
||||||
|
import fcntl
|
||||||
|
name = b' ' * 1024
|
||||||
|
# F_GETPATH = 50
|
||||||
|
name = fcntl.fcntl(fileno, 50, name)
|
||||||
|
filename = name.split(b'\x00')[0].decode()
|
||||||
|
if flush:
|
||||||
|
os.fsync(fileno)
|
||||||
|
return os.path.abspath(filename)
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
return None
|
Loading…
Reference in New Issue
Block a user