Fix joblib auto logging models failing on compressed streams, issue #203

This commit is contained in:
allegroai 2020-09-09 22:10:35 +03:00
parent 299ce14515
commit 2c47e9f248
2 changed files with 120 additions and 19 deletions

View File

@ -10,6 +10,7 @@ from ..binding.frameworks import _patched_call, _Empty, WeightsFileHandler
from ..config import running_remotely
from ..debugging.log import LoggerRoot
from ..model import Framework
from ..utilities.lowlevel.file_access import get_filename_from_file_object, buffer_writer_close_cb
class PatchedJoblib(object):
@ -91,7 +92,19 @@ class PatchedJoblib(object):
ret = original_fn(obj, f, *args, **kwargs)
if not PatchedJoblib._current_task:
return ret
PatchedJoblib._register_dump(obj, f)
fname = f if isinstance(f, six.string_types) else None
fileobj = ret if isinstance(f, six.string_types) else f
if fileobj and hasattr(fileobj, 'close'):
def callback(*_):
PatchedJoblib._register_dump(obj, fname or fileobj)
if isinstance(fname, six.string_types) or hasattr(fileobj, 'name'):
buffer_writer_close_cb(fileobj, callback)
else:
PatchedJoblib._register_dump(obj, f)
return ret
@staticmethod
@ -99,21 +112,24 @@ class PatchedJoblib(object):
ret = original_fn(f, *args, **kwargs)
if not PatchedJoblib._current_task:
return ret
PatchedJoblib._register_dump(obj, f)
fname = f if isinstance(f, six.string_types) else None
fileobj = ret if isinstance(f, six.string_types) else f
if fileobj and hasattr(fileobj, 'close'):
def callback(*_):
PatchedJoblib._register_dump(obj, fname or fileobj)
if isinstance(fname, six.string_types) or hasattr(fileobj, 'name'):
buffer_writer_close_cb(fileobj, callback)
else:
PatchedJoblib._register_dump(obj, f)
return ret
@staticmethod
def _register_dump(obj, f):
if isinstance(f, six.string_types):
filename = f
elif hasattr(f, 'name'):
filename = f.name
# noinspection PyBroadException
try:
f.flush()
except Exception:
pass
else:
filename = get_filename_from_file_object(f, flush=True)
if not filename:
return
# give the model a descriptive name based on the file name
@ -128,16 +144,11 @@ class PatchedJoblib(object):
@staticmethod
def _load(original_fn, f, *args, **kwargs):
if isinstance(f, six.string_types):
filename = f
elif hasattr(f, 'name'):
filename = f.name
else:
filename = None
if not PatchedJoblib._current_task:
return original_fn(f, *args, **kwargs)
filename = get_filename_from_file_object(f, flush=False)
# register input model
empty = _Empty()
# Hack: disabled
@ -165,6 +176,7 @@ class PatchedJoblib(object):
@staticmethod
def get_model_framework(obj):
framework = Framework.scikitlearn
object_orig_module = None
# noinspection PyBroadException
try:
object_orig_module = obj.__module__ if hasattr(obj, '__module__') else obj.__package__

View File

@ -0,0 +1,89 @@
import os
import sys
from functools import partial
import six
from typing import Optional, Any, Callable
def __buffer_writer_close_patch(self):
    """
    Replacement ``close`` for a patched writer object.

    Calls the original close stored on the object as ``_trains_org_close`` and then
    invokes the ``_trains_close_cb`` callback, passing the object itself.

    :param self: the patched object (bound explicitly by the installer, not by a class)
    """
    # The real close runs first and outside the guard, so its errors reach the caller
    self._trains_org_close()

    # noinspection PyBroadException
    try:
        on_closed = self._trains_close_cb
        on_closed(self)
    except Exception:
        # Best effort: a missing or failing callback must never break closing the stream
        pass
