From 1b650b168949f4ed5ecb0ecb3f0d867ba4c31ba9 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Wed, 26 Jul 2023 18:21:16 +0300 Subject: [PATCH] Add projects.get_user_names endpoint --- apiserver/apimodels/projects.py | 12 +++- apiserver/bll/project/project_bll.py | 71 ++++++++++++------- apiserver/schema/services/projects.conf | 48 +++++++++++++ apiserver/services/projects.py | 17 ++++- apiserver/tests/automated/test_subprojects.py | 4 ++ 5 files changed, 125 insertions(+), 27 deletions(-) diff --git a/apiserver/apimodels/projects.py b/apiserver/apimodels/projects.py index 53a5897..eb92412 100644 --- a/apiserver/apimodels/projects.py +++ b/apiserver/apimodels/projects.py @@ -1,10 +1,11 @@ -from enum import Enum +from enum import Enum, auto from jsonmodels import models, fields from apiserver.apimodels import ListField, ActualEnumField, DictField from apiserver.apimodels.organization import TagsRequest from apiserver.database.model import EntityVisibility +from apiserver.utilities.stringenum import StringEnum class ProjectRequest(models.Base): @@ -52,6 +53,15 @@ class ProjectTaskParentsRequest(MultiProjectRequest): task_name = fields.StringField() +class EntityTypeEnum(StringEnum): + task = auto() + model = auto() + + +class ProjectUserNamesRequest(MultiProjectRequest): + entity = ActualEnumField(EntityTypeEnum, default=EntityTypeEnum.task) + + class ProjectHyperparamValuesRequest(MultiProjectRequest): section = fields.StringField(required=True) name = fields.StringField(required=True) diff --git a/apiserver/bll/project/project_bll.py b/apiserver/bll/project/project_bll.py index c4b26e2..2ae635b 100644 --- a/apiserver/bll/project/project_bll.py +++ b/apiserver/bll/project/project_bll.py @@ -13,7 +13,7 @@ from typing import ( TypeVar, Callable, Mapping, - Any, + Any, Union, ) from mongoengine import Q, Document @@ -22,7 +22,7 @@ from apiserver import database from apiserver.apierrors import errors from apiserver.apimodels.projects import ProjectChildrenType from apiserver.config_repo import config -from apiserver.database.model import EntityVisibility, AttributedDocument +from apiserver.database.model import EntityVisibility, AttributedDocument, User from apiserver.database.model.base import GetMixin from apiserver.database.model.model import Model from apiserver.database.model.project import Project @@ -973,6 +973,28 @@ class ProjectBLL: return filtered_ids, selected_project_ids + @staticmethod + def _get_project_query( + company: str, + projects: Sequence, + include_subprojects: bool = True, + state: Optional[EntityVisibility] = None, + ) -> Q: + query = get_company_or_none_constraint(company) + if projects: + if include_subprojects: + projects = _ids_with_children(projects) + query &= Q(project__in=projects) + else: + query &= Q(system_tags__nin=[EntityVisibility.hidden.value]) + + if state == EntityVisibility.archived: + query &= Q(system_tags__in=[EntityVisibility.archived.value]) + elif state == EntityVisibility.active: + query &= Q(system_tags__nin=[EntityVisibility.archived.value]) + + return query + @classmethod def get_task_parents( cls, @@ -986,19 +1008,9 @@ class ProjectBLL: Get list of unique parent tasks sorted by task name for the passed company projects If projects is None or empty then get parents for all the company tasks """ - query = Q(company=company_id) - - if projects: - if include_subprojects: - projects = _ids_with_children(projects) - query &= Q(project__in=projects) - else: - query &= Q(system_tags__nin=[EntityVisibility.hidden.value]) - - if state == EntityVisibility.archived: - query &= Q(system_tags__in=[EntityVisibility.archived.value]) - elif state == EntityVisibility.active: - query &= Q(system_tags__nin=[EntityVisibility.archived.value]) + query = cls._get_project_query( + company_id, projects, include_subprojects=include_subprojects, state=state + ) parent_ids = set(Task.objects(query).distinct("parent")) if not parent_ids: @@ -1014,18 +1026,30 @@ class ProjectBLL: return sorted(parents, key=itemgetter("name")) + @classmethod + def get_entity_users( + cls, + company: str, + entity_cls: Type[Union[Task, Model]], + projects: Sequence[str], + include_subprojects: bool, + ) -> Sequence[dict]: + query = cls._get_project_query( + company, projects, include_subprojects=include_subprojects + ) + user_ids = entity_cls.objects(query).distinct(field="user") + if not user_ids: + return [] + users = User.objects(id__in=user_ids).only("id", "name") + return [{"id": u.id, "name": u.name} for u in users] + @classmethod def get_task_types(cls, company, project_ids: Optional[Sequence]) -> set: """ Return the list of unique task types used by company and public tasks If project ids passed then only tasks from these projects are considered """ - query = get_company_or_none_constraint(company) - if project_ids: - project_ids = _ids_with_children(project_ids) - query &= Q(project__in=project_ids) - else: - query &= Q(system_tags__nin=[EntityVisibility.hidden.value]) + query = cls._get_project_query(company, project_ids) res = Task.objects(query).distinct(field="type") return set(res).intersection(external_task_types) @@ -1035,10 +1059,7 @@ class ProjectBLL: Return the list of unique frameworks used by company and public models If project ids passed then only models from these projects are considered """ - query = get_company_or_none_constraint(company) - if project_ids: - project_ids = _ids_with_children(project_ids) - query &= Q(project__in=project_ids) + query = cls._get_project_query(company, project_ids) return Model.objects(query).distinct(field="framework") @staticmethod diff --git a/apiserver/schema/services/projects.conf b/apiserver/schema/services/projects.conf index 7c6590e..f852317 100644 --- a/apiserver/schema/services/projects.conf +++ b/apiserver/schema/services/projects.conf @@ -1248,3 +1248,51 @@ get_task_parents { } } } +get_user_names { + "999.0" { + description: "Get names and ids of the users who created child entitites under the passed projects" + request { + type: object + properties { + projects { + description: "The list of projects. If not passed or empty then all the projects are searched" + type: array + items { type: string } + } + include_subprojects { + description: "If set to 'true' and the projects field is not empty then the result includes user name from the subprojects children" + type: boolean + default: true + } + entity { + description: The type of the child entity to look for + type: string + enum: [task, model] + default: task + } + } + } + response { + type: object + properties { + users { + description: "The list of users sorted by their names" + type: array + items { + type: object + properties { + id { + description: "The ID of the user" + type: string + } + name { + description: "The name of the user" + type: string + } + } + } + } + } + } + } +} \ No newline at end of file diff --git a/apiserver/services/projects.py b/apiserver/services/projects.py index 9059384..4d74138 100644 --- a/apiserver/services/projects.py +++ b/apiserver/services/projects.py @@ -19,6 +19,8 @@ from apiserver.apimodels.projects import ( ProjectModelMetadataValuesRequest, ProjectChildrenType, GetUniqueMetricsRequest, + ProjectUserNamesRequest, + EntityTypeEnum, ) from apiserver.bll.organization import OrgBLL, Tags from apiserver.bll.project import ProjectBLL, ProjectQueries @@ -29,8 +31,9 @@ from apiserver.bll.project.project_cleanup import ( ) from apiserver.database.errors import translate_errors_context from apiserver.database.model import EntityVisibility +from apiserver.database.model.model import Model from apiserver.database.model.project import Project -from apiserver.database.model.task.task import TaskType +from apiserver.database.model.task.task import TaskType, Task from apiserver.database.utils import ( parse_from_call, get_company_or_none_constraint, @@ -527,3 +530,15 @@ def get_task_parents( name=request.task_name, ) } + + +@endpoint("projects.get_user_names") +def get_user_names(call: APICall, company_id: str, request: ProjectUserNamesRequest): + call.result.data = { + "users": ProjectBLL.get_entity_users( + company_id, + entity_cls=Model if request.entity == EntityTypeEnum.model else Task, + projects=request.projects, + include_subprojects=request.include_subprojects, + ) + } diff --git a/apiserver/tests/automated/test_subprojects.py b/apiserver/tests/automated/test_subprojects.py index d94cfb1..487d179 100644 --- a/apiserver/tests/automated/test_subprojects.py +++ b/apiserver/tests/automated/test_subprojects.py @@ -183,6 +183,8 @@ class TestSubProjects(TestService): self.assertEqual(res.types, []) res = self.api.projects.get_task_parents(projects=[project]) self.assertEqual(res.parents, []) + res = self.api.projects.get_user_names(projects=[project]) + self.assertEqual(res.users, []) res = self.api.organization.get_entities_count( projects={"id": [project]}, active_users=[user] ) @@ -206,6 +208,8 @@ class TestSubProjects(TestService): self.assertEqual(res.projects[0].stats.active.total_tasks, 2) res = self.api.projects.get_task_parents(projects=[project]) self._assert_ids(res.parents, [task1]) + res = self.api.projects.get_user_names(projects=[project]) + self.assertEqual(res.users, [{"id": "Test1", "name": "Test User"}]) res = self.api.models.get_frameworks(projects=[project]) self.assertEqual(res.frameworks, [framework]) res = self.api.tasks.get_types(projects=[project])