Support receiving mixed events for both locked and unlocked tasks and models events.add_batch

This commit is contained in:
allegroai 2023-05-25 19:20:35 +03:00
parent de9651d761
commit dff2ed34e8
4 changed files with 45 additions and 40 deletions

View File

@ -103,64 +103,69 @@ class EventBLL(object):
return self._metrics
@staticmethod
def _get_valid_tasks(company_id, task_ids: Set, allow_locked_tasks=False) -> Set:
"""Verify that task exists and can be updated"""
if not task_ids:
def _get_valid_entities(company_id, ids: Mapping[str, bool], model=False) -> Set:
"""Verify that task or model exists and can be updated"""
if not ids:
return set()
with translate_errors_context():
query = Q(id__in=task_ids, company=company_id)
if not allow_locked_tasks:
query &= Q(status__nin=LOCKED_TASK_STATUSES)
res = Task.objects(query).only("id")
return {r.id for r in res}
allow_locked = {id_ for id_, allowed in ids.items() if allowed}
not_locked = {id_ for id_, allowed in ids.items() if not allowed}
res = set()
allow_locked_q = Q()
not_locked_q = (
Q(ready__ne=True) if model else Q(status__nin=LOCKED_TASK_STATUSES)
)
for requested_ids, locked_q in (
(allow_locked, allow_locked_q),
(not_locked, not_locked_q),
):
if not requested_ids:
continue
query = Q(id__in=requested_ids, company=company_id)
res.update(
(Model if model else Task).objects(query & locked_q).scalar("id")
)
@staticmethod
def _get_valid_models(company_id, model_ids: Set, allow_locked_models=False) -> Set:
"""Verify that task exists and can be updated"""
if not model_ids:
return set()
with translate_errors_context():
query = Q(id__in=model_ids, company=company_id)
if not allow_locked_models:
query &= Q(ready__ne=True)
res = Model.objects(query).only("id")
return {r.id for r in res}
return res
def add_events(
self, company_id, events, worker, allow_locked=False
self, company_id, events, worker
) -> Tuple[int, int, dict]:
task_ids = set()
model_ids = set()
task_ids = {}
model_ids = {}
for event in events:
if event.get("model_event", False):
model = event.pop("model", None)
if model is not None:
event["task"] = model
model_ids.add(event.get("task"))
entity_ids = model_ids
else:
event["model_event"] = False
task_ids.add(event.get("task"))
if event.pop("allow_locked", allow_locked) != allow_locked:
entity_ids = task_ids
id_ = event.get("task")
allow_locked = event.pop("allow_locked", False)
if not id_:
continue
allowed_for_entity = entity_ids.get(id_)
if allowed_for_entity is None:
entity_ids[id_] = allow_locked
elif allowed_for_entity != allow_locked:
raise errors.bad_request.ValidationError(
"Inconsistent allow_locked setting in the passed events"
f"Inconsistent allow_locked setting in the passed events for {id_}"
)
task_ids.discard(None)
model_ids.discard(None)
found_in_both = task_ids.intersection(model_ids)
found_in_both = set(task_ids).intersection(set(model_ids))
if found_in_both:
raise errors.bad_request.ValidationError(
"Inconsistent model_event setting in the passed events",
tasks=found_in_both,
)
valid_models = self._get_valid_models(
company_id, model_ids=model_ids, allow_locked_models=allow_locked,
)
valid_tasks = self._get_valid_tasks(
company_id, task_ids=task_ids, allow_locked_tasks=allow_locked,
)
valid_models = self._get_valid_entities(company_id, ids=model_ids, model=True)
valid_tasks = self._get_valid_entities(company_id, ids=task_ids)
actions: List[dict] = []
used_task_ids = set()
used_model_ids = set()

View File

@ -967,6 +967,7 @@ class PrePopulate:
for ev in events:
ev["task"] = task_id
ev["company_id"] = company_id
ev["allow_locked"] = True
cls.event_bll.add_events(
company_id, events=events, worker="", allow_locked=True
company_id, events=events, worker=""
)

View File

@ -70,9 +70,8 @@ def _assert_task_or_model_exists(
@endpoint("events.add")
def add(call: APICall, company_id, _):
data = call.data.copy()
allow_locked = data.pop("allow_locked", False)
added, err_count, err_info = event_bll.add_events(
company_id, [data], call.worker, allow_locked=allow_locked
company_id, [data], call.worker
)
call.result.data = dict(added=added, errors=err_count, errors_info=err_info)
@ -87,7 +86,6 @@ def add_batch(call: APICall, company_id, _):
company_id,
events,
call.worker,
allow_locked=events[0].get("allow_locked", False),
)
call.result.data = dict(added=added, errors=err_count, errors_info=err_info)

View File

@ -218,6 +218,7 @@ class TestTaskEvents(TestService):
msg=f"This is a log message",
metric="Metric0",
variant="Variant0",
allow_locked=True,
)
)
self.send_batch(events)