From 5456ee4ebfd5c73c1f578d5fec98c979fdfde5f8 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Thu, 20 Jun 2024 17:48:18 +0300 Subject: [PATCH] Data tool export projects by name now includes subprojects + option for exporting all projects added --- apiserver/mongo/initialize/pre_populate.py | 62 ++++++++++++++++------ 1 file changed, 46 insertions(+), 16 deletions(-) diff --git a/apiserver/mongo/initialize/pre_populate.py b/apiserver/mongo/initialize/pre_populate.py index 6a050d0..9f80bd1 100644 --- a/apiserver/mongo/initialize/pre_populate.py +++ b/apiserver/mongo/initialize/pre_populate.py @@ -22,6 +22,7 @@ from typing import ( Mapping, IO, Callable, + Iterable, ) from urllib.parse import unquote, urlparse from uuid import uuid4, UUID, uuid5 @@ -220,6 +221,9 @@ class PrePopulate: raise ValueError("Invalid task statuses") file = Path(filename) + if not (experiments or projects): + projects = cls.project_cls.objects(parent=None).scalar("id") + entities = cls._resolve_entities( experiments=experiments, projects=projects, task_statuses=task_statuses ) @@ -417,24 +421,50 @@ class PrePopulate: featured_index = get_index(project) cls.project_cls.objects(id=project.id).update(featured=featured_index) - @staticmethod - def _resolve_type( - cls: Type[mongoengine.Document], ids: Optional[Sequence[str]] + @classmethod + def _resolve_entity_type( + cls, entity_type: Type[mongoengine.Document], ids: Optional[Sequence[str]] ) -> Sequence[Any]: ids = set(ids) - items = list(cls.objects(id__in=list(ids))) + items = list(entity_type.objects(id__in=list(ids))) resolved = {i.id for i in items} missing = ids - resolved - for name_candidate in missing: - results = list(cls.objects(name=name_candidate)) - if not results: - print(f"ERROR: no match for `{name_candidate}`") - exit(1) - elif len(results) > 1: - print(f"ERROR: more than one match for `{name_candidate}`") - exit(1) - items.append(results[0]) - return items + if not missing: + return items + + resolved_by_name = defaultdict(list) + for entity in entity_type.objects(name__in=list(missing)): + resolved_by_name[entity.name].append(entity) + + not_found = missing - set(resolved_by_name) + if not_found: + print(f"ERROR: no match for {', '.join(not_found)}") + exit(1) + + duplicates = [k for k, v in resolved_by_name.items() if len(v) > 1] + if duplicates: + print(f"ERROR: more than one match for {', '.join(duplicates)}") + exit(1) + + def get_new_items(input_: Iterable) -> list: + return [item for item in input_ if item.id not in resolved] + + def get_projects_with_children(projects: list) -> list: + project_ids = set(item.id for item in projects) + ids_with_children = project_ids_with_children(list(project_ids)) + if project_ids == set(ids_with_children): + return projects + + return get_new_items(entity_type.objects(id__in=ids_with_children)) + + new_items = get_new_items(chain(*resolved_by_name.values())) + if not new_items: + return items + + if entity_type == cls.project_cls: + new_items = get_projects_with_children(new_items) + + return items + new_items @classmethod def _check_projects_hierarchy(cls, projects: Set[Project]): @@ -467,7 +497,7 @@ class PrePopulate: print("Reading projects...") projects = project_ids_with_children(projects) entities[cls.project_cls].update( - cls._resolve_type(cls.project_cls, projects) + cls._resolve_entity_type(cls.project_cls, projects) ) print("--> Reading project experiments...") query = Q( @@ -485,7 +515,7 @@ class PrePopulate: if experiments: print("Reading experiments...") - entities[cls.task_cls].update(cls._resolve_type(cls.task_cls, experiments)) + entities[cls.task_cls].update(cls._resolve_entity_type(cls.task_cls, experiments)) print("--> Reading experiments projects...") objs = cls.project_cls.objects( id__in=list(