Refactor check for tasks write permission

This commit is contained in:
allegroai
2024-01-10 15:08:20 +02:00
parent 88a7773621
commit a604451b01
13 changed files with 340 additions and 188 deletions

View File

@@ -31,6 +31,7 @@ from apiserver.bll.event.history_plots_iterator import HistoryPlotsIterator
from apiserver.bll.event.metric_debug_images_iterator import MetricDebugImagesIterator
from apiserver.bll.event.metric_plots_iterator import MetricPlotsIterator
from apiserver.bll.model import ModelBLL
from apiserver.bll.task.utils import get_many_tasks_for_writing
from apiserver.bll.util import parallel_chunked_decorator
from apiserver.database import utils as dbutils
from apiserver.database.model.model import Model
@@ -42,6 +43,7 @@ from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.task.task import Task, TaskStatus
from apiserver.redis_manager import redman
from apiserver.service_repo.auth import Identity
from apiserver.tools import safe_get
from apiserver.utilities.dicts import nested_get
from apiserver.utilities.json import loads
@@ -55,7 +57,9 @@ MIN_LONG = -(2**63)
log = config.logger(__file__)
async_task_events_delete = config.get("services.tasks.async_events_delete", False)
async_delete_threshold = config.get("services.tasks.async_events_delete_threshold", 100_000)
async_delete_threshold = config.get(
"services.tasks.async_events_delete_threshold", 100_000
)
class EventBLL(object):
@@ -97,7 +101,9 @@ class EventBLL(object):
return self._metrics
@staticmethod
def _get_valid_entities(company_id, ids: Mapping[str, bool], model=False) -> Set:
def _get_valid_entities(
company_id, ids: Mapping[str, bool], identity: Identity, model=False
) -> Set:
"""Verify that task or model exists and can be updated"""
if not ids:
return set()
@@ -116,20 +122,34 @@ class EventBLL(object):
):
if not requested_ids:
continue
query = Q(id__in=requested_ids, company=company_id)
res.update(
(Model if model else Task).objects(query & locked_q).scalar("id")
)
query = Q(id__in=requested_ids) & locked_q
if model:
ids = Model.objects(query & Q(company=company_id)).scalar("id")
else:
ids = {
t.id
for t in get_many_tasks_for_writing(
company_id=company_id,
identity=identity,
query=query,
only=("id",),
throw_on_forbidden=False,
)
}
res.update(ids)
return res
def add_events(
self,
company_id: str,
user_id: str,
identity: Identity,
events: Sequence[dict],
worker: str,
) -> Tuple[int, int, dict]:
user_id = identity.user
task_ids = {}
model_ids = {}
for event in events:
@@ -161,8 +181,12 @@ class EventBLL(object):
"Inconsistent model_event setting in the passed events",
tasks=found_in_both,
)
valid_models = self._get_valid_entities(company_id, ids=model_ids, model=True)
valid_tasks = self._get_valid_entities(company_id, ids=task_ids)
valid_models = self._get_valid_entities(
company_id, ids=model_ids, identity=identity, model=True
)
valid_tasks = self._get_valid_entities(
company_id, ids=task_ids, identity=identity
)
actions: List[dict] = []
used_task_ids = set()

View File

