Add support for Popen subprocesses, with task edits protected across multiple processes

allegroai 2020-03-05 12:05:12 +02:00
parent e3ae4f4e26
commit da804ca75f
3 changed files with 90 additions and 64 deletions

View File

@@ -3,7 +3,7 @@ import itertools
 import logging
 import os
 from enum import Enum
-from threading import Thread
+from tempfile import gettempdir
 from multiprocessing import RLock
 
 try:
@@ -14,7 +14,7 @@ except ImportError:
 import six
 from six.moves.urllib.parse import quote
 
-from ...backend_interface.task.repo.scriptinfo import ScriptRequirements
+from ...utilities.locks import RLock as FileRLock
 from ...backend_interface.task.development.worker import DevWorker
 from ...backend_api import Session
 from ...backend_api.services import tasks, models, events, projects
@@ -81,7 +81,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
         :type force_create: bool
         """
         task_id = self._resolve_task_id(task_id, log=log) if not force_create else None
-        self._edit_lock = RLock()
+        self.__edit_lock = None
         super(Task, self).__init__(id=task_id, session=session, log=log)
         self._project_name = None
         self._storage_uri = None
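
Note: the constructor no longer creates a multiprocessing.RLock eagerly; `self.__edit_lock = None` defers creation, and the new `FileRLock` import (together with `gettempdir`) suggests the lock is built lazily and backed by a file, so edits coming from separate Popen subprocesses serialize on the same lock. The accessor itself is not shown in this diff; a minimal sketch of how such a lazy property could look, assuming a lock file named after the task ID under the system temp directory:

    import os
    from tempfile import gettempdir

    @property
    def _edit_lock(self):
        # hypothetical accessor: create the lock on first use only.
        # A process-local RLock cannot protect edits made by separate
        # Popen subprocesses, so back it with a file lock instead
        # (the file name and location are assumptions, not from the diff)
        if self.__edit_lock is None:
            lock_file = os.path.join(gettempdir(), 'task_edit_{}.lock'.format(self.id))
            self.__edit_lock = FileRLock(filename=lock_file)
        return self.__edit_lock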
@@ -596,7 +596,12 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
         update = kwargs.pop('__update', False)
-        parameters = dict() if not update else self.get_parameters()
+        with self._edit_lock:
+            self.reload()
+            if update:
+                parameters = self.get_parameters()
+            else:
+                parameters = dict()
             parameters.update(itertools.chain.from_iterable(x.items() if isinstance(x, dict) else x for x in args))
             parameters.update(kwargs)
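
The reload-then-merge sequence now runs while holding the edit lock, turning `set_parameters` into an atomic read-modify-write. Without it, two processes updating different parameters can interleave and silently drop one write. A toy demonstration of the same pattern using a multiprocessing lock (names and values are illustrative; the real code locks a file so unrelated Popen processes are covered too):

    from multiprocessing import Manager, Process, RLock

    def update(shared, lock, key, value):
        # read-modify-write guarded by the lock, mirroring set_parameters()
        with lock:
            params = dict(shared)    # "reload" the latest state
            params[key] = value      # merge our change
            shared.update(params)    # write back while still holding the lock

    if __name__ == '__main__':
        manager = Manager()
        shared, lock = manager.dict({'a': '1'}), RLock()
        procs = [Process(target=update, args=(shared, lock, k, v))
                 for k, v in (('b', '2'), ('c', '3'))]
        [p.start() for p in procs]
        [p.join() for p in procs]
        print(dict(shared))  # both updates survive: {'a': '1', 'b': '2', 'c': '3'}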
@@ -615,7 +620,6 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
             # force cast all variables to strings (so that we can later edit them in UI)
             parameters = {k: str(v) if v is not None else "" for k, v in parameters.items()}
-        with self._edit_lock:
             execution = self.data.execution
             if execution is None:
                 execution = tasks.Execution(parameters=parameters)
@@ -631,9 +635,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
         :param value: Parameter value
         :param description: Parameter description (unused for now)
         """
-        params = self.get_parameters()
-        params[name] = value
-        self.set_parameters(params)
+        self.set_parameters({name: value}, __update=True)
 
     def get_parameter(self, name, default=None):
         """

View File

@@ -11,7 +11,7 @@ TASK_ID_ENV_VAR = EnvEntry("TRAINS_TASK_ID", "ALG_TASK_ID")
 DOCKER_IMAGE_ENV_VAR = EnvEntry("TRAINS_DOCKER_IMAGE", "ALG_DOCKER_IMAGE")
 LOG_TO_BACKEND_ENV_VAR = EnvEntry("TRAINS_LOG_TASK_TO_BACKEND", "ALG_LOG_TASK_TO_BACKEND", type=bool)
 NODE_ID_ENV_VAR = EnvEntry("TRAINS_NODE_ID", "ALG_NODE_ID", type=int)
-PROC_MASTER_ID_ENV_VAR = EnvEntry("TRAINS_PROC_MASTER_ID", "ALG_PROC_MASTER_ID", type=int)
+PROC_MASTER_ID_ENV_VAR = EnvEntry("TRAINS_PROC_MASTER_ID", "ALG_PROC_MASTER_ID", type=str)
 LOG_STDERR_REDIRECT_LEVEL = EnvEntry("TRAINS_LOG_STDERR_REDIRECT_LEVEL", "ALG_LOG_STDERR_REDIRECT_LEVEL")
 DEV_WORKER_NAME = EnvEntry("TRAINS_WORKER_NAME", "ALG_WORKER_NAME")
 DEV_TASK_NO_REUSE = EnvEntry("TRAINS_TASK_NO_REUSE", "ALG_TASK_NO_REUSE", type=bool)
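
`PROC_MASTER_ID_ENV_VAR` switching from `type=int` to `type=str` indicates the variable now carries more than the master PID. Judging by the `__get_master_id_task_id()` helper used later in this commit, a plausible encoding is '<master_pid>:<task_id>'; a hypothetical decoder under that assumption:

    import os

    def parse_master_id(value):
        # assumed layout: '<master_pid>:<task_id>' (the task ID part may be
        # empty until the master process actually creates the task)
        master_pid, _, task_id = (value or '').partition(':')
        return master_pid, task_id or None

    # a process is a subprocess if a master was registered and it is not us
    master_pid, task_id = parse_master_id(os.environ.get('TRAINS_PROC_MASTER_ID'))
    is_subprocess = bool(master_pid) and master_pid != str(os.getpid())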

View File

@@ -198,7 +198,7 @@ class Task(_Task):
         if cls.__main_task is not None:
             # if this is a subprocess, regardless of what the init was called for,
             # we have to fix the main task hooks and stdout bindings
-            if cls.__forked_proc_main_pid != os.getpid() and PROC_MASTER_ID_ENV_VAR.get() != os.getpid():
+            if cls.__forked_proc_main_pid != os.getpid() and cls.__is_subprocess():
                 if task_type is None:
                     task_type = cls.__main_task.task_type
                 # make sure we only do it once per process
@@ -223,7 +223,7 @@ class Task(_Task):
         # check that we are not a child process, in that case do nothing.
         # we should not get here unless this is Windows platform, all others support fork
-        if PROC_MASTER_ID_ENV_VAR.get() and PROC_MASTER_ID_ENV_VAR.get() != os.getpid():
+        if cls.__is_subprocess():
             class _TaskStub(object):
                 def __call__(self, *args, **kwargs):
                     return self
@@ -234,9 +234,14 @@ class Task(_Task):
                 def __setattr__(self, attr, val):
                     pass
 
+            is_sub_process_task_id = cls.__get_master_id_task_id()
+            # we could not find a task ID, revert to old stub behaviour
+            if not is_sub_process_task_id:
                 return _TaskStub()
-        # set us as master process
-        PROC_MASTER_ID_ENV_VAR.set(os.getpid())
+        else:
+            # set us as master process (without task ID)
+            cls.__update_master_pid_task()
+            is_sub_process_task_id = None
 
         if task_type is None:
             # Backwards compatibility: if called from Task.current_task and task_type
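
The private helpers `__is_subprocess`, `__get_master_id_task_id` and `__update_master_pid_task` are referenced here, but their bodies are not part of the visible hunks. A sketch consistent with how they are called in this diff (the implementations below are assumptions):

    @classmethod
    def __is_subprocess(cls):
        # assumed: a master PID was published and it is not our own PID
        master = (PROC_MASTER_ID_ENV_VAR.get() or '').split(':')[0]
        return bool(master) and master != str(os.getpid())

    @classmethod
    def __get_master_id_task_id(cls):
        # assumed: return the task ID the master published, if any
        parts = (PROC_MASTER_ID_ENV_VAR.get() or '').split(':', 1)
        return parts[1] if len(parts) > 1 and parts[1] else None

    @classmethod
    def __update_master_pid_task(cls, task=None):
        # assumed: publish '<pid>:<task_id>', with an empty task ID when the
        # master registers itself before the task exists
        PROC_MASTER_ID_ENV_VAR.set('{}:{}'.format(os.getpid(), task.id if task else ''))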
@@ -252,17 +257,29 @@ class Task(_Task):
         try:
             if not running_remotely():
+                # if this is the main process, create the task
+                if not is_sub_process_task_id:
                     task = cls._create_dev_task(
                         project_name,
                         task_name,
                         task_type,
                         reuse_last_task_id,
+                        detect_repo=False if (isinstance(auto_connect_frameworks, dict) and
+                                              not auto_connect_frameworks.get('detect_repository', True)) else True
                     )
+                    # set defaults
                     if output_uri:
                         task.output_uri = output_uri
                     elif cls.__default_output_uri:
                         task.output_uri = cls.__default_output_uri
+                    # store new task ID
+                    cls.__update_master_pid_task(task=task)
+                else:
+                    # subprocess should get back the task info
+                    task = Task.get_task(task_id=is_sub_process_task_id)
             else:
+                # if this is the main process, create the task
+                if not is_sub_process_task_id:
                     task = cls(
                         private=cls.__create_protection,
                         task_id=get_remote_task_id(),
@@ -270,6 +287,11 @@ class Task(_Task):
                     )
                     if cls.__default_output_uri and not task.output_uri:
                         task.output_uri = cls.__default_output_uri
+                    # store new task ID
+                    cls.__update_master_pid_task(task=task)
+                else:
+                    # subprocess should get back the task info
+                    task = Task.get_task(task_id=is_sub_process_task_id)
         except Exception:
             raise
         else:
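
Taken together, the two branches make `Task.init()` safe to call from a Popen child: the first (master) process creates the task and publishes its ID, while any later process in the same environment reattaches to it through `Task.get_task()` instead of receiving a no-op stub. A usage sketch (file names and parameter values are illustrative):

    # parent.py -- becomes the master and creates the task
    import subprocess
    from trains import Task

    task = Task.init(project_name='examples', task_name='popen demo')
    # the child inherits TRAINS_PROC_MASTER_ID through the environment
    subprocess.Popen(['python', 'child.py']).wait()

    # child.py -- detected as a subprocess, reattaches to the same task
    from trains import Task

    task = Task.init(project_name='examples', task_name='popen demo')
    task.set_parameter('child_done', '1')  # edit is protected by the shared lock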
@@ -291,7 +313,7 @@ class Task(_Task):
                 PatchPyTorchModelIO.update_current_task(task)
             if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('xgboost', True):
                 PatchXGBoostModelIO.update_current_task(task)
-        if auto_resource_monitoring:
+        if auto_resource_monitoring and not is_sub_process_task_id:
             task._resource_monitor = ResourceMonitor(task)
             task._resource_monitor.start()
@@ -319,6 +341,7 @@ class Task(_Task):
         # The logger will automatically take care of all patching (we just need to make sure to initialize it)
         logger = task.get_logger()
         # show the debug metrics page in the log, it is very convenient
+        if not is_sub_process_task_id:
             logger.report_text(
                 'TRAINS results page: {}/projects/{}/experiments/{}/output/log'.format(
                     task._get_app_server(),
@@ -1344,7 +1367,7 @@ class Task(_Task):
         if self._at_exit_called:
             return
 
-        is_sub_process = PROC_MASTER_ID_ENV_VAR.get() and PROC_MASTER_ID_ENV_VAR.get() != os.getpid()
+        is_sub_process = self.__is_subprocess()
 
         # noinspection PyBroadException
         try:
@@ -1371,10 +1394,11 @@ class Task(_Task):
             task_status = ('stopped', )
 
         # wait for repository detection (if we didn't crash)
-        if not is_sub_process and wait_for_uploads and self._logger:
+        if wait_for_uploads and self._logger:
             # we should print summary here
             self._summary_artifacts()
             # make sure that if we crashed the thread we are not waiting forever
+            if not is_sub_process:
                 self._wait_for_repo_detection(timeout=10.)
 
         # wait for uploads
@@ -1399,11 +1423,11 @@ class Task(_Task):
         elif self._logger:
             self._logger._flush_stdout_handler()
 
+        if not is_sub_process:
             # from here, do not check worker status
             if self._dev_worker:
                 self._dev_worker.unregister()
 
-        if not is_sub_process:
             # change task status
             if not task_status:
                 pass
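
These last hunks split the shutdown work: every process, master or child, now prints the artifact summary and waits for its own uploads, while only the master waits for repository detection, unregisters the dev worker, and moves the task to its final status. Paraphrased from the diff (method names are from the hunks above):

    # paraphrase of the new shutdown flow
    if wait_for_uploads and self._logger:
        self._summary_artifacts()                       # all processes
        if not is_sub_process:
            self._wait_for_repo_detection(timeout=10.)  # master only
    if not is_sub_process:
        if self._dev_worker:
            self._dev_worker.unregister()               # master only
        # ...followed by the master-only task status change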