Add permission check for Task.init(output_uri=)

This commit is contained in:
allegroai 2019-08-19 21:21:21 +03:00
parent 1f81a109b7
commit 25e3816484

View File

@ -272,8 +272,15 @@ class Task(_Task):
# Make sure we start the logger, it will patch the main logging object and pipe all output
# if we are running locally and using development mode worker, we will pipe all stdout to logger.
# The logger will automatically take care of all patching (we just need to make sure to initialize it)
task.get_logger()
logger = task.get_logger()
# show the debug metrics page in the log, it is very convenient
logger.console(
'TRAINS results page: {}/projects/{}/experiments/{}/output/log'.format(
task._get_app_server(),
task.project if task.project is not None else '*',
task.id,
),
)
# Make sure we start the dev worker if required, otherwise it will only be started when we write
# something to the log.
task._dev_mode_task_start()
@ -378,7 +385,7 @@ class Task(_Task):
task_id=default_task_id,
log_to_backend=True,
)
if ((task.status in (tasks.TaskStatusEnum.published, tasks.TaskStatusEnum.closed))
if ((str(task.status) in (str(tasks.TaskStatusEnum.published), str(tasks.TaskStatusEnum.closed)))
or (ARCHIVED_TAG in task.data.tags) or task.output_model_id):
# If the task is published or closed, we shouldn't reset it so we can't use it in dev mode
# If the task is archived, or already has an output model,
@ -417,6 +424,8 @@ class Task(_Task):
# update this session, for later use
cls.__update_last_used_task_id(default_project_name, default_task_name, default_task_type.value, task.id)
# mark the task as started
task.started()
# force update of base logger to this current task (this is the main logger task)
task._setup_log(replace_existing=True)
logger = task.get_logger()
@ -436,14 +445,6 @@ class Task(_Task):
else:
task._update_repository()
# show the debug metrics page in the log, it is very convenient
logger.console(
'TRAINS results page: {}/projects/{}/experiments/{}/output/log'.format(
task._get_app_server(),
task.project if task.project is not None else '*',
task.id,
),
)
# make sure everything is in sync
task.reload()
# make sure we see something in the UI
@ -470,6 +471,11 @@ class Task(_Task):
@output_uri.setter
def output_uri(self, value):
# check if we have the correct packages / configuration
if value and value != self.storage_uri:
from .storage.helper import StorageHelper
helper = StorageHelper.get(value)
helper.check_write_permissions()
self.storage_uri = value
@property
@ -1358,15 +1364,15 @@ class Task(_Task):
return False
stopped_statuses = (
tasks.TaskStatusEnum.stopped,
tasks.TaskStatusEnum.published,
tasks.TaskStatusEnum.publishing,
tasks.TaskStatusEnum.closed,
tasks.TaskStatusEnum.failed,
tasks.TaskStatusEnum.completed,
str(tasks.TaskStatusEnum.stopped),
str(tasks.TaskStatusEnum.published),
str(tasks.TaskStatusEnum.publishing),
str(tasks.TaskStatusEnum.closed),
str(tasks.TaskStatusEnum.failed),
str(tasks.TaskStatusEnum.completed),
)
if task.status not in stopped_statuses:
if str(task.status) not in stopped_statuses:
cls._send(
cls._get_default_session(),
tasks.StoppedRequest(