From 843450bb9b393eb0be9f2d0c5f368ce6892e24b2 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Sun, 13 Feb 2022 19:31:54 +0200 Subject: [PATCH] Fix add_or_update_artifacts should always be allowed on in_progress tasks Fix delete_artifacts should always be allowed on in_progress tasks Fix query code --- apiserver/database/model/base.py | 61 ++++++++----------- apiserver/services/tasks.py | 4 +- .../tests/automated/test_task_artifacts.py | 8 +-- 3 files changed, 31 insertions(+), 42 deletions(-) diff --git a/apiserver/database/model/base.py b/apiserver/database/model/base.py index 967c9cb..a8770e0 100644 --- a/apiserver/database/model/base.py +++ b/apiserver/database/model/base.py @@ -139,6 +139,7 @@ class GetMixin(PropsMixin): self._current_op = None self._sticky = False self._support_legacy = legacy + self.allow_empty = False def _get_op(self, v: str, translate: bool = False) -> Optional[str]: op = v[len(self.op_prefix):] if v and v.startswith(self.op_prefix) else None @@ -149,9 +150,16 @@ class GetMixin(PropsMixin): def _key(self, v) -> Optional[Union[str, bool]]: if v is None: - self._current_op = None - self._sticky = False - return self.default_mongo_op + self.allow_empty = True + return None + + op = self._get_op(v) + if op is not None: + # operator - set state and return None + self._current_op, self._sticky = self._ops.get( + op, (self.default_mongo_op, self._sticky) + ) + return None elif self._current_op: current_op = self._current_op if not self._sticky: @@ -160,26 +168,20 @@ class GetMixin(PropsMixin): elif self._support_legacy and v.startswith(self._legacy_exclude_prefix): self._current_op = None return False - else: - op = self._get_op(v) - if op is not None: - self._current_op, self._sticky = self._ops.get( - op, (self.default_mongo_op, self._sticky) - ) - return None return self.default_mongo_op - def get_actions(self, data: Sequence[str]) -> Tuple[Dict[str, List[Union[str, None]]], Optional[str]]: - actions = {} - if not data: - return actions, None + def get_global_op(self, data: Sequence[str]) -> int: + op_to_res = { + "in": Q.OR, + "all": Q.AND, + } + data = (x for x in data if x is not None) + first_op = self._get_op(next(data, ""), translate=True) or self.default_mongo_op + return op_to_res.get(first_op, self.default_mongo_op) - global_op = self._get_op(data[0], translate=True) - if global_op in ("in", "all"): - data = data[1:] - else: - global_op = None + def get_actions(self, data: Sequence[str]) -> Dict[str, List[Union[str, None]]]: + actions = {} for val in data: key = self._key(val) @@ -190,7 +192,7 @@ class GetMixin(PropsMixin): val = val[len(self._legacy_exclude_prefix) :] actions.setdefault(key, []).append(val) - return actions, global_op or self.default_mongo_op + return actions get_all_query_options = QueryParameterOptions() @@ -447,18 +449,9 @@ class GetMixin(PropsMixin): if not isinstance(data, (list, tuple)): data = [data] - actions, global_op = cls.ListFieldBucketHelper(legacy=True).get_actions(data) - - default_op = cls.ListFieldBucketHelper.default_mongo_op - - # Handle `allow_empty` hack: controlled using `None` as a specific value in the default "in" action - allow_empty = False - default_op_actions = actions.get(default_op) - if default_op_actions and None in default_op_actions: - allow_empty = True - default_op_actions.remove(None) - if not default_op_actions: - actions.pop(cls.ListFieldBucketHelper.default_mongo_op) + helper = cls.ListFieldBucketHelper(legacy=True) + global_op = helper.get_global_op(data) + actions = helper.get_actions(data) mongoengine_field = field.replace(".", "__") @@ -471,11 +464,11 @@ class GetMixin(PropsMixin): q = RegexQ() else: q = RegexQCombination( - operation=RegexQ.AND if global_op is not default_op else RegexQ.OR, + operation=global_op, children=queries ) - if not allow_empty: + if not helper.allow_empty: return q return ( diff --git a/apiserver/services/tasks.py b/apiserver/services/tasks.py index 9624234..a3ecfd5 100644 --- a/apiserver/services/tasks.py +++ b/apiserver/services/tasks.py @@ -1173,7 +1173,7 @@ def add_or_update_artifacts( company_id=company_id, task_id=request.task, artifacts=request.artifacts, - force=request.force, + force=True, ) } @@ -1189,7 +1189,7 @@ def delete_artifacts(call: APICall, company_id, request: DeleteArtifactsRequest) company_id=company_id, task_id=request.task, artifact_ids=request.artifacts, - force=request.force, + force=True, ) } diff --git a/apiserver/tests/automated/test_task_artifacts.py b/apiserver/tests/automated/test_task_artifacts.py index 60e0487..afee2a2 100644 --- a/apiserver/tests/automated/test_task_artifacts.py +++ b/apiserver/tests/automated/test_task_artifacts.py @@ -82,14 +82,10 @@ class TestTasksArtifacts(TestService): # test edit running task self.api.tasks.started(task=task) - with self.api.raises(InvalidTaskStatus): - self.api.tasks.add_or_update_artifacts(task=task, artifacts=edit) - self.api.tasks.add_or_update_artifacts(task=task, artifacts=edit, force=True) + self.api.tasks.add_or_update_artifacts(task=task, artifacts=edit) res = self.api.tasks.get_all_ex(id=[task]).tasks[0] self._assertTaskArtifacts(artifacts, res) - with self.api.raises(InvalidTaskStatus): - self.api.tasks.delete_artifacts(task=task, artifacts=[{"key": artifacts[-1]["key"]}]) - self.api.tasks.delete_artifacts(task=task, artifacts=[{"key": artifacts[-1]["key"]}], force=True) + self.api.tasks.delete_artifacts(task=task, artifacts=[{"key": artifacts[-1]["key"]}]) res = self.api.tasks.get_all_ex(id=[task]).tasks[0] self._assertTaskArtifacts(artifacts[0: len(artifacts) - 1], res)