diff --git a/apiserver/bll/event/event_bll.py b/apiserver/bll/event/event_bll.py index 9c6725b..1fa4810 100644 --- a/apiserver/bll/event/event_bll.py +++ b/apiserver/bll/event/event_bll.py @@ -31,6 +31,7 @@ from apiserver.bll.event.history_plots_iterator import HistoryPlotsIterator from apiserver.bll.event.metric_debug_images_iterator import MetricDebugImagesIterator from apiserver.bll.event.metric_plots_iterator import MetricPlotsIterator from apiserver.bll.model import ModelBLL +from apiserver.bll.task.utils import get_many_tasks_for_writing from apiserver.bll.util import parallel_chunked_decorator from apiserver.database import utils as dbutils from apiserver.database.model.model import Model @@ -42,6 +43,7 @@ from apiserver.config_repo import config from apiserver.database.errors import translate_errors_context from apiserver.database.model.task.task import Task, TaskStatus from apiserver.redis_manager import redman +from apiserver.service_repo.auth import Identity from apiserver.tools import safe_get from apiserver.utilities.dicts import nested_get from apiserver.utilities.json import loads @@ -55,7 +57,9 @@ MIN_LONG = -(2**63) log = config.logger(__file__) async_task_events_delete = config.get("services.tasks.async_events_delete", False) -async_delete_threshold = config.get("services.tasks.async_events_delete_threshold", 100_000) +async_delete_threshold = config.get( + "services.tasks.async_events_delete_threshold", 100_000 +) class EventBLL(object): @@ -97,7 +101,9 @@ class EventBLL(object): return self._metrics @staticmethod - def _get_valid_entities(company_id, ids: Mapping[str, bool], model=False) -> Set: + def _get_valid_entities( + company_id, ids: Mapping[str, bool], identity: Identity, model=False + ) -> Set: """Verify that task or model exists and can be updated""" if not ids: return set() @@ -116,20 +122,34 @@ class EventBLL(object): ): if not requested_ids: continue - query = Q(id__in=requested_ids, company=company_id) - res.update( - (Model if model else Task).objects(query & locked_q).scalar("id") - ) + + query = Q(id__in=requested_ids) & locked_q + if model: + ids = Model.objects(query & Q(company=company_id)).scalar("id") + else: + ids = { + t.id + for t in get_many_tasks_for_writing( + company_id=company_id, + identity=identity, + query=query, + only=("id",), + throw_on_forbidden=False, + ) + } + + res.update(ids) return res def add_events( self, company_id: str, - user_id: str, + identity: Identity, events: Sequence[dict], worker: str, ) -> Tuple[int, int, dict]: + user_id = identity.user task_ids = {} model_ids = {} for event in events: @@ -161,8 +181,12 @@ class EventBLL(object): "Inconsistent model_event setting in the passed events", tasks=found_in_both, ) - valid_models = self._get_valid_entities(company_id, ids=model_ids, model=True) - valid_tasks = self._get_valid_entities(company_id, ids=task_ids) + valid_models = self._get_valid_entities( + company_id, ids=model_ids, identity=identity, model=True + ) + valid_tasks = self._get_valid_entities( + company_id, ids=task_ids, identity=identity + ) actions: List[dict] = [] used_task_ids = set() diff --git a/apiserver/bll/model/__init__.py b/apiserver/bll/model/__init__.py index 3215e91..3dc1ffe 100644 --- a/apiserver/bll/model/__init__.py +++ b/apiserver/bll/model/__init__.py @@ -10,6 +10,7 @@ from apiserver.config_repo import config from apiserver.database.model import EntityVisibility from apiserver.database.model.model import Model from apiserver.database.model.task.task import Task, TaskStatus +from apiserver.service_repo.auth import Identity from .metadata import Metadata @@ -57,14 +58,15 @@ class ModelBLL: cls, model_id: str, company_id: str, - user_id: str, + identity: Identity, force_publish_task: bool = False, - publish_task_func: Callable[[str, str, str, bool], dict] = None, + publish_task_func: Callable[[str, str, Identity, bool], dict] = None, ) -> Tuple[int, ModelTaskPublishResponse]: model = cls.get_company_model_by_id(company_id=company_id, model_id=model_id) if model.ready: raise errors.bad_request.ModelIsReady(company=company_id, model=model_id) + user_id = identity.user published_task = None if model.task and publish_task_func: task = ( @@ -74,7 +76,7 @@ class ModelBLL: ) if task and task.status != TaskStatus.published: task_publish_res = publish_task_func( - model.task, company_id, user_id, force_publish_task + model.task, company_id, identity, force_publish_task ) published_task = ModelTaskPublishResponse( id=model.task, data=task_publish_res diff --git a/apiserver/bll/queue/queue_bll.py b/apiserver/bll/queue/queue_bll.py index 1c33651..65d4957 100644 --- a/apiserver/bll/queue/queue_bll.py +++ b/apiserver/bll/queue/queue_bll.py @@ -152,7 +152,7 @@ class QueueBLL(object): for item in queue.entries: try: - task = Task.get_for_writing( + task = Task.get( company=company_id, id=item.task, _only=[ diff --git a/apiserver/bll/task/artifacts.py b/apiserver/bll/task/artifacts.py index 7ff19bf..c374e5c 100644 --- a/apiserver/bll/task/artifacts.py +++ b/apiserver/bll/task/artifacts.py @@ -5,6 +5,7 @@ from apiserver.apimodels.tasks import Artifact as ApiArtifact, ArtifactId from apiserver.bll.task.utils import get_task_for_update, update_task from apiserver.database.model.task.task import DEFAULT_ARTIFACT_MODE, Artifact from apiserver.database.utils import hash_field_name +from apiserver.service_repo.auth import Identity from apiserver.utilities.dicts import nested_get, nested_set from apiserver.utilities.parameter_key_escaper import mongoengine_safe @@ -48,12 +49,14 @@ class Artifacts: def add_or_update_artifacts( cls, company_id: str, - user_id: str, + identity: Identity, task_id: str, artifacts: Sequence[ApiArtifact], force: bool, ) -> int: - task = get_task_for_update(company_id=company_id, task_id=task_id, force=force,) + task = get_task_for_update( + company_id=company_id, task_id=task_id, force=force, identity=identity + ) artifacts = { get_artifact_id(a): Artifact(**a) @@ -64,18 +67,20 @@ class Artifacts: f"set__execution__artifacts__{mongoengine_safe(name)}": value for name, value in artifacts.items() } - return update_task(task, user_id=user_id, update_cmds=update_cmds) + return update_task(task, user_id=identity.user, update_cmds=update_cmds) @classmethod def delete_artifacts( cls, company_id: str, - user_id: str, + identity: Identity, task_id: str, artifact_ids: Sequence[ArtifactId], force: bool, ) -> int: - task = get_task_for_update(company_id=company_id, task_id=task_id, force=force,) + task = get_task_for_update( + company_id=company_id, task_id=task_id, force=force, identity=identity + ) artifact_ids = [ get_artifact_id(a) @@ -85,4 +90,4 @@ class Artifacts: f"unset__execution__artifacts__{id_}": 1 for id_ in set(artifact_ids) } - return update_task(task, user_id=user_id, update_cmds=delete_cmds) + return update_task(task, user_id=identity.user, update_cmds=delete_cmds) diff --git a/apiserver/bll/task/hyperparams.py b/apiserver/bll/task/hyperparams.py index 4159b87..eae25b2 100644 --- a/apiserver/bll/task/hyperparams.py +++ b/apiserver/bll/task/hyperparams.py @@ -15,6 +15,7 @@ from apiserver.bll.task import TaskBLL from apiserver.bll.task.utils import get_task_for_update, update_task from apiserver.config_repo import config from apiserver.database.model.task.task import ParamsItem, Task, ConfigurationItem +from apiserver.service_repo.auth import Identity from apiserver.utilities.parameter_key_escaper import ( ParameterKeyEscaper, mongoengine_safe, @@ -31,7 +32,10 @@ class HyperParams: def get_params(cls, company_id: str, task_ids: Sequence[str]) -> Dict[str, dict]: only = ("id", "hyperparams") tasks = task_bll.assert_exists( - company_id=company_id, task_ids=task_ids, only=only, allow_public=True, + company_id=company_id, + task_ids=task_ids, + only=only, + allow_public=True, ) return { @@ -63,7 +67,7 @@ class HyperParams: def delete_params( cls, company_id: str, - user_id: str, + identity: Identity, task_id: str, hyperparams: Sequence[HyperParamKey], force: bool, @@ -74,6 +78,7 @@ class HyperParams: task_id=task_id, allow_all_statuses=properties_only, force=force, + identity=identity, ) with_param, without_param = iterutils.partition( @@ -96,7 +101,7 @@ class HyperParams: return update_task( task, - user_id=user_id, + user_id=identity.user, update_cmds=delete_cmds, set_last_update=not properties_only, ) @@ -105,7 +110,7 @@ class HyperParams: def edit_params( cls, company_id: str, - user_id: str, + identity: Identity, task_id: str, hyperparams: Sequence[HyperParamItem], replace_hyperparams: str, @@ -117,6 +122,7 @@ class HyperParams: task_id=task_id, allow_all_statuses=properties_only, force=force, + identity=identity, ) update_cmds = dict() @@ -135,7 +141,7 @@ class HyperParams: return update_task( task, - user_id=user_id, + user_id=identity.user, update_cmds=update_cmds, set_last_update=not properties_only, ) @@ -163,7 +169,10 @@ class HyperParams: else: only.append("configuration") tasks = task_bll.assert_exists( - company_id=company_id, task_ids=task_ids, only=only, allow_public=True, + company_id=company_id, + task_ids=task_ids, + only=only, + allow_public=True, ) return { @@ -209,13 +218,15 @@ class HyperParams: def edit_configuration( cls, company_id: str, - user_id: str, + identity: Identity, task_id: str, configuration: Sequence[Configuration], replace_configuration: bool, force: bool, ) -> int: - task = get_task_for_update(company_id=company_id, task_id=task_id, force=force) + task = get_task_for_update( + company_id=company_id, task_id=task_id, force=force, identity=identity + ) update_cmds = dict() configuration = { @@ -228,22 +239,24 @@ class HyperParams: for name, value in configuration.items(): update_cmds[f"set__configuration__{mongoengine_safe(name)}"] = value - return update_task(task, user_id=user_id, update_cmds=update_cmds) + return update_task(task, user_id=identity.user, update_cmds=update_cmds) @classmethod def delete_configuration( cls, company_id: str, - user_id: str, + identity: Identity, task_id: str, configuration: Sequence[str], force: bool, ) -> int: - task = get_task_for_update(company_id=company_id, task_id=task_id, force=force) + task = get_task_for_update( + company_id=company_id, task_id=task_id, force=force, identity=identity + ) delete_cmds = { f"unset__configuration__{ParameterKeyEscaper.escape(name)}": 1 for name in set(configuration) } - return update_task(task, user_id=user_id, update_cmds=delete_cmds) + return update_task(task, user_id=identity.user, update_cmds=delete_cmds) diff --git a/apiserver/bll/task/task_bll.py b/apiserver/bll/task/task_bll.py index 657786d..6c99b23 100644 --- a/apiserver/bll/task/task_bll.py +++ b/apiserver/bll/task/task_bll.py @@ -58,27 +58,6 @@ class TaskBLL: self.events_es = events_es or es_factory.connect("events") self.redis: StrictRedis = redis or redman.connection("apiserver") - @staticmethod - def get_task_with_access( - task_id, company_id, only=None, allow_public=False, requires_write_access=False - ) -> Task: - """ - Gets a task that has a required write access - :except errors.bad_request.InvalidTaskId: if the task is not found - :except errors.forbidden.NoWritePermission: if write_access was required and the task cannot be modified - """ - with translate_errors_context(): - query = dict(id=task_id, company=company_id) - if requires_write_access: - task = Task.get_for_writing(_only=only, **query) - else: - task = Task.get(_only=only, **query, include_public=allow_public) - - if not task: - raise errors.bad_request.InvalidTaskId(**query) - - return task - @staticmethod def get_by_id( company_id, diff --git a/apiserver/bll/task/task_operations.py b/apiserver/bll/task/task_operations.py index 8edbd97..6b65161 100644 --- a/apiserver/bll/task/task_operations.py +++ b/apiserver/bll/task/task_operations.py @@ -9,6 +9,7 @@ from apiserver.bll.task import ( ChangeStatusRequest, ) from apiserver.bll.task.task_cleanup import cleanup_task, CleanupResult +from apiserver.bll.task.utils import get_task_with_write_access from apiserver.bll.util import update_project_time from apiserver.config_repo import config from apiserver.database.model import EntityVisibility @@ -24,6 +25,7 @@ from apiserver.database.model.task.task import ( DEFAULT_LAST_ITERATION, ) from apiserver.database.utils import get_options +from apiserver.service_repo.auth import Identity from apiserver.utilities.dicts import nested_set log = config.logger(__file__) @@ -33,7 +35,7 @@ queue_bll = QueueBLL() def archive_task( task: Union[str, Task], company_id: str, - user_id: str, + identity: Identity, status_message: str, status_reason: str, ) -> int: @@ -42,9 +44,10 @@ def archive_task( Return 1 if successful """ if isinstance(task, str): - task = TaskBLL.get_task_with_access( + task = get_task_with_write_access( task, company_id=company_id, + identity=identity, only=( "id", "company", @@ -54,8 +57,9 @@ def archive_task( "system_tags", "enqueue_status", ), - requires_write_access=True, ) + + user_id = identity.user try: TaskBLL.dequeue_and_change_status( task, @@ -79,34 +83,34 @@ def archive_task( def unarchive_task( - task: str, + task_id: str, company_id: str, - user_id: str, + identity: Identity, status_message: str, status_reason: str, ) -> int: """ Unarchive task. Return 1 if successful """ - task = TaskBLL.get_task_with_access( - task, + task = get_task_with_write_access( + task_id, company_id=company_id, + identity=identity, only=("id",), - requires_write_access=True, ) return task.update( status_message=status_message, status_reason=status_reason, pull__system_tags=EntityVisibility.archived.value, last_change=datetime.utcnow(), - last_changed_by=user_id, + last_changed_by=identity.user, ) def dequeue_task( task_id: str, company_id: str, - user_id: str, + identity: Identity, status_message: str, status_reason: str, remove_from_all_queues: bool = False, @@ -119,7 +123,19 @@ def dequeue_task( task = Task.get( id=task_id, company=company_id, - _only=( + _only=("id",), + include_public=True, + ) + if not task: + TaskBLL.remove_task_from_all_queues(company_id, task_id=task_id) + return 1, {"updated": 0} + + user_id = identity.user + task = get_task_with_write_access( + task_id, + company_id=company_id, + identity=identity, + only=( "id", "company", "execution", @@ -127,11 +143,7 @@ def dequeue_task( "project", "enqueue_status", ), - include_public=True, ) - if not task: - TaskBLL.remove_task_from_all_queues(company_id, task_id=task_id) - return 1, {"updated": 0} res = TaskBLL.dequeue_and_change_status( task, @@ -148,7 +160,7 @@ def dequeue_task( def enqueue_task( task_id: str, company_id: str, - user_id: str, + identity: Identity, queue_id: str, status_message: str, status_reason: str, @@ -173,11 +185,11 @@ def enqueue_task( # try to get default queue queue_id = queue_bll.get_default(company_id).id - query = dict(id=task_id, company=company_id) - task = Task.get_for_writing(**query) - if not task: - raise errors.bad_request.InvalidTaskId(**query) + task = get_task_with_write_access( + task_id=task_id, company_id=company_id, identity=identity + ) + user_id = identity.user if validate: TaskBLL.validate(task) @@ -207,9 +219,9 @@ def enqueue_task( # set the current queue ID in the task if task.execution: - Task.objects(**query).update(execution__queue=queue_id, multi=False) + Task.objects(id=task_id).update(execution__queue=queue_id, multi=False) else: - Task.objects(**query).update(execution=Execution(queue=queue_id), multi=False) + Task.objects(id=task_id).update(execution=Execution(queue=queue_id), multi=False) nested_set(res, ("fields", "execution.queue"), queue_id) return 1, res @@ -242,7 +254,7 @@ def move_tasks_to_trash(tasks: Sequence[str]) -> int: def delete_task( task_id: str, company_id: str, - user_id: str, + identity: Identity, move_to_trash: bool, force: bool, return_file_urls: bool, @@ -251,8 +263,9 @@ def delete_task( status_reason: str, delete_external_artifacts: bool, ) -> Tuple[int, Task, CleanupResult]: - task = TaskBLL.get_task_with_access( - task_id, company_id=company_id, requires_write_access=True + user_id = identity.user + task = get_task_with_write_access( + task_id, company_id=company_id, identity=identity ) if ( @@ -305,15 +318,16 @@ def delete_task( def reset_task( task_id: str, company_id: str, - user_id: str, + identity: Identity, force: bool, return_file_urls: bool, delete_output_models: bool, clear_all: bool, delete_external_artifacts: bool, ) -> Tuple[dict, CleanupResult, dict]: - task = TaskBLL.get_task_with_access( - task_id, company_id=company_id, requires_write_access=True + user_id = identity.user + task = get_task_with_write_access( + task_id, company_id=company_id, identity=identity ) if not force and task.status == TaskStatus.published: @@ -392,14 +406,15 @@ def reset_task( def publish_task( task_id: str, company_id: str, - user_id: str, + identity: Identity, force: bool, - publish_model_func: Callable[[str, str, str], Any] = None, + publish_model_func: Callable[[str, str, Identity], Any] = None, status_message: str = "", status_reason: str = "", ) -> dict: - task = TaskBLL.get_task_with_access( - task_id, company_id=company_id, requires_write_access=True + user_id = identity.user + task = get_task_with_write_access( + task_id, company_id=company_id, identity=identity ) if not force: validate_status_change(task.status, TaskStatus.published) @@ -422,7 +437,7 @@ def publish_task( .first() ) if model and not model.ready: - publish_model_func(model.id, company_id, user_id) + publish_model_func(model.id, company_id, identity) # set task status to published, and update (or set) it's new output (view and models) return ChangeStatusRequest( @@ -446,7 +461,7 @@ def publish_task( def stop_task( task_id: str, company_id: str, - user_id: str, + identity: Identity, user_name: str, status_reason: str, force: bool, @@ -459,10 +474,11 @@ def stop_task( is set to 'stopping' to allow the worker to stop the task and report by itself :return: updated task fields """ - - task = TaskBLL.get_task_with_access( + user_id = identity.user + task = get_task_with_write_access( task_id, company_id=company_id, + identity=identity, only=( "status", "project", @@ -472,7 +488,6 @@ def stop_task( "last_update", "execution.queue", ), - requires_write_access=True, ) def is_run_by_worker(t: Task) -> bool: diff --git a/apiserver/bll/task/utils.py b/apiserver/bll/task/utils.py index 99991db..11a7cd9 100644 --- a/apiserver/bll/task/utils.py +++ b/apiserver/bll/task/utils.py @@ -1,7 +1,9 @@ from datetime import datetime +from typing import Sequence import attr import six +from mongoengine import Q from apiserver.apierrors import errors from apiserver.bll.util import update_project_time @@ -10,6 +12,7 @@ from apiserver.database.errors import translate_errors_context from apiserver.database.model.model import Model from apiserver.database.model.task.task import Task, TaskStatus, TaskSystemTags from apiserver.database.utils import get_options +from apiserver.service_repo.auth import Identity from apiserver.utilities.attrs import typed_attrs valid_statuses = get_options(TaskStatus) @@ -157,15 +160,75 @@ def get_possible_status_changes(current_status): return possible +def get_many_tasks_for_writing( + company_id: str, + identity: Identity, + query: Q = None, + only: Sequence = None, + throw_on_forbidden: bool = True, +) -> Sequence[Task]: + if only: + missing = [f for f in ("company", ) if f not in only] + if missing: + only = [*only, *missing] + + result = list( + Task.get_many( + company=company_id, + query=query, + override_projection=only, + allow_public=True, + return_dicts=False, + ) + ) + + forbidden_tasks = {task.id for task in result if not task.company} + if forbidden_tasks: + if throw_on_forbidden: + raise errors.forbidden.NoWritePermission( + f"cannot modify public task(s), ids={tuple(forbidden_tasks)}" + ) + result = [task for task in result if task.id not in forbidden_tasks] + + return result + + +def get_task_with_write_access( + task_id: str, + company_id: str, + identity: Identity, + only=None, +) -> Task: + """ + Gets a task that has a required write access + :except errors.bad_request.InvalidTaskId: if the task is not found + :except errors.forbidden.NoWritePermission: if write_access was required and the task cannot be modified + """ + query = dict(id=task_id, company=company_id) + + task = Task.get_for_writing(_only=only, **query) + if not task: + raise errors.bad_request.InvalidTaskId(**query) + + return task + + def get_task_for_update( - company_id: str, task_id: str, allow_all_statuses: bool = False, force: bool = False + company_id: str, + task_id: str, + identity: Identity, + allow_all_statuses: bool = False, + force: bool = False ) -> Task: """ Loads only task id and return the task only if it is updatable (status == 'created') """ - task = Task.get_for_writing(company=company_id, id=task_id, _only=("id", "status")) - if not task: - raise errors.bad_request.InvalidTaskId(id=task_id) + task = get_task_with_write_access( + task_id=task_id, + company_id=company_id, + only=("id", "status"), + identity=identity, + ) if allow_all_statuses: return task diff --git a/apiserver/database/model/base.py b/apiserver/database/model/base.py index e3c5625..2dff748 100644 --- a/apiserver/database/model/base.py +++ b/apiserver/database/model/base.py @@ -1283,21 +1283,6 @@ class GetMixin(PropsMixin): ) return result - @classmethod - def get_many_for_writing(cls, company, *args, **kwargs): - result = cls.get_many( - company=company, - *args, - **dict(return_dicts=False, **kwargs), - allow_public=True, - ) - forbidden_objects = {obj.id for obj in result if not obj.company} - if forbidden_objects: - object_name = cls.__name__.lower() - raise errors.forbidden.NoWritePermission( - f"cannot modify public {object_name}(s), ids={tuple(forbidden_objects)}" - ) - return result class UpdateMixin(object): diff --git a/apiserver/mongo/initialize/pre_populate.py b/apiserver/mongo/initialize/pre_populate.py index 7860e65..351520e 100644 --- a/apiserver/mongo/initialize/pre_populate.py +++ b/apiserver/mongo/initialize/pre_populate.py @@ -44,6 +44,7 @@ from apiserver.bll.task.param_utils import ( from apiserver.config_repo import config from apiserver.config.info import get_default_company from apiserver.database.model import EntityVisibility, User +from apiserver.database.model.auth import Role from apiserver.database.model.model import Model from apiserver.database.model.project import Project from apiserver.database.model.task.task import ( @@ -54,6 +55,7 @@ from apiserver.database.model.task.task import ( TaskModelNames, ) from apiserver.database.utils import get_options +from apiserver.service_repo.auth import Identity from apiserver.utilities import json from apiserver.utilities.dicts import nested_get, nested_set, nested_delete from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper @@ -717,7 +719,10 @@ class PrePopulate: @classmethod def _generate_new_ids( - cls, reader: ZipFile, entity_files: Sequence, metadata: Mapping[str, Any], + cls, + reader: ZipFile, + entity_files: Sequence, + metadata: Mapping[str, Any], ) -> Mapping[str, str]: if not metadata or not any( metadata.get(key) for key in ("new_ids", "example_ids", "private_ids") @@ -970,7 +975,7 @@ class PrePopulate: ev["allow_locked"] = True cls.event_bll.add_events( company_id=company_id, - user_id=user_id, + identity=Identity(user_id, company=company_id, role=Role.admin), events=events, worker="", ) diff --git a/apiserver/services/models.py b/apiserver/services/models.py index 82e77a5..2ca5443 100644 --- a/apiserver/services/models.py +++ b/apiserver/services/models.py @@ -28,6 +28,7 @@ from apiserver.bll.organization import OrgBLL, Tags from apiserver.bll.project import ProjectBLL from apiserver.bll.task import TaskBLL from apiserver.bll.task.task_operations import publish_task +from apiserver.bll.task.utils import get_task_with_write_access from apiserver.bll.util import run_batch_operation from apiserver.config_repo import config from apiserver.database.model import validate_id @@ -46,6 +47,7 @@ from apiserver.database.utils import ( filter_fields, ) from apiserver.service_repo import APICall, endpoint +from apiserver.service_repo.auth import Identity from apiserver.services.utils import ( conform_tag_fields, conform_output_tags, @@ -249,13 +251,12 @@ def update_for_task(call: APICall, company_id, _): ) query = dict(id=task_id, company=company_id) - task = Task.get_for_writing( - id=task_id, - company=company_id, - _only=["models", "execution", "name", "status", "project"], + task = get_task_with_write_access( + task_id=task_id, + company_id=company_id, + identity=call.identity, + only=("models", "execution", "name", "status", "project"), ) - if not task: - raise errors.bad_request.InvalidTaskId(**query) allowed_states = [TaskStatus.created, TaskStatus.in_progress] if task.status not in allowed_states: @@ -343,7 +344,7 @@ def create(call: APICall, company_id, req_model: CreateModelRequest): task = req_model.task req_data = req_model.to_struct() if task: - validate_task(company_id, req_data) + validate_task(company_id, call.identity, req_data) fields = filter_fields(Model, req_data) conform_tag_fields(call, fields, validate=True) @@ -373,7 +374,7 @@ def prepare_update_fields(call, company_id, fields: dict): # clear UI cache if URI is provided (model updated) fields["ui_cache"] = fields.pop("ui_cache", {}) if "task" in fields: - validate_task(company_id, fields) + validate_task(company_id, call.identity, fields) if "labels" in fields: labels = fields["labels"] @@ -403,8 +404,11 @@ def prepare_update_fields(call, company_id, fields: dict): return fields -def validate_task(company_id, fields: dict): - Task.get_for_writing(company=company_id, id=fields["task"], _only=["id"]) +def validate_task(company_id: str, identity: Identity, fields: dict): + task_id = fields["task"] + get_task_with_write_access( + task_id=task_id, company_id=company_id, identity=identity, only=("id",) + ) @endpoint("models.edit", required_fields=["model"], response_data_model=UpdateResponse) @@ -514,7 +518,7 @@ def set_ready(call: APICall, company_id: str, request: PublishModelRequest): updated, published_task = ModelBLL.publish_model( model_id=request.model, company_id=company_id, - user_id=call.identity.user, + identity=call.identity, force_publish_task=request.force_publish_task, publish_task_func=publish_task if request.publish_task else None, ) @@ -533,7 +537,7 @@ def publish_many(call: APICall, company_id, request: ModelsPublishManyRequest): func=partial( ModelBLL.publish_model, company_id=company_id, - user_id=call.identity.user, + identity=call.identity, force_publish_task=request.force_publish_task, publish_task_func=publish_task if request.publish_task else None, ), diff --git a/apiserver/services/pipelines.py b/apiserver/services/pipelines.py index 0767489..72decf4 100644 --- a/apiserver/services/pipelines.py +++ b/apiserver/services/pipelines.py @@ -57,7 +57,7 @@ def delete_runs(call: APICall, company_id: str, request: DeleteRunsRequest): func=partial( delete_task, company_id=company_id, - user_id=call.identity.user, + identity=call.identity, move_to_trash=False, force=True, return_file_urls=False, @@ -108,7 +108,7 @@ def start_pipeline(call: APICall, company_id: str, request: StartPipelineRequest queued, res = enqueue_task( task_id=task.id, company_id=company_id, - user_id=call.identity.user, + identity=call.identity, queue_id=request.queue, status_message="Starting pipeline", status_reason="", diff --git a/apiserver/services/tasks.py b/apiserver/services/tasks.py index 15028a8..530c45e 100644 --- a/apiserver/services/tasks.py +++ b/apiserver/services/tasks.py @@ -100,7 +100,13 @@ from apiserver.bll.task.task_operations import ( unarchive_task, move_tasks_to_trash, ) -from apiserver.bll.task.utils import update_task, get_task_for_update, deleted_prefix +from apiserver.bll.task.utils import ( + update_task, + get_task_for_update, + deleted_prefix, + get_many_tasks_for_writing, + get_task_with_write_access, +) from apiserver.bll.util import run_batch_operation, update_project_time from apiserver.database.errors import translate_errors_context from apiserver.database.model import EntityVisibility @@ -118,6 +124,7 @@ from apiserver.database.utils import ( get_options, ) from apiserver.service_repo import APICall, endpoint +from apiserver.service_repo.auth import Identity from apiserver.services.utils import ( conform_tag_fields, conform_output_tags, @@ -142,14 +149,34 @@ org_bll = OrgBLL() project_bll = ProjectBLL() +def _assert_writable_tasks( + company_id: str, identity: Identity, ids: Sequence[str], only=("id",) +) -> Sequence[Task]: + tasks = get_many_tasks_for_writing( + company_id=company_id, + identity=identity, + query=Q(id__in=ids), + only=only, + ) + missing_ids = set(ids) - {t.id for t in tasks} + if missing_ids: + raise errors.bad_request.InvalidTaskId(ids=list(missing_ids)) + + return tasks + + def set_task_status_from_call( - request: UpdateRequest, company_id: str, user_id: str, new_status=None, **set_fields + request: UpdateRequest, + company_id: str, + identity: Identity, + new_status=None, + **set_fields, ) -> dict: - task = TaskBLL.get_task_with_access( + task = get_task_with_write_access( request.task, company_id=company_id, + identity=identity, only=("id", "status", "project"), - requires_write_access=True, ) status_reason = request.status_reason @@ -161,15 +188,17 @@ def set_task_status_from_call( status_reason=status_reason, status_message=status_message, force=force, - user_id=user_id, + user_id=identity.user, ).execute(**set_fields) @endpoint("tasks.get_by_id", request_data_model=TaskRequest) -def get_by_id(call: APICall, company_id, req_model: TaskRequest): - task = TaskBLL.get_task_with_access( - req_model.task, company_id=company_id, allow_public=True - ) +def get_by_id(call: APICall, company_id, request: TaskRequest): + task = TaskBLL.assert_exists( + company_id, + task_ids=request.task, + allow_public=True, + )[0] task_dict = task.to_proper_dict() conform_task_data(call, task_dict) call.result.data = {"task": task_dict} @@ -227,7 +256,9 @@ def get_by_id_ex(call: APICall, company_id, _): conform_tag_fields(call, call.data) call_data = escape_execution_parameters(call.data) tasks = Task.get_many_with_join( - company=company_id, query_dict=call_data, allow_public=True, + company=company_id, + query_dict=call_data, + allow_public=True, ) conform_task_data(call, tasks) @@ -278,7 +309,7 @@ def stop(call: APICall, company_id, req_model: UpdateRequest): **stop_task( task_id=req_model.task, company_id=company_id, - user_id=call.identity.user, + identity=call.identity, user_name=call.identity.user_name, status_reason=req_model.status_reason, force=req_model.force, @@ -296,7 +327,7 @@ def stop_many(call: APICall, company_id, request: StopManyRequest): func=partial( stop_task, company_id=company_id, - user_id=call.identity.user, + identity=call.identity, user_name=call.identity.user_name, status_reason=request.status_reason, force=request.force, @@ -319,7 +350,7 @@ def stopped(call: APICall, company_id, req_model: UpdateRequest): **set_task_status_from_call( req_model, company_id=company_id, - user_id=call.identity.user, + identity=call.identity, new_status=TaskStatus.stopped, completed=datetime.utcnow(), ) @@ -336,7 +367,7 @@ def started(call: APICall, company_id, req_model: UpdateRequest): **set_task_status_from_call( req_model, company_id=company_id, - user_id=call.identity.user, + identity=call.identity, new_status=TaskStatus.in_progress, min__started=datetime.utcnow(), # don't override a previous, smaller "started" field value ) @@ -353,7 +384,7 @@ def failed(call: APICall, company_id, req_model: UpdateRequest): **set_task_status_from_call( req_model, company_id=company_id, - user_id=call.identity.user, + identity=call.identity, new_status=TaskStatus.failed, ) ) @@ -367,7 +398,7 @@ def close(call: APICall, company_id, req_model: UpdateRequest): **set_task_status_from_call( req_model, company_id=company_id, - user_id=call.identity.user, + identity=call.identity, new_status=TaskStatus.closed, ) ) @@ -433,13 +464,17 @@ def conform_task_data(call: APICall, tasks_data: Union[Sequence[dict], dict]): for data in tasks_data: params_unprepare_from_saved( - fields=data, copy_to_legacy=need_legacy_params, + fields=data, + copy_to_legacy=need_legacy_params, ) artifacts_unprepare_from_saved(fields=data) def prepare_create_fields( - call: APICall, valid_fields=None, output=None, previous_task: Task = None, + call: APICall, + valid_fields=None, + output=None, + previous_task: Task = None, ): valid_fields = valid_fields if valid_fields is not None else create_fields t_fields = task_fields @@ -566,11 +601,12 @@ def update(call: APICall, company_id, req_model: UpdateRequest): task_id = req_model.task with translate_errors_context(): - task = Task.get_for_writing( - id=task_id, company=company_id, _only=["id", "project"] + task = get_task_with_write_access( + task_id=task_id, + company_id=company_id, + identity=call.identity, + only=("id", "project"), ) - if not task: - raise errors.bad_request.InvalidTaskId(id=task_id) partial_update_dict, valid_fields = prepare_update_fields(call, call.data) @@ -582,7 +618,8 @@ def update(call: APICall, company_id, req_model: UpdateRequest): id=task_id, partial_update_dict=partial_update_dict, injected_update=dict( - last_change=datetime.utcnow(), last_changed_by=call.identity.user, + last_change=datetime.utcnow(), + last_changed_by=call.identity.user, ), ) if updated_count: @@ -606,11 +643,11 @@ def update(call: APICall, company_id, req_model: UpdateRequest): def set_requirements(call: APICall, company_id, req_model: SetRequirementsRequest): requirements = req_model.requirements with translate_errors_context(): - task = TaskBLL.get_task_with_access( + task = get_task_with_write_access( req_model.task, company_id=company_id, + identity=call.identity, only=("status", "script"), - requires_write_access=True, ) if not task.script: raise errors.bad_request.MissingTaskFields( @@ -636,8 +673,11 @@ def update_batch(call: APICall, company_id, _): items = {i["task"]: i for i in items} tasks = { t.id: t - for t in Task.get_many_for_writing( - company=company_id, query=Q(id__in=list(items)) + for t in _assert_writable_tasks( + identity=call.identity, + company_id=company_id, + ids=list(items), + only=("id", "project"), ) } @@ -656,7 +696,8 @@ def update_batch(call: APICall, company_id, _): if not partial_update_dict: continue partial_update_dict.update( - last_change=now, last_changed_by=call.identity.user, + last_change=now, + last_changed_by=call.identity.user, ) update_op = UpdateOne( {"_id": id, "company": company_id}, {"$set": partial_update_dict} @@ -690,9 +731,11 @@ def edit(call: APICall, company_id, req_model: UpdateRequest): force = req_model.force with translate_errors_context(): - task = Task.get_for_writing(id=task_id, company=company_id) - if not task: - raise errors.bad_request.InvalidTaskId(id=task_id) + task = get_task_with_write_access( + task_id=task_id, + company_id=company_id, + identity=call.identity, + ) if not force and task.status != TaskStatus.created: raise errors.bad_request.InvalidTaskStatus( @@ -756,7 +799,8 @@ def edit(call: APICall, company_id, req_model: UpdateRequest): @endpoint( - "tasks.get_hyper_params", request_data_model=GetHyperParamsRequest, + "tasks.get_hyper_params", + request_data_model=GetHyperParamsRequest, ) def get_hyper_params(call: APICall, company_id, request: GetHyperParamsRequest): tasks_params = HyperParams.get_params(company_id, task_ids=request.tasks) @@ -771,7 +815,7 @@ def edit_hyper_params(call: APICall, company_id, request: EditHyperParamsRequest call.result.data = { "updated": HyperParams.edit_params( company_id, - user_id=call.identity.user, + identity=call.identity, task_id=request.task, hyperparams=request.hyperparams, replace_hyperparams=request.replace_hyperparams, @@ -785,7 +829,7 @@ def delete_hyper_params(call: APICall, company_id, request: DeleteHyperParamsReq call.result.data = { "deleted": HyperParams.delete_params( company_id, - user_id=call.identity.user, + identity=call.identity, task_id=request.task, hyperparams=request.hyperparams, force=request.force, @@ -794,7 +838,8 @@ def delete_hyper_params(call: APICall, company_id, request: DeleteHyperParamsReq @endpoint( - "tasks.get_configurations", request_data_model=GetConfigurationsRequest, + "tasks.get_configurations", + request_data_model=GetConfigurationsRequest, ) def get_configurations(call: APICall, company_id, request: GetConfigurationsRequest): tasks_params = HyperParams.get_configurations( @@ -809,7 +854,8 @@ def get_configurations(call: APICall, company_id, request: GetConfigurationsRequ @endpoint( - "tasks.get_configuration_names", request_data_model=GetConfigurationNamesRequest, + "tasks.get_configuration_names", + request_data_model=GetConfigurationNamesRequest, ) def get_configuration_names( call: APICall, company_id, request: GetConfigurationNamesRequest @@ -830,7 +876,7 @@ def edit_configuration(call: APICall, company_id, request: EditConfigurationRequ call.result.data = { "updated": HyperParams.edit_configuration( company_id, - user_id=call.identity.user, + identity=call.identity, task_id=request.task, configuration=request.configuration, replace_configuration=request.replace_configuration, @@ -846,7 +892,7 @@ def delete_configuration( call.result.data = { "deleted": HyperParams.delete_configuration( company_id, - user_id=call.identity.user, + identity=call.identity, task_id=request.task, configuration=request.configuration, force=request.force, @@ -863,7 +909,7 @@ def enqueue(call: APICall, company_id, request: EnqueueRequest): queued, res = enqueue_task( task_id=request.task, company_id=company_id, - user_id=call.identity.user, + identity=call.identity, queue_id=request.queue, status_message=request.status_message, status_reason=request.status_reason, @@ -888,7 +934,7 @@ def enqueue_many(call: APICall, company_id, request: EnqueueManyRequest): func=partial( enqueue_task, company_id=company_id, - user_id=call.identity.user, + identity=call.identity, queue_id=request.queue, status_message=request.status_message, status_reason=request.status_reason, @@ -915,13 +961,14 @@ def enqueue_many(call: APICall, company_id, request: EnqueueManyRequest): @endpoint( - "tasks.dequeue", response_data_model=DequeueResponse, + "tasks.dequeue", + response_data_model=DequeueResponse, ) def dequeue(call: APICall, company_id, request: DequeueRequest): dequeued, res = dequeue_task( task_id=request.task, company_id=company_id, - user_id=call.identity.user, + identity=call.identity, status_message=request.status_message, status_reason=request.status_reason, remove_from_all_queues=request.remove_from_all_queues, @@ -931,14 +978,15 @@ def dequeue(call: APICall, company_id, request: DequeueRequest): @endpoint( - "tasks.dequeue_many", response_data_model=DequeueManyResponse, + "tasks.dequeue_many", + response_data_model=DequeueManyResponse, ) def dequeue_many(call: APICall, company_id, request: DequeueManyRequest): results, failures = run_batch_operation( func=partial( dequeue_task, company_id=company_id, - user_id=call.identity.user, + identity=call.identity, status_message=request.status_message, status_reason=request.status_reason, remove_from_all_queues=request.remove_from_all_queues, @@ -962,7 +1010,7 @@ def reset(call: APICall, company_id, request: ResetRequest): dequeued, cleanup_res, updates = reset_task( task_id=request.task, company_id=company_id, - user_id=call.identity.user, + identity=call.identity, force=request.force, return_file_urls=request.return_file_urls, delete_output_models=request.delete_output_models, @@ -990,7 +1038,7 @@ def reset_many(call: APICall, company_id, request: ResetManyRequest): func=partial( reset_task, company_id=company_id, - user_id=call.identity.user, + identity=call.identity, force=request.force, return_file_urls=request.return_file_urls, delete_output_models=request.delete_output_models, @@ -1027,9 +1075,11 @@ def reset_many(call: APICall, company_id, request: ResetManyRequest): response_data_model=ArchiveResponse, ) def archive(call: APICall, company_id, request: ArchiveRequest): - tasks = TaskBLL.assert_exists( + archived = 0 + tasks = _assert_writable_tasks( company_id, - task_ids=request.tasks, + call.identity, + ids=request.tasks, only=( "id", "company", @@ -1040,11 +1090,10 @@ def archive(call: APICall, company_id, request: ArchiveRequest): "enqueue_status", ), ) - archived = 0 for task in tasks: archived += archive_task( company_id=company_id, - user_id=call.identity.user, + identity=call.identity, task=task, status_message=request.status_message, status_reason=request.status_reason, @@ -1063,7 +1112,7 @@ def archive_many(call: APICall, company_id, request: TaskBatchRequest): func=partial( archive_task, company_id=company_id, - user_id=call.identity.user, + identity=call.identity, status_message=request.status_message, status_reason=request.status_reason, ), @@ -1085,7 +1134,7 @@ def unarchive_many(call: APICall, company_id, request: TaskBatchRequest): func=partial( unarchive_task, company_id=company_id, - user_id=call.identity.user, + identity=call.identity, status_message=request.status_message, status_reason=request.status_reason, ), @@ -1104,7 +1153,7 @@ def delete(call: APICall, company_id, request: DeleteRequest): deleted, task, cleanup_res = delete_task( task_id=request.task, company_id=company_id, - user_id=call.identity.user, + identity=call.identity, move_to_trash=request.move_to_trash, force=request.force, return_file_urls=request.return_file_urls, @@ -1126,7 +1175,7 @@ def delete_many(call: APICall, company_id, request: DeleteManyRequest): func=partial( delete_task, company_id=company_id, - user_id=call.identity.user, + identity=call.identity, move_to_trash=request.move_to_trash, force=request.force, return_file_urls=request.return_file_urls, @@ -1164,7 +1213,7 @@ def publish(call: APICall, company_id, request: PublishRequest): updates = publish_task( task_id=request.task, company_id=company_id, - user_id=call.identity.user, + identity=call.identity, force=request.force, publish_model_func=ModelBLL.publish_model if request.publish_model else None, status_reason=request.status_reason, @@ -1183,7 +1232,7 @@ def publish_many(call: APICall, company_id, request: PublishManyRequest): func=partial( publish_task, company_id=company_id, - user_id=call.identity.user, + identity=call.identity, force=request.force, publish_model_func=ModelBLL.publish_model if request.publish_model @@ -1211,7 +1260,7 @@ def completed(call: APICall, company_id, request: CompletedRequest): **set_task_status_from_call( request, company_id=company_id, - user_id=call.identity.user, + identity=call.identity, new_status=TaskStatus.completed, completed=datetime.utcnow(), ) @@ -1221,7 +1270,7 @@ def completed(call: APICall, company_id, request: CompletedRequest): publish_res = publish_task( task_id=request.task, company_id=company_id, - user_id=call.identity.user, + identity=call.identity, force=request.force, publish_model_func=ModelBLL.publish_model, status_reason=request.status_reason, @@ -1256,7 +1305,7 @@ def add_or_update_artifacts( call.result.data = { "updated": Artifacts.add_or_update_artifacts( company_id=company_id, - user_id=call.identity.user, + identity=call.identity, task_id=request.task, artifacts=request.artifacts, force=True, @@ -1273,7 +1322,7 @@ def delete_artifacts(call: APICall, company_id, request: DeleteArtifactsRequest) call.result.data = { "deleted": Artifacts.delete_artifacts( company_id=company_id, - user_id=call.identity.user, + identity=call.identity, task_id=request.task, artifact_ids=request.artifacts, force=True, @@ -1310,6 +1359,7 @@ def move(call: APICall, company_id: str, request: MoveRequest): "project or project_name is required" ) + _assert_writable_tasks(company_id, call.identity, request.ids) updated_projects = set( t.project for t in Task.objects(id__in=request.ids).only("project") if t.project ) @@ -1330,7 +1380,8 @@ def move(call: APICall, company_id: str, request: MoveRequest): @endpoint("tasks.update_tags") -def update_tags(_, company_id: str, request: UpdateTagsRequest): +def update_tags(call: APICall, company_id: str, request: UpdateTagsRequest): + _assert_writable_tasks(company_id, call.identity, request.ids) return { "updated": org_bll.edit_entity_tags( company_id=company_id, @@ -1344,7 +1395,9 @@ def update_tags(_, company_id: str, request: UpdateTagsRequest): @endpoint("tasks.add_or_update_model", min_version="2.13") def add_or_update_model(call: APICall, company_id: str, request: AddUpdateModelRequest): - get_task_for_update(company_id=company_id, task_id=request.task, force=True) + get_task_for_update( + company_id=company_id, task_id=request.task, force=True, identity=call.identity + ) models_field = f"models__{request.type}" model = ModelItem(name=request.name, model=request.model, updated=datetime.utcnow()) @@ -1364,7 +1417,9 @@ def add_or_update_model(call: APICall, company_id: str, request: AddUpdateModelR @endpoint("tasks.delete_models", min_version="2.13") def delete_models(call: APICall, company_id: str, request: DeleteModelsRequest): - task = get_task_for_update(company_id=company_id, task_id=request.task, force=True) + task = get_task_for_update( + company_id=company_id, task_id=request.task, force=True, identity=call.identity + ) delete_names = { type_: [m.name for m in request.models if m.type == type_] @@ -1377,6 +1432,8 @@ def delete_models(call: APICall, company_id: str, request: DeleteModelsRequest): } updated = task.update( - last_change=datetime.utcnow(), last_changed_by=call.identity.user, **commands, + last_change=datetime.utcnow(), + last_changed_by=call.identity.user, + **commands, ) return {"updated": updated}