@@ -10,6 +10,7 @@ from apiserver.config_repo import config
from apiserver.database.model import EntityVisibility
from apiserver.database.model.model import Model
from apiserver.database.model.task.task import Task, TaskStatus
from apiserver.service_repo.auth import Identity
from .metadata import Metadata
@@ -57,14 +58,15 @@ class ModelBLL:
cls,
model_id: str,
company_id: str,
user_id: str,
identity: Identity,
force_publish_task: bool = False,
publish_task_func: Callable[[str, str, str, bool], dict] = None,
publish_task_func: Callable[[str, str, Identity, bool], dict] = None,
) -> Tuple[int, ModelTaskPublishResponse]:
model = cls.get_company_model_by_id(company_id=company_id, model_id=model_id)
if model.ready:
raise errors.bad_request.ModelIsReady(company=company_id, model=model_id)
user_id = identity.user
published_task = None
if model.task and publish_task_func:
task = (
@@ -74,7 +76,7 @@ class ModelBLL:
)
if task and task.status != TaskStatus.published:
task_publish_res = publish_task_func(
model.task, company_id, user_id, force_publish_task
model.task, company_id, identity, force_publish_task
)
published_task = ModelTaskPublishResponse(
id=model.task, data=task_publish_res

View File

@@ -152,7 +152,7 @@ class QueueBLL(object):
for item in queue.entries:
try:
task = Task.get_for_writing(
task = Task.get(
company=company_id,
id=item.task,
_only=[

View File

@@ -5,6 +5,7 @@ from apiserver.apimodels.tasks import Artifact as ApiArtifact, ArtifactId
from apiserver.bll.task.utils import get_task_for_update, update_task
from apiserver.database.model.task.task import DEFAULT_ARTIFACT_MODE, Artifact
from apiserver.database.utils import hash_field_name
from apiserver.service_repo.auth import Identity
from apiserver.utilities.dicts import nested_get, nested_set
from apiserver.utilities.parameter_key_escaper import mongoengine_safe
@@ -48,12 +49,14 @@ class Artifacts:
def add_or_update_artifacts(
cls,
company_id: str,
user_id: str,
identity: Identity,
task_id: str,
artifacts: Sequence[ApiArtifact],
force: bool,
) -> int:
task = get_task_for_update(company_id=company_id, task_id=task_id, force=force,)
task = get_task_for_update(
company_id=company_id, task_id=task_id, force=force, identity=identity
)
artifacts = {
get_artifact_id(a): Artifact(**a)
@@ -64,18 +67,20 @@ class Artifacts:
f"set__execution__artifacts__{mongoengine_safe(name)}": value
for name, value in artifacts.items()
}
return update_task(task, user_id=user_id, update_cmds=update_cmds)
return update_task(task, user_id=identity.user, update_cmds=update_cmds)
@classmethod
def delete_artifacts(
cls,
company_id: str,
user_id: str,
identity: Identity,
task_id: str,
artifact_ids: Sequence[ArtifactId],
force: bool,
) -> int:
task = get_task_for_update(company_id=company_id, task_id=task_id, force=force,)
task = get_task_for_update(
company_id=company_id, task_id=task_id, force=force, identity=identity
)
artifact_ids = [
get_artifact_id(a)
@@ -85,4 +90,4 @@ class Artifacts:
f"unset__execution__artifacts__{id_}": 1 for id_ in set(artifact_ids)
}
return update_task(task, user_id=user_id, update_cmds=delete_cmds)
return update_task(task, user_id=identity.user, update_cmds=delete_cmds)

View File

@@ -15,6 +15,7 @@ from apiserver.bll.task import TaskBLL
from apiserver.bll.task.utils import get_task_for_update, update_task
from apiserver.config_repo import config
from apiserver.database.model.task.task import ParamsItem, Task, ConfigurationItem
from apiserver.service_repo.auth import Identity
from apiserver.utilities.parameter_key_escaper import (
ParameterKeyEscaper,
mongoengine_safe,
@@ -31,7 +32,10 @@ class HyperParams:
def get_params(cls, company_id: str, task_ids: Sequence[str]) -> Dict[str, dict]:
only = ("id", "hyperparams")
tasks = task_bll.assert_exists(
company_id=company_id, task_ids=task_ids, only=only, allow_public=True,
company_id=company_id,
task_ids=task_ids,
only=only,
allow_public=True,
)
return {
@@ -63,7 +67,7 @@ class HyperParams:
def delete_params(
cls,
company_id: str,
user_id: str,
identity: Identity,
task_id: str,
hyperparams: Sequence[HyperParamKey],
force: bool,
@@ -74,6 +78,7 @@ class HyperParams:
task_id=task_id,
allow_all_statuses=properties_only,
force=force,
identity=identity,
)
with_param, without_param = iterutils.partition(
@@ -96,7 +101,7 @@ class HyperParams:
return update_task(
task,
user_id=user_id,
user_id=identity.user,
update_cmds=delete_cmds,
set_last_update=not properties_only,
)
@@ -105,7 +110,7 @@ class HyperParams:
def edit_params(
cls,
company_id: str,
user_id: str,
identity: Identity,
task_id: str,
hyperparams: Sequence[HyperParamItem],
replace_hyperparams: str,
@@ -117,6 +122,7 @@ class HyperParams:
task_id=task_id,
allow_all_statuses=properties_only,
force=force,
identity=identity,
)
update_cmds = dict()
@@ -135,7 +141,7 @@ class HyperParams:
return update_task(
task,
user_id=user_id,
user_id=identity.user,
update_cmds=update_cmds,
set_last_update=not properties_only,
)
@@ -163,7 +169,10 @@ class HyperParams:
else:
only.append("configuration")
tasks = task_bll.assert_exists(
company_id=company_id, task_ids=task_ids, only=only, allow_public=True,
company_id=company_id,
task_ids=task_ids,
only=only,
allow_public=True,
)
return {
@@ -209,13 +218,15 @@ class HyperParams:
def edit_configuration(
cls,
company_id: str,
user_id: str,
identity: Identity,
task_id: str,
configuration: Sequence[Configuration],
replace_configuration: bool,
force: bool,
) -> int:
task = get_task_for_update(company_id=company_id, task_id=task_id, force=force)
task = get_task_for_update(
company_id=company_id, task_id=task_id, force=force, identity=identity
)
update_cmds = dict()
configuration = {
@@ -228,22 +239,24 @@ class HyperParams:
for name, value in configuration.items():
update_cmds[f"set__configuration__{mongoengine_safe(name)}"] = value
return update_task(task, user_id=user_id, update_cmds=update_cmds)
return update_task(task, user_id=identity.user, update_cmds=update_cmds)
@classmethod
def delete_configuration(
cls,
company_id: str,
user_id: str,
identity: Identity,
task_id: str,
configuration: Sequence[str],
force: bool,
) -> int:
task = get_task_for_update(company_id=company_id, task_id=task_id, force=force)
task = get_task_for_update(
company_id=company_id, task_id=task_id, force=force, identity=identity
)
delete_cmds = {
f"unset__configuration__{ParameterKeyEscaper.escape(name)}": 1
for name in set(configuration)
}
return update_task(task, user_id=user_id, update_cmds=delete_cmds)
return update_task(task, user_id=identity.user, update_cmds=delete_cmds)

View File

@@ -58,27 +58,6 @@ class TaskBLL:
self.events_es = events_es or es_factory.connect("events")
self.redis: StrictRedis = redis or redman.connection("apiserver")
@staticmethod
def get_task_with_access(
task_id, company_id, only=None, allow_public=False, requires_write_access=False
) -> Task:
"""
Gets a task that has a required write access
:except errors.bad_request.InvalidTaskId: if the task is not found
:except errors.forbidden.NoWritePermission: if write_access was required and the task cannot be modified
"""
with translate_errors_context():
query = dict(id=task_id, company=company_id)
if requires_write_access:
task = Task.get_for_writing(_only=only, **query)
else:
task = Task.get(_only=only, **query, include_public=allow_public)
if not task:
raise errors.bad_request.InvalidTaskId(**query)
return task
@staticmethod
def get_by_id(
company_id,

View File

@@ -9,6 +9,7 @@ from apiserver.bll.task import (
ChangeStatusRequest,
)
from apiserver.bll.task.task_cleanup import cleanup_task, CleanupResult
from apiserver.bll.task.utils import get_task_with_write_access
from apiserver.bll.util import update_project_time
from apiserver.config_repo import config
from apiserver.database.model import EntityVisibility
@@ -24,6 +25,7 @@ from apiserver.database.model.task.task import (
DEFAULT_LAST_ITERATION,
)
from apiserver.database.utils import get_options
from apiserver.service_repo.auth import Identity
from apiserver.utilities.dicts import nested_set
log = config.logger(__file__)
@@ -33,7 +35,7 @@ queue_bll = QueueBLL()
def archive_task(
task: Union[str, Task],
company_id: str,
user_id: str,
identity: Identity,
status_message: str,
status_reason: str,
) -> int:
@@ -42,9 +44,10 @@ def archive_task(
Return 1 if successful
"""
if isinstance(task, str):
task = TaskBLL.get_task_with_access(
task = get_task_with_write_access(
task,
company_id=company_id,
identity=identity,
only=(
"id",
"company",
@@ -54,8 +57,9 @@ def archive_task(
"system_tags",
"enqueue_status",
),
requires_write_access=True,
)
user_id = identity.user
try:
TaskBLL.dequeue_and_change_status(
task,
@@ -79,34 +83,34 @@ def archive_task(
def unarchive_task(
task: str,
task_id: str,
company_id: str,
user_id: str,
identity: Identity,
status_message: str,
status_reason: str,
) -> int:
"""
Unarchive task. Return 1 if successful
"""
task = TaskBLL.get_task_with_access(
task,
task = get_task_with_write_access(
task_id,
company_id=company_id,
identity=identity,
only=("id",),
requires_write_access=True,
)
return task.update(
status_message=status_message,
status_reason=status_reason,
pull__system_tags=EntityVisibility.archived.value,
last_change=datetime.utcnow(),
last_changed_by=user_id,
last_changed_by=identity.user,
)
def dequeue_task(
task_id: str,
company_id: str,
user_id: str,
identity: Identity,
status_message: str,
status_reason: str,
remove_from_all_queues: bool = False,
@@ -119,7 +123,19 @@ def dequeue_task(
task = Task.get(
id=task_id,
company=company_id,
_only=(
_only=("id",),
include_public=True,
)
if not task:
TaskBLL.remove_task_from_all_queues(company_id, task_id=task_id)
return 1, {"updated": 0}
user_id = identity.user
task = get_task_with_write_access(
task_id,
company_id=company_id,
identity=identity,
only=(
"id",
"company",
"execution",
@@ -127,11 +143,7 @@ def dequeue_task(
"project",
"enqueue_status",
),
include_public=True,
)
if not task:
TaskBLL.remove_task_from_all_queues(company_id, task_id=task_id)
return 1, {"updated": 0}
res = TaskBLL.dequeue_and_change_status(
task,
@@ -148,7 +160,7 @@ def dequeue_task(
def enqueue_task(
task_id: str,
company_id: str,
user_id: str,
identity: Identity,
queue_id: str,
status_message: str,
status_reason: str,
@@ -173,11 +185,11 @@ def enqueue_task(
# try to get default queue
queue_id = queue_bll.get_default(company_id).id
query = dict(id=task_id, company=company_id)
task = Task.get_for_writing(**query)
if not task:
raise errors.bad_request.InvalidTaskId(**query)
task = get_task_with_write_access(
task_id=task_id, company_id=company_id, identity=identity
)
user_id = identity.user
if validate:
TaskBLL.validate(task)
@@ -207,9 +219,9 @@ def enqueue_task(
# set the current queue ID in the task
if task.execution:
Task.objects(**query).update(execution__queue=queue_id, multi=False)
Task.objects(id=task_id).update(execution__queue=queue_id, multi=False)
else:
Task.objects(**query).update(execution=Execution(queue=queue_id), multi=False)
Task.objects(id=task_id).update(execution=Execution(queue=queue_id), multi=False)
nested_set(res, ("fields", "execution.queue"), queue_id)
return 1, res
@@ -242,7 +254,7 @@ def move_tasks_to_trash(tasks: Sequence[str]) -> int:
def delete_task(
task_id: str,
company_id: str,
user_id: str,
identity: Identity,
move_to_trash: bool,
force: bool,
return_file_urls: bool,
@@ -251,8 +263,9 @@ def delete_task(
status_reason: str,
delete_external_artifacts: bool,
) -> Tuple[int, Task, CleanupResult]:
task = TaskBLL.get_task_with_access(
task_id, company_id=company_id, requires_write_access=True
user_id = identity.user
task = get_task_with_write_access(
task_id, company_id=company_id, identity=identity
)
if (
@@ -305,15 +318,16 @@ def delete_task(
def reset_task(
task_id: str,
company_id: str,
user_id: str,
identity: Identity,
force: bool,
return_file_urls: bool,
delete_output_models: bool,
clear_all: bool,
delete_external_artifacts: bool,
) -> Tuple[dict, CleanupResult, dict]:
task = TaskBLL.get_task_with_access(
task_id, company_id=company_id, requires_write_access=True
user_id = identity.user
task = get_task_with_write_access(
task_id, company_id=company_id, identity=identity
)
if not force and task.status == TaskStatus.published:
@@ -392,14 +406,15 @@ def reset_task(
def publish_task(
task_id: str,
company_id: str,
user_id: str,
identity: Identity,
force: bool,
publish_model_func: Callable[[str, str, str], Any] = None,
publish_model_func: Callable[[str, str, Identity], Any] = None,
status_message: str = "",
status_reason: str = "",
) -> dict:
task = TaskBLL.get_task_with_access(
task_id, company_id=company_id, requires_write_access=True
user_id = identity.user
task = get_task_with_write_access(
task_id, company_id=company_id, identity=identity
)
if not force:
validate_status_change(task.status, TaskStatus.published)
@@ -422,7 +437,7 @@ def publish_task(
.first()
)
if model and not model.ready:
publish_model_func(model.id, company_id, user_id)
publish_model_func(model.id, company_id, identity)
# set task status to published, and update (or set) it's new output (view and models)
return ChangeStatusRequest(
@@ -446,7 +461,7 @@ def publish_task(
def stop_task(
task_id: str,
company_id: str,
user_id: str,
identity: Identity,
user_name: str,
status_reason: str,
force: bool,
@@ -459,10 +474,11 @@ def stop_task(
is set to 'stopping' to allow the worker to stop the task and report by itself
:return: updated task fields
"""
task = TaskBLL.get_task_with_access(
user_id = identity.user
task = get_task_with_write_access(
task_id,
company_id=company_id,
identity=identity,
only=(
"status",
"project",
@@ -472,7 +488,6 @@ def stop_task(
"last_update",
"execution.queue",
),
requires_write_access=True,
)
def is_run_by_worker(t: Task) -> bool:

View File

@@ -1,7 +1,9 @@
from datetime import datetime
from typing import Sequence
import attr
import six
from mongoengine import Q
from apiserver.apierrors import errors
from apiserver.bll.util import update_project_time
@@ -10,6 +12,7 @@ from apiserver.database.errors import translate_errors_context
from apiserver.database.model.model import Model
from apiserver.database.model.task.task import Task, TaskStatus, TaskSystemTags
from apiserver.database.utils import get_options
from apiserver.service_repo.auth import Identity
from apiserver.utilities.attrs import typed_attrs
valid_statuses = get_options(TaskStatus)
@@ -157,15 +160,75 @@ def get_possible_status_changes(current_status):
return possible
def get_many_tasks_for_writing(
company_id: str,
identity: Identity,
query: Q = None,
only: Sequence = None,
throw_on_forbidden: bool = True,
) -> Sequence[Task]:
if only:
missing = [f for f in ("company", ) if f not in only]
if missing:
only = [*only, *missing]
result = list(
Task.get_many(
company=company_id,
query=query,
override_projection=only,
allow_public=True,
return_dicts=False,
)
)
forbidden_tasks = {task.id for task in result if not task.company}
if forbidden_tasks:
if throw_on_forbidden:
raise errors.forbidden.NoWritePermission(
f"cannot modify public task(s), ids={tuple(forbidden_tasks)}"
)
result = [task for task in result if task.id not in forbidden_tasks]
return result
def get_task_with_write_access(
task_id: str,
company_id: str,
identity: Identity,
only=None,
) -> Task:
"""
Gets a task that has a required write access
:except errors.bad_request.InvalidTaskId: if the task is not found
:except errors.forbidden.NoWritePermission: if write_access was required and the task cannot be modified
"""
query = dict(id=task_id, company=company_id)
task = Task.get_for_writing(_only=only, **query)
if not task:
raise errors.bad_request.InvalidTaskId(**query)
return task
def get_task_for_update(
company_id: str, task_id: str, allow_all_statuses: bool = False, force: bool = False
company_id: str,
task_id: str,
identity: Identity,
allow_all_statuses: bool = False,
force: bool = False
) -> Task:
"""
Loads only task id and return the task only if it is updatable (status == 'created')
"""
task = Task.get_for_writing(company=company_id, id=task_id, _only=("id", "status"))
if not task:
raise errors.bad_request.InvalidTaskId(id=task_id)
task = get_task_with_write_access(
task_id=task_id,
company_id=company_id,
only=("id", "status"),
identity=identity,
)
if allow_all_statuses:
return task

View File

@@ -1283,21 +1283,6 @@ class GetMixin(PropsMixin):
)
return result
@classmethod
def get_many_for_writing(cls, company, *args, **kwargs):
result = cls.get_many(
company=company,
*args,
**dict(return_dicts=False, **kwargs),
allow_public=True,
)
forbidden_objects = {obj.id for obj in result if not obj.company}
if forbidden_objects:
object_name = cls.__name__.lower()
raise errors.forbidden.NoWritePermission(
f"cannot modify public {object_name}(s), ids={tuple(forbidden_objects)}"
)
return result
class UpdateMixin(object):

View File

@@ -44,6 +44,7 @@ from apiserver.bll.task.param_utils import (
from apiserver.config_repo import config
from apiserver.config.info import get_default_company
from apiserver.database.model import EntityVisibility, User
from apiserver.database.model.auth import Role
from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import (
@@ -54,6 +55,7 @@ from apiserver.database.model.task.task import (
TaskModelNames,
)
from apiserver.database.utils import get_options
from apiserver.service_repo.auth import Identity
from apiserver.utilities import json
from apiserver.utilities.dicts import nested_get, nested_set, nested_delete
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
@@ -717,7 +719,10 @@ class PrePopulate:
@classmethod
def _generate_new_ids(
cls, reader: ZipFile, entity_files: Sequence, metadata: Mapping[str, Any],
cls,
reader: ZipFile,
entity_files: Sequence,
metadata: Mapping[str, Any],
) -> Mapping[str, str]:
if not metadata or not any(
metadata.get(key) for key in ("new_ids", "example_ids", "private_ids")
@@ -970,7 +975,7 @@ class PrePopulate:
ev["allow_locked"] = True
cls.event_bll.add_events(
company_id=company_id,
user_id=user_id,
identity=Identity(user_id, company=company_id, role=Role.admin),
events=events,
worker="",
)

View File

@@ -28,6 +28,7 @@ from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL
from apiserver.bll.task import TaskBLL
from apiserver.bll.task.task_operations import publish_task
from apiserver.bll.task.utils import get_task_with_write_access
from apiserver.bll.util import run_batch_operation
from apiserver.config_repo import config
from apiserver.database.model import validate_id
@@ -46,6 +47,7 @@ from apiserver.database.utils import (
filter_fields,
)
from apiserver.service_repo import APICall, endpoint
from apiserver.service_repo.auth import Identity
from apiserver.services.utils import (
conform_tag_fields,
conform_output_tags,
@@ -249,13 +251,12 @@ def update_for_task(call: APICall, company_id, _):
)
query = dict(id=task_id, company=company_id)
task = Task.get_for_writing(
id=task_id,
company=company_id,
_only=["models", "execution", "name", "status", "project"],
task = get_task_with_write_access(
task_id=task_id,
company_id=company_id,
identity=call.identity,
only=("models", "execution", "name", "status", "project"),
)
if not task:
raise errors.bad_request.InvalidTaskId(**query)
allowed_states = [TaskStatus.created, TaskStatus.in_progress]
if task.status not in allowed_states:
@@ -343,7 +344,7 @@ def create(call: APICall, company_id, req_model: CreateModelRequest):
task = req_model.task
req_data = req_model.to_struct()
if task:
validate_task(company_id, req_data)
validate_task(company_id, call.identity, req_data)
fields = filter_fields(Model, req_data)
conform_tag_fields(call, fields, validate=True)
@@ -373,7 +374,7 @@ def prepare_update_fields(call, company_id, fields: dict):
# clear UI cache if URI is provided (model updated)
fields["ui_cache"] = fields.pop("ui_cache", {})
if "task" in fields:
validate_task(company_id, fields)
validate_task(company_id, call.identity, fields)
if "labels" in fields:
labels = fields["labels"]
@@ -403,8 +404,11 @@ def prepare_update_fields(call, company_id, fields: dict):
return fields
def validate_task(company_id, fields: dict):
Task.get_for_writing(company=company_id, id=fields["task"], _only=["id"])
def validate_task(company_id: str, identity: Identity, fields: dict):
task_id = fields["task"]
get_task_with_write_access(
task_id=task_id, company_id=company_id, identity=identity, only=("id",)
)
@endpoint("models.edit", required_fields=["model"], response_data_model=UpdateResponse)
@@ -514,7 +518,7 @@ def set_ready(call: APICall, company_id: str, request: PublishModelRequest):
updated, published_task = ModelBLL.publish_model(
model_id=request.model,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
force_publish_task=request.force_publish_task,
publish_task_func=publish_task if request.publish_task else None,
)
@@ -533,7 +537,7 @@ def publish_many(call: APICall, company_id, request: ModelsPublishManyRequest):
func=partial(
ModelBLL.publish_model,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
force_publish_task=request.force_publish_task,
publish_task_func=publish_task if request.publish_task else None,
),

View File

@@ -57,7 +57,7 @@ def delete_runs(call: APICall, company_id: str, request: DeleteRunsRequest):
func=partial(
delete_task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
move_to_trash=False,
force=True,
return_file_urls=False,
@@ -108,7 +108,7 @@ def start_pipeline(call: APICall, company_id: str, request: StartPipelineRequest
queued, res = enqueue_task(
task_id=task.id,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
queue_id=request.queue,
status_message="Starting pipeline",
status_reason="",

View File

@@ -100,7 +100,13 @@ from apiserver.bll.task.task_operations import (
unarchive_task,
move_tasks_to_trash,
)
from apiserver.bll.task.utils import update_task, get_task_for_update, deleted_prefix
from apiserver.bll.task.utils import (
update_task,
get_task_for_update,
deleted_prefix,
get_many_tasks_for_writing,
get_task_with_write_access,
)
from apiserver.bll.util import run_batch_operation, update_project_time
from apiserver.database.errors import translate_errors_context
from apiserver.database.model import EntityVisibility
@@ -118,6 +124,7 @@ from apiserver.database.utils import (
get_options,
)
from apiserver.service_repo import APICall, endpoint
from apiserver.service_repo.auth import Identity
from apiserver.services.utils import (
conform_tag_fields,
conform_output_tags,
@@ -142,14 +149,34 @@ org_bll = OrgBLL()
project_bll = ProjectBLL()
def _assert_writable_tasks(
company_id: str, identity: Identity, ids: Sequence[str], only=("id",)
) -> Sequence[Task]:
tasks = get_many_tasks_for_writing(
company_id=company_id,
identity=identity,
query=Q(id__in=ids),
only=only,
)
missing_ids = set(ids) - {t.id for t in tasks}
if missing_ids:
raise errors.bad_request.InvalidTaskId(ids=list(missing_ids))
return tasks
def set_task_status_from_call(
request: UpdateRequest, company_id: str, user_id: str, new_status=None, **set_fields
request: UpdateRequest,
company_id: str,
identity: Identity,
new_status=None,
**set_fields,
) -> dict:
task = TaskBLL.get_task_with_access(
task = get_task_with_write_access(
request.task,
company_id=company_id,
identity=identity,
only=("id", "status", "project"),
requires_write_access=True,
)
status_reason = request.status_reason
@@ -161,15 +188,17 @@ def set_task_status_from_call(
status_reason=status_reason,
status_message=status_message,
force=force,
user_id=user_id,
user_id=identity.user,
).execute(**set_fields)
@endpoint("tasks.get_by_id", request_data_model=TaskRequest)
def get_by_id(call: APICall, company_id, req_model: TaskRequest):
task = TaskBLL.get_task_with_access(
req_model.task, company_id=company_id, allow_public=True
)
def get_by_id(call: APICall, company_id, request: TaskRequest):
task = TaskBLL.assert_exists(
company_id,
task_ids=request.task,
allow_public=True,
)[0]
task_dict = task.to_proper_dict()
conform_task_data(call, task_dict)
call.result.data = {"task": task_dict}
@@ -227,7 +256,9 @@ def get_by_id_ex(call: APICall, company_id, _):
conform_tag_fields(call, call.data)
call_data = escape_execution_parameters(call.data)
tasks = Task.get_many_with_join(
company=company_id, query_dict=call_data, allow_public=True,
company=company_id,
query_dict=call_data,
allow_public=True,
)
conform_task_data(call, tasks)
@@ -278,7 +309,7 @@ def stop(call: APICall, company_id, req_model: UpdateRequest):
**stop_task(
task_id=req_model.task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
user_name=call.identity.user_name,
status_reason=req_model.status_reason,
force=req_model.force,
@@ -296,7 +327,7 @@ def stop_many(call: APICall, company_id, request: StopManyRequest):
func=partial(
stop_task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
user_name=call.identity.user_name,
status_reason=request.status_reason,
force=request.force,
@@ -319,7 +350,7 @@ def stopped(call: APICall, company_id, req_model: UpdateRequest):
**set_task_status_from_call(
req_model,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
new_status=TaskStatus.stopped,
completed=datetime.utcnow(),
)
@@ -336,7 +367,7 @@ def started(call: APICall, company_id, req_model: UpdateRequest):
**set_task_status_from_call(
req_model,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
new_status=TaskStatus.in_progress,
min__started=datetime.utcnow(), # don't override a previous, smaller "started" field value
)
@@ -353,7 +384,7 @@ def failed(call: APICall, company_id, req_model: UpdateRequest):
**set_task_status_from_call(
req_model,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
new_status=TaskStatus.failed,
)
)
@@ -367,7 +398,7 @@ def close(call: APICall, company_id, req_model: UpdateRequest):
**set_task_status_from_call(
req_model,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
new_status=TaskStatus.closed,
)
)
@@ -433,13 +464,17 @@ def conform_task_data(call: APICall, tasks_data: Union[Sequence[dict], dict]):
for data in tasks_data:
params_unprepare_from_saved(
fields=data, copy_to_legacy=need_legacy_params,
fields=data,
copy_to_legacy=need_legacy_params,
)
artifacts_unprepare_from_saved(fields=data)
def prepare_create_fields(
call: APICall, valid_fields=None, output=None, previous_task: Task = None,
call: APICall,
valid_fields=None,
output=None,
previous_task: Task = None,
):
valid_fields = valid_fields if valid_fields is not None else create_fields
t_fields = task_fields
@@ -566,11 +601,12 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
task_id = req_model.task
with translate_errors_context():
task = Task.get_for_writing(
id=task_id, company=company_id, _only=["id", "project"]
task = get_task_with_write_access(
task_id=task_id,
company_id=company_id,
identity=call.identity,
only=("id", "project"),
)
if not task:
raise errors.bad_request.InvalidTaskId(id=task_id)
partial_update_dict, valid_fields = prepare_update_fields(call, call.data)
@@ -582,7 +618,8 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
id=task_id,
partial_update_dict=partial_update_dict,
injected_update=dict(
last_change=datetime.utcnow(), last_changed_by=call.identity.user,
last_change=datetime.utcnow(),
last_changed_by=call.identity.user,
),
)
if updated_count:
@@ -606,11 +643,11 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
def set_requirements(call: APICall, company_id, req_model: SetRequirementsRequest):
requirements = req_model.requirements
with translate_errors_context():
task = TaskBLL.get_task_with_access(
task = get_task_with_write_access(
req_model.task,
company_id=company_id,
identity=call.identity,
only=("status", "script"),
requires_write_access=True,
)
if not task.script:
raise errors.bad_request.MissingTaskFields(
@@ -636,8 +673,11 @@ def update_batch(call: APICall, company_id, _):
items = {i["task"]: i for i in items}
tasks = {
t.id: t
for t in Task.get_many_for_writing(
company=company_id, query=Q(id__in=list(items))
for t in _assert_writable_tasks(
identity=call.identity,
company_id=company_id,
ids=list(items),
only=("id", "project"),
)
}
@@ -656,7 +696,8 @@ def update_batch(call: APICall, company_id, _):
if not partial_update_dict:
continue
partial_update_dict.update(
last_change=now, last_changed_by=call.identity.user,
last_change=now,
last_changed_by=call.identity.user,
)
update_op = UpdateOne(
{"_id": id, "company": company_id}, {"$set": partial_update_dict}
@@ -690,9 +731,11 @@ def edit(call: APICall, company_id, req_model: UpdateRequest):
force = req_model.force
with translate_errors_context():
task = Task.get_for_writing(id=task_id, company=company_id)
if not task:
raise errors.bad_request.InvalidTaskId(id=task_id)
task = get_task_with_write_access(
task_id=task_id,
company_id=company_id,
identity=call.identity,
)
if not force and task.status != TaskStatus.created:
raise errors.bad_request.InvalidTaskStatus(
@@ -756,7 +799,8 @@ def edit(call: APICall, company_id, req_model: UpdateRequest):
@endpoint(
"tasks.get_hyper_params", request_data_model=GetHyperParamsRequest,
"tasks.get_hyper_params",
request_data_model=GetHyperParamsRequest,
)
def get_hyper_params(call: APICall, company_id, request: GetHyperParamsRequest):
tasks_params = HyperParams.get_params(company_id, task_ids=request.tasks)
@@ -771,7 +815,7 @@ def edit_hyper_params(call: APICall, company_id, request: EditHyperParamsRequest
call.result.data = {
"updated": HyperParams.edit_params(
company_id,
user_id=call.identity.user,
identity=call.identity,
task_id=request.task,
hyperparams=request.hyperparams,
replace_hyperparams=request.replace_hyperparams,
@@ -785,7 +829,7 @@ def delete_hyper_params(call: APICall, company_id, request: DeleteHyperParamsReq
call.result.data = {
"deleted": HyperParams.delete_params(
company_id,
user_id=call.identity.user,
identity=call.identity,
task_id=request.task,
hyperparams=request.hyperparams,
force=request.force,
@@ -794,7 +838,8 @@ def delete_hyper_params(call: APICall, company_id, request: DeleteHyperParamsReq
@endpoint(
"tasks.get_configurations", request_data_model=GetConfigurationsRequest,
"tasks.get_configurations",
request_data_model=GetConfigurationsRequest,
)
def get_configurations(call: APICall, company_id, request: GetConfigurationsRequest):
tasks_params = HyperParams.get_configurations(
@@ -809,7 +854,8 @@ def get_configurations(call: APICall, company_id, request: GetConfigurationsRequ
@endpoint(
"tasks.get_configuration_names", request_data_model=GetConfigurationNamesRequest,
"tasks.get_configuration_names",
request_data_model=GetConfigurationNamesRequest,
)
def get_configuration_names(
call: APICall, company_id, request: GetConfigurationNamesRequest
@@ -830,7 +876,7 @@ def edit_configuration(call: APICall, company_id, request: EditConfigurationRequ
call.result.data = {
"updated": HyperParams.edit_configuration(
company_id,
user_id=call.identity.user,
identity=call.identity,
task_id=request.task,
configuration=request.configuration,
replace_configuration=request.replace_configuration,
@@ -846,7 +892,7 @@ def delete_configuration(
call.result.data = {
"deleted": HyperParams.delete_configuration(
company_id,
user_id=call.identity.user,
identity=call.identity,
task_id=request.task,
configuration=request.configuration,
force=request.force,
@@ -863,7 +909,7 @@ def enqueue(call: APICall, company_id, request: EnqueueRequest):
queued, res = enqueue_task(
task_id=request.task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
queue_id=request.queue,
status_message=request.status_message,
status_reason=request.status_reason,
@@ -888,7 +934,7 @@ def enqueue_many(call: APICall, company_id, request: EnqueueManyRequest):
func=partial(
enqueue_task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
queue_id=request.queue,
status_message=request.status_message,
status_reason=request.status_reason,
@@ -915,13 +961,14 @@ def enqueue_many(call: APICall, company_id, request: EnqueueManyRequest):
@endpoint(
"tasks.dequeue", response_data_model=DequeueResponse,
"tasks.dequeue",
response_data_model=DequeueResponse,
)
def dequeue(call: APICall, company_id, request: DequeueRequest):
dequeued, res = dequeue_task(
task_id=request.task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
status_message=request.status_message,
status_reason=request.status_reason,
remove_from_all_queues=request.remove_from_all_queues,
@@ -931,14 +978,15 @@ def dequeue(call: APICall, company_id, request: DequeueRequest):
@endpoint(
"tasks.dequeue_many", response_data_model=DequeueManyResponse,
"tasks.dequeue_many",
response_data_model=DequeueManyResponse,
)
def dequeue_many(call: APICall, company_id, request: DequeueManyRequest):
results, failures = run_batch_operation(
func=partial(
dequeue_task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
status_message=request.status_message,
status_reason=request.status_reason,
remove_from_all_queues=request.remove_from_all_queues,
@@ -962,7 +1010,7 @@ def reset(call: APICall, company_id, request: ResetRequest):
dequeued, cleanup_res, updates = reset_task(
task_id=request.task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
force=request.force,
return_file_urls=request.return_file_urls,
delete_output_models=request.delete_output_models,
@@ -990,7 +1038,7 @@ def reset_many(call: APICall, company_id, request: ResetManyRequest):
func=partial(
reset_task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
force=request.force,
return_file_urls=request.return_file_urls,
delete_output_models=request.delete_output_models,
@@ -1027,9 +1075,11 @@ def reset_many(call: APICall, company_id, request: ResetManyRequest):
response_data_model=ArchiveResponse,
)
def archive(call: APICall, company_id, request: ArchiveRequest):
tasks = TaskBLL.assert_exists(
archived = 0
tasks = _assert_writable_tasks(
company_id,
task_ids=request.tasks,
call.identity,
ids=request.tasks,
only=(
"id",
"company",
@@ -1040,11 +1090,10 @@ def archive(call: APICall, company_id, request: ArchiveRequest):
"enqueue_status",
),
)
archived = 0
for task in tasks:
archived += archive_task(
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
task=task,
status_message=request.status_message,
status_reason=request.status_reason,
@@ -1063,7 +1112,7 @@ def archive_many(call: APICall, company_id, request: TaskBatchRequest):
func=partial(
archive_task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
status_message=request.status_message,
status_reason=request.status_reason,
),
@@ -1085,7 +1134,7 @@ def unarchive_many(call: APICall, company_id, request: TaskBatchRequest):
func=partial(
unarchive_task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
status_message=request.status_message,
status_reason=request.status_reason,
),
@@ -1104,7 +1153,7 @@ def delete(call: APICall, company_id, request: DeleteRequest):
deleted, task, cleanup_res = delete_task(
task_id=request.task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
move_to_trash=request.move_to_trash,
force=request.force,
return_file_urls=request.return_file_urls,
@@ -1126,7 +1175,7 @@ def delete_many(call: APICall, company_id, request: DeleteManyRequest):
func=partial(
delete_task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
move_to_trash=request.move_to_trash,
force=request.force,
return_file_urls=request.return_file_urls,
@@ -1164,7 +1213,7 @@ def publish(call: APICall, company_id, request: PublishRequest):
updates = publish_task(
task_id=request.task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
force=request.force,
publish_model_func=ModelBLL.publish_model if request.publish_model else None,
status_reason=request.status_reason,
@@ -1183,7 +1232,7 @@ def publish_many(call: APICall, company_id, request: PublishManyRequest):
func=partial(
publish_task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
force=request.force,
publish_model_func=ModelBLL.publish_model
if request.publish_model
@@ -1211,7 +1260,7 @@ def completed(call: APICall, company_id, request: CompletedRequest):
**set_task_status_from_call(
request,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
new_status=TaskStatus.completed,
completed=datetime.utcnow(),
)
@@ -1221,7 +1270,7 @@ def completed(call: APICall, company_id, request: CompletedRequest):
publish_res = publish_task(
task_id=request.task,
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
force=request.force,
publish_model_func=ModelBLL.publish_model,
status_reason=request.status_reason,
@@ -1256,7 +1305,7 @@ def add_or_update_artifacts(
call.result.data = {
"updated": Artifacts.add_or_update_artifacts(
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
task_id=request.task,
artifacts=request.artifacts,
force=True,
@@ -1273,7 +1322,7 @@ def delete_artifacts(call: APICall, company_id, request: DeleteArtifactsRequest)
call.result.data = {
"deleted": Artifacts.delete_artifacts(
company_id=company_id,
user_id=call.identity.user,
identity=call.identity,
task_id=request.task,
artifact_ids=request.artifacts,
force=True,
@@ -1310,6 +1359,7 @@ def move(call: APICall, company_id: str, request: MoveRequest):
"project or project_name is required"
)
_assert_writable_tasks(company_id, call.identity, request.ids)
updated_projects = set(
t.project for t in Task.objects(id__in=request.ids).only("project") if t.project
)
@@ -1330,7 +1380,8 @@ def move(call: APICall, company_id: str, request: MoveRequest):
@endpoint("tasks.update_tags")
def update_tags(_, company_id: str, request: UpdateTagsRequest):
def update_tags(call: APICall, company_id: str, request: UpdateTagsRequest):
_assert_writable_tasks(company_id, call.identity, request.ids)
return {
"updated": org_bll.edit_entity_tags(
company_id=company_id,
@@ -1344,7 +1395,9 @@ def update_tags(_, company_id: str, request: UpdateTagsRequest):
@endpoint("tasks.add_or_update_model", min_version="2.13")
def add_or_update_model(call: APICall, company_id: str, request: AddUpdateModelRequest):
get_task_for_update(company_id=company_id, task_id=request.task, force=True)
get_task_for_update(
company_id=company_id, task_id=request.task, force=True, identity=call.identity
)
models_field = f"models__{request.type}"
model = ModelItem(name=request.name, model=request.model, updated=datetime.utcnow())
@@ -1364,7 +1417,9 @@ def add_or_update_model(call: APICall, company_id: str, request: AddUpdateModelR
@endpoint("tasks.delete_models", min_version="2.13")
def delete_models(call: APICall, company_id: str, request: DeleteModelsRequest):
task = get_task_for_update(company_id=company_id, task_id=request.task, force=True)
task = get_task_for_update(
company_id=company_id, task_id=request.task, force=True, identity=call.identity
)
delete_names = {
type_: [m.name for m in request.models if m.type == type_]
@@ -1377,6 +1432,8 @@ def delete_models(call: APICall, company_id: str, request: DeleteModelsRequest):
}
updated = task.update(
last_change=datetime.utcnow(), last_changed_by=call.identity.user, **commands,
last_change=datetime.utcnow(),
last_changed_by=call.identity.user,
**commands,
)
return {"updated": updated}