Fix add_or_update_artifacts should always be allowed on in_progress tasks

Fix delete_artifacts should always be allowed on in_progress tasks
Fix query code
This commit is contained in:
allegroai 2022-02-13 19:31:54 +02:00
parent e149af58b1
commit 843450bb9b
3 changed files with 31 additions and 42 deletions

View File

@ -139,6 +139,7 @@ class GetMixin(PropsMixin):
self._current_op = None self._current_op = None
self._sticky = False self._sticky = False
self._support_legacy = legacy self._support_legacy = legacy
self.allow_empty = False
def _get_op(self, v: str, translate: bool = False) -> Optional[str]: def _get_op(self, v: str, translate: bool = False) -> Optional[str]:
op = v[len(self.op_prefix):] if v and v.startswith(self.op_prefix) else None op = v[len(self.op_prefix):] if v and v.startswith(self.op_prefix) else None
@ -149,9 +150,16 @@ class GetMixin(PropsMixin):
def _key(self, v) -> Optional[Union[str, bool]]: def _key(self, v) -> Optional[Union[str, bool]]:
if v is None: if v is None:
self._current_op = None self.allow_empty = True
self._sticky = False return None
return self.default_mongo_op
op = self._get_op(v)
if op is not None:
# operator - set state and return None
self._current_op, self._sticky = self._ops.get(
op, (self.default_mongo_op, self._sticky)
)
return None
elif self._current_op: elif self._current_op:
current_op = self._current_op current_op = self._current_op
if not self._sticky: if not self._sticky:
@ -160,26 +168,20 @@ class GetMixin(PropsMixin):
elif self._support_legacy and v.startswith(self._legacy_exclude_prefix): elif self._support_legacy and v.startswith(self._legacy_exclude_prefix):
self._current_op = None self._current_op = None
return False return False
else:
op = self._get_op(v)
if op is not None:
self._current_op, self._sticky = self._ops.get(
op, (self.default_mongo_op, self._sticky)
)
return None
return self.default_mongo_op return self.default_mongo_op
def get_actions(self, data: Sequence[str]) -> Tuple[Dict[str, List[Union[str, None]]], Optional[str]]: def get_global_op(self, data: Sequence[str]) -> int:
actions = {} op_to_res = {
if not data: "in": Q.OR,
return actions, None "all": Q.AND,
}
data = (x for x in data if x is not None)
first_op = self._get_op(next(data, ""), translate=True) or self.default_mongo_op
return op_to_res.get(first_op, self.default_mongo_op)
global_op = self._get_op(data[0], translate=True) def get_actions(self, data: Sequence[str]) -> Dict[str, List[Union[str, None]]]:
if global_op in ("in", "all"): actions = {}
data = data[1:]
else:
global_op = None
for val in data: for val in data:
key = self._key(val) key = self._key(val)
@ -190,7 +192,7 @@ class GetMixin(PropsMixin):
val = val[len(self._legacy_exclude_prefix) :] val = val[len(self._legacy_exclude_prefix) :]
actions.setdefault(key, []).append(val) actions.setdefault(key, []).append(val)
return actions, global_op or self.default_mongo_op return actions
get_all_query_options = QueryParameterOptions() get_all_query_options = QueryParameterOptions()
@ -447,18 +449,9 @@ class GetMixin(PropsMixin):
if not isinstance(data, (list, tuple)): if not isinstance(data, (list, tuple)):
data = [data] data = [data]
actions, global_op = cls.ListFieldBucketHelper(legacy=True).get_actions(data) helper = cls.ListFieldBucketHelper(legacy=True)
global_op = helper.get_global_op(data)
default_op = cls.ListFieldBucketHelper.default_mongo_op actions = helper.get_actions(data)
# Handle `allow_empty` hack: controlled using `None` as a specific value in the default "in" action
allow_empty = False
default_op_actions = actions.get(default_op)
if default_op_actions and None in default_op_actions:
allow_empty = True
default_op_actions.remove(None)
if not default_op_actions:
actions.pop(cls.ListFieldBucketHelper.default_mongo_op)
mongoengine_field = field.replace(".", "__") mongoengine_field = field.replace(".", "__")
@ -471,11 +464,11 @@ class GetMixin(PropsMixin):
q = RegexQ() q = RegexQ()
else: else:
q = RegexQCombination( q = RegexQCombination(
operation=RegexQ.AND if global_op is not default_op else RegexQ.OR, operation=global_op,
children=queries children=queries
) )
if not allow_empty: if not helper.allow_empty:
return q return q
return ( return (

View File

@ -1173,7 +1173,7 @@ def add_or_update_artifacts(
company_id=company_id, company_id=company_id,
task_id=request.task, task_id=request.task,
artifacts=request.artifacts, artifacts=request.artifacts,
force=request.force, force=True,
) )
} }
@ -1189,7 +1189,7 @@ def delete_artifacts(call: APICall, company_id, request: DeleteArtifactsRequest)
company_id=company_id, company_id=company_id,
task_id=request.task, task_id=request.task,
artifact_ids=request.artifacts, artifact_ids=request.artifacts,
force=request.force, force=True,
) )
} }

View File

@ -82,14 +82,10 @@ class TestTasksArtifacts(TestService):
# test edit running task # test edit running task
self.api.tasks.started(task=task) self.api.tasks.started(task=task)
with self.api.raises(InvalidTaskStatus): self.api.tasks.add_or_update_artifacts(task=task, artifacts=edit)
self.api.tasks.add_or_update_artifacts(task=task, artifacts=edit)
self.api.tasks.add_or_update_artifacts(task=task, artifacts=edit, force=True)
res = self.api.tasks.get_all_ex(id=[task]).tasks[0] res = self.api.tasks.get_all_ex(id=[task]).tasks[0]
self._assertTaskArtifacts(artifacts, res) self._assertTaskArtifacts(artifacts, res)
with self.api.raises(InvalidTaskStatus): self.api.tasks.delete_artifacts(task=task, artifacts=[{"key": artifacts[-1]["key"]}])
self.api.tasks.delete_artifacts(task=task, artifacts=[{"key": artifacts[-1]["key"]}])
self.api.tasks.delete_artifacts(task=task, artifacts=[{"key": artifacts[-1]["key"]}], force=True)
res = self.api.tasks.get_all_ex(id=[task]).tasks[0] res = self.api.tasks.get_all_ex(id=[task]).tasks[0]
self._assertTaskArtifacts(artifacts[0: len(artifacts) - 1], res) self._assertTaskArtifacts(artifacts[0: len(artifacts) - 1], res)