Add support for Popen subprocesses with task edit protection from multiple processes

allegroai 2020-03-05 12:05:12 +02:00
parent e3ae4f4e26
commit da804ca75f
3 changed files with 90 additions and 64 deletions
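The core mechanism in this commit: the master process records its identity in an environment variable, which both forked and Popen-spawned children inherit, so any process can cheaply tell whether it is the master. A minimal sketch of such a check, in the spirit of the __is_subprocess() helper referenced in the diff below (the "<pid>:<task_id>" value layout is an assumption here, not shown verbatim in the diff):

import os

# TRAINS_PROC_MASTER_ID is set by the first (master) process and inherited
# by every child, whether created via fork or via subprocess.Popen
def is_subprocess():
    value = os.environ.get('TRAINS_PROC_MASTER_ID', '')
    if not value:
        return False  # no master registered yet; we are about to become it
    return value.split(':')[0] != str(os.getpid())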

View File

@@ -3,7 +3,7 @@ import itertools
import logging
import os
from enum import Enum
from threading import Thread
from tempfile import gettempdir
from multiprocessing import RLock
try:
@@ -14,7 +14,7 @@ except ImportError:
import six
from six.moves.urllib.parse import quote
from ...backend_interface.task.repo.scriptinfo import ScriptRequirements
from ...utilities.locks import RLock as FileRLock
from ...backend_interface.task.development.worker import DevWorker
from ...backend_api import Session
from ...backend_api.services import tasks, models, events, projects
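Two locks are now imported side by side: multiprocessing.RLock only protects threads and forked children that inherit the lock object, while the file-based FileRLock (apparently from trains/utilities/locks) also covers Popen children, which share nothing with the master but the filesystem. A minimal sketch of the file-lock idea, assuming a POSIX system (fcntl is unavailable on Windows; the real FileRLock is re-entrant and portable):

import fcntl

class SimpleFileLock(object):
    """Cross-process exclusive lock backed by a file (sketch only)."""

    def __init__(self, path):
        self._path = path
        self._fd = None

    def __enter__(self):
        # any process opening the same path contends on the same OS lock
        self._fd = open(self._path, 'a+')
        fcntl.flock(self._fd, fcntl.LOCK_EX)  # blocks until acquired
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        fcntl.flock(self._fd, fcntl.LOCK_UN)
        self._fd.close()

Combined with gettempdir() imported above, a per-task lock file such as os.path.join(gettempdir(), 'trains_edit_<task_id>.lock') would give every process editing the same task a common rendezvous point (that file name is hypothetical).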
@@ -81,7 +81,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
:type force_create: bool
"""
task_id = self._resolve_task_id(task_id, log=log) if not force_create else None
self._edit_lock = RLock()
self.__edit_lock = None
super(Task, self).__init__(id=task_id, session=session, log=log)
self._project_name = None
self._storage_uri = None
@@ -596,26 +596,30 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
update = kwargs.pop('__update', False)
parameters = dict() if not update else self.get_parameters()
parameters.update(itertools.chain.from_iterable(x.items() if isinstance(x, dict) else x for x in args))
parameters.update(kwargs)
not_allowed = {
k: type(v).__name__
for k, v in parameters.items()
if not isinstance(v, self._parameters_allowed_types)
}
if not_allowed:
raise ValueError(
"Only builtin types ({}) are allowed for values (got {})".format(
', '.join(t.__name__ for t in self._parameters_allowed_types),
', '.join('%s=>%s' % p for p in not_allowed.items())),
)
# force cast all variables to strings (so that we can later edit them in UI)
parameters = {k: str(v) if v is not None else "" for k, v in parameters.items()}
with self._edit_lock:
self.reload()
if update:
parameters = self.get_parameters()
else:
parameters = dict()
parameters.update(itertools.chain.from_iterable(x.items() if isinstance(x, dict) else x for x in args))
parameters.update(kwargs)
not_allowed = {
k: type(v).__name__
for k, v in parameters.items()
if not isinstance(v, self._parameters_allowed_types)
}
if not_allowed:
raise ValueError(
"Only builtin types ({}) are allowed for values (got {})".format(
', '.join(t.__name__ for t in self._parameters_allowed_types),
', '.join('%s=>%s' % p for p in not_allowed.items())),
)
# force cast all variables to strings (so that we can later edit them in UI)
parameters = {k: str(v) if v is not None else "" for k, v in parameters.items()}
execution = self.data.execution
if execution is None:
execution = tasks.Execution(parameters=parameters)
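The rewritten set_parameters moves the whole read-modify-write cycle inside the edit lock and calls reload() first, so a merge (__update=True) always starts from the latest backend state rather than a possibly stale local copy. The general pattern, as a self-contained sketch:

import threading

# "reload", "get_current" and "write_back" stand in for Task.reload,
# Task.get_parameters and the backend edit call; with a cross-process lock
# (e.g. the file lock sketched above) the same pattern covers Popen children
def locked_merge(lock, reload, get_current, write_back, new_items):
    with lock:
        reload()                      # refresh local state from the backend
        merged = dict(get_current())  # start from the freshest copy
        merged.update(new_items)      # apply only our delta on top
        write_back(merged)            # persist; no other editor ran meanwhile

# toy demonstration with an in-process lock and a plain dict as "backend"
state = {'epochs': '10'}
locked_merge(threading.Lock(), lambda: None, lambda: state, state.update, {'lr': '0.01'})
assert state == {'epochs': '10', 'lr': '0.01'}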
@@ -631,9 +635,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
:param value: Parameter value
:param description: Parameter description (unused for now)
"""
params = self.get_parameters()
params[name] = value
self.set_parameters(params)
self.set_parameters({name: value}, __update=True)
def get_parameter(self, name, default=None):
"""

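One visible consequence of the change above: set_parameter no longer fetches all parameters and rewrites them itself; it delegates to set_parameters with the internal __update merge flag, so the fetch-merge-write happens exactly once, under the lock. Hypothetical usage (the __update flag is internal, shown only for illustration):

from trains import Task

task = Task.init(project_name='examples', task_name='popen demo')
# both calls now merge into the existing parameters atomically, under the
# task edit lock, instead of racing read-modify-write across processes
task.set_parameter('learning_rate', 0.001)
task.set_parameters({'batch_size': 32}, __update=True)  # internal merge flag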
View File

@@ -11,7 +11,7 @@ TASK_ID_ENV_VAR = EnvEntry("TRAINS_TASK_ID", "ALG_TASK_ID")
DOCKER_IMAGE_ENV_VAR = EnvEntry("TRAINS_DOCKER_IMAGE", "ALG_DOCKER_IMAGE")
LOG_TO_BACKEND_ENV_VAR = EnvEntry("TRAINS_LOG_TASK_TO_BACKEND", "ALG_LOG_TASK_TO_BACKEND", type=bool)
NODE_ID_ENV_VAR = EnvEntry("TRAINS_NODE_ID", "ALG_NODE_ID", type=int)
PROC_MASTER_ID_ENV_VAR = EnvEntry("TRAINS_PROC_MASTER_ID", "ALG_PROC_MASTER_ID", type=int)
PROC_MASTER_ID_ENV_VAR = EnvEntry("TRAINS_PROC_MASTER_ID", "ALG_PROC_MASTER_ID", type=str)
LOG_STDERR_REDIRECT_LEVEL = EnvEntry("TRAINS_LOG_STDERR_REDIRECT_LEVEL", "ALG_LOG_STDERR_REDIRECT_LEVEL")
DEV_WORKER_NAME = EnvEntry("TRAINS_WORKER_NAME", "ALG_WORKER_NAME")
DEV_TASK_NO_REUSE = EnvEntry("TRAINS_TASK_NO_REUSE", "ALG_TASK_NO_REUSE", type=bool)
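PROC_MASTER_ID_ENV_VAR switches from int to str because it no longer carries just the master PID: judging by the __get_master_id_task_id() / __update_master_pid_task() helpers in the next file, it now also encodes the master's task ID. A sketch under the assumption of a "<pid>:<task_id>" layout (the actual separator is not shown in this diff):

import os

ENV_KEY = 'TRAINS_PROC_MASTER_ID'

def update_master_pid_task(task_id=''):
    # register the current process as master, optionally with its task ID
    os.environ[ENV_KEY] = '{}:{}'.format(os.getpid(), task_id)

def get_master_id_task_id():
    # return the master's task ID, but only when called from a subprocess
    pid, _, task_id = os.environ.get(ENV_KEY, '').partition(':')
    if task_id and pid != str(os.getpid()):
        return task_id
    return None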

View File

@@ -198,10 +198,10 @@ class Task(_Task):
if cls.__main_task is not None:
# if this is a subprocess, regardless of what the init was called for,
# we have to fix the main task hooks and stdout bindings
if cls.__forked_proc_main_pid != os.getpid() and PROC_MASTER_ID_ENV_VAR.get() != os.getpid():
if cls.__forked_proc_main_pid != os.getpid() and cls.__is_subprocess():
if task_type is None:
task_type = cls.__main_task.task_type
# make sure we only do it once per process
cls.__forked_proc_main_pid = os.getpid()
# make sure we do not wait for the repo detect thread
cls.__main_task._detect_repo_async_thread = None
@@ -223,7 +223,7 @@ class Task(_Task):
# check that we are not a child process, in that case do nothing.
# we should not get here unless this is Windows platform, all others support fork
if PROC_MASTER_ID_ENV_VAR.get() and PROC_MASTER_ID_ENV_VAR.get() != os.getpid():
if cls.__is_subprocess():
class _TaskStub(object):
def __call__(self, *args, **kwargs):
return self
@@ -234,9 +234,14 @@ class Task(_Task):
def __setattr__(self, attr, val):
pass
return _TaskStub()
# set us as master process
PROC_MASTER_ID_ENV_VAR.set(os.getpid())
is_sub_process_task_id = cls.__get_master_id_task_id()
# we could not find a task ID, revert to old stub behaviour
if not is_sub_process_task_id:
return _TaskStub()
else:
# set us as master process (without task ID)
cls.__update_master_pid_task()
is_sub_process_task_id = None
if task_type is None:
# Backwards compatibility: if called from Task.current_task and task_type
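Putting the pieces together, the block above gives a child process three possible outcomes where there used to be one. A self-contained sketch of the decision, using the value layout assumed earlier:

import os

def resolve_task_role():
    """Decide what Task.init should do in the current process (sketch)."""
    value = os.environ.get('TRAINS_PROC_MASTER_ID', '')
    pid, _, task_id = value.partition(':')
    if not value or pid == str(os.getpid()):
        return 'master', None        # create a fresh task, then register it
    if task_id:
        return 'attach', task_id     # reuse the master's task via get_task()
    return 'stub', None              # no task ID recorded: no-op stub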
@@ -252,24 +257,41 @@ class Task(_Task):
try:
if not running_remotely():
task = cls._create_dev_task(
project_name,
task_name,
task_type,
reuse_last_task_id,
)
if output_uri:
task.output_uri = output_uri
elif cls.__default_output_uri:
task.output_uri = cls.__default_output_uri
# if this is the main process, create the task
if not is_sub_process_task_id:
task = cls._create_dev_task(
project_name,
task_name,
task_type,
reuse_last_task_id,
detect_repo=False if (isinstance(auto_connect_frameworks, dict) and
not auto_connect_frameworks.get('detect_repository', True)) else True
)
# set defaults
if output_uri:
task.output_uri = output_uri
elif cls.__default_output_uri:
task.output_uri = cls.__default_output_uri
# store new task ID
cls.__update_master_pid_task(task=task)
else:
# subprocess should get back the task info
task = Task.get_task(task_id=is_sub_process_task_id)
else:
task = cls(
private=cls.__create_protection,
task_id=get_remote_task_id(),
log_to_backend=False,
)
if cls.__default_output_uri and not task.output_uri:
task.output_uri = cls.__default_output_uri
# if this is the main process, create the task
if not is_sub_process_task_id:
task = cls(
private=cls.__create_protection,
task_id=get_remote_task_id(),
log_to_backend=False,
)
if cls.__default_output_uri and not task.output_uri:
task.output_uri = cls.__default_output_uri
# store new task ID
cls.__update_master_pid_task(task=task)
else:
# subprocess should get back the task info
task = Task.get_task(task_id=is_sub_process_task_id)
except Exception:
raise
else:
@@ -291,7 +313,7 @@ class Task(_Task):
PatchPyTorchModelIO.update_current_task(task)
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('xgboost', True):
PatchXGBoostModelIO.update_current_task(task)
if auto_resource_monitoring:
if auto_resource_monitoring and not is_sub_process_task_id:
task._resource_monitor = ResourceMonitor(task)
task._resource_monitor.start()
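ResourceMonitor is likewise restricted to the master process, since a Popen child starting its own monitor would report machine metrics twice for the same task. The guard generalizes to any per-task background reporter; a sketch using the assumed env-var layout:

import os
import threading

def start_reporter_once(target):
    # only the master process (or a process with no master registered at
    # all) may start per-task background reporting threads
    master_pid = os.environ.get('TRAINS_PROC_MASTER_ID', '').split(':')[0]
    if master_pid and master_pid != str(os.getpid()):
        return None  # subprocess: the master already reports for this task
    thread = threading.Thread(target=target, daemon=True)
    thread.start()
    return thread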
@@ -319,13 +341,14 @@ class Task(_Task):
# The logger will automatically take care of all patching (we just need to make sure to initialize it)
logger = task.get_logger()
# show the debug metrics page in the log, it is very convenient
logger.report_text(
'TRAINS results page: {}/projects/{}/experiments/{}/output/log'.format(
task._get_app_server(),
task.project if task.project is not None else '*',
task.id,
),
)
if not is_sub_process_task_id:
logger.report_text(
'TRAINS results page: {}/projects/{}/experiments/{}/output/log'.format(
task._get_app_server(),
task.project if task.project is not None else '*',
task.id,
),
)
# Make sure we start the dev worker if required, otherwise it will only be started when we write
# something to the log.
task._dev_mode_task_start()
@@ -1344,7 +1367,7 @@ class Task(_Task):
if self._at_exit_called:
return
is_sub_process = PROC_MASTER_ID_ENV_VAR.get() and PROC_MASTER_ID_ENV_VAR.get() != os.getpid()
is_sub_process = self.__is_subprocess()
# noinspection PyBroadException
try:
@@ -1371,11 +1394,12 @@ class Task(_Task):
task_status = ('stopped', )
# wait for repository detection (if we didn't crash)
if not is_sub_process and wait_for_uploads and self._logger:
if wait_for_uploads and self._logger:
# we should print summary here
self._summary_artifacts()
# make sure that if we crashed the thread we are not waiting forever
self._wait_for_repo_detection(timeout=10.)
if not is_sub_process:
self._wait_for_repo_detection(timeout=10.)
# wait for uploads
print_done_waiting = False
@@ -1399,11 +1423,11 @@ class Task(_Task):
elif self._logger:
self._logger._flush_stdout_handler()
if not is_sub_process:
# from here, do not check worker status
if self._dev_worker:
self._dev_worker.unregister()
# from here, do not check worker status
if self._dev_worker:
self._dev_worker.unregister()
if not is_sub_process:
# change task status
if not task_status:
pass
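The shutdown path follows the same split: every process flushes its own logger and waits for its own uploads, but only the master touches shared state, i.e. waits for repository detection, unregisters the dev worker, and sets the final task status. In outline (callables are placeholders for the Task methods above):

def task_shutdown(is_sub_process, flush_logs, finalize_task):
    # flushing is safe and necessary in every process (each has its own
    # stdout bindings); finalizing must happen exactly once, in the master,
    # or a child could mark the shared task stopped prematurely
    flush_logs()
    if not is_sub_process:
        finalize_task()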