From 2c47e9f248388e11757659b337e5617df7e58643 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Wed, 9 Sep 2020 22:10:35 +0300 Subject: [PATCH] Fix joblib auto logging models failing on compressed streams, issue #203 --- trains/binding/joblib_bind.py | 50 ++++++++----- trains/utilities/lowlevel/file_access.py | 89 ++++++++++++++++++++++++ 2 files changed, 120 insertions(+), 19 deletions(-) create mode 100644 trains/utilities/lowlevel/file_access.py diff --git a/trains/binding/joblib_bind.py b/trains/binding/joblib_bind.py index d7f0c98f..bcdd7008 100644 --- a/trains/binding/joblib_bind.py +++ b/trains/binding/joblib_bind.py @@ -10,6 +10,7 @@ from ..binding.frameworks import _patched_call, _Empty, WeightsFileHandler from ..config import running_remotely from ..debugging.log import LoggerRoot from ..model import Framework +from ..utilities.lowlevel.file_access import get_filename_from_file_object, buffer_writer_close_cb class PatchedJoblib(object): @@ -91,7 +92,19 @@ class PatchedJoblib(object): ret = original_fn(obj, f, *args, **kwargs) if not PatchedJoblib._current_task: return ret - PatchedJoblib._register_dump(obj, f) + + fname = f if isinstance(f, six.string_types) else None + fileobj = ret if isinstance(f, six.string_types) else f + + if fileobj and hasattr(fileobj, 'close'): + def callback(*_): + PatchedJoblib._register_dump(obj, fname or fileobj) + + if isinstance(fname, six.string_types) or hasattr(fileobj, 'name'): + buffer_writer_close_cb(fileobj, callback) + else: + PatchedJoblib._register_dump(obj, f) + return ret @staticmethod @@ -99,21 +112,24 @@ class PatchedJoblib(object): ret = original_fn(f, *args, **kwargs) if not PatchedJoblib._current_task: return ret - PatchedJoblib._register_dump(obj, f) + + fname = f if isinstance(f, six.string_types) else None + fileobj = ret if isinstance(f, six.string_types) else f + + if fileobj and hasattr(fileobj, 'close'): + def callback(*_): + PatchedJoblib._register_dump(obj, fname or fileobj) + + if isinstance(fname, six.string_types) or hasattr(fileobj, 'name'): + buffer_writer_close_cb(fileobj, callback) + else: + PatchedJoblib._register_dump(obj, f) return ret @staticmethod def _register_dump(obj, f): - if isinstance(f, six.string_types): - filename = f - elif hasattr(f, 'name'): - filename = f.name - # noinspection PyBroadException - try: - f.flush() - except Exception: - pass - else: + filename = get_filename_from_file_object(f, flush=True) + if not filename: return # give the model a descriptive name based on the file name @@ -128,16 +144,11 @@ class PatchedJoblib(object): @staticmethod def _load(original_fn, f, *args, **kwargs): - if isinstance(f, six.string_types): - filename = f - elif hasattr(f, 'name'): - filename = f.name - else: - filename = None - if not PatchedJoblib._current_task: return original_fn(f, *args, **kwargs) + filename = get_filename_from_file_object(f, flush=False) + # register input model empty = _Empty() # Hack: disabled @@ -165,6 +176,7 @@ class PatchedJoblib(object): @staticmethod def get_model_framework(obj): framework = Framework.scikitlearn + object_orig_module = None # noinspection PyBroadException try: object_orig_module = obj.__module__ if hasattr(obj, '__module__') else obj.__package__ diff --git a/trains/utilities/lowlevel/file_access.py b/trains/utilities/lowlevel/file_access.py new file mode 100644 index 00000000..54e89482 --- /dev/null +++ b/trains/utilities/lowlevel/file_access.py @@ -0,0 +1,89 @@ +import os +import sys +from functools import partial + +import six +from typing import Optional, Any, Callable + + +def __buffer_writer_close_patch(self): + self._trains_org_close() + # noinspection PyBroadException + try: + self._trains_close_cb(self) + except Exception: + pass + + +def buffer_writer_close_cb(bufferwriter, callback, overwrite=False): + # type: (Any, Callable[[Any], None], bool) -> () + # noinspection PyBroadException + try: + if not hasattr(bufferwriter, '_trains_org_close'): + bufferwriter._trains_org_close = bufferwriter.close + bufferwriter.close = partial(__buffer_writer_close_patch, bufferwriter) + elif not overwrite and hasattr(bufferwriter, '_trains_close_cb'): + return + + bufferwriter._trains_close_cb = callback + except Exception: + pass + + +def get_filename_from_file_object(file_object, flush=False, analyze_file_handle=False): + # type: (object, bool, bool) -> Optional[str] + """ + Return a string of the file location, extracted from any file object + :param file_object: str, file, stream, FileIO etc. + :param flush: If True, flush file object before returning (default: False) + :param analyze_file_handle: If True try to retrieve filename from file handler object (default: False) + :return: string full path of file location or None if filename cannot be extract + """ + if isinstance(file_object, six.string_types): + # noinspection PyBroadException + try: + return os.path.abspath(file_object) + except Exception: + return file_object + elif hasattr(file_object, 'name'): + filename = file_object.name + if flush: + # noinspection PyBroadException + try: + file_object.flush() + except Exception: + pass + return os.path.abspath(filename) + elif analyze_file_handle and isinstance(file_object, int) or hasattr(file_object, 'fileno'): + # noinspection PyBroadException + try: + fileno = file_object if isinstance(file_object, int) else file_object.fileno() + if sys.platform == 'win32': + import msvcrt + from ctypes import windll, create_string_buffer + handle = msvcrt.get_osfhandle(fileno) + name = create_string_buffer(2050) + windll.kernel32.GetFinalPathNameByHandleA(handle, name, 2048, 0) + filename = name.value.decode('utf-8') + if filename.startswith('\\\\?\\'): + filename = filename[4:] + if flush: + os.fsync(fileno) + return os.path.abspath(filename) + elif sys.platform == 'linux': + filename = os.readlink('/proc/self/fd/{}'.format(fileno)) + if flush: + os.fsync(fileno) + return os.path.abspath(filename) + elif sys.platform == 'darwin': + import fcntl + name = b' ' * 1024 + # F_GETPATH = 50 + name = fcntl.fcntl(fileno, 50, name) + filename = name.split(b'\x00')[0].decode() + if flush: + os.fsync(fileno) + return os.path.abspath(filename) + except Exception: + return None + return None