Fix support for sub-process (process pool)

This commit is contained in:
allegroai
2019-07-20 23:11:54 +03:00
parent 50ce49a3dd
commit c80aae0e1e
6 changed files with 220 additions and 75 deletions

View File

@@ -26,7 +26,7 @@ from .errors import UsageError
from .logger import Logger
from .model import InputModel, OutputModel, ARCHIVED_TAG
from .task_parameters import TaskParameters
from .binding.environ_bind import EnvironmentBind
from .binding.environ_bind import EnvironmentBind, PatchOsFork
from .binding.absl_bind import PatchAbsl
from .utilities.args import argparser_parseargs_called, get_argparser_last_args, \
argparser_update_currenttask
@@ -66,6 +66,7 @@ class Task(_Task):
__create_protection = object()
__main_task = None
__exit_hook = None
__forked_proc_main_pid = None
__task_id_reuse_time_window_in_hours = float(config.get('development.task_reuse_time_window_in_hours', 24.0))
__store_diff_on_train = config.get('development.store_uncommitted_code_diff_on_train', False)
__detect_repo_async = config.get('development.vcs_repo_detect_async', False)
@@ -104,7 +105,6 @@ class Task(_Task):
self._resource_monitor = None
# register atexit, so that we mark the task as stopped
self._at_exit_called = False
self.__register_at_exit(self._at_exit)
@classmethod
def current_task(cls):
@@ -132,9 +132,10 @@ class Task(_Task):
:param project_name: project to create the task in (if project doesn't exist, it will be created)
:param task_name: task name to be created (in development mode, not when running remotely)
:param task_type: task type to be created (in development mode, not when running remotely)
:param reuse_last_task_id: start with the previously used task id (stored in the data cache folder). \
if False every time we call the function we create a new task with the same name \
Notice! The reused task will be reset. (when running remotely, the usual behaviour applies) \
:param reuse_last_task_id: start with the previously used task id (stored in the data cache folder).
if False every time we call the function we create a new task with the same name
Notice! The reused task will be reset. (when running remotely, the usual behaviour applies)
If reuse_last_task_id is of type string, it will assume this is the task_id to reuse!
Note: A closed or published task will not be reused, and a new task will be created.
:param output_uri: Default location for output models (currently support folder/S3/GS/ ).
notice: sub-folders (task_id) is created in the destination folder for all outputs.
@@ -166,12 +167,31 @@ class Task(_Task):
)
if cls.__main_task is not None:
# if this is a subprocess, regardless of what the init was called for,
# we have to fix the main task hooks and stdout bindings
if cls.__forked_proc_main_pid != os.getpid() and PROC_MASTER_ID_ENV_VAR.get() != os.getpid():
# make sure we only do it once per process
cls.__forked_proc_main_pid = os.getpid()
# make sure we do not wait for the repo detect thread
cls.__main_task._detect_repo_async_thread = None
# remove the logger from the previous process
logger = cls.__main_task.get_logger()
logger.set_flush_period(None)
# create a new logger (to catch stdout/err)
cls.__main_task._logger = None
cls.__main_task._reporter = None
cls.__main_task.get_logger()
# unregister signal hooks, they cause subprocess to hang
cls.__main_task.__register_at_exit(cls.__main_task._at_exit)
cls.__main_task.__register_at_exit(None, only_remove_signal_and_exception_hooks=True)
if not running_remotely():
verify_defaults_match()
return cls.__main_task
# check that we are not a child process, in that case do nothing
# check that we are not a child process, in that case do nothing.
# we should not get here unless this is Windows platform, all others support fork
if PROC_MASTER_ID_ENV_VAR.get() and PROC_MASTER_ID_ENV_VAR.get() != os.getpid():
class _TaskStub(object):
def __call__(self, *args, **kwargs):
@@ -212,9 +232,10 @@ class Task(_Task):
raise
else:
Task.__main_task = task
# Patch argparse to be aware of the current task
argparser_update_currenttask(Task.__main_task)
EnvironmentBind.update_current_task(Task.__main_task)
# register the main task for at exit hooks (there should only be one)
task.__register_at_exit(task._at_exit)
# patch OS forking
PatchOsFork.patch_fork()
if auto_connect_frameworks:
PatchedMatplotlib.update_current_task(Task.__main_task)
PatchAbsl.update_current_task(Task.__main_task)
@@ -227,21 +248,19 @@ class Task(_Task):
if auto_resource_monitoring:
task._resource_monitor = ResourceMonitor(task)
task._resource_monitor.start()
# Check if parse args already called. If so, sync task parameters with parser
if argparser_parseargs_called():
parser, parsed_args = get_argparser_last_args()
task._connect_argparse(parser=parser, parsed_args=parsed_args)
# make sure all random generators are initialized with new seed
make_deterministic(task.get_random_seed())
if auto_connect_arg_parser:
EnvironmentBind.update_current_task(Task.__main_task)
# Patch ArgParser to be aware of the current task
argparser_update_currenttask(Task.__main_task)
# Check if parse args already called. If so, sync task parameters with parser
if argparser_parseargs_called():
parser, parsed_args = get_argparser_last_args()
task._connect_argparse(parser, parsed_args=parsed_args)
task._connect_argparse(parser=parser, parsed_args=parsed_args)
# Make sure we start the logger, it will patch the main logging object and pipe all output
# if we are running locally and using development mode worker, we will pipe all stdout to logger.
@@ -339,7 +358,9 @@ class Task(_Task):
in_dev_mode = not running_remotely()
if in_dev_mode:
if not reuse_last_task_id or not cls.__task_is_relevant(default_task):
if isinstance(reuse_last_task_id, str) and reuse_last_task_id:
default_task_id = reuse_last_task_id
elif not reuse_last_task_id or not cls.__task_is_relevant(default_task):
default_task_id = None
closed_old_task = cls.__close_timed_out_task(default_task)
else:
@@ -600,6 +621,9 @@ class Task(_Task):
"""
self._at_exit()
self._at_exit_called = False
# unregister atexit callbacks and signal hooks, if we are the main task
if self.is_main_task():
self.__register_at_exit(None)
def is_current_task(self):
"""
@@ -914,9 +938,12 @@ class Task(_Task):
Will happen automatically once we exit code, i.e. atexit
:return:
"""
# protect sub-process at_exit
if self._at_exit_called:
return
is_sub_process = PROC_MASTER_ID_ENV_VAR.get() and PROC_MASTER_ID_ENV_VAR.get() != os.getpid()
# noinspection PyBroadException
try:
# from here do not get into watch dog
@@ -948,28 +975,32 @@ class Task(_Task):
# from here, do not send log in background thread
if wait_for_uploads:
self.flush(wait_for_uploads=True)
# wait until the reporter flush everything
self.reporter.stop()
if print_done_waiting:
self.log.info('Finished uploading')
else:
self._logger._flush_stdout_handler()
# from here, do not check worker status
if self._dev_worker:
self._dev_worker.unregister()
if not is_sub_process:
# from here, do not check worker status
if self._dev_worker:
self._dev_worker.unregister()
# change task status
if not task_status:
pass
elif task_status[0] == 'failed':
self.mark_failed(status_reason=task_status[1])
elif task_status[0] == 'completed':
self.completed()
elif task_status[0] == 'stopped':
self.stopped()
# change task status
if not task_status:
pass
elif task_status[0] == 'failed':
self.mark_failed(status_reason=task_status[1])
elif task_status[0] == 'completed':
self.completed()
elif task_status[0] == 'stopped':
self.stopped()
# stop resource monitoring
if self._resource_monitor:
self._resource_monitor.stop()
self._logger.set_flush_period(None)
# this is so in theory we can close a main task and start a new one
Task.__main_task = None
@@ -978,7 +1009,7 @@ class Task(_Task):
pass
@classmethod
def __register_at_exit(cls, exit_callback):
def __register_at_exit(cls, exit_callback, only_remove_signal_and_exception_hooks=False):
class ExitHooks(object):
_orig_exit = None
_orig_exc_handler = None
@@ -1000,7 +1031,21 @@ class Task(_Task):
except Exception:
pass
self._exit_callback = callback
atexit.register(self._exit_callback)
if callback:
self.hook()
else:
# un register int hook
print('removing int hook', self._orig_exc_handler)
if self._orig_exc_handler:
sys.excepthook = self._orig_exc_handler
self._orig_exc_handler = None
for s in self._org_handlers:
# noinspection PyBroadException
try:
signal.signal(s, self._org_handlers[s])
except Exception:
pass
self._org_handlers = {}
def hook(self):
if self._orig_exit is None:
@@ -1009,20 +1054,23 @@ class Task(_Task):
if self._orig_exc_handler is None:
self._orig_exc_handler = sys.excepthook
sys.excepthook = self.exc_handler
atexit.register(self._exit_callback)
if sys.platform == 'win32':
catch_signals = [signal.SIGINT, signal.SIGTERM, signal.SIGSEGV, signal.SIGABRT,
signal.SIGILL, signal.SIGFPE]
else:
catch_signals = [signal.SIGINT, signal.SIGTERM, signal.SIGSEGV, signal.SIGABRT,
signal.SIGILL, signal.SIGFPE, signal.SIGQUIT]
for s in catch_signals:
# noinspection PyBroadException
try:
self._org_handlers[s] = signal.getsignal(s)
signal.signal(s, self.signal_handler)
except Exception:
pass
if self._exit_callback:
atexit.register(self._exit_callback)
if self._org_handlers:
if sys.platform == 'win32':
catch_signals = [signal.SIGINT, signal.SIGTERM, signal.SIGSEGV, signal.SIGABRT,
signal.SIGILL, signal.SIGFPE]
else:
catch_signals = [signal.SIGINT, signal.SIGTERM, signal.SIGSEGV, signal.SIGABRT,
signal.SIGILL, signal.SIGFPE, signal.SIGQUIT]
for s in catch_signals:
# noinspection PyBroadException
try:
self._org_handlers[s] = signal.getsignal(s)
signal.signal(s, self.signal_handler)
except Exception:
pass
def exit(self, code=0):
self.exit_code = code
@@ -1077,6 +1125,22 @@ class Task(_Task):
# return handler result
return org_handler
# we only remove the signals since this will hang subprocesses
if only_remove_signal_and_exception_hooks:
if not cls.__exit_hook:
return
if cls.__exit_hook._orig_exc_handler:
sys.excepthook = cls.__exit_hook._orig_exc_handler
cls.__exit_hook._orig_exc_handler = None
for s in cls.__exit_hook._org_handlers:
# noinspection PyBroadException
try:
signal.signal(s, cls.__exit_hook._org_handlers[s])
except Exception:
pass
cls.__exit_hook._org_handlers = {}
return
if cls.__exit_hook is None:
# noinspection PyBroadException
try:
@@ -1084,13 +1148,13 @@ class Task(_Task):
cls.__exit_hook.hook()
except Exception:
cls.__exit_hook = None
elif cls.__main_task is None:
else:
cls.__exit_hook.update_callback(exit_callback)
@classmethod
def __get_task(cls, task_id=None, project_name=None, task_name=None):
if task_id:
return cls(private=cls.__create_protection, task_id=task_id)
return cls(private=cls.__create_protection, task_id=task_id, log_to_backend=False)
res = cls._send(
cls._get_default_session(),