Add new task status and support for reuse override

This commit is contained in:
allegroai 2019-07-06 23:02:58 +03:00
parent bcb10d7adb
commit e40d4f2f41

View File

@ -11,15 +11,13 @@ import psutil
import six
from .backend_api.services import tasks, projects
from .backend_interface import TaskStatusEnum
from .backend_interface.model import Model as BackendModel
from .backend_interface.task import Task as _Task
from .backend_interface.task.args import _Arguments
from .backend_interface.task.development.stop_signal import TaskStopSignal
from .backend_interface.task.development.worker import DevWorker
from .backend_interface.task.repo import ScriptInfo
from .backend_interface.util import get_single_result, exact_match_regex, make_message
from .config import config, PROC_MASTER_ID_ENV_VAR
from .config import config, PROC_MASTER_ID_ENV_VAR, DEV_TASK_NO_REUSE
from .config import running_remotely, get_remote_task_id
from .config.cache import SessionCache
from .debugging.log import LoggerRoot
@ -61,7 +59,7 @@ class Task(_Task):
**Usage: Task.init(...), Task.create() or Task.get_task(...)**
"""
TaskTypes = tasks.TaskTypeEnum
TaskTypes = _Task.TaskTypes
__create_protection = object()
__main_task = None
@ -99,8 +97,6 @@ class Task(_Task):
self._last_input_model_id = None
self._connected_output_model = None
self._dev_worker = None
self._dev_stop_signal = None
self._dev_mode_periodic_flag = False
self._connected_parameter_type = None
self._detect_repo_async_thread = None
self._resource_monitor = None
@ -303,8 +299,6 @@ class Task(_Task):
if task._dev_worker:
task._dev_worker.unregister()
task._dev_worker = None
if task._dev_stop_signal:
task._dev_stop_signal = None
@classmethod
def _create_dev_task(cls, default_project_name, default_task_name, default_task_type, reuse_last_task_id):
@ -326,16 +320,20 @@ class Task(_Task):
except Exception:
pass
# if we have a previous session to use, get the task id from it
default_task = cls.__get_last_used_task_id(
default_project_name,
default_task_name,
default_task_type.value,
)
# if we force no task reuse from os environment
if DEV_TASK_NO_REUSE.get():
default_task = None
else:
# if we have a previous session to use, get the task id from it
default_task = cls.__get_last_used_task_id(
default_project_name,
default_task_name,
default_task_type.value,
)
closed_old_task = False
default_task_id = None
in_dev_mode = not running_remotely() and not DevWorker.is_enabled()
in_dev_mode = not running_remotely()
if in_dev_mode:
if not reuse_last_task_id or not cls.__task_is_relevant(default_task):
@ -351,7 +349,7 @@ class Task(_Task):
task_id=default_task_id,
log_to_backend=True,
)
if ((task.status in (TaskStatusEnum.published, TaskStatusEnum.closed))
if ((task.status in (tasks.TaskStatusEnum.published, tasks.TaskStatusEnum.closed))
or (ARCHIVED_TAG in task.data.tags) or task.output_model_id):
# If the task is published or closed, we shouldn't reset it so we can't use it in dev mode
# If the task is archived, or already has an output model,
@ -402,6 +400,7 @@ class Task(_Task):
# update current repository and put warning into logs
if in_dev_mode and cls.__detect_repo_async:
task._detect_repo_async_thread = threading.Thread(target=task._update_repository)
task._detect_repo_async_thread.daemon = True
task._detect_repo_async_thread.start()
else:
task._update_repository()
@ -554,8 +553,6 @@ class Task(_Task):
:param wait_for_uploads: if True the flush will exit only after all outstanding uploads are completed
:return: True
"""
self._dev_mode_periodic()
# wait for detection repo sync
if self._detect_repo_async_thread:
with self._lock:
@ -826,10 +823,11 @@ class Task(_Task):
""" Called when we suspect the task has started running """
self._dev_mode_setup_worker(model_updated=model_updated)
if TaskStopSignal.enabled and not self._dev_stop_signal:
self._dev_stop_signal = TaskStopSignal(task=self)
def _dev_mode_stop_task(self, stop_reason):
# make sure we do not get called (by a daemon thread) after at_exit
if self._at_exit_called:
return
self.get_logger().warn(
"### TASK STOPPED - USER ABORTED - {} ###".format(
stop_reason.upper().replace('_', ' ')
@ -871,23 +869,6 @@ class Task(_Task):
else:
parent.terminate()
def _dev_mode_periodic(self):
if self._dev_mode_periodic_flag or not self.is_main_task():
# Ensures we won't get into an infinite recursion since we might call self.flush() down the line
return
self._dev_mode_periodic_flag = True
try:
if self._dev_stop_signal:
stop_reason = self._dev_stop_signal.test()
if stop_reason and not self._at_exit_called:
self._dev_mode_stop_task(stop_reason)
if self._dev_worker:
self._dev_worker.status_report()
finally:
self._dev_mode_periodic_flag = False
def _dev_mode_setup_worker(self, model_updated=False):
if running_remotely() or not self.is_main_task():
return
@ -895,11 +876,8 @@ class Task(_Task):
if self._dev_worker:
return self._dev_worker
if not DevWorker.is_enabled(model_updated):
return None
self._dev_worker = DevWorker()
self._dev_worker.register()
self._dev_worker.register(self)
logger = self.get_logger()
flush_period = logger.get_flush_period()
@ -927,26 +905,24 @@ class Task(_Task):
try:
# from here do not get into watch dog
self._at_exit_called = True
self._dev_stop_signal = None
self._dev_mode_periodic_flag = True
wait_for_uploads = True
# first thing mark task as stopped, so we will not end up with "running" on lost tasks
# if we are running remotely, the daemon will take care of it
task_status = None
if not running_remotely() and self.is_main_task():
# from here, do not check worker status
if self._dev_worker:
self._dev_worker.unregister()
# check if we crashed, ot the signal is not interrupt (manual break)
task_status = ('stopped', )
if self.__exit_hook:
if self.__exit_hook.exception is not None or \
(not self.__exit_hook.remote_user_aborted and self.__exit_hook.signal not in (None, 2)):
self.mark_failed(status_reason='Exception')
task_status = ('failed', 'Exception')
wait_for_uploads = False
else:
self.stopped()
wait_for_uploads = (self.__exit_hook.remote_user_aborted or self.__exit_hook.signal is None)
else:
self.stopped()
if not self.__exit_hook.remote_user_aborted and self.__exit_hook.signal is None:
task_status = ('completed', )
else:
task_status = ('stopped', )
# wait for uploads
print_done_waiting = False
@ -960,6 +936,21 @@ class Task(_Task):
self.log.info('Finished uploading')
else:
self._logger._flush_stdout_handler()
# from here, do not check worker status
if self._dev_worker:
self._dev_worker.unregister()
# change task status
if not task_status:
pass
elif task_status[0] == 'failed':
self.mark_failed(status_reason=task_status[1])
elif task_status[0] == 'completed':
self.completed()
elif task_status[0] == 'stopped':
self.stopped()
# stop resource monitoring
if self._resource_monitor:
self._resource_monitor.stop()
@ -1203,10 +1194,6 @@ class Task(_Task):
if not task_data:
return False
# in dev-worker mode, never reuse a task
if DevWorker.is_enabled():
return False
if cls.__task_timed_out(task_data):
return False
@ -1236,8 +1223,9 @@ class Task(_Task):
(task.type, 'type'),
)
return all(server_data == task_data.get(task_data_key)
for server_data, task_data_key in compares)
# compare after casting to string to avoid enum instance issues
# remember we might have replaced the api version by now, so enums are different
return all(str(server_data) == str(task_data.get(task_data_key)) for server_data, task_data_key in compares)
@classmethod
def __close_timed_out_task(cls, task_data):
@ -1255,6 +1243,7 @@ class Task(_Task):
tasks.TaskStatusEnum.publishing,
tasks.TaskStatusEnum.closed,
tasks.TaskStatusEnum.failed,
tasks.TaskStatusEnum.completed,
)
if task.status not in stopped_statuses: