Fix add_or_update_artifacts should always be allowed on in_progress tasks

Fix delete_artifacts should always be allowed on in_progress tasks
Fix query code
This commit is contained in:
allegroai 2022-02-13 19:31:54 +02:00
parent e149af58b1
commit 843450bb9b
3 changed files with 31 additions and 42 deletions

View File

@ -139,6 +139,7 @@ class GetMixin(PropsMixin):
self._current_op = None
self._sticky = False
self._support_legacy = legacy
self.allow_empty = False
def _get_op(self, v: str, translate: bool = False) -> Optional[str]:
op = v[len(self.op_prefix):] if v and v.startswith(self.op_prefix) else None
@ -149,9 +150,16 @@ class GetMixin(PropsMixin):
def _key(self, v) -> Optional[Union[str, bool]]:
if v is None:
self._current_op = None
self._sticky = False
return self.default_mongo_op
self.allow_empty = True
return None
op = self._get_op(v)
if op is not None:
# operator - set state and return None
self._current_op, self._sticky = self._ops.get(
op, (self.default_mongo_op, self._sticky)
)
return None
elif self._current_op:
current_op = self._current_op
if not self._sticky:
@ -160,26 +168,20 @@ class GetMixin(PropsMixin):
elif self._support_legacy and v.startswith(self._legacy_exclude_prefix):
self._current_op = None
return False
else:
op = self._get_op(v)
if op is not None:
self._current_op, self._sticky = self._ops.get(
op, (self.default_mongo_op, self._sticky)
)
return None
return self.default_mongo_op
def get_actions(self, data: Sequence[str]) -> Tuple[Dict[str, List[Union[str, None]]], Optional[str]]:
actions = {}
if not data:
return actions, None
def get_global_op(self, data: Sequence[str]) -> int:
op_to_res = {
"in": Q.OR,
"all": Q.AND,
}
data = (x for x in data if x is not None)
first_op = self._get_op(next(data, ""), translate=True) or self.default_mongo_op
return op_to_res.get(first_op, self.default_mongo_op)
global_op = self._get_op(data[0], translate=True)
if global_op in ("in", "all"):
data = data[1:]
else:
global_op = None
def get_actions(self, data: Sequence[str]) -> Dict[str, List[Union[str, None]]]:
actions = {}
for val in data:
key = self._key(val)
@ -190,7 +192,7 @@ class GetMixin(PropsMixin):
val = val[len(self._legacy_exclude_prefix) :]
actions.setdefault(key, []).append(val)
return actions, global_op or self.default_mongo_op
return actions
get_all_query_options = QueryParameterOptions()
@ -447,18 +449,9 @@ class GetMixin(PropsMixin):
if not isinstance(data, (list, tuple)):
data = [data]
actions, global_op = cls.ListFieldBucketHelper(legacy=True).get_actions(data)
default_op = cls.ListFieldBucketHelper.default_mongo_op
# Handle `allow_empty` hack: controlled using `None` as a specific value in the default "in" action
allow_empty = False
default_op_actions = actions.get(default_op)
if default_op_actions and None in default_op_actions:
allow_empty = True
default_op_actions.remove(None)
if not default_op_actions:
actions.pop(cls.ListFieldBucketHelper.default_mongo_op)
helper = cls.ListFieldBucketHelper(legacy=True)
global_op = helper.get_global_op(data)
actions = helper.get_actions(data)
mongoengine_field = field.replace(".", "__")
@ -471,11 +464,11 @@ class GetMixin(PropsMixin):
q = RegexQ()
else:
q = RegexQCombination(
operation=RegexQ.AND if global_op is not default_op else RegexQ.OR,
operation=global_op,
children=queries
)
if not allow_empty:
if not helper.allow_empty:
return q
return (

View File

@ -1173,7 +1173,7 @@ def add_or_update_artifacts(
company_id=company_id,
task_id=request.task,
artifacts=request.artifacts,
force=request.force,
force=True,
)
}
@ -1189,7 +1189,7 @@ def delete_artifacts(call: APICall, company_id, request: DeleteArtifactsRequest)
company_id=company_id,
task_id=request.task,
artifact_ids=request.artifacts,
force=request.force,
force=True,
)
}

View File

@ -82,14 +82,10 @@ class TestTasksArtifacts(TestService):
# test edit running task
self.api.tasks.started(task=task)
with self.api.raises(InvalidTaskStatus):
self.api.tasks.add_or_update_artifacts(task=task, artifacts=edit)
self.api.tasks.add_or_update_artifacts(task=task, artifacts=edit, force=True)
self.api.tasks.add_or_update_artifacts(task=task, artifacts=edit)
res = self.api.tasks.get_all_ex(id=[task]).tasks[0]
self._assertTaskArtifacts(artifacts, res)
with self.api.raises(InvalidTaskStatus):
self.api.tasks.delete_artifacts(task=task, artifacts=[{"key": artifacts[-1]["key"]}])
self.api.tasks.delete_artifacts(task=task, artifacts=[{"key": artifacts[-1]["key"]}], force=True)
self.api.tasks.delete_artifacts(task=task, artifacts=[{"key": artifacts[-1]["key"]}])
res = self.api.tasks.get_all_ex(id=[task]).tasks[0]
self._assertTaskArtifacts(artifacts[0: len(artifacts) - 1], res)