mirror of
https://github.com/clearml/clearml
synced 2025-01-31 17:17:00 +00:00
Fix joblib auto logging models failing on compressed streams, issue #203
This commit is contained in:
parent
299ce14515
commit
2c47e9f248
@ -10,6 +10,7 @@ from ..binding.frameworks import _patched_call, _Empty, WeightsFileHandler
|
||||
from ..config import running_remotely
|
||||
from ..debugging.log import LoggerRoot
|
||||
from ..model import Framework
|
||||
from ..utilities.lowlevel.file_access import get_filename_from_file_object, buffer_writer_close_cb
|
||||
|
||||
|
||||
class PatchedJoblib(object):
|
||||
@ -91,7 +92,19 @@ class PatchedJoblib(object):
|
||||
ret = original_fn(obj, f, *args, **kwargs)
|
||||
if not PatchedJoblib._current_task:
|
||||
return ret
|
||||
PatchedJoblib._register_dump(obj, f)
|
||||
|
||||
fname = f if isinstance(f, six.string_types) else None
|
||||
fileobj = ret if isinstance(f, six.string_types) else f
|
||||
|
||||
if fileobj and hasattr(fileobj, 'close'):
|
||||
def callback(*_):
|
||||
PatchedJoblib._register_dump(obj, fname or fileobj)
|
||||
|
||||
if isinstance(fname, six.string_types) or hasattr(fileobj, 'name'):
|
||||
buffer_writer_close_cb(fileobj, callback)
|
||||
else:
|
||||
PatchedJoblib._register_dump(obj, f)
|
||||
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
@ -99,21 +112,24 @@ class PatchedJoblib(object):
|
||||
ret = original_fn(f, *args, **kwargs)
|
||||
if not PatchedJoblib._current_task:
|
||||
return ret
|
||||
PatchedJoblib._register_dump(obj, f)
|
||||
|
||||
fname = f if isinstance(f, six.string_types) else None
|
||||
fileobj = ret if isinstance(f, six.string_types) else f
|
||||
|
||||
if fileobj and hasattr(fileobj, 'close'):
|
||||
def callback(*_):
|
||||
PatchedJoblib._register_dump(obj, fname or fileobj)
|
||||
|
||||
if isinstance(fname, six.string_types) or hasattr(fileobj, 'name'):
|
||||
buffer_writer_close_cb(fileobj, callback)
|
||||
else:
|
||||
PatchedJoblib._register_dump(obj, f)
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
def _register_dump(obj, f):
|
||||
if isinstance(f, six.string_types):
|
||||
filename = f
|
||||
elif hasattr(f, 'name'):
|
||||
filename = f.name
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
f.flush()
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
filename = get_filename_from_file_object(f, flush=True)
|
||||
if not filename:
|
||||
return
|
||||
|
||||
# give the model a descriptive name based on the file name
|
||||
@ -128,16 +144,11 @@ class PatchedJoblib(object):
|
||||
|
||||
@staticmethod
|
||||
def _load(original_fn, f, *args, **kwargs):
|
||||
if isinstance(f, six.string_types):
|
||||
filename = f
|
||||
elif hasattr(f, 'name'):
|
||||
filename = f.name
|
||||
else:
|
||||
filename = None
|
||||
|
||||
if not PatchedJoblib._current_task:
|
||||
return original_fn(f, *args, **kwargs)
|
||||
|
||||
filename = get_filename_from_file_object(f, flush=False)
|
||||
|
||||
# register input model
|
||||
empty = _Empty()
|
||||
# Hack: disabled
|
||||
@ -165,6 +176,7 @@ class PatchedJoblib(object):
|
||||
@staticmethod
|
||||
def get_model_framework(obj):
|
||||
framework = Framework.scikitlearn
|
||||
object_orig_module = None
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
object_orig_module = obj.__module__ if hasattr(obj, '__module__') else obj.__package__
|
||||
|
89
trains/utilities/lowlevel/file_access.py
Normal file
89
trains/utilities/lowlevel/file_access.py
Normal file
@ -0,0 +1,89 @@
|
||||
import os
|
||||
import sys
|
||||
from functools import partial
|
||||
|
||||
import six
|
||||
from typing import Optional, Any, Callable
|
||||
|
||||
|
||||
def __buffer_writer_close_patch(self):
|
||||
self._trains_org_close()
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
self._trains_close_cb(self)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def buffer_writer_close_cb(bufferwriter, callback, overwrite=False):
|
||||
# type: (Any, Callable[[Any], None], bool) -> ()
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
if not hasattr(bufferwriter, '_trains_org_close'):
|
||||
bufferwriter._trains_org_close = bufferwriter.close
|
||||
bufferwriter.close = partial(__buffer_writer_close_patch, bufferwriter)
|
||||
elif not overwrite and hasattr(bufferwriter, '_trains_close_cb'):
|
||||
return
|
||||
|
||||
bufferwriter._trains_close_cb = callback
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def get_filename_from_file_object(file_object, flush=False, analyze_file_handle=False):
|
||||
# type: (object, bool, bool) -> Optional[str]
|
||||
"""
|
||||
Return a string of the file location, extracted from any file object
|
||||
:param file_object: str, file, stream, FileIO etc.
|
||||
:param flush: If True, flush file object before returning (default: False)
|
||||
:param analyze_file_handle: If True try to retrieve filename from file handler object (default: False)
|
||||
:return: string full path of file location or None if filename cannot be extract
|
||||
"""
|
||||
if isinstance(file_object, six.string_types):
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
return os.path.abspath(file_object)
|
||||
except Exception:
|
||||
return file_object
|
||||
elif hasattr(file_object, 'name'):
|
||||
filename = file_object.name
|
||||
if flush:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
file_object.flush()
|
||||
except Exception:
|
||||
pass
|
||||
return os.path.abspath(filename)
|
||||
elif analyze_file_handle and isinstance(file_object, int) or hasattr(file_object, 'fileno'):
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
fileno = file_object if isinstance(file_object, int) else file_object.fileno()
|
||||
if sys.platform == 'win32':
|
||||
import msvcrt
|
||||
from ctypes import windll, create_string_buffer
|
||||
handle = msvcrt.get_osfhandle(fileno)
|
||||
name = create_string_buffer(2050)
|
||||
windll.kernel32.GetFinalPathNameByHandleA(handle, name, 2048, 0)
|
||||
filename = name.value.decode('utf-8')
|
||||
if filename.startswith('\\\\?\\'):
|
||||
filename = filename[4:]
|
||||
if flush:
|
||||
os.fsync(fileno)
|
||||
return os.path.abspath(filename)
|
||||
elif sys.platform == 'linux':
|
||||
filename = os.readlink('/proc/self/fd/{}'.format(fileno))
|
||||
if flush:
|
||||
os.fsync(fileno)
|
||||
return os.path.abspath(filename)
|
||||
elif sys.platform == 'darwin':
|
||||
import fcntl
|
||||
name = b' ' * 1024
|
||||
# F_GETPATH = 50
|
||||
name = fcntl.fcntl(fileno, 50, name)
|
||||
filename = name.split(b'\x00')[0].decode()
|
||||
if flush:
|
||||
os.fsync(fileno)
|
||||
return os.path.abspath(filename)
|
||||
except Exception:
|
||||
return None
|
||||
return None
|
Loading…
Reference in New Issue
Block a user