diff --git a/trains/backend_interface/base.py b/trains/backend_interface/base.py
index b2463c43..50649d4e 100644
--- a/trains/backend_interface/base.py
+++ b/trains/backend_interface/base.py
@@ -55,18 +55,22 @@ class InterfaceBase(SessionInterface):
                 if log:
                     log.error(error_msg)
-                if res.meta.result_code <= 500:
-                    # Proper backend error/bad status code - raise or return
-                    if raise_on_errors:
-                        raise SendError(res, error_msg)
-                    return res
             except requests.exceptions.BaseHTTPError as e:
-                log.error('failed sending %s: %s' % (str(req), str(e)))
+                res = None
+                log.error('Failed sending %s: %s' % (str(req), str(e)))
+            except Exception as e:
+                res = None
+                log.error('Failed sending %s: %s' % (str(req), str(e)))
 
-            # Infrastructure error
-            if log:
-                log.info('retrying request %s' % str(req))
+            if res and res.meta.result_code <= 500:
+                # Proper backend error/bad status code - raise or return
+                if raise_on_errors:
+                    raise SendError(res, error_msg)
+                return res
+
+            # # Infrastructure error
+            # if log:
+            #     log.info('retrying request %s' % str(req))
 
     def send(self, req, ignore_errors=False, raise_on_errors=True, async_enable=False):
         return self._send(session=self.session, req=req, ignore_errors=ignore_errors, raise_on_errors=raise_on_errors,
diff --git a/trains/backend_interface/metrics/reporter.py b/trains/backend_interface/metrics/reporter.py
index e29074a6..c2fd08f3 100644
--- a/trains/backend_interface/metrics/reporter.py
+++ b/trains/backend_interface/metrics/reporter.py
@@ -1,8 +1,8 @@
 import collections
 import json
-import cv2
 import six
+from threading import Thread, Event
 
 from ..base import InterfaceBase
 from ..setupuploadmixin import SetupUploadMixin
@@ -47,6 +47,13 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
         self._bucket_config = None
         self._storage_uri = None
         self._async_enable = async_enable
+        self._flush_frequency = 30.0
+        self._exit_flag = False
+        self._flush_event = Event()
+        self._flush_event.clear()
+        self._thread = Thread(target=self._daemon)
+        self._thread.daemon = True
+        self._thread.start()
 
     def _set_storage_uri(self, value):
         value = '/'.join(x for x in (value.rstrip('/'), self._metrics.storage_key_prefix) if x)
@@ -70,10 +77,19 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
     def async_enable(self, value):
         self._async_enable = bool(value)
 
+    def _daemon(self):
+        while not self._exit_flag:
+            self._flush_event.wait(self._flush_frequency)
+            self._flush_event.clear()
+            self._write()
+            # wait for all reports
+            if self.get_num_results() > 0:
+                self.wait_for_results()
+
     def _report(self, ev):
         self._events.append(ev)
         if len(self._events) >= self._flush_threshold:
-            self._write()
+            self.flush()
 
     def _write(self):
         if not self._events:
@@ -88,10 +104,12 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
         """
         Flush cached reports to backend.
         """
-        self._write()
-        # wait for all reports
-        if self.get_num_results() > 0:
-            self.wait_for_results()
+        self._flush_event.set()
+
+    def stop(self):
+        self._exit_flag = True
+        self._flush_event.set()
+        self._thread.join()
 
     def report_scalar(self, title, series, value, iter):
         """
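With this change `Reporter.flush()` no longer does the upload work on the caller's thread: it only sets `_flush_event`, and the daemon thread created in the constructor wakes up every `_flush_frequency` seconds (or immediately when the event is set), writes the buffered events and waits for the pending async results. `stop()` raises the exit flag and joins the thread. A minimal standalone sketch of this periodic-flush pattern (the class and names below are illustrative, not part of the patch):

```python
import threading

class PeriodicFlusher(object):
    """Illustrative sketch only: mirrors the Reporter daemon/flush split."""

    def __init__(self, write_fn, frequency=30.0):
        self._write = write_fn            # does the actual (slow) I/O
        self._frequency = frequency
        self._exit_flag = False
        self._event = threading.Event()
        self._thread = threading.Thread(target=self._daemon)
        self._thread.daemon = True        # never keeps the process alive
        self._thread.start()

    def flush(self):
        # producer side: cheap and non-blocking, just wake the daemon
        self._event.set()

    def stop(self):
        self._exit_flag = True
        self._event.set()
        self._thread.join()

    def _daemon(self):
        while not self._exit_flag:
            # wake every `frequency` seconds, or immediately on flush()
            self._event.wait(self._frequency)
            self._event.clear()
            self._write()
```

Using `Event.wait(timeout)` rather than `time.sleep()` is what lets `flush()` cut the periodic wait short without busy-polling.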
""" - self._write() - # wait for all reports - if self.get_num_results() > 0: - self.wait_for_results() + self._flush_event.set() + + def stop(self): + self._exit_flag = True + self._flush_event.set() + self._thread.join() def report_scalar(self, title, series, value, iter): """ diff --git a/trains/binding/environ_bind.py b/trains/binding/environ_bind.py index 1238a406..799125ec 100644 --- a/trains/binding/environ_bind.py +++ b/trains/binding/environ_bind.py @@ -1,5 +1,7 @@ import os +import six + from ..config import TASK_LOG_ENVIRONMENT, running_remotely @@ -34,3 +36,43 @@ class EnvironmentBind(object): if running_remotely(): # put back into os: os.environ.update(env_param) + + +class PatchOsFork(object): + _original_fork = None + + @classmethod + def patch_fork(cls): + # only once + if cls._original_fork: + return + if six.PY2: + cls._original_fork = staticmethod(os.fork) + else: + cls._original_fork = os.fork + os.fork = cls._patched_fork + + @staticmethod + def _patched_fork(*args, **kwargs): + ret = PatchOsFork._original_fork(*args, **kwargs) + # Make sure the new process stdout is logged + if not ret: + from ..task import Task + if Task.current_task() is not None: + # bind sub-process logger + task = Task.init() + task.get_logger().flush() + + # if we got here patch the os._exit of our instance to call us + def _at_exit_callback(*args, **kwargs): + # call at exit manually + # noinspection PyProtectedMember + task._at_exit() + # noinspection PyProtectedMember + return os._org_exit(*args, **kwargs) + + if not hasattr(os, '_org_exit'): + os._org_exit = os._exit + os._exit = _at_exit_callback + + return ret diff --git a/trains/logger.py b/trains/logger.py index 8d208512..b364b1ef 100644 --- a/trains/logger.py +++ b/trains/logger.py @@ -81,11 +81,14 @@ class Logger(object): self._task_handler = TaskHandler(self._task.session, self._task.id, capacity=100) # noinspection PyBroadException try: - Logger._stdout_original_write = sys.stdout.write + if Logger._stdout_original_write is None: + Logger._stdout_original_write = sys.stdout.write # this will only work in python 3, guard it with try/catch - sys.stdout._original_write = sys.stdout.write + if not hasattr(sys.stdout, '_original_write'): + sys.stdout._original_write = sys.stdout.write sys.stdout.write = stdout__patched__write__ - sys.stderr._original_write = sys.stderr.write + if not hasattr(sys.stderr, '_original_write'): + sys.stderr._original_write = sys.stderr.write sys.stderr.write = stderr__patched__write__ except Exception: pass @@ -113,6 +116,7 @@ class Logger(object): msg='Logger failed casting log level "%s" to integer' % str(level)) level = logging.INFO + # noinspection PyBroadException try: record = self._task.log.makeRecord( "console", level=level, fn='', lno=0, func='', msg=msg, args=args, exc_info=None @@ -128,6 +132,7 @@ class Logger(object): if not omit_console: # if we are here and we grabbed the stdout, we need to print the real thing if DevWorker.report_stdout: + # noinspection PyBroadException try: # make sure we are writing to the original stdout Logger._stdout_original_write(str(msg)+'\n') @@ -637,11 +642,13 @@ class Logger(object): @classmethod def _remove_std_logger(self): if isinstance(sys.stdout, PrintPatchLogger): + # noinspection PyBroadException try: sys.stdout.connect(None) except Exception: pass if isinstance(sys.stderr, PrintPatchLogger): + # noinspection PyBroadException try: sys.stderr.connect(None) except Exception: @@ -711,7 +718,13 @@ class PrintPatchLogger(object): if cur_line: with 
diff --git a/trains/task.py b/trains/task.py
index 1952a9a3..9939deaa 100644
--- a/trains/task.py
+++ b/trains/task.py
@@ -26,7 +26,7 @@ from .errors import UsageError
 from .logger import Logger
 from .model import InputModel, OutputModel, ARCHIVED_TAG
 from .task_parameters import TaskParameters
-from .binding.environ_bind import EnvironmentBind
+from .binding.environ_bind import EnvironmentBind, PatchOsFork
 from .binding.absl_bind import PatchAbsl
 from .utilities.args import argparser_parseargs_called, get_argparser_last_args, \
     argparser_update_currenttask
@@ -66,6 +66,7 @@ class Task(_Task):
     __create_protection = object()
     __main_task = None
     __exit_hook = None
+    __forked_proc_main_pid = None
     __task_id_reuse_time_window_in_hours = float(config.get('development.task_reuse_time_window_in_hours', 24.0))
     __store_diff_on_train = config.get('development.store_uncommitted_code_diff_on_train', False)
     __detect_repo_async = config.get('development.vcs_repo_detect_async', False)
@@ -104,7 +105,6 @@ class Task(_Task):
         self._resource_monitor = None
         # register atexit, so that we mark the task as stopped
         self._at_exit_called = False
-        self.__register_at_exit(self._at_exit)
 
     @classmethod
     def current_task(cls):
@@ -132,9 +132,10 @@ class Task(_Task):
         :param project_name: project to create the task in (if project doesn't exist, it will be created)
         :param task_name: task name to be created (in development mode, not when running remotely)
         :param task_type: task type to be created (in development mode, not when running remotely)
-        :param reuse_last_task_id: start with the previously used task id (stored in the data cache folder). \
-            if False every time we call the function we create a new task with the same name \
-            Notice! The reused task will be reset. (when running remotely, the usual behaviour applies) \
+        :param reuse_last_task_id: start with the previously used task id (stored in the data cache folder).
+            If False, a new task is created every time this function is called (with the same task name).
+            Notice! The reused task will be reset. (When running remotely, the usual behaviour applies.)
+            If reuse_last_task_id is a string, it is assumed to be the task_id to reuse!
             Note: A closed or published task will not be reused, and a new task will be created.
         :param output_uri: Default location for output models (currently support folder/S3/GS/ ).
            notice: sub-folders (task_id) is created in the destination folder for all outputs.
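Per the updated docstring, `reuse_last_task_id` now accepts a task id string in addition to a bool. A usage sketch (the project name, task name and id below are placeholders):

```python
from trains import Task

# default (True): reuse the last task id cached for this script, if still relevant
task = Task.init(project_name='examples', task_name='training')

# False: always create a fresh task with the same name
# task = Task.init(project_name='examples', task_name='training',
#                  reuse_last_task_id=False)

# string (new in this patch): reuse one specific task; the id is a placeholder
# task = Task.init(project_name='examples', task_name='training',
#                  reuse_last_task_id='112233aabbcc')
```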
@@ -166,12 +167,31 @@ class Task(_Task):
         )
 
         if cls.__main_task is not None:
+            # if this is a subprocess, regardless of what the init was called for,
+            # we have to fix the main task hooks and stdout bindings
+            if cls.__forked_proc_main_pid != os.getpid() and PROC_MASTER_ID_ENV_VAR.get() != os.getpid():
+                # make sure we only do it once per process
+                cls.__forked_proc_main_pid = os.getpid()
+                # make sure we do not wait for the repo detect thread
+                cls.__main_task._detect_repo_async_thread = None
+                # remove the logger from the previous process
+                logger = cls.__main_task.get_logger()
+                logger.set_flush_period(None)
+                # create a new logger (to catch stdout/err)
+                cls.__main_task._logger = None
+                cls.__main_task._reporter = None
+                cls.__main_task.get_logger()
+                # unregister signal hooks, they cause subprocess to hang
+                cls.__main_task.__register_at_exit(cls.__main_task._at_exit)
+                cls.__main_task.__register_at_exit(None, only_remove_signal_and_exception_hooks=True)
+
             if not running_remotely():
                 verify_defaults_match()
 
             return cls.__main_task
 
-        # check that we are not a child process, in that case do nothing
+        # check that we are not a child process, in that case do nothing.
+        # we should not get here unless this is Windows platform, all others support fork
         if PROC_MASTER_ID_ENV_VAR.get() and PROC_MASTER_ID_ENV_VAR.get() != os.getpid():
             class _TaskStub(object):
                 def __call__(self, *args, **kwargs):
@@ -212,9 +232,10 @@ class Task(_Task):
                     raise
             else:
                 Task.__main_task = task
-                # Patch argparse to be aware of the current task
-                argparser_update_currenttask(Task.__main_task)
-                EnvironmentBind.update_current_task(Task.__main_task)
+                # register the main task for at exit hooks (there should only be one)
+                task.__register_at_exit(task._at_exit)
+                # patch OS forking
+                PatchOsFork.patch_fork()
                 if auto_connect_frameworks:
                     PatchedMatplotlib.update_current_task(Task.__main_task)
                     PatchAbsl.update_current_task(Task.__main_task)
@@ -227,21 +248,19 @@ class Task(_Task):
             if auto_resource_monitoring:
                 task._resource_monitor = ResourceMonitor(task)
                 task._resource_monitor.start()
-            # Check if parse args already called. If so, sync task parameters with parser
-            if argparser_parseargs_called():
-                parser, parsed_args = get_argparser_last_args()
-                task._connect_argparse(parser=parser, parsed_args=parsed_args)
 
             # make sure all random generators are initialized with new seed
             make_deterministic(task.get_random_seed())
 
             if auto_connect_arg_parser:
+                EnvironmentBind.update_current_task(Task.__main_task)
+                # Patch ArgParser to be aware of the current task
                 argparser_update_currenttask(Task.__main_task)
                 # Check if parse args already called. If so, sync task parameters with parser
                 if argparser_parseargs_called():
                     parser, parsed_args = get_argparser_last_args()
-                    task._connect_argparse(parser, parsed_args=parsed_args)
+                    task._connect_argparse(parser=parser, parsed_args=parsed_args)
 
         # Make sure we start the logger, it will patch the main logging object and pipe all output
         # if we are running locally and using development mode worker, we will pipe all stdout to logger.
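Both the `_TaskStub` short-circuit and the forked-process re-bind above rest on a single idiom: compare the current pid with the master pid recorded in an environment variable that children inherit. Sketched standalone (the variable name is illustrative; the patch uses `PROC_MASTER_ID_ENV_VAR`):

```python
import os

MASTER_PID_ENV = 'MY_APP_MASTER_PID'  # illustrative name

def mark_master_process():
    # called once when the first (master) process initializes
    os.environ.setdefault(MASTER_PID_ENV, str(os.getpid()))

def in_forked_child():
    # children inherit the environment but get a new pid
    master_pid = os.environ.get(MASTER_PID_ENV)
    return bool(master_pid) and master_pid != str(os.getpid())
```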
@@ -339,7 +358,9 @@ class Task(_Task):
         in_dev_mode = not running_remotely()
 
         if in_dev_mode:
-            if not reuse_last_task_id or not cls.__task_is_relevant(default_task):
+            if isinstance(reuse_last_task_id, str) and reuse_last_task_id:
+                default_task_id = reuse_last_task_id
+            elif not reuse_last_task_id or not cls.__task_is_relevant(default_task):
                 default_task_id = None
                 closed_old_task = cls.__close_timed_out_task(default_task)
             else:
@@ -600,6 +621,9 @@ class Task(_Task):
         """
         self._at_exit()
         self._at_exit_called = False
+        # unregister atexit callbacks and signal hooks, if we are the main task
+        if self.is_main_task():
+            self.__register_at_exit(None)
 
     def is_current_task(self):
         """
@@ -914,9 +938,12 @@ class Task(_Task):
         Will happen automatically once we exit code, i.e. atexit
         :return:
         """
+        # protect sub-process at_exit
        if self._at_exit_called:
            return
 
+        is_sub_process = PROC_MASTER_ID_ENV_VAR.get() and PROC_MASTER_ID_ENV_VAR.get() != os.getpid()
+
        # noinspection PyBroadException
        try:
            # from here do not get into watch dog
@@ -948,28 +975,32 @@ class Task(_Task):
             # from here, do not send log in background thread
             if wait_for_uploads:
                 self.flush(wait_for_uploads=True)
+                # wait until the reporter flushes everything
+                self.reporter.stop()
                 if print_done_waiting:
                     self.log.info('Finished uploading')
             else:
                 self._logger._flush_stdout_handler()
 
-            # from here, do not check worker status
-            if self._dev_worker:
-                self._dev_worker.unregister()
+            if not is_sub_process:
+                # from here, do not check worker status
+                if self._dev_worker:
+                    self._dev_worker.unregister()
 
-            # change task status
-            if not task_status:
-                pass
-            elif task_status[0] == 'failed':
-                self.mark_failed(status_reason=task_status[1])
-            elif task_status[0] == 'completed':
-                self.completed()
-            elif task_status[0] == 'stopped':
-                self.stopped()
+                # change task status
+                if not task_status:
+                    pass
+                elif task_status[0] == 'failed':
+                    self.mark_failed(status_reason=task_status[1])
+                elif task_status[0] == 'completed':
+                    self.completed()
+                elif task_status[0] == 'stopped':
+                    self.stopped()
 
             # stop resource monitoring
             if self._resource_monitor:
                 self._resource_monitor.stop()
+            self._logger.set_flush_period(None)
 
             # this is so in theory we can close a main task and start a new one
             Task.__main_task = None
@@ -978,7 +1009,7 @@ class Task(_Task):
             pass
 
     @classmethod
-    def __register_at_exit(cls, exit_callback):
+    def __register_at_exit(cls, exit_callback, only_remove_signal_and_exception_hooks=False):
         class ExitHooks(object):
             _orig_exit = None
             _orig_exc_handler = None
@@ -1000,7 +1031,21 @@ class Task(_Task):
                 except Exception:
                     pass
                 self._exit_callback = callback
-                atexit.register(self._exit_callback)
+                if callback:
+                    self.hook()
+                else:
+                    # unregister int hook
+                    print('removing int hook', self._orig_exc_handler)
+                    if self._orig_exc_handler:
+                        sys.excepthook = self._orig_exc_handler
+                        self._orig_exc_handler = None
+                    for s in self._org_handlers:
+                        # noinspection PyBroadException
+                        try:
+                            signal.signal(s, self._org_handlers[s])
+                        except Exception:
+                            pass
+                    self._org_handlers = {}
 
             def hook(self):
                 if self._orig_exit is None:
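`update_callback(None)` now actively restores `sys.excepthook` and the saved signal handlers instead of merely dropping the callback, so a forked child is not left with hooks that make it hang on exit. The save/restore idiom in isolation (illustrative sketch; the patch stores the originals in `_org_handlers`):

```python
import signal

_saved = {}

def install(handler, signals=(signal.SIGINT, signal.SIGTERM)):
    for s in signals:
        _saved[s] = signal.getsignal(s)   # remember the original handler
        signal.signal(s, handler)

def uninstall():
    for s, original in _saved.items():
        signal.signal(s, original)        # put the original back
    _saved.clear()
```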
@@ -1009,20 +1054,23 @@ class Task(_Task):
                     self._orig_exit = sys.exit
                     sys.exit = self.exit
                 if self._orig_exc_handler is None:
                     self._orig_exc_handler = sys.excepthook
                     sys.excepthook = self.exc_handler
-                atexit.register(self._exit_callback)
-                if sys.platform == 'win32':
-                    catch_signals = [signal.SIGINT, signal.SIGTERM, signal.SIGSEGV, signal.SIGABRT,
-                                     signal.SIGILL, signal.SIGFPE]
-                else:
-                    catch_signals = [signal.SIGINT, signal.SIGTERM, signal.SIGSEGV, signal.SIGABRT,
-                                     signal.SIGILL, signal.SIGFPE, signal.SIGQUIT]
-                for s in catch_signals:
-                    # noinspection PyBroadException
-                    try:
-                        self._org_handlers[s] = signal.getsignal(s)
-                        signal.signal(s, self.signal_handler)
-                    except Exception:
-                        pass
+                if self._exit_callback:
+                    atexit.register(self._exit_callback)
+
+                if self._org_handlers:
+                    if sys.platform == 'win32':
+                        catch_signals = [signal.SIGINT, signal.SIGTERM, signal.SIGSEGV, signal.SIGABRT,
+                                         signal.SIGILL, signal.SIGFPE]
+                    else:
+                        catch_signals = [signal.SIGINT, signal.SIGTERM, signal.SIGSEGV, signal.SIGABRT,
+                                         signal.SIGILL, signal.SIGFPE, signal.SIGQUIT]
+                    for s in catch_signals:
+                        # noinspection PyBroadException
+                        try:
+                            self._org_handlers[s] = signal.getsignal(s)
+                            signal.signal(s, self.signal_handler)
+                        except Exception:
+                            pass
 
             def exit(self, code=0):
                 self.exit_code = code
@@ -1077,6 +1125,22 @@ class Task(_Task):
                 # return handler result
                 return org_handler
 
+        # we only remove the signal and exception hooks, since leaving them would hang subprocesses
+        if only_remove_signal_and_exception_hooks:
+            if not cls.__exit_hook:
+                return
+            if cls.__exit_hook._orig_exc_handler:
+                sys.excepthook = cls.__exit_hook._orig_exc_handler
+                cls.__exit_hook._orig_exc_handler = None
+            for s in cls.__exit_hook._org_handlers:
+                # noinspection PyBroadException
+                try:
+                    signal.signal(s, cls.__exit_hook._org_handlers[s])
+                except Exception:
+                    pass
+            cls.__exit_hook._org_handlers = {}
+            return
+
         if cls.__exit_hook is None:
             # noinspection PyBroadException
             try:
@@ -1084,13 +1148,13 @@ class Task(_Task):
                 cls.__exit_hook = ExitHooks(exit_callback)
                 cls.__exit_hook.hook()
             except Exception:
                 cls.__exit_hook = None
-        elif cls.__main_task is None:
+        else:
             cls.__exit_hook.update_callback(exit_callback)
 
     @classmethod
     def __get_task(cls, task_id=None, project_name=None, task_name=None):
         if task_id:
-            return cls(private=cls.__create_protection, task_id=task_id)
+            return cls(private=cls.__create_protection, task_id=task_id, log_to_backend=False)
         res = cls._send(
             cls._get_default_session(),
diff --git a/trains/utilities/async_manager.py b/trains/utilities/async_manager.py
index 45732d67..3c19441a 100644
--- a/trains/utilities/async_manager.py
+++ b/trains/utilities/async_manager.py
@@ -1,3 +1,4 @@
+import os
 import time
 from threading import Lock
 
@@ -6,7 +7,8 @@ import six
 
 class AsyncManagerMixin(object):
     _async_results_lock = Lock()
-    _async_results = []
+    # per pid (process) list of async jobs (support for sub-process forking)
+    _async_results = {}
 
     @classmethod
     def _add_async_result(cls, result, wait_on_max_results=None, wait_time=30, wait_cb=None):
@@ -14,8 +16,9 @@
         try:
             cls._async_results_lock.acquire()
             # discard completed results
-            cls._async_results = [r for r in cls._async_results if not r.ready()]
-            num_results = len(cls._async_results)
+            pid = os.getpid()
+            cls._async_results[pid] = [r for r in cls._async_results.get(pid, []) if not r.ready()]
+            num_results = len(cls._async_results[pid])
             if wait_on_max_results is not None and num_results >= wait_on_max_results:
                 # At least max_results results are still pending, wait
                 if wait_cb:
@@ -25,7 +28,7 @@
                     continue
             # add result
             if result and not result.ready():
-                cls._async_results.append(result)
+                cls._async_results.setdefault(pid, []).append(result)
             break
         finally:
             cls._async_results_lock.release()
@@ -34,7 +37,8 @@
     def wait_for_results(cls, timeout=None, max_num_uploads=None):
         remaining = timeout
         count = 0
-        for r in cls._async_results:
+        pid = os.getpid()
+        for r in cls._async_results.get(pid, []):
             if r.ready():
                 continue
             t = time.time()
@@ -48,13 +52,14 @@
             if max_num_uploads is not None and max_num_uploads - count <= 0:
                 break
             if timeout is not None:
-                remaining = max(0, remaining - max(0, time.time() - t))
+                remaining = max(0., remaining - max(0., time.time() - t))
                 if not remaining:
                     break
 
     @classmethod
     def get_num_results(cls):
-        if cls._async_results is not None:
-            return len([r for r in cls._async_results if not r.ready()])
+        pid = os.getpid()
+        if cls._async_results.get(pid, []):
+            return len([r for r in cls._async_results.get(pid, []) if not r.ready()])
         else:
             return 0
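Keying the pending-results list by `os.getpid()` keeps each process's uploads independent after a fork: a child inherits a copy of the parent's dict but only ever touches its own entry. One wrinkle worth spelling out: `list.append()` returns `None`, so the new result must be appended through `setdefault` (as above) rather than assigned back from `get(...).append(...)`. The bookkeeping pattern in isolation:

```python
import os

pending = {}  # pid -> list of in-flight async results

def add_result(result):
    # correct: setdefault returns the per-pid list, and append mutates it in place
    pending.setdefault(os.getpid(), []).append(result)
    # wrong: pending[os.getpid()] = pending.get(os.getpid(), []).append(result)
    # would store None, silently losing every previously queued result

def num_pending():
    return len(pending.get(os.getpid(), []))
```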