diff --git a/trains/task.py b/trains/task.py
index 59bd8ca4..58f216e5 100644
--- a/trains/task.py
+++ b/trains/task.py
@@ -11,15 +11,13 @@
 import psutil
 import six
 from .backend_api.services import tasks, projects
-from .backend_interface import TaskStatusEnum
 from .backend_interface.model import Model as BackendModel
 from .backend_interface.task import Task as _Task
 from .backend_interface.task.args import _Arguments
-from .backend_interface.task.development.stop_signal import TaskStopSignal
 from .backend_interface.task.development.worker import DevWorker
 from .backend_interface.task.repo import ScriptInfo
 from .backend_interface.util import get_single_result, exact_match_regex, make_message
-from .config import config, PROC_MASTER_ID_ENV_VAR
+from .config import config, PROC_MASTER_ID_ENV_VAR, DEV_TASK_NO_REUSE
 from .config import running_remotely, get_remote_task_id
 from .config.cache import SessionCache
 from .debugging.log import LoggerRoot
@@ -61,7 +59,7 @@ class Task(_Task):
 
     **Usage: Task.init(...), Task.create() or Task.get_task(...)**
     """
-    TaskTypes = tasks.TaskTypeEnum
+    TaskTypes = _Task.TaskTypes
 
     __create_protection = object()
     __main_task = None
@@ -99,8 +97,6 @@ class Task(_Task):
         self._last_input_model_id = None
         self._connected_output_model = None
         self._dev_worker = None
-        self._dev_stop_signal = None
-        self._dev_mode_periodic_flag = False
         self._connected_parameter_type = None
         self._detect_repo_async_thread = None
         self._resource_monitor = None
@@ -303,8 +299,6 @@ class Task(_Task):
         if task._dev_worker:
             task._dev_worker.unregister()
             task._dev_worker = None
-        if task._dev_stop_signal:
-            task._dev_stop_signal = None
 
     @classmethod
     def _create_dev_task(cls, default_project_name, default_task_name, default_task_type, reuse_last_task_id):
@@ -326,16 +320,20 @@ class Task(_Task):
         except Exception:
             pass
 
-        # if we have a previous session to use, get the task id from it
-        default_task = cls.__get_last_used_task_id(
-            default_project_name,
-            default_task_name,
-            default_task_type.value,
-        )
+        # if we force no task reuse from os environment
+        if DEV_TASK_NO_REUSE.get():
+            default_task = None
+        else:
+            # if we have a previous session to use, get the task id from it
+            default_task = cls.__get_last_used_task_id(
+                default_project_name,
+                default_task_name,
+                default_task_type.value,
+            )
 
         closed_old_task = False
         default_task_id = None
-        in_dev_mode = not running_remotely() and not DevWorker.is_enabled()
+        in_dev_mode = not running_remotely()
 
         if in_dev_mode:
             if not reuse_last_task_id or not cls.__task_is_relevant(default_task):
@@ -351,7 +349,7 @@
                         task_id=default_task_id,
                         log_to_backend=True,
                     )
-                    if ((task.status in (TaskStatusEnum.published, TaskStatusEnum.closed))
+                    if ((task.status in (tasks.TaskStatusEnum.published, tasks.TaskStatusEnum.closed))
                             or (ARCHIVED_TAG in task.data.tags) or task.output_model_id):
                         # If the task is published or closed, we shouldn't reset it so we can't use it in dev mode
                         # If the task is archived, or already has an output model,
@@ -402,6 +400,7 @@
         # update current repository and put warning into logs
         if in_dev_mode and cls.__detect_repo_async:
             task._detect_repo_async_thread = threading.Thread(target=task._update_repository)
+            task._detect_repo_async_thread.daemon = True
             task._detect_repo_async_thread.start()
         else:
             task._update_repository()
@@ -554,8 +553,6 @@
         :param wait_for_uploads: if True the flush will exit only after all outstanding uploads are completed
         :return: True
         """
-        self._dev_mode_periodic()
-
         # wait for detection repo sync
         if self._detect_repo_async_thread:
             with self._lock:
@@ -826,10 +823,11 @@
         """ Called when we suspect the task has started running """
         self._dev_mode_setup_worker(model_updated=model_updated)
 
-        if TaskStopSignal.enabled and not self._dev_stop_signal:
-            self._dev_stop_signal = TaskStopSignal(task=self)
-
     def _dev_mode_stop_task(self, stop_reason):
+        # make sure we do not get called (by a daemon thread) after at_exit
+        if self._at_exit_called:
+            return
+
         self.get_logger().warn(
             "### TASK STOPPED - USER ABORTED - {} ###".format(
                 stop_reason.upper().replace('_', ' ')
@@ -871,23 +869,6 @@
             else:
                 parent.terminate()
 
-    def _dev_mode_periodic(self):
-        if self._dev_mode_periodic_flag or not self.is_main_task():
-            # Ensures we won't get into an infinite recursion since we might call self.flush() down the line
-            return
-        self._dev_mode_periodic_flag = True
-
-        try:
-            if self._dev_stop_signal:
-                stop_reason = self._dev_stop_signal.test()
-                if stop_reason and not self._at_exit_called:
-                    self._dev_mode_stop_task(stop_reason)
-
-            if self._dev_worker:
-                self._dev_worker.status_report()
-        finally:
-            self._dev_mode_periodic_flag = False
-
     def _dev_mode_setup_worker(self, model_updated=False):
         if running_remotely() or not self.is_main_task():
             return
@@ -895,11 +876,8 @@
         if self._dev_worker:
             return self._dev_worker
 
-        if not DevWorker.is_enabled(model_updated):
-            return None
-
         self._dev_worker = DevWorker()
-        self._dev_worker.register()
+        self._dev_worker.register(self)
 
         logger = self.get_logger()
         flush_period = logger.get_flush_period()
@@ -927,26 +905,24 @@
         try:
             # from here do not get into watch dog
             self._at_exit_called = True
-            self._dev_stop_signal = None
-            self._dev_mode_periodic_flag = True
             wait_for_uploads = True
             # first thing mark task as stopped, so we will not end up with "running" on lost tasks
             # if we are running remotely, the daemon will take care of it
+            task_status = None
             if not running_remotely() and self.is_main_task():
-                # from here, do not check worker status
-                if self._dev_worker:
-                    self._dev_worker.unregister()
                 # check if we crashed, ot the signal is not interrupt (manual break)
+                task_status = ('stopped', )
                 if self.__exit_hook:
                     if self.__exit_hook.exception is not None or \
                             (not self.__exit_hook.remote_user_aborted and self.__exit_hook.signal not in (None, 2)):
-                        self.mark_failed(status_reason='Exception')
+                        task_status = ('failed', 'Exception')
                         wait_for_uploads = False
                     else:
-                        self.stopped()
                         wait_for_uploads = (self.__exit_hook.remote_user_aborted or self.__exit_hook.signal is None)
-                else:
-                    self.stopped()
+                        if not self.__exit_hook.remote_user_aborted and self.__exit_hook.signal is None:
+                            task_status = ('completed', )
+                        else:
+                            task_status = ('stopped', )
 
             # wait for uploads
             print_done_waiting = False
@@ -960,6 +936,21 @@
                     self.log.info('Finished uploading')
             else:
                 self._logger._flush_stdout_handler()
+
+            # from here, do not check worker status
+            if self._dev_worker:
+                self._dev_worker.unregister()
+
+            # change task status
+            if not task_status:
+                pass
+            elif task_status[0] == 'failed':
+                self.mark_failed(status_reason=task_status[1])
+            elif task_status[0] == 'completed':
+                self.completed()
+            elif task_status[0] == 'stopped':
+                self.stopped()
+
             # stop resource monitoring
             if self._resource_monitor:
                 self._resource_monitor.stop()
@@ -1203,10 +1194,6 @@
 
         if not task_data:
             return False
-        # in dev-worker mode, never reuse a task
-        if DevWorker.is_enabled():
-            return False
-
         if cls.__task_timed_out(task_data):
             return False
 
@@ -1236,8 +1223,9 @@
             (task.type, 'type'),
         )
 
-        return all(server_data == task_data.get(task_data_key)
-                   for server_data, task_data_key in compares)
+        # compare after casting to string to avoid enum instance issues
+        # remember we might have replaced the api version by now, so enums are different
+        return all(str(server_data) == str(task_data.get(task_data_key)) for server_data, task_data_key in compares)
 
     @classmethod
     def __close_timed_out_task(cls, task_data):
@@ -1255,6 +1243,7 @@
             tasks.TaskStatusEnum.publishing,
             tasks.TaskStatusEnum.closed,
             tasks.TaskStatusEnum.failed,
+            tasks.TaskStatusEnum.completed,
         )
         if task.status not in stopped_statuses:
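Below is a minimal usage sketch of the no-reuse override introduced in _create_dev_task() above. It assumes DEV_TASK_NO_REUSE is an environment-backed config entry; the concrete variable name used here, TRAINS_TASK_NO_REUSE, is an assumption and is not shown in this diff.

import os

# Assumption: DEV_TASK_NO_REUSE reads TRAINS_TASK_NO_REUSE from the environment;
# any truthy value makes _create_dev_task() skip the cached last-used task id.
os.environ['TRAINS_TASK_NO_REUSE'] = '1'

from trains import Task

# With the override set, Task.init() creates a fresh task even when a recent,
# otherwise reusable task with the same project/name/type exists in the session cache.
task = Task.init(project_name='examples', task_name='no-reuse demo')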