diff --git a/apiserver/apimodels/__init__.py b/apiserver/apimodels/__init__.py index 5683b16..50dcb48 100644 --- a/apiserver/apimodels/__init__.py +++ b/apiserver/apimodels/__init__.py @@ -218,7 +218,7 @@ class ActualEnumField(fields.StringField): ) def parse_value(self, value): - if value is None and not self.required: + if value is NotSet and not self.required: return self.get_default_value() try: # noinspection PyArgumentList diff --git a/apiserver/apimodels/projects.py b/apiserver/apimodels/projects.py index 32994ea..ccbd11e 100644 --- a/apiserver/apimodels/projects.py +++ b/apiserver/apimodels/projects.py @@ -30,3 +30,10 @@ class ProjectHyperparamValuesRequest(MultiProjectReq): section = fields.StringField(required=True) name = fields.StringField(required=True) allow_public = fields.BoolField(default=True) + + +class ProjectsGetRequest(models.Base): + include_stats = fields.BoolField(default=False) + stats_for_state = ActualEnumField(EntityVisibility, default=EntityVisibility.active) + non_public = fields.BoolField(default=False) + active_users = fields.ListField(str) diff --git a/apiserver/bll/project/project_bll.py b/apiserver/bll/project/project_bll.py index 0ad70b1..dff97d5 100644 --- a/apiserver/bll/project/project_bll.py +++ b/apiserver/bll/project/project_bll.py @@ -1,15 +1,21 @@ +from collections import defaultdict from datetime import datetime -from typing import Sequence, Optional, Type +from itertools import groupby +from operator import itemgetter +from typing import Sequence, Optional, Type, Tuple, Dict from mongoengine import Q, Document from apiserver import database from apiserver.apierrors import errors from apiserver.config_repo import config +from apiserver.database.model import EntityVisibility from apiserver.database.model.model import Model from apiserver.database.model.project import Project -from apiserver.database.model.task.task import Task +from apiserver.database.model.task.task import Task, TaskStatus +from apiserver.database.utils import get_options from apiserver.timing_context import TimingContext +from apiserver.tools import safe_get log = config.logger(__file__) @@ -132,6 +138,205 @@ class ProjectBLL: if hasattr(entity_cls, "last_change") else {} ) - entity_cls.objects(company=company, id__in=ids).update(set__project=project, **extra) + entity_cls.objects(company=company, id__in=ids).update( + set__project=project, **extra + ) return project + + archived_tasks_cond = {"$in": [EntityVisibility.archived.value, "$system_tags"]} + + @classmethod + def make_projects_get_all_pipelines( + cls, + company_id: str, + project_ids: Sequence[str], + specific_state: Optional[EntityVisibility] = None, + ) -> Tuple[Sequence, Sequence]: + archived = EntityVisibility.archived.value + + def ensure_valid_fields(): + """ + Make sure system tags is always an array (required by subsequent $in in archived_tasks_cond + """ + return { + "$addFields": { + "system_tags": { + "$cond": { + "if": {"$ne": [{"$type": "$system_tags"}, "array"]}, + "then": [], + "else": "$system_tags", + } + }, + "status": {"$ifNull": ["$status", "unknown"]}, + } + } + + status_count_pipeline = [ + # count tasks per project per status + { + "$match": { + "company": {"$in": [None, "", company_id]}, + "project": {"$in": project_ids}, + } + }, + ensure_valid_fields(), + { + "$group": { + "_id": { + "project": "$project", + "status": "$status", + archived: cls.archived_tasks_cond, + }, + "count": {"$sum": 1}, + } + }, + # for each project, create a list of (status, count, archived) + { + "$group": { + 
"_id": "$_id.project", + "counts": { + "$push": { + "status": "$_id.status", + "count": "$count", + archived: "$_id.%s" % archived, + } + }, + } + }, + ] + + def runtime_subquery(additional_cond): + return { + # the sum of + "$sum": { + # for each task + "$cond": { + # if completed and started and completed > started + "if": { + "$and": [ + "$started", + "$completed", + {"$gt": ["$completed", "$started"]}, + additional_cond, + ] + }, + # then: floor((completed - started) / 1000) + "then": { + "$floor": { + "$divide": [ + {"$subtract": ["$completed", "$started"]}, + 1000.0, + ] + } + }, + "else": 0, + } + } + } + + group_step = {"_id": "$project"} + + for state in EntityVisibility: + if specific_state and state != specific_state: + continue + if state == EntityVisibility.active: + group_step[state.value] = runtime_subquery( + {"$not": cls.archived_tasks_cond} + ) + elif state == EntityVisibility.archived: + group_step[state.value] = runtime_subquery(cls.archived_tasks_cond) + + runtime_pipeline = [ + # only count run time for these types of tasks + { + "$match": { + "company": {"$in": [None, "", company_id]}, + "type": {"$in": ["training", "testing", "annotation"]}, + "project": {"$in": project_ids}, + } + }, + ensure_valid_fields(), + { + # for each project + "$group": group_step + }, + ] + + return status_count_pipeline, runtime_pipeline + + @classmethod + def get_project_stats( + cls, + company: str, + project_ids: Sequence[str], + specific_state: Optional[EntityVisibility] = None, + ) -> Dict[str, dict]: + if not project_ids: + return {} + + status_count_pipeline, runtime_pipeline = cls.make_projects_get_all_pipelines( + company, project_ids=project_ids, specific_state=specific_state + ) + + default_counts = dict.fromkeys(get_options(TaskStatus), 0) + + def set_default_count(entry): + return dict(default_counts, **entry) + + status_count = defaultdict(lambda: {}) + key = itemgetter(EntityVisibility.archived.value) + for result in Task.aggregate(status_count_pipeline): + for k, group in groupby(sorted(result["counts"], key=key), key): + section = ( + EntityVisibility.archived if k else EntityVisibility.active + ).value + status_count[result["_id"]][section] = set_default_count( + { + count_entry["status"]: count_entry["count"] + for count_entry in group + } + ) + + runtime = { + result["_id"]: {k: v for k, v in result.items() if k != "_id"} + for result in Task.aggregate(runtime_pipeline) + } + + def get_status_counts(project_id, section): + path = "/".join((project_id, section)) + return { + "total_runtime": safe_get(runtime, path, 0), + "status_count": safe_get(status_count, path, default_counts), + } + + report_for_states = [ + s for s in EntityVisibility if not specific_state or specific_state == s + ] + + return { + project: { + task_state.value: get_status_counts(project, task_state.value) + for task_state in report_for_states + } + for project in project_ids + } + + @classmethod + def get_projects_with_active_user( + cls, + company: str, + users: Sequence[str], + project_ids: Optional[Sequence[str]] = None, + allow_public: bool = True, + ) -> Sequence[str]: + """Get the projects ids where user created any tasks""" + company = ( + {"company__in": [None, "", company]} + if allow_public + else {"company": company} + ) + projects = {"project__in": project_ids} if project_ids else {} + return Task.objects(**company, user__in=users, **projects).distinct( + field="project" + ) diff --git a/apiserver/database/model/base.py b/apiserver/database/model/base.py index 05366ce..52c3dac 100644 
--- a/apiserver/database/model/base.py +++ b/apiserver/database/model/base.py @@ -637,6 +637,35 @@ class GetMixin(PropsMixin): return qs + @classmethod + def _get_queries_for_order_field( + cls, query: Q, order_field: str + ) -> Union[None, Tuple[Q, Q]]: + """ + In case the order_field is one of the cls fields and the sorting is ascending + then return the tuple of 2 queries: + 1. original query with not empty constraint on the order_by field + 2. original query with empty constraint on the order_by field + """ + if not order_field or order_field.startswith("-") or "[" in order_field: + return + + mongo_field_name = order_field.replace(".", "__") + mongo_field = first( + v for k, v in cls.get_all_fields_with_instance() if k == mongo_field_name + ) + if not mongo_field: + return + + params = {} + if isinstance(mongo_field, ListField): + params["is_list"] = True + elif isinstance(mongo_field, StringField): + params["empty_value"] = "" + non_empty = query & field_exists(mongo_field_name, **params) + empty = query & field_does_not_exist(mongo_field_name, **params) + return non_empty, empty + @classmethod def _get_many_override_none_ordering( cls: Union[Document, "GetMixin"], @@ -675,21 +704,9 @@ class GetMixin(PropsMixin): order_field = first( field for field in order_by if not field.startswith("$") ) - if ( - order_field - and not order_field.startswith("-") - and "[" not in order_field - ): - params = {} - mongo_field = order_field.replace(".", "__") - if mongo_field in cls.get_field_names_for_type(of_type=ListField): - params["is_list"] = True - elif mongo_field in cls.get_field_names_for_type(of_type=StringField): - params["empty_value"] = "" - non_empty = query & field_exists(mongo_field, **params) - empty = query & field_does_not_exist(mongo_field, **params) - query_sets = [cls.objects(non_empty), cls.objects(empty)] - + res = cls._get_queries_for_order_field(query, order_field) + if res: + query_sets = [cls.objects(q) for q in res] query_sets = [qs.order_by(*order_by) for qs in query_sets] if order_field: collation_override = first( diff --git a/apiserver/database/props.py b/apiserver/database/props.py index 4eb03c9..248d313 100644 --- a/apiserver/database/props.py +++ b/apiserver/database/props.py @@ -1,12 +1,11 @@ -from collections import OrderedDict, defaultdict -from itertools import chain +from collections import OrderedDict from operator import attrgetter from threading import Lock from typing import Sequence import six from mongoengine import EmbeddedDocumentField, EmbeddedDocumentListField -from mongoengine.base import get_document, BaseField +from mongoengine.base import get_document from apiserver.database.fields import ( LengthRangeEmbeddedDocumentListField, @@ -21,7 +20,7 @@ class PropsMixin(object): __cached_reference_fields = None __cached_exclude_fields = None __cached_fields_with_instance = None - __cached_field_names_per_type = None + __cached_all_fields_with_instance = None __cached_dpath_computed_fields_lock = Lock() __cached_dpath_computed_fields = None @@ -33,37 +32,12 @@ class PropsMixin(object): return cls.__cached_fields @classmethod - def get_field_names_for_type(cls, of_type=BaseField): - """ - Return field names per type including subfields - The fields of derived types are also returned - """ - assert issubclass(of_type, BaseField) - if cls.__cached_field_names_per_type is None: - fields = defaultdict(list) - for name, field in get_fields(cls, return_instance=True, subfields=True): - fields[type(field)].append(name) - for type_ in fields: - 
fields[type_].extend( + chain.from_iterable( + fields[other_type] + for other_type in fields + if other_type != type_ and issubclass(other_type, type_) + ) + ) - cls.__cached_field_names_per_type = fields - - if of_type not in cls.__cached_field_names_per_type: - names = list( - chain.from_iterable( - field_names - for type_, field_names in cls.__cached_field_names_per_type.items() - if issubclass(type_, of_type) - ) + def get_all_fields_with_instance(cls): + if cls.__cached_all_fields_with_instance is None: + cls.__cached_all_fields_with_instance = get_fields( + cls, return_instance=True, subfields=True ) - cls.__cached_field_names_per_type[of_type] = names - - return cls.__cached_field_names_per_type[of_type] + return cls.__cached_all_fields_with_instance @classmethod def get_fields_with_instance(cls, doc_cls): diff --git a/apiserver/schema/services/projects.conf b/apiserver/schema/services/projects.conf index 793dcbe..608a11c 100644 --- a/apiserver/schema/services/projects.conf +++ b/apiserver/schema/services/projects.conf @@ -413,6 +413,17 @@ get_all_ex { } } } + "2.13": ${get_all_ex."2.1"} { + request { + properties { + active_users { + description: "The list of users that were active in the project. If passed then the resulting projects are filtered to the ones that have tasks created by these users" + type: array + items: {type: string} + } + } + } + } } update { "2.1" { diff --git a/apiserver/services/projects.py b/apiserver/services/projects.py index 1e510c7..2628233 100644 --- a/apiserver/services/projects.py +++ b/apiserver/services/projects.py @@ -1,9 +1,5 @@ -from collections import defaultdict from datetime import datetime -from itertools import groupby -from operator import itemgetter -import dpath from mongoengine import Q from apiserver.apierrors import errors @@ -15,6 +11,7 @@ from apiserver.apimodels.projects import ( ProjectTagsRequest, ProjectTaskParentsRequest, ProjectHyperparamValuesRequest, + ProjectsGetRequest, ) from apiserver.bll.organization import OrgBLL, Tags from apiserver.bll.project import ProjectBLL @@ -23,10 +20,9 @@ from apiserver.database.errors import translate_errors_context from apiserver.database.model import EntityVisibility from apiserver.database.model.model import Model from apiserver.database.model.project import Project -from apiserver.database.model.task.task import Task, TaskStatus +from apiserver.database.model.task.task import Task from apiserver.database.utils import ( parse_from_call, - get_options, get_company_or_none_constraint, ) from apiserver.service_repo import APICall, endpoint @@ -40,7 +36,7 @@ from apiserver.timing_context import TimingContext org_bll = OrgBLL() task_bll = TaskBLL() -archived_tasks_cond = {"$in": [EntityVisibility.archived.value, "$system_tags"]} +project_bll = ProjectBLL() create_fields = { "name": None, @@ -75,199 +71,46 @@ def get_by_id(call): call.result.data = {"project": project_dict} -def make_projects_get_all_pipelines(company_id, project_ids, specific_state=None): - archived = EntityVisibility.archived.value - - def ensure_valid_fields(): - """ - Make sure system tags is always an array (required by subsequent $in in archived_tasks_cond - """ - return { - "$addFields": { - "system_tags": { - "$cond": { - "if": {"$ne": [{"$type": "$system_tags"}, "array"]}, - "then": [], - "else": "$system_tags", - } - }, - "status": {"$ifNull": ["$status", "unknown"]}, - } - } - - status_count_pipeline = [ - # count tasks per project per status - { - "$match": { - "company": {"$in": [None, "", company_id]}, - "project":
{"$in": project_ids}, - } - }, - ensure_valid_fields(), - { - "$group": { - "_id": { - "project": "$project", - "status": "$status", - archived: archived_tasks_cond, - }, - "count": {"$sum": 1}, - } - }, - # for each project, create a list of (status, count, archived) - { - "$group": { - "_id": "$_id.project", - "counts": { - "$push": { - "status": "$_id.status", - "count": "$count", - archived: "$_id.%s" % archived, - } - }, - } - }, - ] - - def runtime_subquery(additional_cond): - return { - # the sum of - "$sum": { - # for each task - "$cond": { - # if completed and started and completed > started - "if": { - "$and": [ - "$started", - "$completed", - {"$gt": ["$completed", "$started"]}, - additional_cond, - ] - }, - # then: floor((completed - started) / 1000) - "then": { - "$floor": { - "$divide": [ - {"$subtract": ["$completed", "$started"]}, - 1000.0, - ] - } - }, - "else": 0, - } - } - } - - group_step = {"_id": "$project"} - - for state in EntityVisibility: - if specific_state and state != specific_state: - continue - if state == EntityVisibility.active: - group_step[state.value] = runtime_subquery({"$not": archived_tasks_cond}) - elif state == EntityVisibility.archived: - group_step[state.value] = runtime_subquery(archived_tasks_cond) - - runtime_pipeline = [ - # only count run time for these types of tasks - { - "$match": { - "type": {"$in": ["training", "testing"]}, - "company": {"$in": [None, "", company_id]}, - "project": {"$in": project_ids}, - } - }, - ensure_valid_fields(), - { - # for each project - "$group": group_step - }, - ] - - return status_count_pipeline, runtime_pipeline - - -@endpoint("projects.get_all_ex") -def get_all_ex(call: APICall): - include_stats = call.data.get("include_stats") - stats_for_state = call.data.get("stats_for_state", EntityVisibility.active.value) - allow_public = not call.data.get("non_public", False) - - if stats_for_state: - try: - specific_state = EntityVisibility(stats_for_state) - except ValueError: - raise errors.bad_request.FieldsValueError(stats_for_state=stats_for_state) - else: - specific_state = None - +@endpoint("projects.get_all_ex", request_data_model=ProjectsGetRequest) +def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest): conform_tag_fields(call, call.data) - with translate_errors_context(), TimingContext("mongo", "projects_get_all"): + allow_public = not request.non_public + with TimingContext("mongo", "projects_get_all"): + if request.active_users: + ids = project_bll.get_projects_with_active_user( + company=company_id, + users=request.active_users, + project_ids=call.data.get("id"), + allow_public=allow_public, + ) + if not ids: + call.result.data = {"projects": []} + return + call.data["id"] = ids + projects = Project.get_many_with_join( - company=call.identity.company, + company=company_id, query_dict=call.data, query_options=get_all_query_options, allow_public=allow_public, ) - conform_output_tags(call, projects) - if not include_stats: + conform_output_tags(call, projects) + if not request.include_stats: call.result.data = {"projects": projects} return - ids = [project["id"] for project in projects] - status_count_pipeline, runtime_pipeline = make_projects_get_all_pipelines( - call.identity.company, ids, specific_state=specific_state + project_ids = {project["id"] for project in projects} + stats = project_bll.get_project_stats( + company=company_id, + project_ids=list(project_ids), + specific_state=request.stats_for_state, ) - default_counts = dict.fromkeys(get_options(TaskStatus), 0) + for 
project in projects: + project["stats"] = stats[project["id"]] - def set_default_count(entry): - return dict(default_counts, **entry) - - status_count = defaultdict(lambda: {}) - key = itemgetter(EntityVisibility.archived.value) - for result in Task.aggregate(status_count_pipeline): - for k, group in groupby(sorted(result["counts"], key=key), key): - section = ( - EntityVisibility.archived if k else EntityVisibility.active - ).value - status_count[result["_id"]][section] = set_default_count( - { - count_entry["status"]: count_entry["count"] - for count_entry in group - } - ) - - runtime = { - result["_id"]: {k: v for k, v in result.items() if k != "_id"} - for result in Task.aggregate(runtime_pipeline) - } - - def safe_get(obj, path, default=None): - try: - return dpath.get(obj, path) - except KeyError: - return default - - def get_status_counts(project_id, section): - path = "/".join((project_id, section)) - return { - "total_runtime": safe_get(runtime, path, 0), - "status_count": safe_get(status_count, path, default_counts), - } - - report_for_states = [ - s for s in EntityVisibility if not specific_state or specific_state == s - ] - - for project in projects: - project["stats"] = { - task_state.value: get_status_counts(project["id"], task_state.value) - for task_state in report_for_states - } - - call.result.data = {"projects": projects} + call.result.data = {"projects": projects} @endpoint("projects.get_all") diff --git a/apiserver/tests/automated/test_entity_ordering.py b/apiserver/tests/automated/test_entity_ordering.py index 611fa62..2966b51 100644 --- a/apiserver/tests/automated/test_entity_ordering.py +++ b/apiserver/tests/automated/test_entity_ordering.py @@ -28,7 +28,9 @@ class TestEntityOrdering(TestService): self._assertGetTasksWithOrdering(order_by="comment") # sort by parameter which type is not part of db schema - self._assertGetTasksWithOrdering(order_by="execution.parameters.test") + self._assertGetTasksWithOrdering( + order_by="execution.parameters.test", valid_order=False + ) def test_order_with_paging(self): order_field = "started" @@ -97,7 +99,9 @@ class TestEntityOrdering(TestService): return val - def _assertGetTasksWithOrdering(self, order_by: str = None, **kwargs): + def _assertGetTasksWithOrdering( + self, order_by: str = None, valid_order=True, **kwargs + ): tasks = self.api.tasks.get_all_ex( only_fields=self.only_fields, order_by=[order_by] if isinstance(order_by, str) else order_by, @@ -105,14 +109,16 @@ class TestEntityOrdering(TestService): **kwargs, ).tasks self.assertLessEqual(set(self.task_ids), set(t.id for t in tasks)) - if order_by: + if order_by and valid_order: # test that the output is correctly ordered field_name = order_by if not order_by.startswith("-") else order_by[1:] - field_vals = [self._get_value_for_path(t, field_name.split(".")) for t in tasks] + field_vals = [ + self._get_value_for_path(t, field_name.split(".")) for t in tasks + ] self._assertSorted( field_vals, ascending=not order_by.startswith("-"), - is_numeric=field_name.startswith("execution.parameters.") + is_numeric=field_name.startswith("execution.parameters."), ) def _create_tasks(self): diff --git a/apiserver/tests/automated/test_projects_retrieval.py b/apiserver/tests/automated/test_projects_retrieval.py new file mode 100644 index 0000000..effa739 --- /dev/null +++ b/apiserver/tests/automated/test_projects_retrieval.py @@ -0,0 +1,65 @@ +from boltons.iterutils import first + +from apiserver.tests.automated import TestService + + +class TestProjectsRetrieval(TestService): + def 
setUp(self, **kwargs): + super().setUp(version="2.13") + + def test_active_user(self): + user = self.api.users.get_current_user().user.id + project1 = self.temp_project(name="Project retrieval1") + project2 = self.temp_project(name="Project retrieval2") + self.temp_task(project=project2) + + projects = self.api.projects.get_all_ex().projects + self.assertTrue({project1, project2}.issubset({p.id for p in projects})) + + projects = self.api.projects.get_all_ex(active_users=[user]).projects + ids = {p.id for p in projects} + self.assertFalse(project1 in ids) + self.assertTrue(project2 in ids) + + def test_stats(self): + project = self.temp_project() + self.temp_task(project=project) + self.temp_task(project=project) + archived_task = self.temp_task(project=project) + self.api.tasks.archive(tasks=[archived_task]) + + p = self._get_project(project) + self.assertFalse("stats" in p) + + p = self._get_project(project, include_stats=True) + self.assertFalse("archived" in p.stats) + self.assertEqual(p.stats.active.status_count.created, 2) + + p = self._get_project(project, include_stats=True, stats_for_state=None) + self.assertEqual(p.stats.active.status_count.created, 2) + self.assertEqual(p.stats.archived.status_count.created, 1) + + def _get_project(self, project, **kwargs): + projects = self.api.projects.get_all_ex(id=[project], **kwargs).projects + p = first(p for p in projects if p.id == project) + self.assertIsNotNone(p) + return p + + def temp_project(self, **kwargs) -> str: + self.update_missing( + kwargs, + name="Test projects retrieval", + description="test", + delete_params=dict(force=True), + ) + return self.create_temp("projects", **kwargs) + + def temp_task(self, **kwargs) -> str: + self.update_missing( + kwargs, + type="testing", + name="test projects retrieval", + input=dict(view=dict()), + delete_params=dict(force=True), + ) + return self.create_temp("tasks", **kwargs) diff --git a/apiserver/utilities/stringenum.py b/apiserver/utilities/stringenum.py index 8eee1cd..a62dbf7 100644 --- a/apiserver/utilities/stringenum.py +++ b/apiserver/utilities/stringenum.py @@ -1,10 +1,14 @@ from enum import Enum -class StringEnum(Enum): +class StringEnum(str, Enum): def __str__(self): return self.value + @classmethod + def values(cls): + return list(map(str, cls)) + # noinspection PyMethodParameters def _generate_next_value_(name, start, count, last_values): - return name \ No newline at end of file + return name
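
A rough usage sketch (not part of the patch) of the new projects.get_all_ex options, based on ProjectsGetRequest and tests/automated/test_projects_retrieval.py above. Here `api` stands for any APIClient-style session object (the tests use self.api), the user id is a placeholder, and the stats layout follows ProjectBLL.get_project_stats.

    # Hedged sketch: field names come from ProjectsGetRequest; attribute access on the
    # response mirrors the test client used in the automated tests.
    projects = api.projects.get_all_ex(
        include_stats=True,          # default False: no "stats" key on the returned projects
        stats_for_state=None,        # default "active"; None/null reports both active and archived sections
        active_users=["<user-id>"],  # keep only projects that have tasks created by these users
        non_public=False,            # True drops public (company-less) projects from the result
    ).projects

    for p in projects:
        # per get_project_stats, each section looks like
        # {"total_runtime": <seconds>, "status_count": {"created": 0, "queued": 0, ...}}
        print(p.id, p.stats.active.total_runtime, p.stats.archived.status_count)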
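
For readers less familiar with Mongo aggregation syntax: runtime_subquery in ProjectBLL.make_projects_get_all_pipelines sums, per project and per visibility state (active vs. archived, restricted to the task types listed in the $match stage), floor((completed - started) / 1000) over tasks whose completed timestamp is later than started. A plain-Python equivalent of the per-task term, for illustration only:

    from datetime import datetime

    def task_runtime_seconds(started, completed) -> int:
        # Mirrors the $cond/$floor/$divide expression in runtime_subquery:
        # a task missing either timestamp, or with completed <= started, contributes 0.
        if not (started and completed and completed > started):
            return 0
        return int((completed - started).total_seconds())  # floor of the millisecond delta / 1000

    print(task_runtime_seconds(datetime(2021, 1, 1, 10, 0, 0), datetime(2021, 1, 1, 10, 5, 30)))  # 330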
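
project_bll.py now imports safe_get from apiserver.tools instead of the dpath-based helper this patch removes from services/projects.py. The call sites above (safe_get(runtime, path, 0), safe_get(status_count, path, default_counts)) assume roughly the removed behavior; a sketch of that assumed helper follows, and the real apiserver.tools implementation may differ in details:

    import dpath

    def safe_get(obj, path, default=None):
        # Look up a "<project-id>/<section>" style path in a nested dict,
        # falling back to `default` when any part of the path is missing.
        try:
            return dpath.get(obj, path)
        except KeyError:
            return default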
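
Deriving StringEnum from both str and Enum makes every member its own string value: members compare equal to plain strings and serialize as strings, which is presumably the point of the change (enum-typed request fields and defaults such as EntityVisibility.active can be treated directly as "active"/"archived"), and values() is a small convenience on top of that. A self-contained illustration using a stand-in enum (the real EntityVisibility lives in apiserver.database.model):

    from enum import Enum, auto

    class StringEnum(str, Enum):
        def __str__(self):
            return self.value

        @classmethod
        def values(cls):
            return list(map(str, cls))

        # noinspection PyMethodParameters
        def _generate_next_value_(name, start, count, last_values):
            return name

    class Visibility(StringEnum):  # illustrative stand-in for EntityVisibility
        active = auto()
        archived = auto()

    assert Visibility.archived == "archived"             # str mixin: plain string comparison works
    assert Visibility.values() == ["active", "archived"]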