From c1e7f8f9c17534b22235de053734d62f815efd7e Mon Sep 17 00:00:00 2001
From: allegroai <>
Date: Fri, 17 Nov 2023 09:38:32 +0200
Subject: [PATCH] Optimize deletion of projects with many tasks

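Collecting debug-image and plot URLs one task at a time made deleting a
project with thousands of tasks prohibitively slow. Batch the work
instead:

* event_bll: get_debug_image_urls and get_plot_image_urls now accept a
  list of task ids and filter Elasticsearch with a "terms" clause over
  the whole list instead of a "term" clause per task, and short-circuit
  on an empty id list or empty event data. Also fix
  delete_multi_task_events to return the accumulated deleted count
  rather than the count of the last batch only, and fix a typo in its
  docstring.
* task_cleanup: collect_debug_image_urls and collect_plot_image_urls
  accept either a single id or a sequence of ids and process them in
  chunks of 100 (boltons chunked_iter). cleanup_task deletes the events
  of all output models with one delete_multi_task_events call instead
  of a per-model delete_task_events loop.
* project_cleanup: _delete_tasks and _delete_models call the batched
  collectors once per entity list instead of iterating over each task
  or model.
* services/schema: models.update_tags and tasks.update_tags return the
  result under "updated" (previously "update") and get endpoint
  descriptions.

An illustrative sketch of the new batched call flow (assuming a company
id, a list of task ids, and an EventBLL instance named event_bll, as in
the modules touched here):

    from apiserver.bll.task.task_cleanup import (
        collect_debug_image_urls,
        collect_plot_image_urls,
    )

    # Each chunk of up to 100 ids is fetched with a "terms" filter on
    # the task field, paginating within the chunk as needed, instead of
    # issuing a separate query stream per task.
    event_urls = collect_debug_image_urls(company, task_ids)
    event_urls |= collect_plot_image_urls(company, task_ids)

    # Events for all tasks are then deleted in one batched call.
    event_bll.delete_multi_task_events(company, task_ids, async_delete=True)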
---
 apiserver/bll/event/event_bll.py         | 64 +++++++++++------
 apiserver/bll/project/project_cleanup.py | 32 +++++----
 apiserver/bll/task/task_cleanup.py       | 91 +++++++++++++-----------
 apiserver/schema/services/models.conf    |  1 +
 apiserver/schema/services/tasks.conf     |  1 +
 apiserver/services/models.py             |  2 +-
 apiserver/services/tasks.py              |  2 +-
 7 files changed, 117 insertions(+), 76 deletions(-)

diff --git a/apiserver/bll/event/event_bll.py b/apiserver/bll/event/event_bll.py
index c1350a7..b1b44ee 100644
--- a/apiserver/bll/event/event_bll.py
+++ b/apiserver/bll/event/event_bll.py
@@ -49,8 +49,8 @@ from apiserver.utilities.json import loads
 # noinspection PyTypeChecker
 EVENT_TYPES: Set[str] = set(map(attrgetter("value"), EventType))
 LOCKED_TASK_STATUSES = (TaskStatus.publishing, TaskStatus.published)
-MAX_LONG = 2 ** 63 - 1
-MIN_LONG = -(2 ** 63)
+MAX_LONG = 2**63 - 1
+MIN_LONG = -(2**63)
 
 
 log = config.logger(__file__)
@@ -272,11 +272,13 @@ class EventBLL(object):
             else:
                 used_task_ids.add(task_or_model_id)
                 self._update_last_metric_events_for_task(
-                    last_events=task_last_events[task_or_model_id], event=event,
+                    last_events=task_last_events[task_or_model_id],
+                    event=event,
                 )
             if event_type == EventType.metrics_scalar.value:
                 self._update_last_scalar_events_for_task(
-                    last_events=task_last_scalar_events[task_or_model_id], event=event,
+                    last_events=task_last_scalar_events[task_or_model_id],
+                    event=event,
                 )
 
             actions.append(es_action)
@@ -583,7 +585,8 @@ class EventBLL(object):
         query = {"bool": {"must": must}}
         search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
         max_metrics, max_variants = get_max_metric_and_variant_counts(
-            query=query, **search_args,
+            query=query,
+            **search_args,
         )
         max_variants = int(max_variants // last_iterations_per_plot)
 
@@ -650,9 +653,11 @@ class EventBLL(object):
         return events, total_events, next_scroll_id
 
     def get_debug_image_urls(
-        self, company_id: str, task_id: str, after_key: dict = None
+        self, company_id: str, task_ids: Sequence[str], after_key: dict = None
     ) -> Tuple[Sequence[str], Optional[dict]]:
-        if check_empty_data(self.es, company_id, EventType.metrics_image):
+        if not task_ids or check_empty_data(
+            self.es, company_id, EventType.metrics_image
+        ):
             return [], None
 
         es_req = {
@@ -668,7 +673,10 @@ class EventBLL(object):
             },
             "query": {
                 "bool": {
-                    "must": [{"term": {"task": task_id}}, {"exists": {"field": "url"}}]
+                    "must": [
+                        {"terms": {"task": task_ids}},
+                        {"exists": {"field": "url"}},
+                    ]
                 }
             },
         }
@@ -686,9 +694,13 @@ class EventBLL(object):
         return [bucket["key"]["url"] for bucket in res["buckets"]], res.get("after_key")
 
     def get_plot_image_urls(
-        self, company_id: str, task_id: str, scroll_id: Optional[str]
+        self, company_id: str, task_ids: Sequence[str], scroll_id: Optional[str]
     ) -> Tuple[Sequence[dict], Optional[str]]:
-        if scroll_id == self.empty_scroll:
+        if (
+            scroll_id == self.empty_scroll
+            or not task_ids
+            or check_empty_data(self.es, company_id, EventType.metrics_plot)
+        ):
             return [], None
 
         if scroll_id:
@@ -703,7 +715,7 @@ class EventBLL(object):
                 "query": {
                     "bool": {
                         "must": [
-                            {"term": {"task": task_id}},
+                            {"terms": {"task": task_ids}},
                             {"exists": {"field": PlotFields.source_urls}},
                         ]
                     }
@@ -839,7 +851,8 @@ class EventBLL(object):
         query = {"bool": {"must": [{"term": {"task": task_id}}]}}
         search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
         max_metrics, max_variants = get_max_metric_and_variant_counts(
-            query=query, **search_args,
+            query=query,
+            **search_args,
         )
         es_req = {
             "size": 0,
@@ -893,7 +906,8 @@ class EventBLL(object):
         }
         search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
         max_metrics, max_variants = get_max_metric_and_variant_counts(
-            query=query, **search_args,
+            query=query,
+            **search_args,
         )
         max_variants = int(max_variants // 2)
         es_req = {
@@ -1037,9 +1051,9 @@ class EventBLL(object):
                                         "order": {"_key": "desc"},
                                     }
                                 }
-                            }
+                            },
                         }
-                    }
+                    },
                 }
             },
             "query": {"bool": {"must": must}},
@@ -1105,7 +1119,10 @@ class EventBLL(object):
 
         with translate_errors_context():
             es_res = search_company_events(
-                self.es, company_id=company_ids, event_type=event_type, body=es_req,
+                self.es,
+                company_id=company_ids,
+                event_type=event_type,
+                body=es_req,
             )
 
         if "aggregations" not in es_res:
@@ -1157,11 +1174,18 @@ class EventBLL(object):
         return {"refresh": True}
 
     def delete_task_events(
-        self, company_id, task_id, allow_locked=False, model=False, async_delete=False,
+        self,
+        company_id,
+        task_id,
+        allow_locked=False,
+        model=False,
+        async_delete=False,
     ):
         if model:
             self._validate_model_state(
-                company_id=company_id, model_id=task_id, allow_locked=allow_locked,
+                company_id=company_id,
+                model_id=task_id,
+                allow_locked=allow_locked,
             )
         else:
             self._validate_task_state(
@@ -1228,7 +1252,7 @@ class EventBLL(object):
         self, company_id: str, task_ids: Sequence[str], async_delete=False
     ):
         """
-        Delete mutliple task events. No check is done for tasks write access
+        Delete multiple task events. No check is done for tasks write access
         so it should be checked by the calling code
         """
         deleted = 0
@@ -1246,7 +1270,7 @@ class EventBLL(object):
                     deleted += es_res.get("deleted", 0)
 
         if not async_delete:
-            return es_res.get("deleted", 0)
+            return deleted
 
     def clear_scroll(self, scroll_id: str):
         if scroll_id == self.empty_scroll:
diff --git a/apiserver/bll/project/project_cleanup.py b/apiserver/bll/project/project_cleanup.py
index bf3cbd1..15f97f6 100644
--- a/apiserver/bll/project/project_cleanup.py
+++ b/apiserver/bll/project/project_cleanup.py
@@ -83,7 +83,8 @@ def validate_project_delete(company: str, project_id: str):
         ret["pipelines"] = 0
     if dataset_ids:
         datasets_with_data = Task.objects(
-            project__in=dataset_ids, system_tags__nin=[EntityVisibility.archived.value],
+            project__in=dataset_ids,
+            system_tags__nin=[EntityVisibility.archived.value],
         ).distinct("project")
         ret["datasets"] = len(datasets_with_data)
     else:
@@ -217,7 +218,9 @@ def delete_project(
     return res, affected
 
 
-def _delete_tasks(company: str, user: str, projects: Sequence[str]) -> Tuple[int, Set, Set]:
+def _delete_tasks(
+    company: str, user: str, projects: Sequence[str]
+) -> Tuple[int, Set, Set]:
     """
     Delete only the tasks themselves and their non-published versions.
     Child models under the same project are deleted separately.
@@ -228,7 +231,7 @@ def _delete_tasks(company: str, user: str, projects: Sequence[str]) -> Tuple[int
     if not tasks:
         return 0, set(), set()
 
-    task_ids = {t.id for t in tasks}
+    task_ids = list({t.id for t in tasks})
     now = datetime.utcnow()
     Task.objects(parent__in=task_ids, project__nin=projects).update(
         parent=None,
@@ -241,10 +244,11 @@ def _delete_tasks(company: str, user: str, projects: Sequence[str]) -> Tuple[int
         last_changed_by=user,
     )
 
-    event_urls, artifact_urls = set(), set()
+    event_urls = collect_debug_image_urls(company, task_ids) | collect_plot_image_urls(
+        company, task_ids
+    )
+    artifact_urls = set()
     for task in tasks:
-        event_urls.update(collect_debug_image_urls(company, task.id))
-        event_urls.update(collect_plot_image_urls(company, task.id))
         if task.execution and task.execution.artifacts:
             artifact_urls.update(
                 {
@@ -255,7 +259,7 @@ def _delete_tasks(company: str, user: str, projects: Sequence[str]) -> Tuple[int
             )
 
     event_bll.delete_multi_task_events(
-        company, list(task_ids), async_delete=async_events_delete
+        company, task_ids, async_delete=async_events_delete
     )
     deleted = tasks.delete()
     return deleted, event_urls, artifact_urls
@@ -307,19 +311,19 @@ def _delete_models(
         )
         # update unpublished tasks
         Task.objects(
-            id__in=model_tasks, project__nin=projects, status__ne=TaskStatus.published,
+            id__in=model_tasks,
+            project__nin=projects,
+            status__ne=TaskStatus.published,
         ).update(
             pull__models__output__model__in=model_ids,
             set__last_change=now,
             set__last_changed_by=user,
         )
 
-    event_urls, model_urls = set(), set()
-    for m in models:
-        event_urls.update(collect_debug_image_urls(company, m.id))
-        event_urls.update(collect_plot_image_urls(company, m.id))
-        if m.uri:
-            model_urls.add(m.uri)
+    event_urls = collect_debug_image_urls(company, model_ids) | collect_plot_image_urls(
+        company, model_ids
+    )
+    model_urls = {m.uri for m in models if m.uri}
 
     event_bll.delete_multi_task_events(
         company, model_ids, async_delete=async_events_delete
diff --git a/apiserver/bll/task/task_cleanup.py b/apiserver/bll/task/task_cleanup.py
index 2c74003..1f27a9e 100644
--- a/apiserver/bll/task/task_cleanup.py
+++ b/apiserver/bll/task/task_cleanup.py
@@ -1,10 +1,10 @@
 from datetime import datetime
 from itertools import chain
 from operator import attrgetter
-from typing import Sequence, Set, Tuple
+from typing import Sequence, Set, Tuple, Union
 
 import attr
-from boltons.iterutils import partition, bucketize, first
+from boltons.iterutils import partition, bucketize, first, chunked_iter
 from furl import furl
 from mongoengine import NotUniqueError
 from pymongo.errors import DuplicateKeyError
@@ -69,37 +69,47 @@ class CleanupResult:
         )
 
 
-def collect_plot_image_urls(company: str, task_or_model: str) -> Set[str]:
+def collect_plot_image_urls(
+    company: str, task_or_model: Union[str, Sequence[str]]
+) -> Set[str]:
     urls = set()
-    next_scroll_id = None
-    while True:
-        events, next_scroll_id = event_bll.get_plot_image_urls(
-            company_id=company, task_id=task_or_model, scroll_id=next_scroll_id
-        )
-        if not events:
-            break
-        for event in events:
-            event_urls = event.get(PlotFields.source_urls)
-            if event_urls:
-                urls.update(set(event_urls))
+    task_ids = task_or_model if isinstance(task_or_model, list) else [task_or_model]
+    for tasks in chunked_iter(task_ids, 100):
+        next_scroll_id = None
+        while True:
+            events, next_scroll_id = event_bll.get_plot_image_urls(
+                company_id=company, task_ids=tasks, scroll_id=next_scroll_id
+            )
+            if not events:
+                break
+            for event in events:
+                event_urls = event.get(PlotFields.source_urls)
+                if event_urls:
+                    urls.update(set(event_urls))
 
     return urls
 
 
-def collect_debug_image_urls(company: str, task_or_model: str) -> Set[str]:
+def collect_debug_image_urls(
+    company: str, task_or_model: Union[str, Sequence[str]]
+) -> Set[str]:
     """
     Return the set of unique image urls
     Uses DebugImagesIterator to make sure that we do not retrieve recycled urls
     """
-    after_key = None
     urls = set()
-    while True:
-        res, after_key = event_bll.get_debug_image_urls(
-            company_id=company, task_id=task_or_model, after_key=after_key,
-        )
-        urls.update(res)
-        if not after_key:
-            break
+    task_ids = task_or_model if isinstance(task_or_model, list) else [task_or_model]
+    for tasks in chunked_iter(task_ids, 100):
+        after_key = None
+        while True:
+            res, after_key = event_bll.get_debug_image_urls(
+                company_id=company,
+                task_ids=tasks,
+                after_key=after_key,
+            )
+            urls.update(res)
+            if not after_key:
+                break
 
     return urls
 
@@ -122,7 +132,11 @@ supported_storage_types.update(
 
 
 def _schedule_for_delete(
-    company: str, user: str, task_id: str, urls: Set[str], can_delete_folders: bool,
+    company: str,
+    user: str,
+    task_id: str,
+    urls: Set[str],
+    can_delete_folders: bool,
 ) -> Set[str]:
     urls_per_storage = bucketize(
         urls,
@@ -236,23 +250,19 @@ def cleanup_task(
         if not models:
             continue
         if delete_output_models and allow_delete:
-            model_ids = set(m.id for m in models if m.id not in in_use_model_ids)
-            for m_id in model_ids:
+            model_ids = list({m.id for m in models if m.id not in in_use_model_ids})
+            if model_ids:
                 if return_file_urls or delete_external_artifacts:
-                    event_urls.update(collect_debug_image_urls(task.company, m_id))
-                    event_urls.update(collect_plot_image_urls(task.company, m_id))
-                try:
-                    event_bll.delete_task_events(
-                        task.company,
-                        m_id,
-                        allow_locked=True,
-                        model=True,
-                        async_delete=async_events_delete,
-                    )
-                except errors.bad_request.InvalidModelId as ex:
-                    log.info(f"Error deleting events for the model {m_id}: {str(ex)}")
+                    event_urls.update(collect_debug_image_urls(task.company, model_ids))
+                    event_urls.update(collect_plot_image_urls(task.company, model_ids))
+
+                event_bll.delete_multi_task_events(
+                    task.company,
+                    model_ids,
+                    async_delete=async_events_delete,
+                )
+                deleted_models += Model.objects(id__in=list(model_ids)).delete()
 
-            deleted_models += Model.objects(id__in=list(model_ids)).delete()
             if in_use_model_ids:
                 Model.objects(id__in=list(in_use_model_ids)).update(
                     unset__task=1,
@@ -319,7 +329,8 @@ def verify_task_children_and_ouptuts(
 
     model_fields = ["id", "ready", "uri"]
     published_models, draft_models = partition(
-        Model.objects(task=task.id).only(*model_fields), key=attrgetter("ready"),
+        Model.objects(task=task.id).only(*model_fields),
+        key=attrgetter("ready"),
     )
     if not force and published_models:
         raise errors.bad_request.TaskCannotBeDeleted(
diff --git a/apiserver/schema/services/models.conf b/apiserver/schema/services/models.conf
index 3c9bd1f..8170fc5 100644
--- a/apiserver/schema/services/models.conf
+++ b/apiserver/schema/services/models.conf
@@ -1097,6 +1097,7 @@ delete_metadata {
 }
 update_tags {
     "999.0" {
+        description: Add or remove tags from multiple models
         request {
             type: object
             properties {
diff --git a/apiserver/schema/services/tasks.conf b/apiserver/schema/services/tasks.conf
index 6b4ba65..f3ca19d 100644
--- a/apiserver/schema/services/tasks.conf
+++ b/apiserver/schema/services/tasks.conf
@@ -2060,6 +2060,7 @@ move {
 }
 update_tags {
     "999.0" {
+        description: Add or remove tags from multiple tasks
         request {
             type: object
             properties {
diff --git a/apiserver/services/models.py b/apiserver/services/models.py
index 25c2980..82e77a5 100644
--- a/apiserver/services/models.py
+++ b/apiserver/services/models.py
@@ -682,7 +682,7 @@ def move(call: APICall, company_id: str, request: MoveRequest):
 @endpoint("models.update_tags")
 def update_tags(_, company_id: str, request: UpdateTagsRequest):
     return {
-        "update": org_bll.edit_entity_tags(
+        "updated": org_bll.edit_entity_tags(
             company_id=company_id,
             entity_cls=Model,
             entity_ids=request.ids,
diff --git a/apiserver/services/tasks.py b/apiserver/services/tasks.py
index 4140f57..15028a8 100644
--- a/apiserver/services/tasks.py
+++ b/apiserver/services/tasks.py
@@ -1332,7 +1332,7 @@ def move(call: APICall, company_id: str, request: MoveRequest):
 @endpoint("tasks.update_tags")
 def update_tags(_, company_id: str, request: UpdateTagsRequest):
     return {
-        "update": org_bll.edit_entity_tags(
+        "updated": org_bll.edit_entity_tags(
             company_id=company_id,
             entity_cls=Task,
             entity_ids=request.ids,