Fix support for sub-process (process pool)

allegroai 2019-07-20 23:11:54 +03:00
parent 50ce49a3dd
commit c80aae0e1e
6 changed files with 220 additions and 75 deletions
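The change set below makes Task, Logger and Reporter state fork-safe so that worker processes in a process pool report correctly. A minimal sketch of the scenario being fixed, assuming the trains package from this repository (the project, task and metric names are illustrative):

from multiprocessing import Pool

from trains import Task


def work(i):
    # inside a forked worker, Task.init() should re-bind the existing main
    # task (logger, stdout capture) instead of creating a second task
    task = Task.init(project_name='examples', task_name='pool demo')
    task.get_logger().report_scalar('worker', 'value', value=i, iteration=i)
    return i * 2


if __name__ == '__main__':
    Task.init(project_name='examples', task_name='pool demo')
    with Pool(4) as pool:
        print(pool.map(work, range(8)))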

View File

@@ -55,18 +55,22 @@ class InterfaceBase(SessionInterface):
if log:
log.error(error_msg)
if res.meta.result_code <= 500:
# Proper backend error/bad status code - raise or return
if raise_on_errors:
raise SendError(res, error_msg)
return res
except requests.exceptions.BaseHTTPError as e:
log.error('failed sending %s: %s' % (str(req), str(e)))
res = None
log.error('Failed sending %s: %s' % (str(req), str(e)))
except Exception as e:
res = None
log.error('Failed sending %s: %s' % (str(req), str(e)))
# Infrastructure error
if log:
log.info('retrying request %s' % str(req))
if res and res.meta.result_code <= 500:
# Proper backend error/bad status code - raise or return
if raise_on_errors:
raise SendError(res, error_msg)
return res
# # Infrastructure error
# if log:
# log.info('retrying request %s' % str(req))
def send(self, req, ignore_errors=False, raise_on_errors=True, async_enable=False):
return self._send(session=self.session, req=req, ignore_errors=ignore_errors, raise_on_errors=raise_on_errors,

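The hunk above moves the status-code check after the exception handlers, so a transport failure (which leaves res as None) can no longer be dereferenced. A minimal standalone sketch of the corrected control flow, with hypothetical session, response and log objects:

def send_once(session, req, log, raise_on_errors=True):
    res = None
    try:
        res = session.send(req)
    except Exception as e:
        # transport failure: res stays None
        log.error('Failed sending %s: %s' % (str(req), str(e)))
    # guard on res before touching res.meta - it is None after an exception
    if res and res.meta.result_code <= 500:
        # proper backend error / bad status code - raise or return
        if raise_on_errors:
            raise RuntimeError('failed sending %s' % str(req))
        return res
    return res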
View File

@@ -1,8 +1,8 @@
import collections
import json
import cv2
import six
from threading import Thread, Event
from ..base import InterfaceBase
from ..setupuploadmixin import SetupUploadMixin
@@ -47,6 +47,13 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
self._bucket_config = None
self._storage_uri = None
self._async_enable = async_enable
self._flush_frequency = 30.0
self._exit_flag = False
self._flush_event = Event()
self._flush_event.clear()
self._thread = Thread(target=self._daemon)
self._thread.daemon = True
self._thread.start()
def _set_storage_uri(self, value):
value = '/'.join(x for x in (value.rstrip('/'), self._metrics.storage_key_prefix) if x)
@@ -70,10 +77,19 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
def async_enable(self, value):
self._async_enable = bool(value)
def _daemon(self):
while not self._exit_flag:
self._flush_event.wait(self._flush_frequency)
self._flush_event.clear()
self._write()
# wait for all reports
if self.get_num_results() > 0:
self.wait_for_results()
def _report(self, ev):
self._events.append(ev)
if len(self._events) >= self._flush_threshold:
self._write()
self.flush()
def _write(self):
if not self._events:
@@ -88,10 +104,12 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
"""
Flush cached reports to backend.
"""
self._write()
# wait for all reports
if self.get_num_results() > 0:
self.wait_for_results()
self._flush_event.set()
def stop(self):
self._exit_flag = True
self._flush_event.set()
self._thread.join()
def report_scalar(self, title, series, value, iter):
"""

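The Reporter changes above replace synchronous flushing with a background daemon thread driven by an Event: flush() now just wakes the thread. A minimal standalone sketch of this pattern (the class and method names are illustrative):

from threading import Thread, Event


class PeriodicFlusher(object):
    def __init__(self, flush_frequency=30.0):
        self._flush_frequency = flush_frequency
        self._exit_flag = False
        self._flush_event = Event()
        self._thread = Thread(target=self._daemon)
        self._thread.daemon = True
        self._thread.start()

    def _daemon(self):
        while not self._exit_flag:
            # wake up on an explicit flush() or on the periodic timeout
            self._flush_event.wait(self._flush_frequency)
            self._flush_event.clear()
            self._write()

    def _write(self):
        pass  # drain queued events here

    def flush(self):
        # wake the daemon instead of writing on the caller's thread
        self._flush_event.set()

    def stop(self):
        self._exit_flag = True
        self._flush_event.set()
        self._thread.join()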
View File

@@ -1,5 +1,7 @@
import os
import six
from ..config import TASK_LOG_ENVIRONMENT, running_remotely
@@ -34,3 +36,43 @@ class EnvironmentBind(object):
if running_remotely():
# put back into os:
os.environ.update(env_param)
class PatchOsFork(object):
_original_fork = None
@classmethod
def patch_fork(cls):
# only once
if cls._original_fork:
return
if six.PY2:
cls._original_fork = staticmethod(os.fork)
else:
cls._original_fork = os.fork
os.fork = cls._patched_fork
@staticmethod
def _patched_fork(*args, **kwargs):
ret = PatchOsFork._original_fork(*args, **kwargs)
# Make sure the new process stdout is logged
if not ret:
from ..task import Task
if Task.current_task() is not None:
# bind sub-process logger
task = Task.init()
task.get_logger().flush()
# if we got here patch the os._exit of our instance to call us
def _at_exit_callback(*args, **kwargs):
# call at exit manually
# noinspection PyProtectedMember
task._at_exit()
# noinspection PyProtectedMember
return os._org_exit(*args, **kwargs)
if not hasattr(os, '_org_exit'):
os._org_exit = os._exit
os._exit = _at_exit_callback
return ret

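PatchOsFork above wraps os.fork so the child process (where fork returns 0) can re-bind its logging before user code runs. A bare-bones sketch of the technique, POSIX-only since os.fork does not exist on Windows (the callback body is illustrative):

import os

_original_fork = os.fork


def _patched_fork(*args, **kwargs):
    ret = _original_fork(*args, **kwargs)
    if not ret:
        # fork() returned 0: we are the child, re-bind per-process state here
        print('child %d started' % os.getpid())
    return ret


os.fork = _patched_fork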
View File

@@ -81,11 +81,14 @@ class Logger(object):
self._task_handler = TaskHandler(self._task.session, self._task.id, capacity=100)
# noinspection PyBroadException
try:
Logger._stdout_original_write = sys.stdout.write
if Logger._stdout_original_write is None:
Logger._stdout_original_write = sys.stdout.write
# this will only work in python 3, guard it with try/catch
sys.stdout._original_write = sys.stdout.write
if not hasattr(sys.stdout, '_original_write'):
sys.stdout._original_write = sys.stdout.write
sys.stdout.write = stdout__patched__write__
sys.stderr._original_write = sys.stderr.write
if not hasattr(sys.stderr, '_original_write'):
sys.stderr._original_write = sys.stderr.write
sys.stderr.write = stderr__patched__write__
except Exception:
pass
@@ -113,6 +116,7 @@ class Logger(object):
msg='Logger failed casting log level "%s" to integer' % str(level))
level = logging.INFO
# noinspection PyBroadException
try:
record = self._task.log.makeRecord(
"console", level=level, fn='', lno=0, func='', msg=msg, args=args, exc_info=None
@@ -128,6 +132,7 @@ class Logger(object):
if not omit_console:
# if we are here and we grabbed the stdout, we need to print the real thing
if DevWorker.report_stdout:
# noinspection PyBroadException
try:
# make sure we are writing to the original stdout
Logger._stdout_original_write(str(msg)+'\n')
@@ -637,11 +642,13 @@ class Logger(object):
@classmethod
def _remove_std_logger(self):
if isinstance(sys.stdout, PrintPatchLogger):
# noinspection PyBroadException
try:
sys.stdout.connect(None)
except Exception:
pass
if isinstance(sys.stderr, PrintPatchLogger):
# noinspection PyBroadException
try:
sys.stderr.connect(None)
except Exception:
@@ -711,7 +718,13 @@ class PrintPatchLogger(object):
if cur_line:
with PrintPatchLogger.recursion_protect_lock:
self._log.console(cur_line, level=self._log_level, omit_console=True)
# noinspection PyBroadException
try:
if self._log:
self._log.console(cur_line, level=self._log_level, omit_console=True)
except Exception:
# what can we do, nothing
pass
else:
if hasattr(self._terminal, '_original_write'):
self._terminal._original_write(message)
@@ -719,8 +732,7 @@ class PrintPatchLogger(object):
self._terminal.write(message)
def connect(self, logger):
if self._log:
self._log._flush_stdout_handler()
self._cur_line = ''
self._log = logger
def __getattr__(self, attr):

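The hasattr guards above make the stdout/stderr patching idempotent, so a forked child that patches again does not wrap the already-patched write. A minimal sketch of the guard, Python 3 only as the original comment notes (the patched writer is illustrative):

import sys


def _patched_write(text):
    # forward through the saved original to avoid infinite recursion
    sys.stdout._original_write('[captured] ' + text)


if not hasattr(sys.stdout, '_original_write'):
    sys.stdout._original_write = sys.stdout.write
sys.stdout.write = _patched_write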
View File

@@ -26,7 +26,7 @@ from .errors import UsageError
from .logger import Logger
from .model import InputModel, OutputModel, ARCHIVED_TAG
from .task_parameters import TaskParameters
from .binding.environ_bind import EnvironmentBind
from .binding.environ_bind import EnvironmentBind, PatchOsFork
from .binding.absl_bind import PatchAbsl
from .utilities.args import argparser_parseargs_called, get_argparser_last_args, \
argparser_update_currenttask
@@ -66,6 +66,7 @@ class Task(_Task):
__create_protection = object()
__main_task = None
__exit_hook = None
__forked_proc_main_pid = None
__task_id_reuse_time_window_in_hours = float(config.get('development.task_reuse_time_window_in_hours', 24.0))
__store_diff_on_train = config.get('development.store_uncommitted_code_diff_on_train', False)
__detect_repo_async = config.get('development.vcs_repo_detect_async', False)
@@ -104,7 +105,6 @@ class Task(_Task):
self._resource_monitor = None
# register atexit, so that we mark the task as stopped
self._at_exit_called = False
self.__register_at_exit(self._at_exit)
@classmethod
def current_task(cls):
@@ -132,9 +132,10 @@
:param project_name: project to create the task in (if project doesn't exist, it will be created)
:param task_name: task name to be created (in development mode, not when running remotely)
:param task_type: task type to be created (in development mode, not when running remotely)
:param reuse_last_task_id: start with the previously used task id (stored in the data cache folder). \
if False every time we call the function we create a new task with the same name \
Notice! The reused task will be reset. (when running remotely, the usual behaviour applies) \
:param reuse_last_task_id: start with the previously used task id (stored in the data cache folder).
If False, a new task (with the same name) is created every time the function is called.
Notice! The reused task will be reset (when running remotely, the usual behaviour applies).
If reuse_last_task_id is a string, it is assumed to be the task_id to reuse!
Note: A closed or published task will not be reused, and a new task will be created.
:param output_uri: Default location for output models (currently supports folder/S3/GS).
notice: a sub-folder (task_id) is created in the destination folder for all outputs.
@@ -166,12 +167,31 @@ class Task(_Task):
)
if cls.__main_task is not None:
# if this is a subprocess, regardless of what the init was called for,
# we have to fix the main task hooks and stdout bindings
if cls.__forked_proc_main_pid != os.getpid() and PROC_MASTER_ID_ENV_VAR.get() != os.getpid():
# make sure we only do it once per process
cls.__forked_proc_main_pid = os.getpid()
# make sure we do not wait for the repo detect thread
cls.__main_task._detect_repo_async_thread = None
# remove the logger from the previous process
logger = cls.__main_task.get_logger()
logger.set_flush_period(None)
# create a new logger (to catch stdout/err)
cls.__main_task._logger = None
cls.__main_task._reporter = None
cls.__main_task.get_logger()
# unregister signal hooks, they cause subprocess to hang
cls.__main_task.__register_at_exit(cls.__main_task._at_exit)
cls.__main_task.__register_at_exit(None, only_remove_signal_and_exception_hooks=True)
if not running_remotely():
verify_defaults_match()
return cls.__main_task
# check that we are not a child process, in that case do nothing
# check that we are not a child process, in that case do nothing.
# we should not get here unless this is Windows platform, all others support fork
if PROC_MASTER_ID_ENV_VAR.get() and PROC_MASTER_ID_ENV_VAR.get() != os.getpid():
class _TaskStub(object):
def __call__(self, *args, **kwargs):
@@ -212,9 +232,10 @@
raise
else:
Task.__main_task = task
# Patch argparse to be aware of the current task
argparser_update_currenttask(Task.__main_task)
EnvironmentBind.update_current_task(Task.__main_task)
# register the main task for at exit hooks (there should only be one)
task.__register_at_exit(task._at_exit)
# patch OS forking
PatchOsFork.patch_fork()
if auto_connect_frameworks:
PatchedMatplotlib.update_current_task(Task.__main_task)
PatchAbsl.update_current_task(Task.__main_task)
@@ -227,21 +248,19 @@
if auto_resource_monitoring:
task._resource_monitor = ResourceMonitor(task)
task._resource_monitor.start()
# Check if parse args already called. If so, sync task parameters with parser
if argparser_parseargs_called():
parser, parsed_args = get_argparser_last_args()
task._connect_argparse(parser=parser, parsed_args=parsed_args)
# make sure all random generators are initialized with new seed
make_deterministic(task.get_random_seed())
if auto_connect_arg_parser:
EnvironmentBind.update_current_task(Task.__main_task)
# Patch ArgParser to be aware of the current task
argparser_update_currenttask(Task.__main_task)
# Check if parse args already called. If so, sync task parameters with parser
if argparser_parseargs_called():
parser, parsed_args = get_argparser_last_args()
task._connect_argparse(parser, parsed_args=parsed_args)
task._connect_argparse(parser=parser, parsed_args=parsed_args)
# Make sure we start the logger, it will patch the main logging object and pipe all output
# if we are running locally and using development mode worker, we will pipe all stdout to logger.
@@ -339,7 +358,9 @@ class Task(_Task):
in_dev_mode = not running_remotely()
if in_dev_mode:
if not reuse_last_task_id or not cls.__task_is_relevant(default_task):
if isinstance(reuse_last_task_id, str) and reuse_last_task_id:
default_task_id = reuse_last_task_id
elif not reuse_last_task_id or not cls.__task_is_relevant(default_task):
default_task_id = None
closed_old_task = cls.__close_timed_out_task(default_task)
else:
@@ -600,6 +621,9 @@
"""
self._at_exit()
self._at_exit_called = False
# unregister atexit callbacks and signal hooks, if we are the main task
if self.is_main_task():
self.__register_at_exit(None)
def is_current_task(self):
"""
@@ -914,9 +938,12 @@
Will happen automatically once we exit code, i.e. atexit
:return:
"""
# protect sub-process at_exit
if self._at_exit_called:
return
is_sub_process = PROC_MASTER_ID_ENV_VAR.get() and PROC_MASTER_ID_ENV_VAR.get() != os.getpid()
# noinspection PyBroadException
try:
# from here do not get into watch dog
@@ -948,28 +975,32 @@
# from here, do not send log in background thread
if wait_for_uploads:
self.flush(wait_for_uploads=True)
# wait until the reporter flushes everything
self.reporter.stop()
if print_done_waiting:
self.log.info('Finished uploading')
else:
self._logger._flush_stdout_handler()
# from here, do not check worker status
if self._dev_worker:
self._dev_worker.unregister()
if not is_sub_process:
# from here, do not check worker status
if self._dev_worker:
self._dev_worker.unregister()
# change task status
if not task_status:
pass
elif task_status[0] == 'failed':
self.mark_failed(status_reason=task_status[1])
elif task_status[0] == 'completed':
self.completed()
elif task_status[0] == 'stopped':
self.stopped()
# change task status
if not task_status:
pass
elif task_status[0] == 'failed':
self.mark_failed(status_reason=task_status[1])
elif task_status[0] == 'completed':
self.completed()
elif task_status[0] == 'stopped':
self.stopped()
# stop resource monitoring
if self._resource_monitor:
self._resource_monitor.stop()
self._logger.set_flush_period(None)
# this is so in theory we can close a main task and start a new one
Task.__main_task = None
@@ -978,7 +1009,7 @@
pass
@classmethod
def __register_at_exit(cls, exit_callback):
def __register_at_exit(cls, exit_callback, only_remove_signal_and_exception_hooks=False):
class ExitHooks(object):
_orig_exit = None
_orig_exc_handler = None
@@ -1000,7 +1031,21 @@
except Exception:
pass
self._exit_callback = callback
atexit.register(self._exit_callback)
if callback:
self.hook()
else:
# unregister the interrupt and exception hooks
print('removing int hook', self._orig_exc_handler)
if self._orig_exc_handler:
sys.excepthook = self._orig_exc_handler
self._orig_exc_handler = None
for s in self._org_handlers:
# noinspection PyBroadException
try:
signal.signal(s, self._org_handlers[s])
except Exception:
pass
self._org_handlers = {}
def hook(self):
if self._orig_exit is None:
@@ -1009,20 +1054,23 @@
if self._orig_exc_handler is None:
self._orig_exc_handler = sys.excepthook
sys.excepthook = self.exc_handler
atexit.register(self._exit_callback)
if sys.platform == 'win32':
catch_signals = [signal.SIGINT, signal.SIGTERM, signal.SIGSEGV, signal.SIGABRT,
signal.SIGILL, signal.SIGFPE]
else:
catch_signals = [signal.SIGINT, signal.SIGTERM, signal.SIGSEGV, signal.SIGABRT,
signal.SIGILL, signal.SIGFPE, signal.SIGQUIT]
for s in catch_signals:
# noinspection PyBroadException
try:
self._org_handlers[s] = signal.getsignal(s)
signal.signal(s, self.signal_handler)
except Exception:
pass
if self._exit_callback:
atexit.register(self._exit_callback)
if self._org_handlers:
if sys.platform == 'win32':
catch_signals = [signal.SIGINT, signal.SIGTERM, signal.SIGSEGV, signal.SIGABRT,
signal.SIGILL, signal.SIGFPE]
else:
catch_signals = [signal.SIGINT, signal.SIGTERM, signal.SIGSEGV, signal.SIGABRT,
signal.SIGILL, signal.SIGFPE, signal.SIGQUIT]
for s in catch_signals:
# noinspection PyBroadException
try:
self._org_handlers[s] = signal.getsignal(s)
signal.signal(s, self.signal_handler)
except Exception:
pass
def exit(self, code=0):
self.exit_code = code
@@ -1077,6 +1125,22 @@
# return handler result
return org_handler
# we only remove the signal hooks, since leaving them installed hangs subprocesses
if only_remove_signal_and_exception_hooks:
if not cls.__exit_hook:
return
if cls.__exit_hook._orig_exc_handler:
sys.excepthook = cls.__exit_hook._orig_exc_handler
cls.__exit_hook._orig_exc_handler = None
for s in cls.__exit_hook._org_handlers:
# noinspection PyBroadException
try:
signal.signal(s, cls.__exit_hook._org_handlers[s])
except Exception:
pass
cls.__exit_hook._org_handlers = {}
return
if cls.__exit_hook is None:
# noinspection PyBroadException
try:
@@ -1084,13 +1148,13 @@
cls.__exit_hook.hook()
except Exception:
cls.__exit_hook = None
elif cls.__main_task is None:
else:
cls.__exit_hook.update_callback(exit_callback)
@classmethod
def __get_task(cls, task_id=None, project_name=None, task_name=None):
if task_id:
return cls(private=cls.__create_protection, task_id=task_id)
return cls(private=cls.__create_protection, task_id=task_id, log_to_backend=False)
res = cls._send(
cls._get_default_session(),

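The only_remove_signal_and_exception_hooks path above restores the saved handlers in forked children, since the inherited signal hooks were making subprocesses hang on exit. A minimal standalone sketch of the restore step (the module-level names are illustrative):

import signal
import sys

_org_handlers = {}
_orig_excepthook = None


def remove_signal_and_exception_hooks():
    # put back the handlers that were saved when the hooks were installed
    global _orig_excepthook
    if _orig_excepthook:
        sys.excepthook = _orig_excepthook
        _orig_excepthook = None
    for sig, handler in _org_handlers.items():
        # noinspection PyBroadException
        try:
            signal.signal(sig, handler)
        except Exception:
            pass
    _org_handlers.clear()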
View File

@@ -1,3 +1,4 @@
import os
import time
from threading import Lock
@@ -6,7 +7,8 @@ import six
class AsyncManagerMixin(object):
_async_results_lock = Lock()
_async_results = []
# per-pid (per-process) list of async jobs (supports forked sub-processes)
_async_results = {}
@classmethod
def _add_async_result(cls, result, wait_on_max_results=None, wait_time=30, wait_cb=None):
@@ -14,8 +16,9 @@ class AsyncManagerMixin(object):
try:
cls._async_results_lock.acquire()
# discard completed results
cls._async_results = [r for r in cls._async_results if not r.ready()]
num_results = len(cls._async_results)
pid = os.getpid()
cls._async_results[pid] = [r for r in cls._async_results.get(pid, []) if not r.ready()]
num_results = len(cls._async_results[pid])
if wait_on_max_results is not None and num_results >= wait_on_max_results:
# At least max_results results are still pending, wait
if wait_cb:
@@ -25,7 +28,7 @@ class AsyncManagerMixin(object):
continue
# add result
if result and not result.ready():
cls._async_results.append(result)
cls._async_results.setdefault(pid, []).append(result)
break
finally:
cls._async_results_lock.release()
@@ -34,7 +37,8 @@ class AsyncManagerMixin(object):
def wait_for_results(cls, timeout=None, max_num_uploads=None):
remaining = timeout
count = 0
for r in cls._async_results:
pid = os.getpid()
for r in cls._async_results.get(pid, []):
if r.ready():
continue
t = time.time()
@@ -48,13 +52,14 @@
if max_num_uploads is not None and max_num_uploads - count <= 0:
break
if timeout is not None:
remaining = max(0, remaining - max(0, time.time() - t))
remaining = max(0., remaining - max(0., time.time() - t))
if not remaining:
break
@classmethod
def get_num_results(cls):
if cls._async_results is not None:
return len([r for r in cls._async_results if not r.ready()])
pid = os.getpid()
if cls._async_results.get(pid, []):
return len([r for r in cls._async_results.get(pid, []) if not r.ready()])
else:
return 0
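Keying the pending-results list by os.getpid(), as above, lets a forked child inherit the shared dict yet track only its own uploads, so parent and child never wait on each other's results. A minimal sketch of the bookkeeping (the class and method names are illustrative):

import os


class PerPidResults(object):
    # pid -> list of pending async results, inherited across fork()
    _results = {}

    @classmethod
    def add(cls, result):
        cls._results.setdefault(os.getpid(), []).append(result)

    @classmethod
    def pending(cls):
        return [r for r in cls._results.get(os.getpid(), []) if not r.ready()]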