Support active users in projects

allegroai 2021-05-03 17:36:04 +03:00
parent 6411954002
commit d029d56508
10 changed files with 379 additions and 247 deletions
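For API clients, the practical effect is a new active_users filter on projects.get_all_ex (API version 2.13), alongside the existing include_stats / stats_for_state options. A rough usage sketch in the style of the repo's automated tests; the api object and the user id are illustrative placeholders, not part of this commit:

    # Only return projects containing tasks created by the given users,
    # and attach per-state status counts and total runtimes to each project.
    projects = api.projects.get_all_ex(
        active_users=["<user-id>"],
        include_stats=True,
        stats_for_state="active",  # pass None to report archived stats as well
    ).projects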

View File

@@ -218,7 +218,7 @@ class ActualEnumField(fields.StringField):
        )
 
    def parse_value(self, value):
-       if value is None and not self.required:
+       if value is NotSet and not self.required:
            return self.get_default_value()
        try:
            # noinspection PyArgumentList

View File

@@ -30,3 +30,10 @@ class ProjectHyperparamValuesRequest(MultiProjectReq):
     section = fields.StringField(required=True)
     name = fields.StringField(required=True)
     allow_public = fields.BoolField(default=True)
+
+
+class ProjectsGetRequest(models.Base):
+    include_stats = fields.BoolField(default=False)
+    stats_for_state = ActualEnumField(EntityVisibility, default=EntityVisibility.active)
+    non_public = fields.BoolField(default=False)
+    active_users = fields.ListField(str)

View File

@@ -1,15 +1,21 @@
+from collections import defaultdict
 from datetime import datetime
-from typing import Sequence, Optional, Type
+from itertools import groupby
+from operator import itemgetter
+from typing import Sequence, Optional, Type, Tuple, Dict
 
 from mongoengine import Q, Document
 
 from apiserver import database
 from apiserver.apierrors import errors
 from apiserver.config_repo import config
+from apiserver.database.model import EntityVisibility
 from apiserver.database.model.model import Model
 from apiserver.database.model.project import Project
-from apiserver.database.model.task.task import Task
+from apiserver.database.model.task.task import Task, TaskStatus
+from apiserver.database.utils import get_options
 from apiserver.timing_context import TimingContext
+from apiserver.tools import safe_get
 
 log = config.logger(__file__)
@@ -132,6 +138,205 @@ class ProjectBLL:
             if hasattr(entity_cls, "last_change")
             else {}
         )
-        entity_cls.objects(company=company, id__in=ids).update(set__project=project, **extra)
+        entity_cls.objects(company=company, id__in=ids).update(
+            set__project=project, **extra
+        )
         return project
+
+    archived_tasks_cond = {"$in": [EntityVisibility.archived.value, "$system_tags"]}
+
+    @classmethod
+    def make_projects_get_all_pipelines(
+        cls,
+        company_id: str,
+        project_ids: Sequence[str],
+        specific_state: Optional[EntityVisibility] = None,
+    ) -> Tuple[Sequence, Sequence]:
+        archived = EntityVisibility.archived.value
+
+        def ensure_valid_fields():
+            """
+            Make sure system tags is always an array (required by subsequent $in in archived_tasks_cond)
+            """
+            return {
+                "$addFields": {
+                    "system_tags": {
+                        "$cond": {
+                            "if": {"$ne": [{"$type": "$system_tags"}, "array"]},
+                            "then": [],
+                            "else": "$system_tags",
+                        }
+                    },
+                    "status": {"$ifNull": ["$status", "unknown"]},
+                }
+            }
+
+        status_count_pipeline = [
+            # count tasks per project per status
+            {
+                "$match": {
+                    "company": {"$in": [None, "", company_id]},
+                    "project": {"$in": project_ids},
+                }
+            },
+            ensure_valid_fields(),
+            {
+                "$group": {
+                    "_id": {
+                        "project": "$project",
+                        "status": "$status",
+                        archived: cls.archived_tasks_cond,
+                    },
+                    "count": {"$sum": 1},
+                }
+            },
+            # for each project, create a list of (status, count, archived)
+            {
+                "$group": {
+                    "_id": "$_id.project",
+                    "counts": {
+                        "$push": {
+                            "status": "$_id.status",
+                            "count": "$count",
+                            archived: "$_id.%s" % archived,
+                        }
+                    },
+                }
+            },
+        ]
+
+        def runtime_subquery(additional_cond):
+            return {
+                # the sum of
+                "$sum": {
+                    # for each task
+                    "$cond": {
+                        # if completed and started and completed > started
+                        "if": {
+                            "$and": [
+                                "$started",
+                                "$completed",
+                                {"$gt": ["$completed", "$started"]},
+                                additional_cond,
+                            ]
+                        },
+                        # then: floor((completed - started) / 1000)
+                        "then": {
+                            "$floor": {
+                                "$divide": [
+                                    {"$subtract": ["$completed", "$started"]},
+                                    1000.0,
+                                ]
+                            }
+                        },
+                        "else": 0,
+                    }
+                }
+            }
+
+        group_step = {"_id": "$project"}
+        for state in EntityVisibility:
+            if specific_state and state != specific_state:
+                continue
+            if state == EntityVisibility.active:
+                group_step[state.value] = runtime_subquery(
+                    {"$not": cls.archived_tasks_cond}
+                )
+            elif state == EntityVisibility.archived:
+                group_step[state.value] = runtime_subquery(cls.archived_tasks_cond)
+
+        runtime_pipeline = [
+            # only count run time for these types of tasks
+            {
+                "$match": {
+                    "company": {"$in": [None, "", company_id]},
+                    "type": {"$in": ["training", "testing", "annotation"]},
+                    "project": {"$in": project_ids},
+                }
+            },
+            ensure_valid_fields(),
+            {
+                # for each project
+                "$group": group_step
+            },
+        ]
+
+        return status_count_pipeline, runtime_pipeline
+
+    @classmethod
+    def get_project_stats(
+        cls,
+        company: str,
+        project_ids: Sequence[str],
+        specific_state: Optional[EntityVisibility] = None,
+    ) -> Dict[str, dict]:
+        if not project_ids:
+            return {}
+
+        status_count_pipeline, runtime_pipeline = cls.make_projects_get_all_pipelines(
+            company, project_ids=project_ids, specific_state=specific_state
+        )
+
+        default_counts = dict.fromkeys(get_options(TaskStatus), 0)
+
+        def set_default_count(entry):
+            return dict(default_counts, **entry)
+
+        status_count = defaultdict(lambda: {})
+        key = itemgetter(EntityVisibility.archived.value)
+        for result in Task.aggregate(status_count_pipeline):
+            for k, group in groupby(sorted(result["counts"], key=key), key):
+                section = (
+                    EntityVisibility.archived if k else EntityVisibility.active
+                ).value
+                status_count[result["_id"]][section] = set_default_count(
+                    {
+                        count_entry["status"]: count_entry["count"]
+                        for count_entry in group
+                    }
+                )
+
+        runtime = {
+            result["_id"]: {k: v for k, v in result.items() if k != "_id"}
+            for result in Task.aggregate(runtime_pipeline)
+        }
+
+        def get_status_counts(project_id, section):
+            path = "/".join((project_id, section))
+            return {
+                "total_runtime": safe_get(runtime, path, 0),
+                "status_count": safe_get(status_count, path, default_counts),
+            }
+
+        report_for_states = [
+            s for s in EntityVisibility if not specific_state or specific_state == s
+        ]
+
+        return {
+            project: {
+                task_state.value: get_status_counts(project, task_state.value)
+                for task_state in report_for_states
+            }
+            for project in project_ids
+        }
+
+    @classmethod
+    def get_projects_with_active_user(
+        cls,
+        company: str,
+        users: Sequence[str],
+        project_ids: Optional[Sequence[str]] = None,
+        allow_public: bool = True,
+    ) -> Sequence[str]:
"""Get the projects ids where user created any tasks"""
+        company = (
+            {"company__in": [None, "", company]}
+            if allow_public
+            else {"company": company}
+        )
+        projects = {"project__in": project_ids} if project_ids else {}
+        return Task.objects(**company, user__in=users, **projects).distinct(
+            field="project"
+        )
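A caller can combine the two new ProjectBLL helpers roughly as the reworked projects.get_all_ex endpoint (later in this diff) does. Minimal sketch; company_id and user_id are placeholders:

    # Resolve the projects in which the given users created tasks,
    # then compute per-project stats for the active (non-archived) state.
    project_ids = ProjectBLL.get_projects_with_active_user(
        company=company_id, users=[user_id], allow_public=True
    )
    stats = ProjectBLL.get_project_stats(
        company=company_id,
        project_ids=project_ids,
        specific_state=EntityVisibility.active,
    )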

View File

@@ -637,6 +637,35 @@ class GetMixin(PropsMixin):
         return qs
 
+    @classmethod
+    def _get_queries_for_order_field(
+        cls, query: Q, order_field: str
+    ) -> Union[None, Tuple[Q, Q]]:
"""
In case the order_field is one of the cls fields and the sorting is ascending
then return the tuple of 2 queries:
1. original query with not empty constraint on the order_by field
2. original query with empty constraint on the order_by field
"""
+        if not order_field or order_field.startswith("-") or "[" in order_field:
+            return
+
+        mongo_field_name = order_field.replace(".", "__")
+        mongo_field = first(
+            v for k, v in cls.get_all_fields_with_instance() if k == mongo_field_name
+        )
+        if not mongo_field:
+            return
+
+        params = {}
+        if isinstance(mongo_field, ListField):
+            params["is_list"] = True
+        elif isinstance(mongo_field, StringField):
+            params["empty_value"] = ""
+
+        non_empty = query & field_exists(mongo_field_name, **params)
+        empty = query & field_does_not_exist(mongo_field_name, **params)
+        return non_empty, empty
+
     @classmethod
     def _get_many_override_none_ordering(
         cls: Union[Document, "GetMixin"],
@@ -675,21 +704,9 @@
             order_field = first(
                 field for field in order_by if not field.startswith("$")
             )
-            if (
-                order_field
-                and not order_field.startswith("-")
-                and "[" not in order_field
-            ):
-                params = {}
-                mongo_field = order_field.replace(".", "__")
-                if mongo_field in cls.get_field_names_for_type(of_type=ListField):
-                    params["is_list"] = True
-                elif mongo_field in cls.get_field_names_for_type(of_type=StringField):
-                    params["empty_value"] = ""
-                non_empty = query & field_exists(mongo_field, **params)
-                empty = query & field_does_not_exist(mongo_field, **params)
-                query_sets = [cls.objects(non_empty), cls.objects(empty)]
+            res = cls._get_queries_for_order_field(query, order_field)
+            if res:
+                query_sets = [cls.objects(q) for q in res]
             query_sets = [qs.order_by(*order_by) for qs in query_sets]
 
         if order_field:
             collation_override = first(

View File

@@ -1,12 +1,11 @@
-from collections import OrderedDict, defaultdict
-from itertools import chain
+from collections import OrderedDict
 from operator import attrgetter
 from threading import Lock
 from typing import Sequence
 
 import six
 from mongoengine import EmbeddedDocumentField, EmbeddedDocumentListField
-from mongoengine.base import get_document, BaseField
+from mongoengine.base import get_document
 
 from apiserver.database.fields import (
     LengthRangeEmbeddedDocumentListField,
@@ -21,7 +20,7 @@ class PropsMixin(object):
     __cached_reference_fields = None
     __cached_exclude_fields = None
     __cached_fields_with_instance = None
-    __cached_field_names_per_type = None
+    __cached_all_fields_with_instance = None
 
     __cached_dpath_computed_fields_lock = Lock()
     __cached_dpath_computed_fields = None
@@ -33,37 +32,12 @@
         return cls.__cached_fields
 
     @classmethod
-    def get_field_names_for_type(cls, of_type=BaseField):
-        """
-        Return field names per type including subfields
-        The fields of derived types are also returned
-        """
-        assert issubclass(of_type, BaseField)
-        if cls.__cached_field_names_per_type is None:
-            fields = defaultdict(list)
-            for name, field in get_fields(cls, return_instance=True, subfields=True):
-                fields[type(field)].append(name)
-            for type_ in fields:
-                fields[type_].extend(
-                    chain.from_iterable(
-                        fields[other_type]
-                        for other_type in fields
-                        if other_type != type_ and issubclass(other_type, type_)
-                    )
-                )
-            cls.__cached_field_names_per_type = fields
-
-        if of_type not in cls.__cached_field_names_per_type:
-            names = list(
-                chain.from_iterable(
-                    field_names
-                    for type_, field_names in cls.__cached_field_names_per_type.items()
-                    if issubclass(type_, of_type)
-                )
-            )
-            cls.__cached_field_names_per_type[of_type] = names
-
-        return cls.__cached_field_names_per_type[of_type]
+    def get_all_fields_with_instance(cls):
+        if cls.__cached_all_fields_with_instance is None:
+            cls.__cached_all_fields_with_instance = get_fields(
+                cls, return_instance=True, subfields=True
+            )
+        return cls.__cached_all_fields_with_instance
 
     @classmethod
     def get_fields_with_instance(cls, doc_cls):

View File

@@ -413,6 +413,17 @@ get_all_ex {
             }
         }
     }
+    "2.13": ${get_all_ex."2.1"} {
+        request {
+            properties {
+                active_users {
+                    description: "The list of users that were active in the project. If passed, the resulting projects are filtered to those that have tasks created by these users"
+                    type: array
+                    items: {type: string}
+                }
+            }
+        }
+    }
 }
 update {
     "2.1" {

View File

@@ -1,9 +1,5 @@
-from collections import defaultdict
 from datetime import datetime
-from itertools import groupby
-from operator import itemgetter
 
-import dpath
 from mongoengine import Q
 
 from apiserver.apierrors import errors
@@ -15,6 +11,7 @@ from apiserver.apimodels.projects import (
     ProjectTagsRequest,
     ProjectTaskParentsRequest,
     ProjectHyperparamValuesRequest,
+    ProjectsGetRequest,
 )
 from apiserver.bll.organization import OrgBLL, Tags
 from apiserver.bll.project import ProjectBLL
@@ -23,10 +20,9 @@ from apiserver.database.errors import translate_errors_context
 from apiserver.database.model import EntityVisibility
 from apiserver.database.model.model import Model
 from apiserver.database.model.project import Project
-from apiserver.database.model.task.task import Task, TaskStatus
+from apiserver.database.model.task.task import Task
 from apiserver.database.utils import (
     parse_from_call,
-    get_options,
     get_company_or_none_constraint,
 )
 from apiserver.service_repo import APICall, endpoint
@@ -40,7 +36,7 @@ from apiserver.timing_context import TimingContext
 
 org_bll = OrgBLL()
 task_bll = TaskBLL()
-archived_tasks_cond = {"$in": [EntityVisibility.archived.value, "$system_tags"]}
+project_bll = ProjectBLL()
 
 create_fields = {
     "name": None,
@@ -75,199 +71,46 @@ def get_by_id(call):
     call.result.data = {"project": project_dict}
 
 
-def make_projects_get_all_pipelines(company_id, project_ids, specific_state=None):
-    archived = EntityVisibility.archived.value
-
-    def ensure_valid_fields():
-        """
-        Make sure system tags is always an array (required by subsequent $in in archived_tasks_cond)
-        """
-        return {
-            "$addFields": {
-                "system_tags": {
-                    "$cond": {
-                        "if": {"$ne": [{"$type": "$system_tags"}, "array"]},
-                        "then": [],
-                        "else": "$system_tags",
-                    }
-                },
-                "status": {"$ifNull": ["$status", "unknown"]},
-            }
-        }
-
-    status_count_pipeline = [
-        # count tasks per project per status
-        {
-            "$match": {
-                "company": {"$in": [None, "", company_id]},
-                "project": {"$in": project_ids},
-            }
-        },
-        ensure_valid_fields(),
-        {
-            "$group": {
-                "_id": {
-                    "project": "$project",
-                    "status": "$status",
-                    archived: archived_tasks_cond,
-                },
-                "count": {"$sum": 1},
-            }
-        },
-        # for each project, create a list of (status, count, archived)
-        {
-            "$group": {
-                "_id": "$_id.project",
-                "counts": {
-                    "$push": {
-                        "status": "$_id.status",
-                        "count": "$count",
-                        archived: "$_id.%s" % archived,
-                    }
-                },
-            }
-        },
-    ]
-
-    def runtime_subquery(additional_cond):
-        return {
-            # the sum of
-            "$sum": {
-                # for each task
-                "$cond": {
-                    # if completed and started and completed > started
-                    "if": {
-                        "$and": [
-                            "$started",
-                            "$completed",
-                            {"$gt": ["$completed", "$started"]},
-                            additional_cond,
-                        ]
-                    },
-                    # then: floor((completed - started) / 1000)
-                    "then": {
-                        "$floor": {
-                            "$divide": [
-                                {"$subtract": ["$completed", "$started"]},
-                                1000.0,
-                            ]
-                        }
-                    },
-                    "else": 0,
-                }
-            }
-        }
-
-    group_step = {"_id": "$project"}
-    for state in EntityVisibility:
-        if specific_state and state != specific_state:
-            continue
-        if state == EntityVisibility.active:
-            group_step[state.value] = runtime_subquery({"$not": archived_tasks_cond})
-        elif state == EntityVisibility.archived:
-            group_step[state.value] = runtime_subquery(archived_tasks_cond)
-
-    runtime_pipeline = [
-        # only count run time for these types of tasks
-        {
-            "$match": {
-                "type": {"$in": ["training", "testing"]},
-                "company": {"$in": [None, "", company_id]},
-                "project": {"$in": project_ids},
-            }
-        },
-        ensure_valid_fields(),
-        {
-            # for each project
-            "$group": group_step
-        },
-    ]
-
-    return status_count_pipeline, runtime_pipeline
-
-
-@endpoint("projects.get_all_ex")
-def get_all_ex(call: APICall):
-    include_stats = call.data.get("include_stats")
-    stats_for_state = call.data.get("stats_for_state", EntityVisibility.active.value)
-    allow_public = not call.data.get("non_public", False)
-    if stats_for_state:
-        try:
-            specific_state = EntityVisibility(stats_for_state)
-        except ValueError:
-            raise errors.bad_request.FieldsValueError(stats_for_state=stats_for_state)
-    else:
-        specific_state = None
-
+@endpoint("projects.get_all_ex", request_data_model=ProjectsGetRequest)
+def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
     conform_tag_fields(call, call.data)
-    with translate_errors_context(), TimingContext("mongo", "projects_get_all"):
+    allow_public = not request.non_public
+    with TimingContext("mongo", "projects_get_all"):
+        if request.active_users:
+            ids = project_bll.get_projects_with_active_user(
+                company=company_id,
+                users=request.active_users,
+                project_ids=call.data.get("id"),
+                allow_public=allow_public,
+            )
+            if not ids:
+                call.result.data = {"projects": []}
+                return
+            call.data["id"] = ids
+
         projects = Project.get_many_with_join(
-            company=call.identity.company,
+            company=company_id,
             query_dict=call.data,
             query_options=get_all_query_options,
             allow_public=allow_public,
         )
         conform_output_tags(call, projects)
-        if not include_stats:
+        if not request.include_stats:
             call.result.data = {"projects": projects}
             return
 
-        ids = [project["id"] for project in projects]
-        status_count_pipeline, runtime_pipeline = make_projects_get_all_pipelines(
-            call.identity.company, ids, specific_state=specific_state
-        )
-
-        default_counts = dict.fromkeys(get_options(TaskStatus), 0)
-
-        def set_default_count(entry):
-            return dict(default_counts, **entry)
-
-        status_count = defaultdict(lambda: {})
-        key = itemgetter(EntityVisibility.archived.value)
-        for result in Task.aggregate(status_count_pipeline):
-            for k, group in groupby(sorted(result["counts"], key=key), key):
-                section = (
-                    EntityVisibility.archived if k else EntityVisibility.active
-                ).value
-                status_count[result["_id"]][section] = set_default_count(
-                    {
-                        count_entry["status"]: count_entry["count"]
-                        for count_entry in group
-                    }
-                )
-
-        runtime = {
-            result["_id"]: {k: v for k, v in result.items() if k != "_id"}
-            for result in Task.aggregate(runtime_pipeline)
-        }
-
-        def safe_get(obj, path, default=None):
-            try:
-                return dpath.get(obj, path)
-            except KeyError:
-                return default
-
-        def get_status_counts(project_id, section):
-            path = "/".join((project_id, section))
-            return {
-                "total_runtime": safe_get(runtime, path, 0),
-                "status_count": safe_get(status_count, path, default_counts),
-            }
-
-        report_for_states = [
-            s for s in EntityVisibility if not specific_state or specific_state == s
-        ]
-
-        for project in projects:
-            project["stats"] = {
-                task_state.value: get_status_counts(project["id"], task_state.value)
-                for task_state in report_for_states
-            }
-
-        call.result.data = {"projects": projects}
+        project_ids = {project["id"] for project in projects}
+        stats = project_bll.get_project_stats(
+            company=company_id,
+            project_ids=list(project_ids),
+            specific_state=request.stats_for_state,
+        )
+
+        for project in projects:
+            project["stats"] = stats[project["id"]]
+
+        call.result.data = {"projects": projects}
 
 
 @endpoint("projects.get_all")

View File

@@ -28,7 +28,9 @@ class TestEntityOrdering(TestService):
         self._assertGetTasksWithOrdering(order_by="comment")
 
         # sort by parameter which type is not part of db schema
-        self._assertGetTasksWithOrdering(order_by="execution.parameters.test")
+        self._assertGetTasksWithOrdering(
+            order_by="execution.parameters.test", valid_order=False
+        )
 
     def test_order_with_paging(self):
         order_field = "started"
@@ -97,7 +99,9 @@ class TestEntityOrdering(TestService):
         return val
 
-    def _assertGetTasksWithOrdering(self, order_by: str = None, **kwargs):
+    def _assertGetTasksWithOrdering(
+        self, order_by: str = None, valid_order=True, **kwargs
+    ):
         tasks = self.api.tasks.get_all_ex(
             only_fields=self.only_fields,
             order_by=[order_by] if isinstance(order_by, str) else order_by,
@@ -105,14 +109,16 @@ class TestEntityOrdering(TestService):
             **kwargs,
         ).tasks
         self.assertLessEqual(set(self.task_ids), set(t.id for t in tasks))
-        if order_by:
+        if order_by and valid_order:
             # test that the output is correctly ordered
             field_name = order_by if not order_by.startswith("-") else order_by[1:]
-            field_vals = [self._get_value_for_path(t, field_name.split(".")) for t in tasks]
+            field_vals = [
+                self._get_value_for_path(t, field_name.split(".")) for t in tasks
+            ]
             self._assertSorted(
                 field_vals,
                 ascending=not order_by.startswith("-"),
-                is_numeric=field_name.startswith("execution.parameters.")
+                is_numeric=field_name.startswith("execution.parameters."),
             )
 
     def _create_tasks(self):

View File

@@ -0,0 +1,65 @@
+from boltons.iterutils import first
+
+from apiserver.tests.automated import TestService
+
+
+class TestProjectsRetrieval(TestService):
+    def setUp(self, **kwargs):
+        super().setUp(version="2.13")
+
+    def test_active_user(self):
+        user = self.api.users.get_current_user().user.id
+        project1 = self.temp_project(name="Project retrieval1")
+        project2 = self.temp_project(name="Project retrieval2")
+        self.temp_task(project=project2)
+
+        projects = self.api.projects.get_all_ex().projects
+        self.assertTrue({project1, project2}.issubset({p.id for p in projects}))
+
+        projects = self.api.projects.get_all_ex(active_users=[user]).projects
+        ids = {p.id for p in projects}
+        self.assertFalse(project1 in ids)
+        self.assertTrue(project2 in ids)
+
+    def test_stats(self):
+        project = self.temp_project()
+        self.temp_task(project=project)
+        self.temp_task(project=project)
+        archived_task = self.temp_task(project=project)
+        self.api.tasks.archive(tasks=[archived_task])
+
+        p = self._get_project(project)
+        self.assertFalse("stats" in p)
+
+        p = self._get_project(project, include_stats=True)
+        self.assertFalse("archived" in p.stats)
+        self.assertTrue(p.stats.active.status_count.created, 2)
+
+        p = self._get_project(project, include_stats=True, stats_for_state=None)
+        self.assertTrue(p.stats.active.status_count.created, 2)
+        self.assertTrue(p.stats.archived.status_count.created, 1)
+
+    def _get_project(self, project, **kwargs):
+        projects = self.api.projects.get_all_ex(id=[project], **kwargs).projects
+        p = first(p for p in projects if p.id == project)
+        self.assertIsNotNone(p)
+        return p
+
+    def temp_project(self, **kwargs) -> str:
+        self.update_missing(
+            kwargs,
+            name="Test projects retrieval",
+            description="test",
+            delete_params=dict(force=True),
+        )
+        return self.create_temp("projects", **kwargs)
+
+    def temp_task(self, **kwargs) -> str:
+        self.update_missing(
+            kwargs,
+            type="testing",
+            name="test projects retrieval",
+            input=dict(view=dict()),
+            delete_params=dict(force=True),
+        )
+        return self.create_temp("tasks", **kwargs)

View File

@@ -1,10 +1,14 @@
 from enum import Enum
 
 
-class StringEnum(Enum):
+class StringEnum(str, Enum):
     def __str__(self):
         return self.value
 
+    @classmethod
+    def values(cls):
+        return list(map(str, cls))
+
     # noinspection PyMethodParameters
     def _generate_next_value_(name, start, count, last_values):
         return name