mirror of
https://github.com/clearml/clearml-server
synced 2025-06-26 23:15:47 +00:00
Refactor check for tasks write permission
This commit is contained in:
@@ -31,6 +31,7 @@ from apiserver.bll.event.history_plots_iterator import HistoryPlotsIterator
|
||||
from apiserver.bll.event.metric_debug_images_iterator import MetricDebugImagesIterator
|
||||
from apiserver.bll.event.metric_plots_iterator import MetricPlotsIterator
|
||||
from apiserver.bll.model import ModelBLL
|
||||
from apiserver.bll.task.utils import get_many_tasks_for_writing
|
||||
from apiserver.bll.util import parallel_chunked_decorator
|
||||
from apiserver.database import utils as dbutils
|
||||
from apiserver.database.model.model import Model
|
||||
@@ -42,6 +43,7 @@ from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.task.task import Task, TaskStatus
|
||||
from apiserver.redis_manager import redman
|
||||
from apiserver.service_repo.auth import Identity
|
||||
from apiserver.tools import safe_get
|
||||
from apiserver.utilities.dicts import nested_get
|
||||
from apiserver.utilities.json import loads
|
||||
@@ -55,7 +57,9 @@ MIN_LONG = -(2**63)
|
||||
|
||||
log = config.logger(__file__)
|
||||
async_task_events_delete = config.get("services.tasks.async_events_delete", False)
|
||||
async_delete_threshold = config.get("services.tasks.async_events_delete_threshold", 100_000)
|
||||
async_delete_threshold = config.get(
|
||||
"services.tasks.async_events_delete_threshold", 100_000
|
||||
)
|
||||
|
||||
|
||||
class EventBLL(object):
|
||||
@@ -97,7 +101,9 @@ class EventBLL(object):
|
||||
return self._metrics
|
||||
|
||||
@staticmethod
|
||||
def _get_valid_entities(company_id, ids: Mapping[str, bool], model=False) -> Set:
|
||||
def _get_valid_entities(
|
||||
company_id, ids: Mapping[str, bool], identity: Identity, model=False
|
||||
) -> Set:
|
||||
"""Verify that task or model exists and can be updated"""
|
||||
if not ids:
|
||||
return set()
|
||||
@@ -116,20 +122,34 @@ class EventBLL(object):
|
||||
):
|
||||
if not requested_ids:
|
||||
continue
|
||||
query = Q(id__in=requested_ids, company=company_id)
|
||||
res.update(
|
||||
(Model if model else Task).objects(query & locked_q).scalar("id")
|
||||
)
|
||||
|
||||
query = Q(id__in=requested_ids) & locked_q
|
||||
if model:
|
||||
ids = Model.objects(query & Q(company=company_id)).scalar("id")
|
||||
else:
|
||||
ids = {
|
||||
t.id
|
||||
for t in get_many_tasks_for_writing(
|
||||
company_id=company_id,
|
||||
identity=identity,
|
||||
query=query,
|
||||
only=("id",),
|
||||
throw_on_forbidden=False,
|
||||
)
|
||||
}
|
||||
|
||||
res.update(ids)
|
||||
|
||||
return res
|
||||
|
||||
def add_events(
|
||||
self,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
identity: Identity,
|
||||
events: Sequence[dict],
|
||||
worker: str,
|
||||
) -> Tuple[int, int, dict]:
|
||||
user_id = identity.user
|
||||
task_ids = {}
|
||||
model_ids = {}
|
||||
for event in events:
|
||||
@@ -161,8 +181,12 @@ class EventBLL(object):
|
||||
"Inconsistent model_event setting in the passed events",
|
||||
tasks=found_in_both,
|
||||
)
|
||||
valid_models = self._get_valid_entities(company_id, ids=model_ids, model=True)
|
||||
valid_tasks = self._get_valid_entities(company_id, ids=task_ids)
|
||||
valid_models = self._get_valid_entities(
|
||||
company_id, ids=model_ids, identity=identity, model=True
|
||||
)
|
||||
valid_tasks = self._get_valid_entities(
|
||||
company_id, ids=task_ids, identity=identity
|
||||
)
|
||||
|
||||
actions: List[dict] = []
|
||||
used_task_ids = set()
|
||||
|
||||
@@ -10,6 +10,7 @@ from apiserver.config_repo import config
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.task.task import Task, TaskStatus
|
||||
from apiserver.service_repo.auth import Identity
|
||||
from .metadata import Metadata
|
||||
|
||||
|
||||
@@ -57,14 +58,15 @@ class ModelBLL:
|
||||
cls,
|
||||
model_id: str,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
identity: Identity,
|
||||
force_publish_task: bool = False,
|
||||
publish_task_func: Callable[[str, str, str, bool], dict] = None,
|
||||
publish_task_func: Callable[[str, str, Identity, bool], dict] = None,
|
||||
) -> Tuple[int, ModelTaskPublishResponse]:
|
||||
model = cls.get_company_model_by_id(company_id=company_id, model_id=model_id)
|
||||
if model.ready:
|
||||
raise errors.bad_request.ModelIsReady(company=company_id, model=model_id)
|
||||
|
||||
user_id = identity.user
|
||||
published_task = None
|
||||
if model.task and publish_task_func:
|
||||
task = (
|
||||
@@ -74,7 +76,7 @@ class ModelBLL:
|
||||
)
|
||||
if task and task.status != TaskStatus.published:
|
||||
task_publish_res = publish_task_func(
|
||||
model.task, company_id, user_id, force_publish_task
|
||||
model.task, company_id, identity, force_publish_task
|
||||
)
|
||||
published_task = ModelTaskPublishResponse(
|
||||
id=model.task, data=task_publish_res
|
||||
|
||||
@@ -152,7 +152,7 @@ class QueueBLL(object):
|
||||
|
||||
for item in queue.entries:
|
||||
try:
|
||||
task = Task.get_for_writing(
|
||||
task = Task.get(
|
||||
company=company_id,
|
||||
id=item.task,
|
||||
_only=[
|
||||
|
||||
@@ -5,6 +5,7 @@ from apiserver.apimodels.tasks import Artifact as ApiArtifact, ArtifactId
|
||||
from apiserver.bll.task.utils import get_task_for_update, update_task
|
||||
from apiserver.database.model.task.task import DEFAULT_ARTIFACT_MODE, Artifact
|
||||
from apiserver.database.utils import hash_field_name
|
||||
from apiserver.service_repo.auth import Identity
|
||||
from apiserver.utilities.dicts import nested_get, nested_set
|
||||
from apiserver.utilities.parameter_key_escaper import mongoengine_safe
|
||||
|
||||
@@ -48,12 +49,14 @@ class Artifacts:
|
||||
def add_or_update_artifacts(
|
||||
cls,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
identity: Identity,
|
||||
task_id: str,
|
||||
artifacts: Sequence[ApiArtifact],
|
||||
force: bool,
|
||||
) -> int:
|
||||
task = get_task_for_update(company_id=company_id, task_id=task_id, force=force,)
|
||||
task = get_task_for_update(
|
||||
company_id=company_id, task_id=task_id, force=force, identity=identity
|
||||
)
|
||||
|
||||
artifacts = {
|
||||
get_artifact_id(a): Artifact(**a)
|
||||
@@ -64,18 +67,20 @@ class Artifacts:
|
||||
f"set__execution__artifacts__{mongoengine_safe(name)}": value
|
||||
for name, value in artifacts.items()
|
||||
}
|
||||
return update_task(task, user_id=user_id, update_cmds=update_cmds)
|
||||
return update_task(task, user_id=identity.user, update_cmds=update_cmds)
|
||||
|
||||
@classmethod
|
||||
def delete_artifacts(
|
||||
cls,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
identity: Identity,
|
||||
task_id: str,
|
||||
artifact_ids: Sequence[ArtifactId],
|
||||
force: bool,
|
||||
) -> int:
|
||||
task = get_task_for_update(company_id=company_id, task_id=task_id, force=force,)
|
||||
task = get_task_for_update(
|
||||
company_id=company_id, task_id=task_id, force=force, identity=identity
|
||||
)
|
||||
|
||||
artifact_ids = [
|
||||
get_artifact_id(a)
|
||||
@@ -85,4 +90,4 @@ class Artifacts:
|
||||
f"unset__execution__artifacts__{id_}": 1 for id_ in set(artifact_ids)
|
||||
}
|
||||
|
||||
return update_task(task, user_id=user_id, update_cmds=delete_cmds)
|
||||
return update_task(task, user_id=identity.user, update_cmds=delete_cmds)
|
||||
|
||||
@@ -15,6 +15,7 @@ from apiserver.bll.task import TaskBLL
|
||||
from apiserver.bll.task.utils import get_task_for_update, update_task
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model.task.task import ParamsItem, Task, ConfigurationItem
|
||||
from apiserver.service_repo.auth import Identity
|
||||
from apiserver.utilities.parameter_key_escaper import (
|
||||
ParameterKeyEscaper,
|
||||
mongoengine_safe,
|
||||
@@ -31,7 +32,10 @@ class HyperParams:
|
||||
def get_params(cls, company_id: str, task_ids: Sequence[str]) -> Dict[str, dict]:
|
||||
only = ("id", "hyperparams")
|
||||
tasks = task_bll.assert_exists(
|
||||
company_id=company_id, task_ids=task_ids, only=only, allow_public=True,
|
||||
company_id=company_id,
|
||||
task_ids=task_ids,
|
||||
only=only,
|
||||
allow_public=True,
|
||||
)
|
||||
|
||||
return {
|
||||
@@ -63,7 +67,7 @@ class HyperParams:
|
||||
def delete_params(
|
||||
cls,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
identity: Identity,
|
||||
task_id: str,
|
||||
hyperparams: Sequence[HyperParamKey],
|
||||
force: bool,
|
||||
@@ -74,6 +78,7 @@ class HyperParams:
|
||||
task_id=task_id,
|
||||
allow_all_statuses=properties_only,
|
||||
force=force,
|
||||
identity=identity,
|
||||
)
|
||||
|
||||
with_param, without_param = iterutils.partition(
|
||||
@@ -96,7 +101,7 @@ class HyperParams:
|
||||
|
||||
return update_task(
|
||||
task,
|
||||
user_id=user_id,
|
||||
user_id=identity.user,
|
||||
update_cmds=delete_cmds,
|
||||
set_last_update=not properties_only,
|
||||
)
|
||||
@@ -105,7 +110,7 @@ class HyperParams:
|
||||
def edit_params(
|
||||
cls,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
identity: Identity,
|
||||
task_id: str,
|
||||
hyperparams: Sequence[HyperParamItem],
|
||||
replace_hyperparams: str,
|
||||
@@ -117,6 +122,7 @@ class HyperParams:
|
||||
task_id=task_id,
|
||||
allow_all_statuses=properties_only,
|
||||
force=force,
|
||||
identity=identity,
|
||||
)
|
||||
|
||||
update_cmds = dict()
|
||||
@@ -135,7 +141,7 @@ class HyperParams:
|
||||
|
||||
return update_task(
|
||||
task,
|
||||
user_id=user_id,
|
||||
user_id=identity.user,
|
||||
update_cmds=update_cmds,
|
||||
set_last_update=not properties_only,
|
||||
)
|
||||
@@ -163,7 +169,10 @@ class HyperParams:
|
||||
else:
|
||||
only.append("configuration")
|
||||
tasks = task_bll.assert_exists(
|
||||
company_id=company_id, task_ids=task_ids, only=only, allow_public=True,
|
||||
company_id=company_id,
|
||||
task_ids=task_ids,
|
||||
only=only,
|
||||
allow_public=True,
|
||||
)
|
||||
|
||||
return {
|
||||
@@ -209,13 +218,15 @@ class HyperParams:
|
||||
def edit_configuration(
|
||||
cls,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
identity: Identity,
|
||||
task_id: str,
|
||||
configuration: Sequence[Configuration],
|
||||
replace_configuration: bool,
|
||||
force: bool,
|
||||
) -> int:
|
||||
task = get_task_for_update(company_id=company_id, task_id=task_id, force=force)
|
||||
task = get_task_for_update(
|
||||
company_id=company_id, task_id=task_id, force=force, identity=identity
|
||||
)
|
||||
|
||||
update_cmds = dict()
|
||||
configuration = {
|
||||
@@ -228,22 +239,24 @@ class HyperParams:
|
||||
for name, value in configuration.items():
|
||||
update_cmds[f"set__configuration__{mongoengine_safe(name)}"] = value
|
||||
|
||||
return update_task(task, user_id=user_id, update_cmds=update_cmds)
|
||||
return update_task(task, user_id=identity.user, update_cmds=update_cmds)
|
||||
|
||||
@classmethod
|
||||
def delete_configuration(
|
||||
cls,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
identity: Identity,
|
||||
task_id: str,
|
||||
configuration: Sequence[str],
|
||||
force: bool,
|
||||
) -> int:
|
||||
task = get_task_for_update(company_id=company_id, task_id=task_id, force=force)
|
||||
task = get_task_for_update(
|
||||
company_id=company_id, task_id=task_id, force=force, identity=identity
|
||||
)
|
||||
|
||||
delete_cmds = {
|
||||
f"unset__configuration__{ParameterKeyEscaper.escape(name)}": 1
|
||||
for name in set(configuration)
|
||||
}
|
||||
|
||||
return update_task(task, user_id=user_id, update_cmds=delete_cmds)
|
||||
return update_task(task, user_id=identity.user, update_cmds=delete_cmds)
|
||||
|
||||
@@ -58,27 +58,6 @@ class TaskBLL:
|
||||
self.events_es = events_es or es_factory.connect("events")
|
||||
self.redis: StrictRedis = redis or redman.connection("apiserver")
|
||||
|
||||
@staticmethod
|
||||
def get_task_with_access(
|
||||
task_id, company_id, only=None, allow_public=False, requires_write_access=False
|
||||
) -> Task:
|
||||
"""
|
||||
Gets a task that has a required write access
|
||||
:except errors.bad_request.InvalidTaskId: if the task is not found
|
||||
:except errors.forbidden.NoWritePermission: if write_access was required and the task cannot be modified
|
||||
"""
|
||||
with translate_errors_context():
|
||||
query = dict(id=task_id, company=company_id)
|
||||
if requires_write_access:
|
||||
task = Task.get_for_writing(_only=only, **query)
|
||||
else:
|
||||
task = Task.get(_only=only, **query, include_public=allow_public)
|
||||
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(**query)
|
||||
|
||||
return task
|
||||
|
||||
@staticmethod
|
||||
def get_by_id(
|
||||
company_id,
|
||||
|
||||
@@ -9,6 +9,7 @@ from apiserver.bll.task import (
|
||||
ChangeStatusRequest,
|
||||
)
|
||||
from apiserver.bll.task.task_cleanup import cleanup_task, CleanupResult
|
||||
from apiserver.bll.task.utils import get_task_with_write_access
|
||||
from apiserver.bll.util import update_project_time
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model import EntityVisibility
|
||||
@@ -24,6 +25,7 @@ from apiserver.database.model.task.task import (
|
||||
DEFAULT_LAST_ITERATION,
|
||||
)
|
||||
from apiserver.database.utils import get_options
|
||||
from apiserver.service_repo.auth import Identity
|
||||
from apiserver.utilities.dicts import nested_set
|
||||
|
||||
log = config.logger(__file__)
|
||||
@@ -33,7 +35,7 @@ queue_bll = QueueBLL()
|
||||
def archive_task(
|
||||
task: Union[str, Task],
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
identity: Identity,
|
||||
status_message: str,
|
||||
status_reason: str,
|
||||
) -> int:
|
||||
@@ -42,9 +44,10 @@ def archive_task(
|
||||
Return 1 if successful
|
||||
"""
|
||||
if isinstance(task, str):
|
||||
task = TaskBLL.get_task_with_access(
|
||||
task = get_task_with_write_access(
|
||||
task,
|
||||
company_id=company_id,
|
||||
identity=identity,
|
||||
only=(
|
||||
"id",
|
||||
"company",
|
||||
@@ -54,8 +57,9 @@ def archive_task(
|
||||
"system_tags",
|
||||
"enqueue_status",
|
||||
),
|
||||
requires_write_access=True,
|
||||
)
|
||||
|
||||
user_id = identity.user
|
||||
try:
|
||||
TaskBLL.dequeue_and_change_status(
|
||||
task,
|
||||
@@ -79,34 +83,34 @@ def archive_task(
|
||||
|
||||
|
||||
def unarchive_task(
|
||||
task: str,
|
||||
task_id: str,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
identity: Identity,
|
||||
status_message: str,
|
||||
status_reason: str,
|
||||
) -> int:
|
||||
"""
|
||||
Unarchive task. Return 1 if successful
|
||||
"""
|
||||
task = TaskBLL.get_task_with_access(
|
||||
task,
|
||||
task = get_task_with_write_access(
|
||||
task_id,
|
||||
company_id=company_id,
|
||||
identity=identity,
|
||||
only=("id",),
|
||||
requires_write_access=True,
|
||||
)
|
||||
return task.update(
|
||||
status_message=status_message,
|
||||
status_reason=status_reason,
|
||||
pull__system_tags=EntityVisibility.archived.value,
|
||||
last_change=datetime.utcnow(),
|
||||
last_changed_by=user_id,
|
||||
last_changed_by=identity.user,
|
||||
)
|
||||
|
||||
|
||||
def dequeue_task(
|
||||
task_id: str,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
identity: Identity,
|
||||
status_message: str,
|
||||
status_reason: str,
|
||||
remove_from_all_queues: bool = False,
|
||||
@@ -119,7 +123,19 @@ def dequeue_task(
|
||||
task = Task.get(
|
||||
id=task_id,
|
||||
company=company_id,
|
||||
_only=(
|
||||
_only=("id",),
|
||||
include_public=True,
|
||||
)
|
||||
if not task:
|
||||
TaskBLL.remove_task_from_all_queues(company_id, task_id=task_id)
|
||||
return 1, {"updated": 0}
|
||||
|
||||
user_id = identity.user
|
||||
task = get_task_with_write_access(
|
||||
task_id,
|
||||
company_id=company_id,
|
||||
identity=identity,
|
||||
only=(
|
||||
"id",
|
||||
"company",
|
||||
"execution",
|
||||
@@ -127,11 +143,7 @@ def dequeue_task(
|
||||
"project",
|
||||
"enqueue_status",
|
||||
),
|
||||
include_public=True,
|
||||
)
|
||||
if not task:
|
||||
TaskBLL.remove_task_from_all_queues(company_id, task_id=task_id)
|
||||
return 1, {"updated": 0}
|
||||
|
||||
res = TaskBLL.dequeue_and_change_status(
|
||||
task,
|
||||
@@ -148,7 +160,7 @@ def dequeue_task(
|
||||
def enqueue_task(
|
||||
task_id: str,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
identity: Identity,
|
||||
queue_id: str,
|
||||
status_message: str,
|
||||
status_reason: str,
|
||||
@@ -173,11 +185,11 @@ def enqueue_task(
|
||||
# try to get default queue
|
||||
queue_id = queue_bll.get_default(company_id).id
|
||||
|
||||
query = dict(id=task_id, company=company_id)
|
||||
task = Task.get_for_writing(**query)
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(**query)
|
||||
task = get_task_with_write_access(
|
||||
task_id=task_id, company_id=company_id, identity=identity
|
||||
)
|
||||
|
||||
user_id = identity.user
|
||||
if validate:
|
||||
TaskBLL.validate(task)
|
||||
|
||||
@@ -207,9 +219,9 @@ def enqueue_task(
|
||||
|
||||
# set the current queue ID in the task
|
||||
if task.execution:
|
||||
Task.objects(**query).update(execution__queue=queue_id, multi=False)
|
||||
Task.objects(id=task_id).update(execution__queue=queue_id, multi=False)
|
||||
else:
|
||||
Task.objects(**query).update(execution=Execution(queue=queue_id), multi=False)
|
||||
Task.objects(id=task_id).update(execution=Execution(queue=queue_id), multi=False)
|
||||
|
||||
nested_set(res, ("fields", "execution.queue"), queue_id)
|
||||
return 1, res
|
||||
@@ -242,7 +254,7 @@ def move_tasks_to_trash(tasks: Sequence[str]) -> int:
|
||||
def delete_task(
|
||||
task_id: str,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
identity: Identity,
|
||||
move_to_trash: bool,
|
||||
force: bool,
|
||||
return_file_urls: bool,
|
||||
@@ -251,8 +263,9 @@ def delete_task(
|
||||
status_reason: str,
|
||||
delete_external_artifacts: bool,
|
||||
) -> Tuple[int, Task, CleanupResult]:
|
||||
task = TaskBLL.get_task_with_access(
|
||||
task_id, company_id=company_id, requires_write_access=True
|
||||
user_id = identity.user
|
||||
task = get_task_with_write_access(
|
||||
task_id, company_id=company_id, identity=identity
|
||||
)
|
||||
|
||||
if (
|
||||
@@ -305,15 +318,16 @@ def delete_task(
|
||||
def reset_task(
|
||||
task_id: str,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
identity: Identity,
|
||||
force: bool,
|
||||
return_file_urls: bool,
|
||||
delete_output_models: bool,
|
||||
clear_all: bool,
|
||||
delete_external_artifacts: bool,
|
||||
) -> Tuple[dict, CleanupResult, dict]:
|
||||
task = TaskBLL.get_task_with_access(
|
||||
task_id, company_id=company_id, requires_write_access=True
|
||||
user_id = identity.user
|
||||
task = get_task_with_write_access(
|
||||
task_id, company_id=company_id, identity=identity
|
||||
)
|
||||
|
||||
if not force and task.status == TaskStatus.published:
|
||||
@@ -392,14 +406,15 @@ def reset_task(
|
||||
def publish_task(
|
||||
task_id: str,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
identity: Identity,
|
||||
force: bool,
|
||||
publish_model_func: Callable[[str, str, str], Any] = None,
|
||||
publish_model_func: Callable[[str, str, Identity], Any] = None,
|
||||
status_message: str = "",
|
||||
status_reason: str = "",
|
||||
) -> dict:
|
||||
task = TaskBLL.get_task_with_access(
|
||||
task_id, company_id=company_id, requires_write_access=True
|
||||
user_id = identity.user
|
||||
task = get_task_with_write_access(
|
||||
task_id, company_id=company_id, identity=identity
|
||||
)
|
||||
if not force:
|
||||
validate_status_change(task.status, TaskStatus.published)
|
||||
@@ -422,7 +437,7 @@ def publish_task(
|
||||
.first()
|
||||
)
|
||||
if model and not model.ready:
|
||||
publish_model_func(model.id, company_id, user_id)
|
||||
publish_model_func(model.id, company_id, identity)
|
||||
|
||||
# set task status to published, and update (or set) it's new output (view and models)
|
||||
return ChangeStatusRequest(
|
||||
@@ -446,7 +461,7 @@ def publish_task(
|
||||
def stop_task(
|
||||
task_id: str,
|
||||
company_id: str,
|
||||
user_id: str,
|
||||
identity: Identity,
|
||||
user_name: str,
|
||||
status_reason: str,
|
||||
force: bool,
|
||||
@@ -459,10 +474,11 @@ def stop_task(
|
||||
is set to 'stopping' to allow the worker to stop the task and report by itself
|
||||
:return: updated task fields
|
||||
"""
|
||||
|
||||
task = TaskBLL.get_task_with_access(
|
||||
user_id = identity.user
|
||||
task = get_task_with_write_access(
|
||||
task_id,
|
||||
company_id=company_id,
|
||||
identity=identity,
|
||||
only=(
|
||||
"status",
|
||||
"project",
|
||||
@@ -472,7 +488,6 @@ def stop_task(
|
||||
"last_update",
|
||||
"execution.queue",
|
||||
),
|
||||
requires_write_access=True,
|
||||
)
|
||||
|
||||
def is_run_by_worker(t: Task) -> bool:
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
from datetime import datetime
|
||||
from typing import Sequence
|
||||
|
||||
import attr
|
||||
import six
|
||||
from mongoengine import Q
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.bll.util import update_project_time
|
||||
@@ -10,6 +12,7 @@ from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.task.task import Task, TaskStatus, TaskSystemTags
|
||||
from apiserver.database.utils import get_options
|
||||
from apiserver.service_repo.auth import Identity
|
||||
from apiserver.utilities.attrs import typed_attrs
|
||||
|
||||
valid_statuses = get_options(TaskStatus)
|
||||
@@ -157,15 +160,75 @@ def get_possible_status_changes(current_status):
|
||||
return possible
|
||||
|
||||
|
||||
def get_many_tasks_for_writing(
|
||||
company_id: str,
|
||||
identity: Identity,
|
||||
query: Q = None,
|
||||
only: Sequence = None,
|
||||
throw_on_forbidden: bool = True,
|
||||
) -> Sequence[Task]:
|
||||
if only:
|
||||
missing = [f for f in ("company", ) if f not in only]
|
||||
if missing:
|
||||
only = [*only, *missing]
|
||||
|
||||
result = list(
|
||||
Task.get_many(
|
||||
company=company_id,
|
||||
query=query,
|
||||
override_projection=only,
|
||||
allow_public=True,
|
||||
return_dicts=False,
|
||||
)
|
||||
)
|
||||
|
||||
forbidden_tasks = {task.id for task in result if not task.company}
|
||||
if forbidden_tasks:
|
||||
if throw_on_forbidden:
|
||||
raise errors.forbidden.NoWritePermission(
|
||||
f"cannot modify public task(s), ids={tuple(forbidden_tasks)}"
|
||||
)
|
||||
result = [task for task in result if task.id not in forbidden_tasks]
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def get_task_with_write_access(
|
||||
task_id: str,
|
||||
company_id: str,
|
||||
identity: Identity,
|
||||
only=None,
|
||||
) -> Task:
|
||||
"""
|
||||
Gets a task that has a required write access
|
||||
:except errors.bad_request.InvalidTaskId: if the task is not found
|
||||
:except errors.forbidden.NoWritePermission: if write_access was required and the task cannot be modified
|
||||
"""
|
||||
query = dict(id=task_id, company=company_id)
|
||||
|
||||
task = Task.get_for_writing(_only=only, **query)
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(**query)
|
||||
|
||||
return task
|
||||
|
||||
|
||||
def get_task_for_update(
|
||||
company_id: str, task_id: str, allow_all_statuses: bool = False, force: bool = False
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
identity: Identity,
|
||||
allow_all_statuses: bool = False,
|
||||
force: bool = False
|
||||
) -> Task:
|
||||
"""
|
||||
Loads only task id and return the task only if it is updatable (status == 'created')
|
||||
"""
|
||||
task = Task.get_for_writing(company=company_id, id=task_id, _only=("id", "status"))
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(id=task_id)
|
||||
task = get_task_with_write_access(
|
||||
task_id=task_id,
|
||||
company_id=company_id,
|
||||
only=("id", "status"),
|
||||
identity=identity,
|
||||
)
|
||||
|
||||
if allow_all_statuses:
|
||||
return task
|
||||
|
||||
@@ -1283,21 +1283,6 @@ class GetMixin(PropsMixin):
|
||||
)
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def get_many_for_writing(cls, company, *args, **kwargs):
|
||||
result = cls.get_many(
|
||||
company=company,
|
||||
*args,
|
||||
**dict(return_dicts=False, **kwargs),
|
||||
allow_public=True,
|
||||
)
|
||||
forbidden_objects = {obj.id for obj in result if not obj.company}
|
||||
if forbidden_objects:
|
||||
object_name = cls.__name__.lower()
|
||||
raise errors.forbidden.NoWritePermission(
|
||||
f"cannot modify public {object_name}(s), ids={tuple(forbidden_objects)}"
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
class UpdateMixin(object):
|
||||
|
||||
@@ -44,6 +44,7 @@ from apiserver.bll.task.param_utils import (
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.config.info import get_default_company
|
||||
from apiserver.database.model import EntityVisibility, User
|
||||
from apiserver.database.model.auth import Role
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.task.task import (
|
||||
@@ -54,6 +55,7 @@ from apiserver.database.model.task.task import (
|
||||
TaskModelNames,
|
||||
)
|
||||
from apiserver.database.utils import get_options
|
||||
from apiserver.service_repo.auth import Identity
|
||||
from apiserver.utilities import json
|
||||
from apiserver.utilities.dicts import nested_get, nested_set, nested_delete
|
||||
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
|
||||
@@ -717,7 +719,10 @@ class PrePopulate:
|
||||
|
||||
@classmethod
|
||||
def _generate_new_ids(
|
||||
cls, reader: ZipFile, entity_files: Sequence, metadata: Mapping[str, Any],
|
||||
cls,
|
||||
reader: ZipFile,
|
||||
entity_files: Sequence,
|
||||
metadata: Mapping[str, Any],
|
||||
) -> Mapping[str, str]:
|
||||
if not metadata or not any(
|
||||
metadata.get(key) for key in ("new_ids", "example_ids", "private_ids")
|
||||
@@ -970,7 +975,7 @@ class PrePopulate:
|
||||
ev["allow_locked"] = True
|
||||
cls.event_bll.add_events(
|
||||
company_id=company_id,
|
||||
user_id=user_id,
|
||||
identity=Identity(user_id, company=company_id, role=Role.admin),
|
||||
events=events,
|
||||
worker="",
|
||||
)
|
||||
|
||||
@@ -28,6 +28,7 @@ from apiserver.bll.organization import OrgBLL, Tags
|
||||
from apiserver.bll.project import ProjectBLL
|
||||
from apiserver.bll.task import TaskBLL
|
||||
from apiserver.bll.task.task_operations import publish_task
|
||||
from apiserver.bll.task.utils import get_task_with_write_access
|
||||
from apiserver.bll.util import run_batch_operation
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model import validate_id
|
||||
@@ -46,6 +47,7 @@ from apiserver.database.utils import (
|
||||
filter_fields,
|
||||
)
|
||||
from apiserver.service_repo import APICall, endpoint
|
||||
from apiserver.service_repo.auth import Identity
|
||||
from apiserver.services.utils import (
|
||||
conform_tag_fields,
|
||||
conform_output_tags,
|
||||
@@ -249,13 +251,12 @@ def update_for_task(call: APICall, company_id, _):
|
||||
)
|
||||
|
||||
query = dict(id=task_id, company=company_id)
|
||||
task = Task.get_for_writing(
|
||||
id=task_id,
|
||||
company=company_id,
|
||||
_only=["models", "execution", "name", "status", "project"],
|
||||
task = get_task_with_write_access(
|
||||
task_id=task_id,
|
||||
company_id=company_id,
|
||||
identity=call.identity,
|
||||
only=("models", "execution", "name", "status", "project"),
|
||||
)
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(**query)
|
||||
|
||||
allowed_states = [TaskStatus.created, TaskStatus.in_progress]
|
||||
if task.status not in allowed_states:
|
||||
@@ -343,7 +344,7 @@ def create(call: APICall, company_id, req_model: CreateModelRequest):
|
||||
task = req_model.task
|
||||
req_data = req_model.to_struct()
|
||||
if task:
|
||||
validate_task(company_id, req_data)
|
||||
validate_task(company_id, call.identity, req_data)
|
||||
|
||||
fields = filter_fields(Model, req_data)
|
||||
conform_tag_fields(call, fields, validate=True)
|
||||
@@ -373,7 +374,7 @@ def prepare_update_fields(call, company_id, fields: dict):
|
||||
# clear UI cache if URI is provided (model updated)
|
||||
fields["ui_cache"] = fields.pop("ui_cache", {})
|
||||
if "task" in fields:
|
||||
validate_task(company_id, fields)
|
||||
validate_task(company_id, call.identity, fields)
|
||||
|
||||
if "labels" in fields:
|
||||
labels = fields["labels"]
|
||||
@@ -403,8 +404,11 @@ def prepare_update_fields(call, company_id, fields: dict):
|
||||
return fields
|
||||
|
||||
|
||||
def validate_task(company_id, fields: dict):
|
||||
Task.get_for_writing(company=company_id, id=fields["task"], _only=["id"])
|
||||
def validate_task(company_id: str, identity: Identity, fields: dict):
|
||||
task_id = fields["task"]
|
||||
get_task_with_write_access(
|
||||
task_id=task_id, company_id=company_id, identity=identity, only=("id",)
|
||||
)
|
||||
|
||||
|
||||
@endpoint("models.edit", required_fields=["model"], response_data_model=UpdateResponse)
|
||||
@@ -514,7 +518,7 @@ def set_ready(call: APICall, company_id: str, request: PublishModelRequest):
|
||||
updated, published_task = ModelBLL.publish_model(
|
||||
model_id=request.model,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
force_publish_task=request.force_publish_task,
|
||||
publish_task_func=publish_task if request.publish_task else None,
|
||||
)
|
||||
@@ -533,7 +537,7 @@ def publish_many(call: APICall, company_id, request: ModelsPublishManyRequest):
|
||||
func=partial(
|
||||
ModelBLL.publish_model,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
force_publish_task=request.force_publish_task,
|
||||
publish_task_func=publish_task if request.publish_task else None,
|
||||
),
|
||||
|
||||
@@ -57,7 +57,7 @@ def delete_runs(call: APICall, company_id: str, request: DeleteRunsRequest):
|
||||
func=partial(
|
||||
delete_task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
move_to_trash=False,
|
||||
force=True,
|
||||
return_file_urls=False,
|
||||
@@ -108,7 +108,7 @@ def start_pipeline(call: APICall, company_id: str, request: StartPipelineRequest
|
||||
queued, res = enqueue_task(
|
||||
task_id=task.id,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
queue_id=request.queue,
|
||||
status_message="Starting pipeline",
|
||||
status_reason="",
|
||||
|
||||
@@ -100,7 +100,13 @@ from apiserver.bll.task.task_operations import (
|
||||
unarchive_task,
|
||||
move_tasks_to_trash,
|
||||
)
|
||||
from apiserver.bll.task.utils import update_task, get_task_for_update, deleted_prefix
|
||||
from apiserver.bll.task.utils import (
|
||||
update_task,
|
||||
get_task_for_update,
|
||||
deleted_prefix,
|
||||
get_many_tasks_for_writing,
|
||||
get_task_with_write_access,
|
||||
)
|
||||
from apiserver.bll.util import run_batch_operation, update_project_time
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model import EntityVisibility
|
||||
@@ -118,6 +124,7 @@ from apiserver.database.utils import (
|
||||
get_options,
|
||||
)
|
||||
from apiserver.service_repo import APICall, endpoint
|
||||
from apiserver.service_repo.auth import Identity
|
||||
from apiserver.services.utils import (
|
||||
conform_tag_fields,
|
||||
conform_output_tags,
|
||||
@@ -142,14 +149,34 @@ org_bll = OrgBLL()
|
||||
project_bll = ProjectBLL()
|
||||
|
||||
|
||||
def _assert_writable_tasks(
|
||||
company_id: str, identity: Identity, ids: Sequence[str], only=("id",)
|
||||
) -> Sequence[Task]:
|
||||
tasks = get_many_tasks_for_writing(
|
||||
company_id=company_id,
|
||||
identity=identity,
|
||||
query=Q(id__in=ids),
|
||||
only=only,
|
||||
)
|
||||
missing_ids = set(ids) - {t.id for t in tasks}
|
||||
if missing_ids:
|
||||
raise errors.bad_request.InvalidTaskId(ids=list(missing_ids))
|
||||
|
||||
return tasks
|
||||
|
||||
|
||||
def set_task_status_from_call(
|
||||
request: UpdateRequest, company_id: str, user_id: str, new_status=None, **set_fields
|
||||
request: UpdateRequest,
|
||||
company_id: str,
|
||||
identity: Identity,
|
||||
new_status=None,
|
||||
**set_fields,
|
||||
) -> dict:
|
||||
task = TaskBLL.get_task_with_access(
|
||||
task = get_task_with_write_access(
|
||||
request.task,
|
||||
company_id=company_id,
|
||||
identity=identity,
|
||||
only=("id", "status", "project"),
|
||||
requires_write_access=True,
|
||||
)
|
||||
|
||||
status_reason = request.status_reason
|
||||
@@ -161,15 +188,17 @@ def set_task_status_from_call(
|
||||
status_reason=status_reason,
|
||||
status_message=status_message,
|
||||
force=force,
|
||||
user_id=user_id,
|
||||
user_id=identity.user,
|
||||
).execute(**set_fields)
|
||||
|
||||
|
||||
@endpoint("tasks.get_by_id", request_data_model=TaskRequest)
|
||||
def get_by_id(call: APICall, company_id, req_model: TaskRequest):
|
||||
task = TaskBLL.get_task_with_access(
|
||||
req_model.task, company_id=company_id, allow_public=True
|
||||
)
|
||||
def get_by_id(call: APICall, company_id, request: TaskRequest):
|
||||
task = TaskBLL.assert_exists(
|
||||
company_id,
|
||||
task_ids=request.task,
|
||||
allow_public=True,
|
||||
)[0]
|
||||
task_dict = task.to_proper_dict()
|
||||
conform_task_data(call, task_dict)
|
||||
call.result.data = {"task": task_dict}
|
||||
@@ -227,7 +256,9 @@ def get_by_id_ex(call: APICall, company_id, _):
|
||||
conform_tag_fields(call, call.data)
|
||||
call_data = escape_execution_parameters(call.data)
|
||||
tasks = Task.get_many_with_join(
|
||||
company=company_id, query_dict=call_data, allow_public=True,
|
||||
company=company_id,
|
||||
query_dict=call_data,
|
||||
allow_public=True,
|
||||
)
|
||||
|
||||
conform_task_data(call, tasks)
|
||||
@@ -278,7 +309,7 @@ def stop(call: APICall, company_id, req_model: UpdateRequest):
|
||||
**stop_task(
|
||||
task_id=req_model.task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
user_name=call.identity.user_name,
|
||||
status_reason=req_model.status_reason,
|
||||
force=req_model.force,
|
||||
@@ -296,7 +327,7 @@ def stop_many(call: APICall, company_id, request: StopManyRequest):
|
||||
func=partial(
|
||||
stop_task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
user_name=call.identity.user_name,
|
||||
status_reason=request.status_reason,
|
||||
force=request.force,
|
||||
@@ -319,7 +350,7 @@ def stopped(call: APICall, company_id, req_model: UpdateRequest):
|
||||
**set_task_status_from_call(
|
||||
req_model,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
new_status=TaskStatus.stopped,
|
||||
completed=datetime.utcnow(),
|
||||
)
|
||||
@@ -336,7 +367,7 @@ def started(call: APICall, company_id, req_model: UpdateRequest):
|
||||
**set_task_status_from_call(
|
||||
req_model,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
new_status=TaskStatus.in_progress,
|
||||
min__started=datetime.utcnow(), # don't override a previous, smaller "started" field value
|
||||
)
|
||||
@@ -353,7 +384,7 @@ def failed(call: APICall, company_id, req_model: UpdateRequest):
|
||||
**set_task_status_from_call(
|
||||
req_model,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
new_status=TaskStatus.failed,
|
||||
)
|
||||
)
|
||||
@@ -367,7 +398,7 @@ def close(call: APICall, company_id, req_model: UpdateRequest):
|
||||
**set_task_status_from_call(
|
||||
req_model,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
new_status=TaskStatus.closed,
|
||||
)
|
||||
)
|
||||
@@ -433,13 +464,17 @@ def conform_task_data(call: APICall, tasks_data: Union[Sequence[dict], dict]):
|
||||
|
||||
for data in tasks_data:
|
||||
params_unprepare_from_saved(
|
||||
fields=data, copy_to_legacy=need_legacy_params,
|
||||
fields=data,
|
||||
copy_to_legacy=need_legacy_params,
|
||||
)
|
||||
artifacts_unprepare_from_saved(fields=data)
|
||||
|
||||
|
||||
def prepare_create_fields(
|
||||
call: APICall, valid_fields=None, output=None, previous_task: Task = None,
|
||||
call: APICall,
|
||||
valid_fields=None,
|
||||
output=None,
|
||||
previous_task: Task = None,
|
||||
):
|
||||
valid_fields = valid_fields if valid_fields is not None else create_fields
|
||||
t_fields = task_fields
|
||||
@@ -566,11 +601,12 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
|
||||
task_id = req_model.task
|
||||
|
||||
with translate_errors_context():
|
||||
task = Task.get_for_writing(
|
||||
id=task_id, company=company_id, _only=["id", "project"]
|
||||
task = get_task_with_write_access(
|
||||
task_id=task_id,
|
||||
company_id=company_id,
|
||||
identity=call.identity,
|
||||
only=("id", "project"),
|
||||
)
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(id=task_id)
|
||||
|
||||
partial_update_dict, valid_fields = prepare_update_fields(call, call.data)
|
||||
|
||||
@@ -582,7 +618,8 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
|
||||
id=task_id,
|
||||
partial_update_dict=partial_update_dict,
|
||||
injected_update=dict(
|
||||
last_change=datetime.utcnow(), last_changed_by=call.identity.user,
|
||||
last_change=datetime.utcnow(),
|
||||
last_changed_by=call.identity.user,
|
||||
),
|
||||
)
|
||||
if updated_count:
|
||||
@@ -606,11 +643,11 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
|
||||
def set_requirements(call: APICall, company_id, req_model: SetRequirementsRequest):
|
||||
requirements = req_model.requirements
|
||||
with translate_errors_context():
|
||||
task = TaskBLL.get_task_with_access(
|
||||
task = get_task_with_write_access(
|
||||
req_model.task,
|
||||
company_id=company_id,
|
||||
identity=call.identity,
|
||||
only=("status", "script"),
|
||||
requires_write_access=True,
|
||||
)
|
||||
if not task.script:
|
||||
raise errors.bad_request.MissingTaskFields(
|
||||
@@ -636,8 +673,11 @@ def update_batch(call: APICall, company_id, _):
|
||||
items = {i["task"]: i for i in items}
|
||||
tasks = {
|
||||
t.id: t
|
||||
for t in Task.get_many_for_writing(
|
||||
company=company_id, query=Q(id__in=list(items))
|
||||
for t in _assert_writable_tasks(
|
||||
identity=call.identity,
|
||||
company_id=company_id,
|
||||
ids=list(items),
|
||||
only=("id", "project"),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -656,7 +696,8 @@ def update_batch(call: APICall, company_id, _):
|
||||
if not partial_update_dict:
|
||||
continue
|
||||
partial_update_dict.update(
|
||||
last_change=now, last_changed_by=call.identity.user,
|
||||
last_change=now,
|
||||
last_changed_by=call.identity.user,
|
||||
)
|
||||
update_op = UpdateOne(
|
||||
{"_id": id, "company": company_id}, {"$set": partial_update_dict}
|
||||
@@ -690,9 +731,11 @@ def edit(call: APICall, company_id, req_model: UpdateRequest):
|
||||
force = req_model.force
|
||||
|
||||
with translate_errors_context():
|
||||
task = Task.get_for_writing(id=task_id, company=company_id)
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(id=task_id)
|
||||
task = get_task_with_write_access(
|
||||
task_id=task_id,
|
||||
company_id=company_id,
|
||||
identity=call.identity,
|
||||
)
|
||||
|
||||
if not force and task.status != TaskStatus.created:
|
||||
raise errors.bad_request.InvalidTaskStatus(
|
||||
@@ -756,7 +799,8 @@ def edit(call: APICall, company_id, req_model: UpdateRequest):
|
||||
|
||||
|
||||
@endpoint(
|
||||
"tasks.get_hyper_params", request_data_model=GetHyperParamsRequest,
|
||||
"tasks.get_hyper_params",
|
||||
request_data_model=GetHyperParamsRequest,
|
||||
)
|
||||
def get_hyper_params(call: APICall, company_id, request: GetHyperParamsRequest):
|
||||
tasks_params = HyperParams.get_params(company_id, task_ids=request.tasks)
|
||||
@@ -771,7 +815,7 @@ def edit_hyper_params(call: APICall, company_id, request: EditHyperParamsRequest
|
||||
call.result.data = {
|
||||
"updated": HyperParams.edit_params(
|
||||
company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
task_id=request.task,
|
||||
hyperparams=request.hyperparams,
|
||||
replace_hyperparams=request.replace_hyperparams,
|
||||
@@ -785,7 +829,7 @@ def delete_hyper_params(call: APICall, company_id, request: DeleteHyperParamsReq
|
||||
call.result.data = {
|
||||
"deleted": HyperParams.delete_params(
|
||||
company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
task_id=request.task,
|
||||
hyperparams=request.hyperparams,
|
||||
force=request.force,
|
||||
@@ -794,7 +838,8 @@ def delete_hyper_params(call: APICall, company_id, request: DeleteHyperParamsReq
|
||||
|
||||
|
||||
@endpoint(
|
||||
"tasks.get_configurations", request_data_model=GetConfigurationsRequest,
|
||||
"tasks.get_configurations",
|
||||
request_data_model=GetConfigurationsRequest,
|
||||
)
|
||||
def get_configurations(call: APICall, company_id, request: GetConfigurationsRequest):
|
||||
tasks_params = HyperParams.get_configurations(
|
||||
@@ -809,7 +854,8 @@ def get_configurations(call: APICall, company_id, request: GetConfigurationsRequ
|
||||
|
||||
|
||||
@endpoint(
|
||||
"tasks.get_configuration_names", request_data_model=GetConfigurationNamesRequest,
|
||||
"tasks.get_configuration_names",
|
||||
request_data_model=GetConfigurationNamesRequest,
|
||||
)
|
||||
def get_configuration_names(
|
||||
call: APICall, company_id, request: GetConfigurationNamesRequest
|
||||
@@ -830,7 +876,7 @@ def edit_configuration(call: APICall, company_id, request: EditConfigurationRequ
|
||||
call.result.data = {
|
||||
"updated": HyperParams.edit_configuration(
|
||||
company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
task_id=request.task,
|
||||
configuration=request.configuration,
|
||||
replace_configuration=request.replace_configuration,
|
||||
@@ -846,7 +892,7 @@ def delete_configuration(
|
||||
call.result.data = {
|
||||
"deleted": HyperParams.delete_configuration(
|
||||
company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
task_id=request.task,
|
||||
configuration=request.configuration,
|
||||
force=request.force,
|
||||
@@ -863,7 +909,7 @@ def enqueue(call: APICall, company_id, request: EnqueueRequest):
|
||||
queued, res = enqueue_task(
|
||||
task_id=request.task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
queue_id=request.queue,
|
||||
status_message=request.status_message,
|
||||
status_reason=request.status_reason,
|
||||
@@ -888,7 +934,7 @@ def enqueue_many(call: APICall, company_id, request: EnqueueManyRequest):
|
||||
func=partial(
|
||||
enqueue_task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
queue_id=request.queue,
|
||||
status_message=request.status_message,
|
||||
status_reason=request.status_reason,
|
||||
@@ -915,13 +961,14 @@ def enqueue_many(call: APICall, company_id, request: EnqueueManyRequest):
|
||||
|
||||
|
||||
@endpoint(
|
||||
"tasks.dequeue", response_data_model=DequeueResponse,
|
||||
"tasks.dequeue",
|
||||
response_data_model=DequeueResponse,
|
||||
)
|
||||
def dequeue(call: APICall, company_id, request: DequeueRequest):
|
||||
dequeued, res = dequeue_task(
|
||||
task_id=request.task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
status_message=request.status_message,
|
||||
status_reason=request.status_reason,
|
||||
remove_from_all_queues=request.remove_from_all_queues,
|
||||
@@ -931,14 +978,15 @@ def dequeue(call: APICall, company_id, request: DequeueRequest):
|
||||
|
||||
|
||||
@endpoint(
|
||||
"tasks.dequeue_many", response_data_model=DequeueManyResponse,
|
||||
"tasks.dequeue_many",
|
||||
response_data_model=DequeueManyResponse,
|
||||
)
|
||||
def dequeue_many(call: APICall, company_id, request: DequeueManyRequest):
|
||||
results, failures = run_batch_operation(
|
||||
func=partial(
|
||||
dequeue_task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
status_message=request.status_message,
|
||||
status_reason=request.status_reason,
|
||||
remove_from_all_queues=request.remove_from_all_queues,
|
||||
@@ -962,7 +1010,7 @@ def reset(call: APICall, company_id, request: ResetRequest):
|
||||
dequeued, cleanup_res, updates = reset_task(
|
||||
task_id=request.task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
force=request.force,
|
||||
return_file_urls=request.return_file_urls,
|
||||
delete_output_models=request.delete_output_models,
|
||||
@@ -990,7 +1038,7 @@ def reset_many(call: APICall, company_id, request: ResetManyRequest):
|
||||
func=partial(
|
||||
reset_task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
force=request.force,
|
||||
return_file_urls=request.return_file_urls,
|
||||
delete_output_models=request.delete_output_models,
|
||||
@@ -1027,9 +1075,11 @@ def reset_many(call: APICall, company_id, request: ResetManyRequest):
|
||||
response_data_model=ArchiveResponse,
|
||||
)
|
||||
def archive(call: APICall, company_id, request: ArchiveRequest):
|
||||
tasks = TaskBLL.assert_exists(
|
||||
archived = 0
|
||||
tasks = _assert_writable_tasks(
|
||||
company_id,
|
||||
task_ids=request.tasks,
|
||||
call.identity,
|
||||
ids=request.tasks,
|
||||
only=(
|
||||
"id",
|
||||
"company",
|
||||
@@ -1040,11 +1090,10 @@ def archive(call: APICall, company_id, request: ArchiveRequest):
|
||||
"enqueue_status",
|
||||
),
|
||||
)
|
||||
archived = 0
|
||||
for task in tasks:
|
||||
archived += archive_task(
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
task=task,
|
||||
status_message=request.status_message,
|
||||
status_reason=request.status_reason,
|
||||
@@ -1063,7 +1112,7 @@ def archive_many(call: APICall, company_id, request: TaskBatchRequest):
|
||||
func=partial(
|
||||
archive_task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
status_message=request.status_message,
|
||||
status_reason=request.status_reason,
|
||||
),
|
||||
@@ -1085,7 +1134,7 @@ def unarchive_many(call: APICall, company_id, request: TaskBatchRequest):
|
||||
func=partial(
|
||||
unarchive_task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
status_message=request.status_message,
|
||||
status_reason=request.status_reason,
|
||||
),
|
||||
@@ -1104,7 +1153,7 @@ def delete(call: APICall, company_id, request: DeleteRequest):
|
||||
deleted, task, cleanup_res = delete_task(
|
||||
task_id=request.task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
move_to_trash=request.move_to_trash,
|
||||
force=request.force,
|
||||
return_file_urls=request.return_file_urls,
|
||||
@@ -1126,7 +1175,7 @@ def delete_many(call: APICall, company_id, request: DeleteManyRequest):
|
||||
func=partial(
|
||||
delete_task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
move_to_trash=request.move_to_trash,
|
||||
force=request.force,
|
||||
return_file_urls=request.return_file_urls,
|
||||
@@ -1164,7 +1213,7 @@ def publish(call: APICall, company_id, request: PublishRequest):
|
||||
updates = publish_task(
|
||||
task_id=request.task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
force=request.force,
|
||||
publish_model_func=ModelBLL.publish_model if request.publish_model else None,
|
||||
status_reason=request.status_reason,
|
||||
@@ -1183,7 +1232,7 @@ def publish_many(call: APICall, company_id, request: PublishManyRequest):
|
||||
func=partial(
|
||||
publish_task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
force=request.force,
|
||||
publish_model_func=ModelBLL.publish_model
|
||||
if request.publish_model
|
||||
@@ -1211,7 +1260,7 @@ def completed(call: APICall, company_id, request: CompletedRequest):
|
||||
**set_task_status_from_call(
|
||||
request,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
new_status=TaskStatus.completed,
|
||||
completed=datetime.utcnow(),
|
||||
)
|
||||
@@ -1221,7 +1270,7 @@ def completed(call: APICall, company_id, request: CompletedRequest):
|
||||
publish_res = publish_task(
|
||||
task_id=request.task,
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
force=request.force,
|
||||
publish_model_func=ModelBLL.publish_model,
|
||||
status_reason=request.status_reason,
|
||||
@@ -1256,7 +1305,7 @@ def add_or_update_artifacts(
|
||||
call.result.data = {
|
||||
"updated": Artifacts.add_or_update_artifacts(
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
task_id=request.task,
|
||||
artifacts=request.artifacts,
|
||||
force=True,
|
||||
@@ -1273,7 +1322,7 @@ def delete_artifacts(call: APICall, company_id, request: DeleteArtifactsRequest)
|
||||
call.result.data = {
|
||||
"deleted": Artifacts.delete_artifacts(
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
identity=call.identity,
|
||||
task_id=request.task,
|
||||
artifact_ids=request.artifacts,
|
||||
force=True,
|
||||
@@ -1310,6 +1359,7 @@ def move(call: APICall, company_id: str, request: MoveRequest):
|
||||
"project or project_name is required"
|
||||
)
|
||||
|
||||
_assert_writable_tasks(company_id, call.identity, request.ids)
|
||||
updated_projects = set(
|
||||
t.project for t in Task.objects(id__in=request.ids).only("project") if t.project
|
||||
)
|
||||
@@ -1330,7 +1380,8 @@ def move(call: APICall, company_id: str, request: MoveRequest):
|
||||
|
||||
|
||||
@endpoint("tasks.update_tags")
|
||||
def update_tags(_, company_id: str, request: UpdateTagsRequest):
|
||||
def update_tags(call: APICall, company_id: str, request: UpdateTagsRequest):
|
||||
_assert_writable_tasks(company_id, call.identity, request.ids)
|
||||
return {
|
||||
"updated": org_bll.edit_entity_tags(
|
||||
company_id=company_id,
|
||||
@@ -1344,7 +1395,9 @@ def update_tags(_, company_id: str, request: UpdateTagsRequest):
|
||||
|
||||
@endpoint("tasks.add_or_update_model", min_version="2.13")
|
||||
def add_or_update_model(call: APICall, company_id: str, request: AddUpdateModelRequest):
|
||||
get_task_for_update(company_id=company_id, task_id=request.task, force=True)
|
||||
get_task_for_update(
|
||||
company_id=company_id, task_id=request.task, force=True, identity=call.identity
|
||||
)
|
||||
|
||||
models_field = f"models__{request.type}"
|
||||
model = ModelItem(name=request.name, model=request.model, updated=datetime.utcnow())
|
||||
@@ -1364,7 +1417,9 @@ def add_or_update_model(call: APICall, company_id: str, request: AddUpdateModelR
|
||||
|
||||
@endpoint("tasks.delete_models", min_version="2.13")
|
||||
def delete_models(call: APICall, company_id: str, request: DeleteModelsRequest):
|
||||
task = get_task_for_update(company_id=company_id, task_id=request.task, force=True)
|
||||
task = get_task_for_update(
|
||||
company_id=company_id, task_id=request.task, force=True, identity=call.identity
|
||||
)
|
||||
|
||||
delete_names = {
|
||||
type_: [m.name for m in request.models if m.type == type_]
|
||||
@@ -1377,6 +1432,8 @@ def delete_models(call: APICall, company_id: str, request: DeleteModelsRequest):
|
||||
}
|
||||
|
||||
updated = task.update(
|
||||
last_change=datetime.utcnow(), last_changed_by=call.identity.user, **commands,
|
||||
last_change=datetime.utcnow(),
|
||||
last_changed_by=call.identity.user,
|
||||
**commands,
|
||||
)
|
||||
return {"updated": updated}
|
||||
|
||||
Reference in New Issue
Block a user