diff --git a/trains/task.py b/trains/task.py index 1bc8bd29..fe0e4412 100644 --- a/trains/task.py +++ b/trains/task.py @@ -272,8 +272,15 @@ class Task(_Task): # Make sure we start the logger, it will patch the main logging object and pipe all output # if we are running locally and using development mode worker, we will pipe all stdout to logger. # The logger will automatically take care of all patching (we just need to make sure to initialize it) - task.get_logger() - + logger = task.get_logger() + # show the debug metrics page in the log, it is very convenient + logger.console( + 'TRAINS results page: {}/projects/{}/experiments/{}/output/log'.format( + task._get_app_server(), + task.project if task.project is not None else '*', + task.id, + ), + ) # Make sure we start the dev worker if required, otherwise it will only be started when we write # something to the log. task._dev_mode_task_start() @@ -378,7 +385,7 @@ class Task(_Task): task_id=default_task_id, log_to_backend=True, ) - if ((task.status in (tasks.TaskStatusEnum.published, tasks.TaskStatusEnum.closed)) + if ((str(task.status) in (str(tasks.TaskStatusEnum.published), str(tasks.TaskStatusEnum.closed))) or (ARCHIVED_TAG in task.data.tags) or task.output_model_id): # If the task is published or closed, we shouldn't reset it so we can't use it in dev mode # If the task is archived, or already has an output model, @@ -417,6 +424,8 @@ class Task(_Task): # update this session, for later use cls.__update_last_used_task_id(default_project_name, default_task_name, default_task_type.value, task.id) + # mark the task as started + task.started() # force update of base logger to this current task (this is the main logger task) task._setup_log(replace_existing=True) logger = task.get_logger() @@ -436,14 +445,6 @@ class Task(_Task): else: task._update_repository() - # show the debug metrics page in the log, it is very convenient - logger.console( - 'TRAINS results page: {}/projects/{}/experiments/{}/output/log'.format( - task._get_app_server(), - task.project if task.project is not None else '*', - task.id, - ), - ) # make sure everything is in sync task.reload() # make sure we see something in the UI @@ -470,6 +471,11 @@ class Task(_Task): @output_uri.setter def output_uri(self, value): + # check if we have the correct packages / configuration + if value and value != self.storage_uri: + from .storage.helper import StorageHelper + helper = StorageHelper.get(value) + helper.check_write_permissions() self.storage_uri = value @property @@ -1358,15 +1364,15 @@ class Task(_Task): return False stopped_statuses = ( - tasks.TaskStatusEnum.stopped, - tasks.TaskStatusEnum.published, - tasks.TaskStatusEnum.publishing, - tasks.TaskStatusEnum.closed, - tasks.TaskStatusEnum.failed, - tasks.TaskStatusEnum.completed, + str(tasks.TaskStatusEnum.stopped), + str(tasks.TaskStatusEnum.published), + str(tasks.TaskStatusEnum.publishing), + str(tasks.TaskStatusEnum.closed), + str(tasks.TaskStatusEnum.failed), + str(tasks.TaskStatusEnum.completed), ) - if task.status not in stopped_statuses: + if str(task.status) not in stopped_statuses: cls._send( cls._get_default_session(), tasks.StoppedRequest(