Support active users in projects

This commit is contained in:
allegroai 2021-05-03 17:36:04 +03:00
parent 6411954002
commit d029d56508
10 changed files with 379 additions and 247 deletions

View File

@ -218,7 +218,7 @@ class ActualEnumField(fields.StringField):
)
def parse_value(self, value):
if value is None and not self.required:
if value is NotSet and not self.required:
return self.get_default_value()
try:
# noinspection PyArgumentList

View File

@ -30,3 +30,10 @@ class ProjectHyperparamValuesRequest(MultiProjectReq):
section = fields.StringField(required=True)
name = fields.StringField(required=True)
allow_public = fields.BoolField(default=True)
class ProjectsGetRequest(models.Base):
# Request model for the projects.get_all_ex endpoint
# whether to compute per-project task statistics into each returned project
include_stats = fields.BoolField(default=False)
# the task visibility state to report stats for; defaults to active tasks only
stats_for_state = ActualEnumField(EntityVisibility, default=EntityVisibility.active)
# when True, public projects are excluded from the results
non_public = fields.BoolField(default=False)
# when passed, only projects having tasks created by these users are returned
active_users = fields.ListField(str)

View File

@ -1,15 +1,21 @@
from collections import defaultdict
from datetime import datetime
from typing import Sequence, Optional, Type
from itertools import groupby
from operator import itemgetter
from typing import Sequence, Optional, Type, Tuple, Dict
from mongoengine import Q, Document
from apiserver import database
from apiserver.apierrors import errors
from apiserver.config_repo import config
from apiserver.database.model import EntityVisibility
from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task
from apiserver.database.model.task.task import Task, TaskStatus
from apiserver.database.utils import get_options
from apiserver.timing_context import TimingContext
from apiserver.tools import safe_get
log = config.logger(__file__)
@ -132,6 +138,205 @@ class ProjectBLL:
if hasattr(entity_cls, "last_change")
else {}
)
entity_cls.objects(company=company, id__in=ids).update(set__project=project, **extra)
entity_cls.objects(company=company, id__in=ids).update(
set__project=project, **extra
)
return project
archived_tasks_cond = {"$in": [EntityVisibility.archived.value, "$system_tags"]}
@classmethod
def make_projects_get_all_pipelines(
cls,
company_id: str,
project_ids: Sequence[str],
specific_state: Optional[EntityVisibility] = None,
) -> Tuple[Sequence, Sequence]:
"""
Build the two Mongo aggregation pipelines backing project statistics:
1. a status-count pipeline: per project, a list of (status, count, archived) entries
2. a runtime pipeline: per project, the summed task run time (whole seconds)
grouped per visibility state (active/archived)
:param company_id: match tasks of this company (None/"" company also matches)
:param project_ids: only tasks in these projects are aggregated
:param specific_state: if passed, the runtime pipeline groups only this state
:return: (status_count_pipeline, runtime_pipeline)
"""
archived = EntityVisibility.archived.value
def ensure_valid_fields():
"""
Make sure system_tags is always an array (required by the subsequent $in
in archived_tasks_cond) and that status always has a value
"""
return {
"$addFields": {
"system_tags": {
"$cond": {
"if": {"$ne": [{"$type": "$system_tags"}, "array"]},
"then": [],
"else": "$system_tags",
}
},
"status": {"$ifNull": ["$status", "unknown"]},
}
}
status_count_pipeline = [
# count tasks per project per status
{
"$match": {
"company": {"$in": [None, "", company_id]},
"project": {"$in": project_ids},
}
},
ensure_valid_fields(),
{
"$group": {
"_id": {
"project": "$project",
"status": "$status",
archived: cls.archived_tasks_cond,
},
"count": {"$sum": 1},
}
},
# for each project, create a list of (status, count, archived)
{
"$group": {
"_id": "$_id.project",
"counts": {
"$push": {
"status": "$_id.status",
"count": "$count",
archived: "$_id.%s" % archived,
}
},
}
},
]
def runtime_subquery(additional_cond):
# accumulator summing run time (seconds) of the tasks matching additional_cond
return {
# the sum of
"$sum": {
# for each task
"$cond": {
# if completed and started and completed > started
"if": {
"$and": [
"$started",
"$completed",
{"$gt": ["$completed", "$started"]},
additional_cond,
]
},
# then: floor((completed - started) / 1000)
"then": {
"$floor": {
"$divide": [
{"$subtract": ["$completed", "$started"]},
1000.0,
]
}
},
"else": 0,
}
}
}
# one runtime accumulator per requested visibility state
group_step = {"_id": "$project"}
for state in EntityVisibility:
if specific_state and state != specific_state:
continue
if state == EntityVisibility.active:
group_step[state.value] = runtime_subquery(
{"$not": cls.archived_tasks_cond}
)
elif state == EntityVisibility.archived:
group_step[state.value] = runtime_subquery(cls.archived_tasks_cond)
runtime_pipeline = [
# only count run time for these types of tasks
{
"$match": {
"company": {"$in": [None, "", company_id]},
"type": {"$in": ["training", "testing", "annotation"]},
"project": {"$in": project_ids},
}
},
ensure_valid_fields(),
{
# for each project
"$group": group_step
},
]
return status_count_pipeline, runtime_pipeline
@classmethod
def get_project_stats(
cls,
company: str,
project_ids: Sequence[str],
specific_state: Optional[EntityVisibility] = None,
) -> Dict[str, dict]:
"""
Compute per-project task statistics.
Returns a mapping of each project id to
{state: {"total_runtime": seconds, "status_count": {status: count}}}
for every reported visibility state (all states, or only specific_state
if passed). Projects with no matching tasks get zero defaults.
"""
if not project_ids:
return {}
status_count_pipeline, runtime_pipeline = cls.make_projects_get_all_pipelines(
company, project_ids=project_ids, specific_state=specific_state
)
# all possible task statuses, each initialized to a zero count
default_counts = dict.fromkeys(get_options(TaskStatus), 0)
def set_default_count(entry):
# fill in zero counts for statuses missing from the aggregation result
return dict(default_counts, **entry)
status_count = defaultdict(lambda: {})
key = itemgetter(EntityVisibility.archived.value)
for result in Task.aggregate(status_count_pipeline):
# split each project's counts into archived/active sections
# (groupby requires the input sorted by the same key)
for k, group in groupby(sorted(result["counts"], key=key), key):
section = (
EntityVisibility.archived if k else EntityVisibility.active
).value
status_count[result["_id"]][section] = set_default_count(
{
count_entry["status"]: count_entry["count"]
for count_entry in group
}
)
# per-project total runtime per state, keyed by project id
runtime = {
result["_id"]: {k: v for k, v in result.items() if k != "_id"}
for result in Task.aggregate(runtime_pipeline)
}
def get_status_counts(project_id, section):
# dpath-style lookup: "<project_id>/<state>"
path = "/".join((project_id, section))
return {
"total_runtime": safe_get(runtime, path, 0),
"status_count": safe_get(status_count, path, default_counts),
}
report_for_states = [
s for s in EntityVisibility if not specific_state or specific_state == s
]
return {
project: {
task_state.value: get_status_counts(project, task_state.value)
for task_state in report_for_states
}
for project in project_ids
}
@classmethod
def get_projects_with_active_user(
    cls,
    company: str,
    users: Sequence[str],
    project_ids: Optional[Sequence[str]] = None,
    allow_public: bool = True,
) -> Sequence[str]:
    """
    Return the ids of projects in which any of the given users created tasks.
    :param company: company whose tasks are searched
    :param users: user ids whose created tasks qualify a project
    :param project_ids: if passed, restrict the search to these projects
    :param allow_public: also match tasks with no company (public tasks)
    """
    if allow_public:
        company_filter = {"company__in": [None, "", company]}
    else:
        company_filter = {"company": company}
    project_filter = {"project__in": project_ids} if project_ids else {}
    query = Task.objects(user__in=users, **company_filter, **project_filter)
    return query.distinct(field="project")

View File

@ -637,6 +637,35 @@ class GetMixin(PropsMixin):
return qs
@classmethod
def _get_queries_for_order_field(
    cls, query: Q, order_field: str
) -> Union[None, Tuple[Q, Q]]:
    """
    If order_field names one of cls's fields and the sorting is ascending,
    return a pair of queries derived from the original one:
    1. restricted to documents where the order field is present/non-empty
    2. restricted to documents where the order field is absent/empty
    Otherwise return None.
    """
    # descending sorts and indexed sub-field expressions are not handled here
    if not order_field or order_field.startswith("-") or "[" in order_field:
        return
    field_name = order_field.replace(".", "__")
    field_instance = next(
        (
            instance
            for name, instance in cls.get_all_fields_with_instance()
            if name == field_name
        ),
        None,
    )
    if not field_instance:
        return
    # empty-ness is type dependent: empty list for list fields, "" for strings
    extra = {}
    if isinstance(field_instance, ListField):
        extra["is_list"] = True
    elif isinstance(field_instance, StringField):
        extra["empty_value"] = ""
    return (
        query & field_exists(field_name, **extra),
        query & field_does_not_exist(field_name, **extra),
    )
@classmethod
def _get_many_override_none_ordering(
cls: Union[Document, "GetMixin"],
@ -675,21 +704,9 @@ class GetMixin(PropsMixin):
order_field = first(
field for field in order_by if not field.startswith("$")
)
if (
order_field
and not order_field.startswith("-")
and "[" not in order_field
):
params = {}
mongo_field = order_field.replace(".", "__")
if mongo_field in cls.get_field_names_for_type(of_type=ListField):
params["is_list"] = True
elif mongo_field in cls.get_field_names_for_type(of_type=StringField):
params["empty_value"] = ""
non_empty = query & field_exists(mongo_field, **params)
empty = query & field_does_not_exist(mongo_field, **params)
query_sets = [cls.objects(non_empty), cls.objects(empty)]
res = cls._get_queries_for_order_field(query, order_field)
if res:
query_sets = [cls.objects(q) for q in res]
query_sets = [qs.order_by(*order_by) for qs in query_sets]
if order_field:
collation_override = first(

View File

@ -1,12 +1,11 @@
from collections import OrderedDict, defaultdict
from itertools import chain
from collections import OrderedDict
from operator import attrgetter
from threading import Lock
from typing import Sequence
import six
from mongoengine import EmbeddedDocumentField, EmbeddedDocumentListField
from mongoengine.base import get_document, BaseField
from mongoengine.base import get_document
from apiserver.database.fields import (
LengthRangeEmbeddedDocumentListField,
@ -21,7 +20,7 @@ class PropsMixin(object):
__cached_reference_fields = None
__cached_exclude_fields = None
__cached_fields_with_instance = None
__cached_field_names_per_type = None
__cached_all_fields_with_instance = None
__cached_dpath_computed_fields_lock = Lock()
__cached_dpath_computed_fields = None
@ -33,37 +32,12 @@ class PropsMixin(object):
return cls.__cached_fields
@classmethod
def get_field_names_for_type(cls, of_type=BaseField):
"""
Return field names per type including subfields
The fields of derived types are also returned
"""
assert issubclass(of_type, BaseField)
if cls.__cached_field_names_per_type is None:
fields = defaultdict(list)
for name, field in get_fields(cls, return_instance=True, subfields=True):
fields[type(field)].append(name)
for type_ in fields:
fields[type_].extend(
chain.from_iterable(
fields[other_type]
for other_type in fields
if other_type != type_ and issubclass(other_type, type_)
)
)
cls.__cached_field_names_per_type = fields
if of_type not in cls.__cached_field_names_per_type:
names = list(
chain.from_iterable(
field_names
for type_, field_names in cls.__cached_field_names_per_type.items()
if issubclass(type_, of_type)
)
def get_all_fields_with_instance(cls):
if cls.__cached_all_fields_with_instance is None:
cls.__cached_all_fields_with_instance = get_fields(
cls, return_instance=True, subfields=True
)
cls.__cached_field_names_per_type[of_type] = names
return cls.__cached_field_names_per_type[of_type]
return cls.__cached_all_fields_with_instance
@classmethod
def get_fields_with_instance(cls, doc_cls):

View File

@ -413,6 +413,17 @@ get_all_ex {
}
}
}
"2.13": ${get_all_ex."2.1"} {
request {
properties {
active_users {
description: "The list of users that were active in the project. If passed then the resulting projects are filtered to the ones that have tasks created by these users"
type: array
items: {type: string}
}
}
}
}
}
update {
"2.1" {

View File

@ -1,9 +1,5 @@
from collections import defaultdict
from datetime import datetime
from itertools import groupby
from operator import itemgetter
import dpath
from mongoengine import Q
from apiserver.apierrors import errors
@ -15,6 +11,7 @@ from apiserver.apimodels.projects import (
ProjectTagsRequest,
ProjectTaskParentsRequest,
ProjectHyperparamValuesRequest,
ProjectsGetRequest,
)
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL
@ -23,10 +20,9 @@ from apiserver.database.errors import translate_errors_context
from apiserver.database.model import EntityVisibility
from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task, TaskStatus
from apiserver.database.model.task.task import Task
from apiserver.database.utils import (
parse_from_call,
get_options,
get_company_or_none_constraint,
)
from apiserver.service_repo import APICall, endpoint
@ -40,7 +36,7 @@ from apiserver.timing_context import TimingContext
org_bll = OrgBLL()
task_bll = TaskBLL()
archived_tasks_cond = {"$in": [EntityVisibility.archived.value, "$system_tags"]}
project_bll = ProjectBLL()
create_fields = {
"name": None,
@ -75,199 +71,46 @@ def get_by_id(call):
call.result.data = {"project": project_dict}
def make_projects_get_all_pipelines(company_id, project_ids, specific_state=None):
archived = EntityVisibility.archived.value
def ensure_valid_fields():
"""
Make sure system tags is always an array (required by subsequent $in in archived_tasks_cond
"""
return {
"$addFields": {
"system_tags": {
"$cond": {
"if": {"$ne": [{"$type": "$system_tags"}, "array"]},
"then": [],
"else": "$system_tags",
}
},
"status": {"$ifNull": ["$status", "unknown"]},
}
}
status_count_pipeline = [
# count tasks per project per status
{
"$match": {
"company": {"$in": [None, "", company_id]},
"project": {"$in": project_ids},
}
},
ensure_valid_fields(),
{
"$group": {
"_id": {
"project": "$project",
"status": "$status",
archived: archived_tasks_cond,
},
"count": {"$sum": 1},
}
},
# for each project, create a list of (status, count, archived)
{
"$group": {
"_id": "$_id.project",
"counts": {
"$push": {
"status": "$_id.status",
"count": "$count",
archived: "$_id.%s" % archived,
}
},
}
},
]
def runtime_subquery(additional_cond):
return {
# the sum of
"$sum": {
# for each task
"$cond": {
# if completed and started and completed > started
"if": {
"$and": [
"$started",
"$completed",
{"$gt": ["$completed", "$started"]},
additional_cond,
]
},
# then: floor((completed - started) / 1000)
"then": {
"$floor": {
"$divide": [
{"$subtract": ["$completed", "$started"]},
1000.0,
]
}
},
"else": 0,
}
}
}
group_step = {"_id": "$project"}
for state in EntityVisibility:
if specific_state and state != specific_state:
continue
if state == EntityVisibility.active:
group_step[state.value] = runtime_subquery({"$not": archived_tasks_cond})
elif state == EntityVisibility.archived:
group_step[state.value] = runtime_subquery(archived_tasks_cond)
runtime_pipeline = [
# only count run time for these types of tasks
{
"$match": {
"type": {"$in": ["training", "testing"]},
"company": {"$in": [None, "", company_id]},
"project": {"$in": project_ids},
}
},
ensure_valid_fields(),
{
# for each project
"$group": group_step
},
]
return status_count_pipeline, runtime_pipeline
@endpoint("projects.get_all_ex")
def get_all_ex(call: APICall):
include_stats = call.data.get("include_stats")
stats_for_state = call.data.get("stats_for_state", EntityVisibility.active.value)
allow_public = not call.data.get("non_public", False)
if stats_for_state:
try:
specific_state = EntityVisibility(stats_for_state)
except ValueError:
raise errors.bad_request.FieldsValueError(stats_for_state=stats_for_state)
else:
specific_state = None
@endpoint("projects.get_all_ex", request_data_model=ProjectsGetRequest)
def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
conform_tag_fields(call, call.data)
with translate_errors_context(), TimingContext("mongo", "projects_get_all"):
allow_public = not request.non_public
with TimingContext("mongo", "projects_get_all"):
if request.active_users:
ids = project_bll.get_projects_with_active_user(
company=company_id,
users=request.active_users,
project_ids=call.data.get("id"),
allow_public=allow_public,
)
if not ids:
call.result.data = {"projects": []}
return
call.data["id"] = ids
projects = Project.get_many_with_join(
company=call.identity.company,
company=company_id,
query_dict=call.data,
query_options=get_all_query_options,
allow_public=allow_public,
)
conform_output_tags(call, projects)
if not include_stats:
conform_output_tags(call, projects)
if not request.include_stats:
call.result.data = {"projects": projects}
return
ids = [project["id"] for project in projects]
status_count_pipeline, runtime_pipeline = make_projects_get_all_pipelines(
call.identity.company, ids, specific_state=specific_state
project_ids = {project["id"] for project in projects}
stats = project_bll.get_project_stats(
company=company_id,
project_ids=list(project_ids),
specific_state=request.stats_for_state,
)
default_counts = dict.fromkeys(get_options(TaskStatus), 0)
for project in projects:
project["stats"] = stats[project["id"]]
def set_default_count(entry):
return dict(default_counts, **entry)
status_count = defaultdict(lambda: {})
key = itemgetter(EntityVisibility.archived.value)
for result in Task.aggregate(status_count_pipeline):
for k, group in groupby(sorted(result["counts"], key=key), key):
section = (
EntityVisibility.archived if k else EntityVisibility.active
).value
status_count[result["_id"]][section] = set_default_count(
{
count_entry["status"]: count_entry["count"]
for count_entry in group
}
)
runtime = {
result["_id"]: {k: v for k, v in result.items() if k != "_id"}
for result in Task.aggregate(runtime_pipeline)
}
def safe_get(obj, path, default=None):
try:
return dpath.get(obj, path)
except KeyError:
return default
def get_status_counts(project_id, section):
path = "/".join((project_id, section))
return {
"total_runtime": safe_get(runtime, path, 0),
"status_count": safe_get(status_count, path, default_counts),
}
report_for_states = [
s for s in EntityVisibility if not specific_state or specific_state == s
]
for project in projects:
project["stats"] = {
task_state.value: get_status_counts(project["id"], task_state.value)
for task_state in report_for_states
}
call.result.data = {"projects": projects}
call.result.data = {"projects": projects}
@endpoint("projects.get_all")

View File

@ -28,7 +28,9 @@ class TestEntityOrdering(TestService):
self._assertGetTasksWithOrdering(order_by="comment")
# sort by parameter which type is not part of db schema
self._assertGetTasksWithOrdering(order_by="execution.parameters.test")
self._assertGetTasksWithOrdering(
order_by="execution.parameters.test", valid_order=False
)
def test_order_with_paging(self):
order_field = "started"
@ -97,7 +99,9 @@ class TestEntityOrdering(TestService):
return val
def _assertGetTasksWithOrdering(self, order_by: str = None, **kwargs):
def _assertGetTasksWithOrdering(
self, order_by: str = None, valid_order=True, **kwargs
):
tasks = self.api.tasks.get_all_ex(
only_fields=self.only_fields,
order_by=[order_by] if isinstance(order_by, str) else order_by,
@ -105,14 +109,16 @@ class TestEntityOrdering(TestService):
**kwargs,
).tasks
self.assertLessEqual(set(self.task_ids), set(t.id for t in tasks))
if order_by:
if order_by and valid_order:
# test that the output is correctly ordered
field_name = order_by if not order_by.startswith("-") else order_by[1:]
field_vals = [self._get_value_for_path(t, field_name.split(".")) for t in tasks]
field_vals = [
self._get_value_for_path(t, field_name.split(".")) for t in tasks
]
self._assertSorted(
field_vals,
ascending=not order_by.startswith("-"),
is_numeric=field_name.startswith("execution.parameters.")
is_numeric=field_name.startswith("execution.parameters."),
)
def _create_tasks(self):

View File

@ -0,0 +1,65 @@
from boltons.iterutils import first
from apiserver.tests.automated import TestService
class TestProjectsRetrieval(TestService):
    """Tests for projects.get_all_ex: active_users filtering and project stats."""

    def setUp(self, **kwargs):
        # active_users and the stats request fields were introduced in API 2.13
        super().setUp(version="2.13")

    def test_active_user(self):
        user = self.api.users.get_current_user().user.id
        project1 = self.temp_project(name="Project retrieval1")
        project2 = self.temp_project(name="Project retrieval2")
        # only project2 gets a task created by the current user
        self.temp_task(project=project2)

        projects = self.api.projects.get_all_ex().projects
        self.assertTrue({project1, project2}.issubset({p.id for p in projects}))

        # filtering by active_users keeps only projects with tasks by those users
        projects = self.api.projects.get_all_ex(active_users=[user]).projects
        ids = {p.id for p in projects}
        self.assertNotIn(project1, ids)
        self.assertIn(project2, ids)

    def test_stats(self):
        project = self.temp_project()
        self.temp_task(project=project)
        self.temp_task(project=project)
        archived_task = self.temp_task(project=project)
        self.api.tasks.archive(tasks=[archived_task])

        # without include_stats no stats are returned
        p = self._get_project(project)
        self.assertNotIn("stats", p)

        # by default stats are reported for the active state only
        p = self._get_project(project, include_stats=True)
        self.assertNotIn("archived", p.stats)
        # fixed: assertTrue(x, 2) only checked truthiness (2 was the msg argument)
        self.assertEqual(p.stats.active.status_count.created, 2)

        # stats_for_state=None reports stats for all visibility states
        p = self._get_project(project, include_stats=True, stats_for_state=None)
        self.assertEqual(p.stats.active.status_count.created, 2)
        self.assertEqual(p.stats.archived.status_count.created, 1)

    def _get_project(self, project, **kwargs):
        """Retrieve a single project by id via get_all_ex and assert it exists."""
        projects = self.api.projects.get_all_ex(id=[project], **kwargs).projects
        p = first(p for p in projects if p.id == project)
        self.assertIsNotNone(p)
        return p

    def temp_project(self, **kwargs) -> str:
        """Create a temporary project (auto-deleted on teardown); return its id."""
        self.update_missing(
            kwargs,
            name="Test projects retrieval",
            description="test",
            delete_params=dict(force=True),
        )
        return self.create_temp("projects", **kwargs)

    def temp_task(self, **kwargs) -> str:
        """Create a temporary task (auto-deleted on teardown); return its id."""
        self.update_missing(
            kwargs,
            type="testing",
            name="test projects retrieval",
            input=dict(view=dict()),
            delete_params=dict(force=True),
        )
        return self.create_temp("tasks", **kwargs)

View File

@ -1,10 +1,14 @@
from enum import Enum
class StringEnum(str, Enum):
    """
    Enum whose members are also strings; auto() values equal the member name.
    (The str mixin makes members directly comparable to plain strings.)
    """

    def __str__(self):
        return self.value

    @classmethod
    def values(cls):
        # all member values as plain strings
        return list(map(str, cls))

    # noinspection PyMethodParameters
    def _generate_next_value_(name, start, count, last_values):
        # called by the enum machinery for auto(): use the member name as value
        return name