Add projects.get_user_names endpoint

This commit is contained in:
allegroai 2023-07-26 18:21:16 +03:00
parent 14d18a7aba
commit 1b650b1689
5 changed files with 125 additions and 27 deletions

View File

@ -1,10 +1,11 @@
from enum import Enum
from enum import Enum, auto
from jsonmodels import models, fields
from apiserver.apimodels import ListField, ActualEnumField, DictField
from apiserver.apimodels.organization import TagsRequest
from apiserver.database.model import EntityVisibility
from apiserver.utilities.stringenum import StringEnum
class ProjectRequest(models.Base):
@ -52,6 +53,15 @@ class ProjectTaskParentsRequest(MultiProjectRequest):
task_name = fields.StringField()
class EntityTypeEnum(StringEnum):
task = auto()
model = auto()
class ProjectUserNamesRequest(MultiProjectRequest):
entity = ActualEnumField(EntityTypeEnum, default=EntityTypeEnum.task)
class ProjectHyperparamValuesRequest(MultiProjectRequest):
section = fields.StringField(required=True)
name = fields.StringField(required=True)

View File

@ -13,7 +13,7 @@ from typing import (
TypeVar,
Callable,
Mapping,
Any,
Any, Union,
)
from mongoengine import Q, Document
@ -22,7 +22,7 @@ from apiserver import database
from apiserver.apierrors import errors
from apiserver.apimodels.projects import ProjectChildrenType
from apiserver.config_repo import config
from apiserver.database.model import EntityVisibility, AttributedDocument
from apiserver.database.model import EntityVisibility, AttributedDocument, User
from apiserver.database.model.base import GetMixin
from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
@ -973,6 +973,28 @@ class ProjectBLL:
return filtered_ids, selected_project_ids
@staticmethod
def _get_project_query(
company: str,
projects: Sequence,
include_subprojects: bool = True,
state: Optional[EntityVisibility] = None,
) -> Q:
query = get_company_or_none_constraint(company)
if projects:
if include_subprojects:
projects = _ids_with_children(projects)
query &= Q(project__in=projects)
else:
query &= Q(system_tags__nin=[EntityVisibility.hidden.value])
if state == EntityVisibility.archived:
query &= Q(system_tags__in=[EntityVisibility.archived.value])
elif state == EntityVisibility.active:
query &= Q(system_tags__nin=[EntityVisibility.archived.value])
return query
@classmethod
def get_task_parents(
cls,
@ -986,19 +1008,9 @@ class ProjectBLL:
Get list of unique parent tasks sorted by task name for the passed company projects
If projects is None or empty then get parents for all the company tasks
"""
query = Q(company=company_id)
if projects:
if include_subprojects:
projects = _ids_with_children(projects)
query &= Q(project__in=projects)
else:
query &= Q(system_tags__nin=[EntityVisibility.hidden.value])
if state == EntityVisibility.archived:
query &= Q(system_tags__in=[EntityVisibility.archived.value])
elif state == EntityVisibility.active:
query &= Q(system_tags__nin=[EntityVisibility.archived.value])
query = cls._get_project_query(
company_id, projects, include_subprojects=include_subprojects, state=state
)
parent_ids = set(Task.objects(query).distinct("parent"))
if not parent_ids:
@ -1014,18 +1026,30 @@ class ProjectBLL:
return sorted(parents, key=itemgetter("name"))
@classmethod
def get_entity_users(
cls,
company: str,
entity_cls: Type[Union[Task, Model]],
projects: Sequence[str],
include_subprojects: bool,
) -> Sequence[dict]:
query = cls._get_project_query(
company, projects, include_subprojects=include_subprojects
)
user_ids = entity_cls.objects(query).distinct(field="user")
if not user_ids:
return []
users = User.objects(id__in=user_ids).only("id", "name")
return [{"id": u.id, "name": u.name} for u in users]
@classmethod
def get_task_types(cls, company, project_ids: Optional[Sequence]) -> set:
"""
Return the list of unique task types used by company and public tasks
If project ids passed then only tasks from these projects are considered
"""
query = get_company_or_none_constraint(company)
if project_ids:
project_ids = _ids_with_children(project_ids)
query &= Q(project__in=project_ids)
else:
query &= Q(system_tags__nin=[EntityVisibility.hidden.value])
query = cls._get_project_query(company, project_ids)
res = Task.objects(query).distinct(field="type")
return set(res).intersection(external_task_types)
@ -1035,10 +1059,7 @@ class ProjectBLL:
Return the list of unique frameworks used by company and public models
If project ids passed then only models from these projects are considered
"""
query = get_company_or_none_constraint(company)
if project_ids:
project_ids = _ids_with_children(project_ids)
query &= Q(project__in=project_ids)
query = cls._get_project_query(company, project_ids)
return Model.objects(query).distinct(field="framework")
@staticmethod

View File

@ -1248,3 +1248,51 @@ get_task_parents {
}
}
}
get_user_names {
"999.0" {
description: "Get names and ids of the users who created child entitites under the passed projects"
request {
type: object
properties {
projects {
description: "The list of projects. If not passed or empty then all the projects are searched"
type: array
items { type: string }
}
include_subprojects {
description: "If set to 'true' and the projects field is not empty then the result includes user name from the subprojects children"
type: boolean
default: true
}
entity {
description: The type of the child entity to look for
type: string
enum: [task, model]
default: task
}
}
}
response {
type: object
properties {
users {
description: "The list of users sorted by their names"
type: array
items {
type: object
properties {
id {
description: "The ID of the user"
type: string
}
name {
description: "The name of the user"
type: string
}
}
}
}
}
}
}
}

View File

@ -19,6 +19,8 @@ from apiserver.apimodels.projects import (
ProjectModelMetadataValuesRequest,
ProjectChildrenType,
GetUniqueMetricsRequest,
ProjectUserNamesRequest,
EntityTypeEnum,
)
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL, ProjectQueries
@ -29,8 +31,9 @@ from apiserver.bll.project.project_cleanup import (
)
from apiserver.database.errors import translate_errors_context
from apiserver.database.model import EntityVisibility
from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import TaskType
from apiserver.database.model.task.task import TaskType, Task
from apiserver.database.utils import (
parse_from_call,
get_company_or_none_constraint,
@ -527,3 +530,15 @@ def get_task_parents(
name=request.task_name,
)
}
@endpoint("projects.get_user_names")
def get_user_names(call: APICall, company_id: str, request: ProjectUserNamesRequest):
call.result.data = {
"users": ProjectBLL.get_entity_users(
company_id,
entity_cls=Model if request.entity == EntityTypeEnum.model else Task,
projects=request.projects,
include_subprojects=request.include_subprojects,
)
}

View File

@ -183,6 +183,8 @@ class TestSubProjects(TestService):
self.assertEqual(res.types, [])
res = self.api.projects.get_task_parents(projects=[project])
self.assertEqual(res.parents, [])
res = self.api.projects.get_user_names(projects=[project])
self.assertEqual(res.users, [])
res = self.api.organization.get_entities_count(
projects={"id": [project]}, active_users=[user]
)
@ -206,6 +208,8 @@ class TestSubProjects(TestService):
self.assertEqual(res.projects[0].stats.active.total_tasks, 2)
res = self.api.projects.get_task_parents(projects=[project])
self._assert_ids(res.parents, [task1])
res = self.api.projects.get_user_names(projects=[project])
self.assertEqual(res.users, [{"id": "Test1", "name": "Test User"}])
res = self.api.models.get_frameworks(projects=[project])
self.assertEqual(res.frameworks, [framework])
res = self.api.tasks.get_types(projects=[project])