def buffer_writer_close_cb(bufferwriter, callback, overwrite=False):
    # type: (Any, Callable[[Any], None], bool) -> ()
    """
    Attach a callback that is stored on ``bufferwriter`` as ``_trains_close_cb``.

    On first use the object's ``close`` is saved as ``_trains_org_close`` and replaced
    with a wrapper bound to the object. On later calls an existing callback is kept
    unless ``overwrite`` is True. Any failure while patching is silently ignored.

    :param bufferwriter: object exposing a ``close`` attribute
    :param callback: callable stored as the close callback
    :param overwrite: If True, replace a callback registered earlier (default: False)
    """
    # noinspection PyBroadException
    try:
        if hasattr(bufferwriter, '_trains_org_close'):
            # close() was wrapped before: only decide whether the callback may be replaced
            if not overwrite and hasattr(bufferwriter, '_trains_close_cb'):
                return
        else:
            # first registration: remember the real close() and install the wrapper
            bufferwriter._trains_org_close = bufferwriter.close
            bufferwriter.close = partial(__buffer_writer_close_patch, bufferwriter)

        bufferwriter._trains_close_cb = callback
    except Exception:
        # objects that reject attribute assignment (or have no close) are left untouched
        pass
def get_filename_from_file_object(file_object, flush=False, analyze_file_handle=False):
    # type: (object, bool, bool) -> Optional[str]
    """
    Return a string of the file location, extracted from any file object

    Resolution order: a string is treated as a path; otherwise the object's ``name``
    attribute is used; otherwise the path is looked up from the OS file descriptor.

    :param file_object: str, file, stream, FileIO etc.
    :param flush: If True, flush file object before returning (default: False)
    :param analyze_file_handle: If True, a raw integer file descriptor is also accepted and
        resolved through the OS. Objects exposing ``fileno()`` are resolved through their
        descriptor regardless of this flag when they have no usable ``name`` (default: False)
    :return: string full path of file location or None if filename cannot be extracted
    """
    if isinstance(file_object, six.string_types):
        # noinspection PyBroadException
        try:
            return os.path.abspath(file_object)
        except Exception:
            return file_object

    if hasattr(file_object, 'name'):
        filename = file_object.name
        if flush:
            # noinspection PyBroadException
            try:
                file_object.flush()
            except Exception:
                pass
        # noinspection PyBroadException
        try:
            return os.path.abspath(filename)
        except Exception:
            # `name` is not a path (e.g. an int for descriptor-opened files, or None):
            # do not crash, try resolving through the file descriptor instead
            pass

    # Explicit grouping of the original `a and b or c` test: the flag only gates raw
    # integer descriptors, anything exposing fileno() is always analyzed.
    is_raw_descriptor = analyze_file_handle and isinstance(file_object, int)
    if not (is_raw_descriptor or hasattr(file_object, 'fileno')):
        return None

    # noinspection PyBroadException
    try:
        fileno = file_object if isinstance(file_object, int) else file_object.fileno()
        if sys.platform == 'win32':
            import msvcrt
            from ctypes import windll, create_string_buffer
            handle = msvcrt.get_osfhandle(fileno)
            name = create_string_buffer(2050)
            windll.kernel32.GetFinalPathNameByHandleA(handle, name, 2048, 0)
            filename = name.value.decode('utf-8')
            # strip the extended-length path prefix returned by the API
            if filename.startswith('\\\\?\\'):
                filename = filename[4:]
        elif sys.platform.startswith('linux'):
            # startswith: Python 2 reports 'linux2', Python 3 reports 'linux'
            filename = os.readlink('/proc/self/fd/{}'.format(fileno))
        elif sys.platform == 'darwin':
            import fcntl
            f_getpath = 50  # F_GETPATH
            name = fcntl.fcntl(fileno, f_getpath, b' ' * 1024)
            filename = name.split(b'\x00')[0].decode()
        else:
            # unsupported platform, nothing to resolve the descriptor with
            return None

        if flush:
            os.fsync(fileno)
        return os.path.abspath(filename)
    except Exception:
        return None