mirror of
https://github.com/clearml/clearml
synced 2025-06-26 18:16:07 +00:00
Fix support for sub-process (process pool)
This commit is contained in:
154
trains/task.py
154
trains/task.py
@@ -26,7 +26,7 @@ from .errors import UsageError
|
||||
from .logger import Logger
|
||||
from .model import InputModel, OutputModel, ARCHIVED_TAG
|
||||
from .task_parameters import TaskParameters
|
||||
from .binding.environ_bind import EnvironmentBind
|
||||
from .binding.environ_bind import EnvironmentBind, PatchOsFork
|
||||
from .binding.absl_bind import PatchAbsl
|
||||
from .utilities.args import argparser_parseargs_called, get_argparser_last_args, \
|
||||
argparser_update_currenttask
|
||||
@@ -66,6 +66,7 @@ class Task(_Task):
|
||||
__create_protection = object()
|
||||
__main_task = None
|
||||
__exit_hook = None
|
||||
__forked_proc_main_pid = None
|
||||
__task_id_reuse_time_window_in_hours = float(config.get('development.task_reuse_time_window_in_hours', 24.0))
|
||||
__store_diff_on_train = config.get('development.store_uncommitted_code_diff_on_train', False)
|
||||
__detect_repo_async = config.get('development.vcs_repo_detect_async', False)
|
||||
@@ -104,7 +105,6 @@ class Task(_Task):
|
||||
self._resource_monitor = None
|
||||
# register atexit, so that we mark the task as stopped
|
||||
self._at_exit_called = False
|
||||
self.__register_at_exit(self._at_exit)
|
||||
|
||||
@classmethod
|
||||
def current_task(cls):
|
||||
@@ -132,9 +132,10 @@ class Task(_Task):
|
||||
:param project_name: project to create the task in (if project doesn't exist, it will be created)
|
||||
:param task_name: task name to be created (in development mode, not when running remotely)
|
||||
:param task_type: task type to be created (in development mode, not when running remotely)
|
||||
:param reuse_last_task_id: start with the previously used task id (stored in the data cache folder). \
|
||||
if False every time we call the function we create a new task with the same name \
|
||||
Notice! The reused task will be reset. (when running remotely, the usual behaviour applies) \
|
||||
:param reuse_last_task_id: start with the previously used task id (stored in the data cache folder).
|
||||
if False every time we call the function we create a new task with the same name
|
||||
Notice! The reused task will be reset. (when running remotely, the usual behaviour applies)
|
||||
If reuse_last_task_id is of type string, it will assume this is the task_id to reuse!
|
||||
Note: A closed or published task will not be reused, and a new task will be created.
|
||||
:param output_uri: Default location for output models (currently supports folder/S3/GS).
|
||||
Notice: a sub-folder (task_id) is created in the destination folder for all outputs.
|
||||
@@ -166,12 +167,31 @@ class Task(_Task):
|
||||
)
|
||||
|
||||
if cls.__main_task is not None:
|
||||
# if this is a subprocess, regardless of what the init was called for,
|
||||
# we have to fix the main task hooks and stdout bindings
|
||||
if cls.__forked_proc_main_pid != os.getpid() and PROC_MASTER_ID_ENV_VAR.get() != os.getpid():
|
||||
# make sure we only do it once per process
|
||||
cls.__forked_proc_main_pid = os.getpid()
|
||||
# make sure we do not wait for the repo detect thread
|
||||
cls.__main_task._detect_repo_async_thread = None
|
||||
# remove the logger from the previous process
|
||||
logger = cls.__main_task.get_logger()
|
||||
logger.set_flush_period(None)
|
||||
# create a new logger (to catch stdout/err)
|
||||
cls.__main_task._logger = None
|
||||
cls.__main_task._reporter = None
|
||||
cls.__main_task.get_logger()
|
||||
# unregister signal hooks, they cause subprocess to hang
|
||||
cls.__main_task.__register_at_exit(cls.__main_task._at_exit)
|
||||
cls.__main_task.__register_at_exit(None, only_remove_signal_and_exception_hooks=True)
|
||||
|
||||
if not running_remotely():
|
||||
verify_defaults_match()
|
||||
|
||||
return cls.__main_task
|
||||
|
||||
# check that we are not a child process, in that case do nothing
|
||||
# check that we are not a child process, in that case do nothing.
|
||||
# we should not get here unless this is a Windows platform; all other platforms support fork
|
||||
if PROC_MASTER_ID_ENV_VAR.get() and PROC_MASTER_ID_ENV_VAR.get() != os.getpid():
|
||||
class _TaskStub(object):
|
||||
def __call__(self, *args, **kwargs):
|
||||
@@ -212,9 +232,10 @@ class Task(_Task):
|
||||
raise
|
||||
else:
|
||||
Task.__main_task = task
|
||||
# Patch argparse to be aware of the current task
|
||||
argparser_update_currenttask(Task.__main_task)
|
||||
EnvironmentBind.update_current_task(Task.__main_task)
|
||||
# register the main task for at exit hooks (there should only be one)
|
||||
task.__register_at_exit(task._at_exit)
|
||||
# patch OS forking
|
||||
PatchOsFork.patch_fork()
|
||||
if auto_connect_frameworks:
|
||||
PatchedMatplotlib.update_current_task(Task.__main_task)
|
||||
PatchAbsl.update_current_task(Task.__main_task)
|
||||
@@ -227,21 +248,19 @@ class Task(_Task):
|
||||
if auto_resource_monitoring:
|
||||
task._resource_monitor = ResourceMonitor(task)
|
||||
task._resource_monitor.start()
|
||||
# Check if parse args already called. If so, sync task parameters with parser
|
||||
if argparser_parseargs_called():
|
||||
parser, parsed_args = get_argparser_last_args()
|
||||
task._connect_argparse(parser=parser, parsed_args=parsed_args)
|
||||
|
||||
# make sure all random generators are initialized with new seed
|
||||
make_deterministic(task.get_random_seed())
|
||||
|
||||
if auto_connect_arg_parser:
|
||||
EnvironmentBind.update_current_task(Task.__main_task)
|
||||
|
||||
# Patch ArgParser to be aware of the current task
|
||||
argparser_update_currenttask(Task.__main_task)
|
||||
# Check if parse args already called. If so, sync task parameters with parser
|
||||
if argparser_parseargs_called():
|
||||
parser, parsed_args = get_argparser_last_args()
|
||||
task._connect_argparse(parser, parsed_args=parsed_args)
|
||||
task._connect_argparse(parser=parser, parsed_args=parsed_args)
|
||||
|
||||
# Make sure we start the logger, it will patch the main logging object and pipe all output
|
||||
# if we are running locally and using development mode worker, we will pipe all stdout to logger.
|
||||
@@ -339,7 +358,9 @@ class Task(_Task):
|
||||
in_dev_mode = not running_remotely()
|
||||
|
||||
if in_dev_mode:
|
||||
if not reuse_last_task_id or not cls.__task_is_relevant(default_task):
|
||||
if isinstance(reuse_last_task_id, str) and reuse_last_task_id:
|
||||
default_task_id = reuse_last_task_id
|
||||
elif not reuse_last_task_id or not cls.__task_is_relevant(default_task):
|
||||
default_task_id = None
|
||||
closed_old_task = cls.__close_timed_out_task(default_task)
|
||||
else:
|
||||
@@ -600,6 +621,9 @@ class Task(_Task):
|
||||
"""
|
||||
self._at_exit()
|
||||
self._at_exit_called = False
|
||||
# unregister atexit callbacks and signal hooks, if we are the main task
|
||||
if self.is_main_task():
|
||||
self.__register_at_exit(None)
|
||||
|
||||
def is_current_task(self):
|
||||
"""
|
||||
@@ -914,9 +938,12 @@ class Task(_Task):
|
||||
Will happen automatically once we exit code, i.e. atexit
|
||||
:return:
|
||||
"""
|
||||
# protect sub-process at_exit
|
||||
if self._at_exit_called:
|
||||
return
|
||||
|
||||
is_sub_process = PROC_MASTER_ID_ENV_VAR.get() and PROC_MASTER_ID_ENV_VAR.get() != os.getpid()
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
# from here do not get into watch dog
|
||||
@@ -948,28 +975,32 @@ class Task(_Task):
|
||||
# from here, do not send log in background thread
|
||||
if wait_for_uploads:
|
||||
self.flush(wait_for_uploads=True)
|
||||
# wait until the reporter flushes everything
|
||||
self.reporter.stop()
|
||||
if print_done_waiting:
|
||||
self.log.info('Finished uploading')
|
||||
else:
|
||||
self._logger._flush_stdout_handler()
|
||||
|
||||
# from here, do not check worker status
|
||||
if self._dev_worker:
|
||||
self._dev_worker.unregister()
|
||||
if not is_sub_process:
|
||||
# from here, do not check worker status
|
||||
if self._dev_worker:
|
||||
self._dev_worker.unregister()
|
||||
|
||||
# change task status
|
||||
if not task_status:
|
||||
pass
|
||||
elif task_status[0] == 'failed':
|
||||
self.mark_failed(status_reason=task_status[1])
|
||||
elif task_status[0] == 'completed':
|
||||
self.completed()
|
||||
elif task_status[0] == 'stopped':
|
||||
self.stopped()
|
||||
# change task status
|
||||
if not task_status:
|
||||
pass
|
||||
elif task_status[0] == 'failed':
|
||||
self.mark_failed(status_reason=task_status[1])
|
||||
elif task_status[0] == 'completed':
|
||||
self.completed()
|
||||
elif task_status[0] == 'stopped':
|
||||
self.stopped()
|
||||
|
||||
# stop resource monitoring
|
||||
if self._resource_monitor:
|
||||
self._resource_monitor.stop()
|
||||
|
||||
self._logger.set_flush_period(None)
|
||||
# this is so in theory we can close a main task and start a new one
|
||||
Task.__main_task = None
|
||||
@@ -978,7 +1009,7 @@ class Task(_Task):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def __register_at_exit(cls, exit_callback):
|
||||
def __register_at_exit(cls, exit_callback, only_remove_signal_and_exception_hooks=False):
|
||||
class ExitHooks(object):
|
||||
_orig_exit = None
|
||||
_orig_exc_handler = None
|
||||
@@ -1000,7 +1031,21 @@ class Task(_Task):
|
||||
except Exception:
|
||||
pass
|
||||
self._exit_callback = callback
|
||||
atexit.register(self._exit_callback)
|
||||
if callback:
|
||||
self.hook()
|
||||
else:
|
||||
# unregister the exception hook and restore the original signal handlers
|
||||
print('removing int hook', self._orig_exc_handler)
|
||||
if self._orig_exc_handler:
|
||||
sys.excepthook = self._orig_exc_handler
|
||||
self._orig_exc_handler = None
|
||||
for s in self._org_handlers:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
signal.signal(s, self._org_handlers[s])
|
||||
except Exception:
|
||||
pass
|
||||
self._org_handlers = {}
|
||||
|
||||
def hook(self):
|
||||
if self._orig_exit is None:
|
||||
@@ -1009,20 +1054,23 @@ class Task(_Task):
|
||||
if self._orig_exc_handler is None:
|
||||
self._orig_exc_handler = sys.excepthook
|
||||
sys.excepthook = self.exc_handler
|
||||
atexit.register(self._exit_callback)
|
||||
if sys.platform == 'win32':
|
||||
catch_signals = [signal.SIGINT, signal.SIGTERM, signal.SIGSEGV, signal.SIGABRT,
|
||||
signal.SIGILL, signal.SIGFPE]
|
||||
else:
|
||||
catch_signals = [signal.SIGINT, signal.SIGTERM, signal.SIGSEGV, signal.SIGABRT,
|
||||
signal.SIGILL, signal.SIGFPE, signal.SIGQUIT]
|
||||
for s in catch_signals:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
self._org_handlers[s] = signal.getsignal(s)
|
||||
signal.signal(s, self.signal_handler)
|
||||
except Exception:
|
||||
pass
|
||||
if self._exit_callback:
|
||||
atexit.register(self._exit_callback)
|
||||
|
||||
if self._org_handlers:
|
||||
if sys.platform == 'win32':
|
||||
catch_signals = [signal.SIGINT, signal.SIGTERM, signal.SIGSEGV, signal.SIGABRT,
|
||||
signal.SIGILL, signal.SIGFPE]
|
||||
else:
|
||||
catch_signals = [signal.SIGINT, signal.SIGTERM, signal.SIGSEGV, signal.SIGABRT,
|
||||
signal.SIGILL, signal.SIGFPE, signal.SIGQUIT]
|
||||
for s in catch_signals:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
self._org_handlers[s] = signal.getsignal(s)
|
||||
signal.signal(s, self.signal_handler)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def exit(self, code=0):
|
||||
self.exit_code = code
|
||||
@@ -1077,6 +1125,22 @@ class Task(_Task):
|
||||
# return handler result
|
||||
return org_handler
|
||||
|
||||
# we only remove the signals since this will hang subprocesses
|
||||
if only_remove_signal_and_exception_hooks:
|
||||
if not cls.__exit_hook:
|
||||
return
|
||||
if cls.__exit_hook._orig_exc_handler:
|
||||
sys.excepthook = cls.__exit_hook._orig_exc_handler
|
||||
cls.__exit_hook._orig_exc_handler = None
|
||||
for s in cls.__exit_hook._org_handlers:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
signal.signal(s, cls.__exit_hook._org_handlers[s])
|
||||
except Exception:
|
||||
pass
|
||||
cls.__exit_hook._org_handlers = {}
|
||||
return
|
||||
|
||||
if cls.__exit_hook is None:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
@@ -1084,13 +1148,13 @@ class Task(_Task):
|
||||
cls.__exit_hook.hook()
|
||||
except Exception:
|
||||
cls.__exit_hook = None
|
||||
elif cls.__main_task is None:
|
||||
else:
|
||||
cls.__exit_hook.update_callback(exit_callback)
|
||||
|
||||
@classmethod
|
||||
def __get_task(cls, task_id=None, project_name=None, task_name=None):
|
||||
if task_id:
|
||||
return cls(private=cls.__create_protection, task_id=task_id)
|
||||
return cls(private=cls.__create_protection, task_id=task_id, log_to_backend=False)
|
||||
|
||||
res = cls._send(
|
||||
cls._get_default_session(),
|
||||
|
||||
Reference in New Issue
Block a user