Add new task status and support for reuse override

This commit is contained in:
allegroai 2019-07-06 23:02:58 +03:00
parent bcb10d7adb
commit e40d4f2f41

View File

@ -11,15 +11,13 @@ import psutil
import six import six
from .backend_api.services import tasks, projects from .backend_api.services import tasks, projects
from .backend_interface import TaskStatusEnum
from .backend_interface.model import Model as BackendModel from .backend_interface.model import Model as BackendModel
from .backend_interface.task import Task as _Task from .backend_interface.task import Task as _Task
from .backend_interface.task.args import _Arguments from .backend_interface.task.args import _Arguments
from .backend_interface.task.development.stop_signal import TaskStopSignal
from .backend_interface.task.development.worker import DevWorker from .backend_interface.task.development.worker import DevWorker
from .backend_interface.task.repo import ScriptInfo from .backend_interface.task.repo import ScriptInfo
from .backend_interface.util import get_single_result, exact_match_regex, make_message from .backend_interface.util import get_single_result, exact_match_regex, make_message
from .config import config, PROC_MASTER_ID_ENV_VAR from .config import config, PROC_MASTER_ID_ENV_VAR, DEV_TASK_NO_REUSE
from .config import running_remotely, get_remote_task_id from .config import running_remotely, get_remote_task_id
from .config.cache import SessionCache from .config.cache import SessionCache
from .debugging.log import LoggerRoot from .debugging.log import LoggerRoot
@ -61,7 +59,7 @@ class Task(_Task):
**Usage: Task.init(...), Task.create() or Task.get_task(...)** **Usage: Task.init(...), Task.create() or Task.get_task(...)**
""" """
TaskTypes = tasks.TaskTypeEnum TaskTypes = _Task.TaskTypes
__create_protection = object() __create_protection = object()
__main_task = None __main_task = None
@ -99,8 +97,6 @@ class Task(_Task):
self._last_input_model_id = None self._last_input_model_id = None
self._connected_output_model = None self._connected_output_model = None
self._dev_worker = None self._dev_worker = None
self._dev_stop_signal = None
self._dev_mode_periodic_flag = False
self._connected_parameter_type = None self._connected_parameter_type = None
self._detect_repo_async_thread = None self._detect_repo_async_thread = None
self._resource_monitor = None self._resource_monitor = None
@ -303,8 +299,6 @@ class Task(_Task):
if task._dev_worker: if task._dev_worker:
task._dev_worker.unregister() task._dev_worker.unregister()
task._dev_worker = None task._dev_worker = None
if task._dev_stop_signal:
task._dev_stop_signal = None
@classmethod @classmethod
def _create_dev_task(cls, default_project_name, default_task_name, default_task_type, reuse_last_task_id): def _create_dev_task(cls, default_project_name, default_task_name, default_task_type, reuse_last_task_id):
@ -326,6 +320,10 @@ class Task(_Task):
except Exception: except Exception:
pass pass
# if we force no task reuse from os environment
if DEV_TASK_NO_REUSE.get():
default_task = None
else:
# if we have a previous session to use, get the task id from it # if we have a previous session to use, get the task id from it
default_task = cls.__get_last_used_task_id( default_task = cls.__get_last_used_task_id(
default_project_name, default_project_name,
@ -335,7 +333,7 @@ class Task(_Task):
closed_old_task = False closed_old_task = False
default_task_id = None default_task_id = None
in_dev_mode = not running_remotely() and not DevWorker.is_enabled() in_dev_mode = not running_remotely()
if in_dev_mode: if in_dev_mode:
if not reuse_last_task_id or not cls.__task_is_relevant(default_task): if not reuse_last_task_id or not cls.__task_is_relevant(default_task):
@ -351,7 +349,7 @@ class Task(_Task):
task_id=default_task_id, task_id=default_task_id,
log_to_backend=True, log_to_backend=True,
) )
if ((task.status in (TaskStatusEnum.published, TaskStatusEnum.closed)) if ((task.status in (tasks.TaskStatusEnum.published, tasks.TaskStatusEnum.closed))
or (ARCHIVED_TAG in task.data.tags) or task.output_model_id): or (ARCHIVED_TAG in task.data.tags) or task.output_model_id):
# If the task is published or closed, we shouldn't reset it so we can't use it in dev mode # If the task is published or closed, we shouldn't reset it so we can't use it in dev mode
# If the task is archived, or already has an output model, # If the task is archived, or already has an output model,
@ -402,6 +400,7 @@ class Task(_Task):
# update current repository and put warning into logs # update current repository and put warning into logs
if in_dev_mode and cls.__detect_repo_async: if in_dev_mode and cls.__detect_repo_async:
task._detect_repo_async_thread = threading.Thread(target=task._update_repository) task._detect_repo_async_thread = threading.Thread(target=task._update_repository)
task._detect_repo_async_thread.daemon = True
task._detect_repo_async_thread.start() task._detect_repo_async_thread.start()
else: else:
task._update_repository() task._update_repository()
@ -554,8 +553,6 @@ class Task(_Task):
:param wait_for_uploads: if True the flush will exit only after all outstanding uploads are completed :param wait_for_uploads: if True the flush will exit only after all outstanding uploads are completed
:return: True :return: True
""" """
self._dev_mode_periodic()
# wait for detection repo sync # wait for detection repo sync
if self._detect_repo_async_thread: if self._detect_repo_async_thread:
with self._lock: with self._lock:
@ -826,10 +823,11 @@ class Task(_Task):
""" Called when we suspect the task has started running """ """ Called when we suspect the task has started running """
self._dev_mode_setup_worker(model_updated=model_updated) self._dev_mode_setup_worker(model_updated=model_updated)
if TaskStopSignal.enabled and not self._dev_stop_signal:
self._dev_stop_signal = TaskStopSignal(task=self)
def _dev_mode_stop_task(self, stop_reason): def _dev_mode_stop_task(self, stop_reason):
# make sure we do not get called (by a daemon thread) after at_exit
if self._at_exit_called:
return
self.get_logger().warn( self.get_logger().warn(
"### TASK STOPPED - USER ABORTED - {} ###".format( "### TASK STOPPED - USER ABORTED - {} ###".format(
stop_reason.upper().replace('_', ' ') stop_reason.upper().replace('_', ' ')
@ -871,23 +869,6 @@ class Task(_Task):
else: else:
parent.terminate() parent.terminate()
def _dev_mode_periodic(self):
if self._dev_mode_periodic_flag or not self.is_main_task():
# Ensures we won't get into an infinite recursion since we might call self.flush() down the line
return
self._dev_mode_periodic_flag = True
try:
if self._dev_stop_signal:
stop_reason = self._dev_stop_signal.test()
if stop_reason and not self._at_exit_called:
self._dev_mode_stop_task(stop_reason)
if self._dev_worker:
self._dev_worker.status_report()
finally:
self._dev_mode_periodic_flag = False
def _dev_mode_setup_worker(self, model_updated=False): def _dev_mode_setup_worker(self, model_updated=False):
if running_remotely() or not self.is_main_task(): if running_remotely() or not self.is_main_task():
return return
@ -895,11 +876,8 @@ class Task(_Task):
if self._dev_worker: if self._dev_worker:
return self._dev_worker return self._dev_worker
if not DevWorker.is_enabled(model_updated):
return None
self._dev_worker = DevWorker() self._dev_worker = DevWorker()
self._dev_worker.register() self._dev_worker.register(self)
logger = self.get_logger() logger = self.get_logger()
flush_period = logger.get_flush_period() flush_period = logger.get_flush_period()
@ -927,26 +905,24 @@ class Task(_Task):
try: try:
# from here do not get into watch dog # from here do not get into watch dog
self._at_exit_called = True self._at_exit_called = True
self._dev_stop_signal = None
self._dev_mode_periodic_flag = True
wait_for_uploads = True wait_for_uploads = True
# first thing mark task as stopped, so we will not end up with "running" on lost tasks # first thing mark task as stopped, so we will not end up with "running" on lost tasks
# if we are running remotely, the daemon will take care of it # if we are running remotely, the daemon will take care of it
task_status = None
if not running_remotely() and self.is_main_task(): if not running_remotely() and self.is_main_task():
# from here, do not check worker status
if self._dev_worker:
self._dev_worker.unregister()
# check if we crashed, ot the signal is not interrupt (manual break) # check if we crashed, ot the signal is not interrupt (manual break)
task_status = ('stopped', )
if self.__exit_hook: if self.__exit_hook:
if self.__exit_hook.exception is not None or \ if self.__exit_hook.exception is not None or \
(not self.__exit_hook.remote_user_aborted and self.__exit_hook.signal not in (None, 2)): (not self.__exit_hook.remote_user_aborted and self.__exit_hook.signal not in (None, 2)):
self.mark_failed(status_reason='Exception') task_status = ('failed', 'Exception')
wait_for_uploads = False wait_for_uploads = False
else: else:
self.stopped()
wait_for_uploads = (self.__exit_hook.remote_user_aborted or self.__exit_hook.signal is None) wait_for_uploads = (self.__exit_hook.remote_user_aborted or self.__exit_hook.signal is None)
if not self.__exit_hook.remote_user_aborted and self.__exit_hook.signal is None:
task_status = ('completed', )
else: else:
self.stopped() task_status = ('stopped', )
# wait for uploads # wait for uploads
print_done_waiting = False print_done_waiting = False
@ -960,6 +936,21 @@ class Task(_Task):
self.log.info('Finished uploading') self.log.info('Finished uploading')
else: else:
self._logger._flush_stdout_handler() self._logger._flush_stdout_handler()
# from here, do not check worker status
if self._dev_worker:
self._dev_worker.unregister()
# change task status
if not task_status:
pass
elif task_status[0] == 'failed':
self.mark_failed(status_reason=task_status[1])
elif task_status[0] == 'completed':
self.completed()
elif task_status[0] == 'stopped':
self.stopped()
# stop resource monitoring # stop resource monitoring
if self._resource_monitor: if self._resource_monitor:
self._resource_monitor.stop() self._resource_monitor.stop()
@ -1203,10 +1194,6 @@ class Task(_Task):
if not task_data: if not task_data:
return False return False
# in dev-worker mode, never reuse a task
if DevWorker.is_enabled():
return False
if cls.__task_timed_out(task_data): if cls.__task_timed_out(task_data):
return False return False
@ -1236,8 +1223,9 @@ class Task(_Task):
(task.type, 'type'), (task.type, 'type'),
) )
return all(server_data == task_data.get(task_data_key) # compare after casting to string to avoid enum instance issues
for server_data, task_data_key in compares) # remember we might have replaced the api version by now, so enums are different
return all(str(server_data) == str(task_data.get(task_data_key)) for server_data, task_data_key in compares)
@classmethod @classmethod
def __close_timed_out_task(cls, task_data): def __close_timed_out_task(cls, task_data):
@ -1255,6 +1243,7 @@ class Task(_Task):
tasks.TaskStatusEnum.publishing, tasks.TaskStatusEnum.publishing,
tasks.TaskStatusEnum.closed, tasks.TaskStatusEnum.closed,
tasks.TaskStatusEnum.failed, tasks.TaskStatusEnum.failed,
tasks.TaskStatusEnum.completed,
) )
if task.status not in stopped_statuses: if task.status not in stopped_statuses: