diff --git a/apiserver/bll/event/event_bll.py b/apiserver/bll/event/event_bll.py index ce0fee3..59ddcf1 100644 --- a/apiserver/bll/event/event_bll.py +++ b/apiserver/bll/event/event_bll.py @@ -103,64 +103,69 @@ class EventBLL(object): return self._metrics @staticmethod - def _get_valid_tasks(company_id, task_ids: Set, allow_locked_tasks=False) -> Set: - """Verify that task exists and can be updated""" - if not task_ids: + def _get_valid_entities(company_id, ids: Mapping[str, bool], model=False) -> Set: + """Verify that task or model exists and can be updated""" + if not ids: return set() with translate_errors_context(): - query = Q(id__in=task_ids, company=company_id) - if not allow_locked_tasks: - query &= Q(status__nin=LOCKED_TASK_STATUSES) - res = Task.objects(query).only("id") - return {r.id for r in res} + allow_locked = {id_ for id_, allowed in ids.items() if allowed} + not_locked = {id_ for id_, allowed in ids.items() if not allowed} + res = set() + allow_locked_q = Q() + not_locked_q = ( + Q(ready__ne=True) if model else Q(status__nin=LOCKED_TASK_STATUSES) + ) + for requested_ids, locked_q in ( + (allow_locked, allow_locked_q), + (not_locked, not_locked_q), + ): + if not requested_ids: + continue + query = Q(id__in=requested_ids, company=company_id) + res.update( + (Model if model else Task).objects(query & locked_q).scalar("id") + ) - @staticmethod - def _get_valid_models(company_id, model_ids: Set, allow_locked_models=False) -> Set: - """Verify that task exists and can be updated""" - if not model_ids: - return set() - - with translate_errors_context(): - query = Q(id__in=model_ids, company=company_id) - if not allow_locked_models: - query &= Q(ready__ne=True) - res = Model.objects(query).only("id") - return {r.id for r in res} + return res def add_events( - self, company_id, events, worker, allow_locked=False + self, company_id, events, worker ) -> Tuple[int, int, dict]: - task_ids = set() - model_ids = set() + task_ids = {} + model_ids = {} for event in events: if event.get("model_event", False): model = event.pop("model", None) if model is not None: event["task"] = model - model_ids.add(event.get("task")) + entity_ids = model_ids else: event["model_event"] = False - task_ids.add(event.get("task")) - if event.pop("allow_locked", allow_locked) != allow_locked: + entity_ids = task_ids + + id_ = event.get("task") + allow_locked = event.pop("allow_locked", False) + if not id_: + continue + + allowed_for_entity = entity_ids.get(id_) + if allowed_for_entity is None: + entity_ids[id_] = allow_locked + elif allowed_for_entity != allow_locked: raise errors.bad_request.ValidationError( - "Inconsistent allow_locked setting in the passed events" + f"Inconsistent allow_locked setting in the passed events for {id_}" ) - task_ids.discard(None) - model_ids.discard(None) - found_in_both = task_ids.intersection(model_ids) + found_in_both = set(task_ids).intersection(set(model_ids)) if found_in_both: raise errors.bad_request.ValidationError( "Inconsistent model_event setting in the passed events", tasks=found_in_both, ) - valid_models = self._get_valid_models( - company_id, model_ids=model_ids, allow_locked_models=allow_locked, - ) - valid_tasks = self._get_valid_tasks( - company_id, task_ids=task_ids, allow_locked_tasks=allow_locked, - ) + valid_models = self._get_valid_entities(company_id, ids=model_ids, model=True) + valid_tasks = self._get_valid_entities(company_id, ids=task_ids) + actions: List[dict] = [] used_task_ids = set() used_model_ids = set() diff --git a/apiserver/mongo/initialize/pre_populate.py b/apiserver/mongo/initialize/pre_populate.py index 9ea08c1..d7a68dc 100644 --- a/apiserver/mongo/initialize/pre_populate.py +++ b/apiserver/mongo/initialize/pre_populate.py @@ -967,6 +967,7 @@ class PrePopulate: for ev in events: ev["task"] = task_id ev["company_id"] = company_id + ev["allow_locked"] = True cls.event_bll.add_events( - company_id, events=events, worker="", allow_locked=True + company_id, events=events, worker="" ) diff --git a/apiserver/services/events.py b/apiserver/services/events.py index 55f4bd0..977194c 100644 --- a/apiserver/services/events.py +++ b/apiserver/services/events.py @@ -70,9 +70,8 @@ def _assert_task_or_model_exists( @endpoint("events.add") def add(call: APICall, company_id, _): data = call.data.copy() - allow_locked = data.pop("allow_locked", False) added, err_count, err_info = event_bll.add_events( - company_id, [data], call.worker, allow_locked=allow_locked + company_id, [data], call.worker ) call.result.data = dict(added=added, errors=err_count, errors_info=err_info) @@ -87,7 +86,6 @@ def add_batch(call: APICall, company_id, _): company_id, events, call.worker, - allow_locked=events[0].get("allow_locked", False), ) call.result.data = dict(added=added, errors=err_count, errors_info=err_info) diff --git a/apiserver/tests/automated/test_task_events.py b/apiserver/tests/automated/test_task_events.py index 6ca7f44..a9ff3cf 100644 --- a/apiserver/tests/automated/test_task_events.py +++ b/apiserver/tests/automated/test_task_events.py @@ -218,6 +218,7 @@ class TestTaskEvents(TestService): msg=f"This is a log message", metric="Metric0", variant="Variant0", + allow_locked=True, ) ) self.send_batch(events)