mirror of
https://github.com/clearml/clearml-server
synced 2025-04-03 04:40:57 +00:00
Support active users in projects
This commit is contained in:
parent
6411954002
commit
d029d56508
@ -218,7 +218,7 @@ class ActualEnumField(fields.StringField):
|
||||
)
|
||||
|
||||
def parse_value(self, value):
|
||||
if value is None and not self.required:
|
||||
if value is NotSet and not self.required:
|
||||
return self.get_default_value()
|
||||
try:
|
||||
# noinspection PyArgumentList
|
||||
|
@ -30,3 +30,10 @@ class ProjectHyperparamValuesRequest(MultiProjectReq):
|
||||
section = fields.StringField(required=True)
|
||||
name = fields.StringField(required=True)
|
||||
allow_public = fields.BoolField(default=True)
|
||||
|
||||
|
||||
class ProjectsGetRequest(models.Base):
|
||||
include_stats = fields.BoolField(default=False)
|
||||
stats_for_state = ActualEnumField(EntityVisibility, default=EntityVisibility.active)
|
||||
non_public = fields.BoolField(default=False)
|
||||
active_users = fields.ListField(str)
|
||||
|
@ -1,15 +1,21 @@
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from typing import Sequence, Optional, Type
|
||||
from itertools import groupby
|
||||
from operator import itemgetter
|
||||
from typing import Sequence, Optional, Type, Tuple, Dict
|
||||
|
||||
from mongoengine import Q, Document
|
||||
|
||||
from apiserver import database
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.database.model.task.task import Task, TaskStatus
|
||||
from apiserver.database.utils import get_options
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.tools import safe_get
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
@ -132,6 +138,205 @@ class ProjectBLL:
|
||||
if hasattr(entity_cls, "last_change")
|
||||
else {}
|
||||
)
|
||||
entity_cls.objects(company=company, id__in=ids).update(set__project=project, **extra)
|
||||
entity_cls.objects(company=company, id__in=ids).update(
|
||||
set__project=project, **extra
|
||||
)
|
||||
|
||||
return project
|
||||
|
||||
archived_tasks_cond = {"$in": [EntityVisibility.archived.value, "$system_tags"]}
|
||||
|
||||
@classmethod
|
||||
def make_projects_get_all_pipelines(
|
||||
cls,
|
||||
company_id: str,
|
||||
project_ids: Sequence[str],
|
||||
specific_state: Optional[EntityVisibility] = None,
|
||||
) -> Tuple[Sequence, Sequence]:
|
||||
archived = EntityVisibility.archived.value
|
||||
|
||||
def ensure_valid_fields():
|
||||
"""
|
||||
Make sure system tags is always an array (required by subsequent $in in archived_tasks_cond
|
||||
"""
|
||||
return {
|
||||
"$addFields": {
|
||||
"system_tags": {
|
||||
"$cond": {
|
||||
"if": {"$ne": [{"$type": "$system_tags"}, "array"]},
|
||||
"then": [],
|
||||
"else": "$system_tags",
|
||||
}
|
||||
},
|
||||
"status": {"$ifNull": ["$status", "unknown"]},
|
||||
}
|
||||
}
|
||||
|
||||
status_count_pipeline = [
|
||||
# count tasks per project per status
|
||||
{
|
||||
"$match": {
|
||||
"company": {"$in": [None, "", company_id]},
|
||||
"project": {"$in": project_ids},
|
||||
}
|
||||
},
|
||||
ensure_valid_fields(),
|
||||
{
|
||||
"$group": {
|
||||
"_id": {
|
||||
"project": "$project",
|
||||
"status": "$status",
|
||||
archived: cls.archived_tasks_cond,
|
||||
},
|
||||
"count": {"$sum": 1},
|
||||
}
|
||||
},
|
||||
# for each project, create a list of (status, count, archived)
|
||||
{
|
||||
"$group": {
|
||||
"_id": "$_id.project",
|
||||
"counts": {
|
||||
"$push": {
|
||||
"status": "$_id.status",
|
||||
"count": "$count",
|
||||
archived: "$_id.%s" % archived,
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
def runtime_subquery(additional_cond):
|
||||
return {
|
||||
# the sum of
|
||||
"$sum": {
|
||||
# for each task
|
||||
"$cond": {
|
||||
# if completed and started and completed > started
|
||||
"if": {
|
||||
"$and": [
|
||||
"$started",
|
||||
"$completed",
|
||||
{"$gt": ["$completed", "$started"]},
|
||||
additional_cond,
|
||||
]
|
||||
},
|
||||
# then: floor((completed - started) / 1000)
|
||||
"then": {
|
||||
"$floor": {
|
||||
"$divide": [
|
||||
{"$subtract": ["$completed", "$started"]},
|
||||
1000.0,
|
||||
]
|
||||
}
|
||||
},
|
||||
"else": 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
group_step = {"_id": "$project"}
|
||||
|
||||
for state in EntityVisibility:
|
||||
if specific_state and state != specific_state:
|
||||
continue
|
||||
if state == EntityVisibility.active:
|
||||
group_step[state.value] = runtime_subquery(
|
||||
{"$not": cls.archived_tasks_cond}
|
||||
)
|
||||
elif state == EntityVisibility.archived:
|
||||
group_step[state.value] = runtime_subquery(cls.archived_tasks_cond)
|
||||
|
||||
runtime_pipeline = [
|
||||
# only count run time for these types of tasks
|
||||
{
|
||||
"$match": {
|
||||
"company": {"$in": [None, "", company_id]},
|
||||
"type": {"$in": ["training", "testing", "annotation"]},
|
||||
"project": {"$in": project_ids},
|
||||
}
|
||||
},
|
||||
ensure_valid_fields(),
|
||||
{
|
||||
# for each project
|
||||
"$group": group_step
|
||||
},
|
||||
]
|
||||
|
||||
return status_count_pipeline, runtime_pipeline
|
||||
|
||||
@classmethod
|
||||
def get_project_stats(
|
||||
cls,
|
||||
company: str,
|
||||
project_ids: Sequence[str],
|
||||
specific_state: Optional[EntityVisibility] = None,
|
||||
) -> Dict[str, dict]:
|
||||
if not project_ids:
|
||||
return {}
|
||||
|
||||
status_count_pipeline, runtime_pipeline = cls.make_projects_get_all_pipelines(
|
||||
company, project_ids=project_ids, specific_state=specific_state
|
||||
)
|
||||
|
||||
default_counts = dict.fromkeys(get_options(TaskStatus), 0)
|
||||
|
||||
def set_default_count(entry):
|
||||
return dict(default_counts, **entry)
|
||||
|
||||
status_count = defaultdict(lambda: {})
|
||||
key = itemgetter(EntityVisibility.archived.value)
|
||||
for result in Task.aggregate(status_count_pipeline):
|
||||
for k, group in groupby(sorted(result["counts"], key=key), key):
|
||||
section = (
|
||||
EntityVisibility.archived if k else EntityVisibility.active
|
||||
).value
|
||||
status_count[result["_id"]][section] = set_default_count(
|
||||
{
|
||||
count_entry["status"]: count_entry["count"]
|
||||
for count_entry in group
|
||||
}
|
||||
)
|
||||
|
||||
runtime = {
|
||||
result["_id"]: {k: v for k, v in result.items() if k != "_id"}
|
||||
for result in Task.aggregate(runtime_pipeline)
|
||||
}
|
||||
|
||||
def get_status_counts(project_id, section):
|
||||
path = "/".join((project_id, section))
|
||||
return {
|
||||
"total_runtime": safe_get(runtime, path, 0),
|
||||
"status_count": safe_get(status_count, path, default_counts),
|
||||
}
|
||||
|
||||
report_for_states = [
|
||||
s for s in EntityVisibility if not specific_state or specific_state == s
|
||||
]
|
||||
|
||||
return {
|
||||
project: {
|
||||
task_state.value: get_status_counts(project, task_state.value)
|
||||
for task_state in report_for_states
|
||||
}
|
||||
for project in project_ids
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_projects_with_active_user(
|
||||
cls,
|
||||
company: str,
|
||||
users: Sequence[str],
|
||||
project_ids: Optional[Sequence[str]] = None,
|
||||
allow_public: bool = True,
|
||||
) -> Sequence[str]:
|
||||
"""Get the projects ids where user created any tasks"""
|
||||
company = (
|
||||
{"company__in": [None, "", company]}
|
||||
if allow_public
|
||||
else {"company": company}
|
||||
)
|
||||
projects = {"project__in": project_ids} if project_ids else {}
|
||||
return Task.objects(**company, user__in=users, **projects).distinct(
|
||||
field="project"
|
||||
)
|
||||
|
@ -637,6 +637,35 @@ class GetMixin(PropsMixin):
|
||||
|
||||
return qs
|
||||
|
||||
@classmethod
|
||||
def _get_queries_for_order_field(
|
||||
cls, query: Q, order_field: str
|
||||
) -> Union[None, Tuple[Q, Q]]:
|
||||
"""
|
||||
In case the order_field is one of the cls fields and the sorting is ascending
|
||||
then return the tuple of 2 queries:
|
||||
1. original query with not empty constraint on the order_by field
|
||||
2. original query with empty constraint on the order_by field
|
||||
"""
|
||||
if not order_field or order_field.startswith("-") or "[" in order_field:
|
||||
return
|
||||
|
||||
mongo_field_name = order_field.replace(".", "__")
|
||||
mongo_field = first(
|
||||
v for k, v in cls.get_all_fields_with_instance() if k == mongo_field_name
|
||||
)
|
||||
if not mongo_field:
|
||||
return
|
||||
|
||||
params = {}
|
||||
if isinstance(mongo_field, ListField):
|
||||
params["is_list"] = True
|
||||
elif isinstance(mongo_field, StringField):
|
||||
params["empty_value"] = ""
|
||||
non_empty = query & field_exists(mongo_field_name, **params)
|
||||
empty = query & field_does_not_exist(mongo_field_name, **params)
|
||||
return non_empty, empty
|
||||
|
||||
@classmethod
|
||||
def _get_many_override_none_ordering(
|
||||
cls: Union[Document, "GetMixin"],
|
||||
@ -675,21 +704,9 @@ class GetMixin(PropsMixin):
|
||||
order_field = first(
|
||||
field for field in order_by if not field.startswith("$")
|
||||
)
|
||||
if (
|
||||
order_field
|
||||
and not order_field.startswith("-")
|
||||
and "[" not in order_field
|
||||
):
|
||||
params = {}
|
||||
mongo_field = order_field.replace(".", "__")
|
||||
if mongo_field in cls.get_field_names_for_type(of_type=ListField):
|
||||
params["is_list"] = True
|
||||
elif mongo_field in cls.get_field_names_for_type(of_type=StringField):
|
||||
params["empty_value"] = ""
|
||||
non_empty = query & field_exists(mongo_field, **params)
|
||||
empty = query & field_does_not_exist(mongo_field, **params)
|
||||
query_sets = [cls.objects(non_empty), cls.objects(empty)]
|
||||
|
||||
res = cls._get_queries_for_order_field(query, order_field)
|
||||
if res:
|
||||
query_sets = [cls.objects(q) for q in res]
|
||||
query_sets = [qs.order_by(*order_by) for qs in query_sets]
|
||||
if order_field:
|
||||
collation_override = first(
|
||||
|
@ -1,12 +1,11 @@
|
||||
from collections import OrderedDict, defaultdict
|
||||
from itertools import chain
|
||||
from collections import OrderedDict
|
||||
from operator import attrgetter
|
||||
from threading import Lock
|
||||
from typing import Sequence
|
||||
|
||||
import six
|
||||
from mongoengine import EmbeddedDocumentField, EmbeddedDocumentListField
|
||||
from mongoengine.base import get_document, BaseField
|
||||
from mongoengine.base import get_document
|
||||
|
||||
from apiserver.database.fields import (
|
||||
LengthRangeEmbeddedDocumentListField,
|
||||
@ -21,7 +20,7 @@ class PropsMixin(object):
|
||||
__cached_reference_fields = None
|
||||
__cached_exclude_fields = None
|
||||
__cached_fields_with_instance = None
|
||||
__cached_field_names_per_type = None
|
||||
__cached_all_fields_with_instance = None
|
||||
|
||||
__cached_dpath_computed_fields_lock = Lock()
|
||||
__cached_dpath_computed_fields = None
|
||||
@ -33,37 +32,12 @@ class PropsMixin(object):
|
||||
return cls.__cached_fields
|
||||
|
||||
@classmethod
|
||||
def get_field_names_for_type(cls, of_type=BaseField):
|
||||
"""
|
||||
Return field names per type including subfields
|
||||
The fields of derived types are also returned
|
||||
"""
|
||||
assert issubclass(of_type, BaseField)
|
||||
if cls.__cached_field_names_per_type is None:
|
||||
fields = defaultdict(list)
|
||||
for name, field in get_fields(cls, return_instance=True, subfields=True):
|
||||
fields[type(field)].append(name)
|
||||
for type_ in fields:
|
||||
fields[type_].extend(
|
||||
chain.from_iterable(
|
||||
fields[other_type]
|
||||
for other_type in fields
|
||||
if other_type != type_ and issubclass(other_type, type_)
|
||||
)
|
||||
)
|
||||
cls.__cached_field_names_per_type = fields
|
||||
|
||||
if of_type not in cls.__cached_field_names_per_type:
|
||||
names = list(
|
||||
chain.from_iterable(
|
||||
field_names
|
||||
for type_, field_names in cls.__cached_field_names_per_type.items()
|
||||
if issubclass(type_, of_type)
|
||||
)
|
||||
def get_all_fields_with_instance(cls):
|
||||
if cls.__cached_all_fields_with_instance is None:
|
||||
cls.__cached_all_fields_with_instance = get_fields(
|
||||
cls, return_instance=True, subfields=True
|
||||
)
|
||||
cls.__cached_field_names_per_type[of_type] = names
|
||||
|
||||
return cls.__cached_field_names_per_type[of_type]
|
||||
return cls.__cached_all_fields_with_instance
|
||||
|
||||
@classmethod
|
||||
def get_fields_with_instance(cls, doc_cls):
|
||||
|
@ -413,6 +413,17 @@ get_all_ex {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.13": ${get_all_ex."2.1"} {
|
||||
request {
|
||||
properties {
|
||||
active_users {
|
||||
descritpion: "The list of users that were active in the project. If passes then the resulting projects are filtered to the ones that have tasks created by these users"
|
||||
type: array
|
||||
items: {type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
update {
|
||||
"2.1" {
|
||||
|
@ -1,9 +1,5 @@
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from itertools import groupby
|
||||
from operator import itemgetter
|
||||
|
||||
import dpath
|
||||
from mongoengine import Q
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
@ -15,6 +11,7 @@ from apiserver.apimodels.projects import (
|
||||
ProjectTagsRequest,
|
||||
ProjectTaskParentsRequest,
|
||||
ProjectHyperparamValuesRequest,
|
||||
ProjectsGetRequest,
|
||||
)
|
||||
from apiserver.bll.organization import OrgBLL, Tags
|
||||
from apiserver.bll.project import ProjectBLL
|
||||
@ -23,10 +20,9 @@ from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.task.task import Task, TaskStatus
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.database.utils import (
|
||||
parse_from_call,
|
||||
get_options,
|
||||
get_company_or_none_constraint,
|
||||
)
|
||||
from apiserver.service_repo import APICall, endpoint
|
||||
@ -40,7 +36,7 @@ from apiserver.timing_context import TimingContext
|
||||
|
||||
org_bll = OrgBLL()
|
||||
task_bll = TaskBLL()
|
||||
archived_tasks_cond = {"$in": [EntityVisibility.archived.value, "$system_tags"]}
|
||||
project_bll = ProjectBLL()
|
||||
|
||||
create_fields = {
|
||||
"name": None,
|
||||
@ -75,199 +71,46 @@ def get_by_id(call):
|
||||
call.result.data = {"project": project_dict}
|
||||
|
||||
|
||||
def make_projects_get_all_pipelines(company_id, project_ids, specific_state=None):
|
||||
archived = EntityVisibility.archived.value
|
||||
|
||||
def ensure_valid_fields():
|
||||
"""
|
||||
Make sure system tags is always an array (required by subsequent $in in archived_tasks_cond
|
||||
"""
|
||||
return {
|
||||
"$addFields": {
|
||||
"system_tags": {
|
||||
"$cond": {
|
||||
"if": {"$ne": [{"$type": "$system_tags"}, "array"]},
|
||||
"then": [],
|
||||
"else": "$system_tags",
|
||||
}
|
||||
},
|
||||
"status": {"$ifNull": ["$status", "unknown"]},
|
||||
}
|
||||
}
|
||||
|
||||
status_count_pipeline = [
|
||||
# count tasks per project per status
|
||||
{
|
||||
"$match": {
|
||||
"company": {"$in": [None, "", company_id]},
|
||||
"project": {"$in": project_ids},
|
||||
}
|
||||
},
|
||||
ensure_valid_fields(),
|
||||
{
|
||||
"$group": {
|
||||
"_id": {
|
||||
"project": "$project",
|
||||
"status": "$status",
|
||||
archived: archived_tasks_cond,
|
||||
},
|
||||
"count": {"$sum": 1},
|
||||
}
|
||||
},
|
||||
# for each project, create a list of (status, count, archived)
|
||||
{
|
||||
"$group": {
|
||||
"_id": "$_id.project",
|
||||
"counts": {
|
||||
"$push": {
|
||||
"status": "$_id.status",
|
||||
"count": "$count",
|
||||
archived: "$_id.%s" % archived,
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
def runtime_subquery(additional_cond):
|
||||
return {
|
||||
# the sum of
|
||||
"$sum": {
|
||||
# for each task
|
||||
"$cond": {
|
||||
# if completed and started and completed > started
|
||||
"if": {
|
||||
"$and": [
|
||||
"$started",
|
||||
"$completed",
|
||||
{"$gt": ["$completed", "$started"]},
|
||||
additional_cond,
|
||||
]
|
||||
},
|
||||
# then: floor((completed - started) / 1000)
|
||||
"then": {
|
||||
"$floor": {
|
||||
"$divide": [
|
||||
{"$subtract": ["$completed", "$started"]},
|
||||
1000.0,
|
||||
]
|
||||
}
|
||||
},
|
||||
"else": 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
group_step = {"_id": "$project"}
|
||||
|
||||
for state in EntityVisibility:
|
||||
if specific_state and state != specific_state:
|
||||
continue
|
||||
if state == EntityVisibility.active:
|
||||
group_step[state.value] = runtime_subquery({"$not": archived_tasks_cond})
|
||||
elif state == EntityVisibility.archived:
|
||||
group_step[state.value] = runtime_subquery(archived_tasks_cond)
|
||||
|
||||
runtime_pipeline = [
|
||||
# only count run time for these types of tasks
|
||||
{
|
||||
"$match": {
|
||||
"type": {"$in": ["training", "testing"]},
|
||||
"company": {"$in": [None, "", company_id]},
|
||||
"project": {"$in": project_ids},
|
||||
}
|
||||
},
|
||||
ensure_valid_fields(),
|
||||
{
|
||||
# for each project
|
||||
"$group": group_step
|
||||
},
|
||||
]
|
||||
|
||||
return status_count_pipeline, runtime_pipeline
|
||||
|
||||
|
||||
@endpoint("projects.get_all_ex")
|
||||
def get_all_ex(call: APICall):
|
||||
include_stats = call.data.get("include_stats")
|
||||
stats_for_state = call.data.get("stats_for_state", EntityVisibility.active.value)
|
||||
allow_public = not call.data.get("non_public", False)
|
||||
|
||||
if stats_for_state:
|
||||
try:
|
||||
specific_state = EntityVisibility(stats_for_state)
|
||||
except ValueError:
|
||||
raise errors.bad_request.FieldsValueError(stats_for_state=stats_for_state)
|
||||
else:
|
||||
specific_state = None
|
||||
|
||||
@endpoint("projects.get_all_ex", request_data_model=ProjectsGetRequest)
|
||||
def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
|
||||
conform_tag_fields(call, call.data)
|
||||
with translate_errors_context(), TimingContext("mongo", "projects_get_all"):
|
||||
allow_public = not request.non_public
|
||||
with TimingContext("mongo", "projects_get_all"):
|
||||
if request.active_users:
|
||||
ids = project_bll.get_projects_with_active_user(
|
||||
company=company_id,
|
||||
users=request.active_users,
|
||||
project_ids=call.data.get("id"),
|
||||
allow_public=allow_public,
|
||||
)
|
||||
if not ids:
|
||||
call.result.data = {"projects": []}
|
||||
return
|
||||
call.data["id"] = ids
|
||||
|
||||
projects = Project.get_many_with_join(
|
||||
company=call.identity.company,
|
||||
company=company_id,
|
||||
query_dict=call.data,
|
||||
query_options=get_all_query_options,
|
||||
allow_public=allow_public,
|
||||
)
|
||||
conform_output_tags(call, projects)
|
||||
|
||||
if not include_stats:
|
||||
conform_output_tags(call, projects)
|
||||
if not request.include_stats:
|
||||
call.result.data = {"projects": projects}
|
||||
return
|
||||
|
||||
ids = [project["id"] for project in projects]
|
||||
status_count_pipeline, runtime_pipeline = make_projects_get_all_pipelines(
|
||||
call.identity.company, ids, specific_state=specific_state
|
||||
project_ids = {project["id"] for project in projects}
|
||||
stats = project_bll.get_project_stats(
|
||||
company=company_id,
|
||||
project_ids=list(project_ids),
|
||||
specific_state=request.stats_for_state,
|
||||
)
|
||||
|
||||
default_counts = dict.fromkeys(get_options(TaskStatus), 0)
|
||||
for project in projects:
|
||||
project["stats"] = stats[project["id"]]
|
||||
|
||||
def set_default_count(entry):
|
||||
return dict(default_counts, **entry)
|
||||
|
||||
status_count = defaultdict(lambda: {})
|
||||
key = itemgetter(EntityVisibility.archived.value)
|
||||
for result in Task.aggregate(status_count_pipeline):
|
||||
for k, group in groupby(sorted(result["counts"], key=key), key):
|
||||
section = (
|
||||
EntityVisibility.archived if k else EntityVisibility.active
|
||||
).value
|
||||
status_count[result["_id"]][section] = set_default_count(
|
||||
{
|
||||
count_entry["status"]: count_entry["count"]
|
||||
for count_entry in group
|
||||
}
|
||||
)
|
||||
|
||||
runtime = {
|
||||
result["_id"]: {k: v for k, v in result.items() if k != "_id"}
|
||||
for result in Task.aggregate(runtime_pipeline)
|
||||
}
|
||||
|
||||
def safe_get(obj, path, default=None):
|
||||
try:
|
||||
return dpath.get(obj, path)
|
||||
except KeyError:
|
||||
return default
|
||||
|
||||
def get_status_counts(project_id, section):
|
||||
path = "/".join((project_id, section))
|
||||
return {
|
||||
"total_runtime": safe_get(runtime, path, 0),
|
||||
"status_count": safe_get(status_count, path, default_counts),
|
||||
}
|
||||
|
||||
report_for_states = [
|
||||
s for s in EntityVisibility if not specific_state or specific_state == s
|
||||
]
|
||||
|
||||
for project in projects:
|
||||
project["stats"] = {
|
||||
task_state.value: get_status_counts(project["id"], task_state.value)
|
||||
for task_state in report_for_states
|
||||
}
|
||||
|
||||
call.result.data = {"projects": projects}
|
||||
call.result.data = {"projects": projects}
|
||||
|
||||
|
||||
@endpoint("projects.get_all")
|
||||
|
@ -28,7 +28,9 @@ class TestEntityOrdering(TestService):
|
||||
self._assertGetTasksWithOrdering(order_by="comment")
|
||||
|
||||
# sort by parameter which type is not part of db schema
|
||||
self._assertGetTasksWithOrdering(order_by="execution.parameters.test")
|
||||
self._assertGetTasksWithOrdering(
|
||||
order_by="execution.parameters.test", valid_order=False
|
||||
)
|
||||
|
||||
def test_order_with_paging(self):
|
||||
order_field = "started"
|
||||
@ -97,7 +99,9 @@ class TestEntityOrdering(TestService):
|
||||
|
||||
return val
|
||||
|
||||
def _assertGetTasksWithOrdering(self, order_by: str = None, **kwargs):
|
||||
def _assertGetTasksWithOrdering(
|
||||
self, order_by: str = None, valid_order=True, **kwargs
|
||||
):
|
||||
tasks = self.api.tasks.get_all_ex(
|
||||
only_fields=self.only_fields,
|
||||
order_by=[order_by] if isinstance(order_by, str) else order_by,
|
||||
@ -105,14 +109,16 @@ class TestEntityOrdering(TestService):
|
||||
**kwargs,
|
||||
).tasks
|
||||
self.assertLessEqual(set(self.task_ids), set(t.id for t in tasks))
|
||||
if order_by:
|
||||
if order_by and valid_order:
|
||||
# test that the output is correctly ordered
|
||||
field_name = order_by if not order_by.startswith("-") else order_by[1:]
|
||||
field_vals = [self._get_value_for_path(t, field_name.split(".")) for t in tasks]
|
||||
field_vals = [
|
||||
self._get_value_for_path(t, field_name.split(".")) for t in tasks
|
||||
]
|
||||
self._assertSorted(
|
||||
field_vals,
|
||||
ascending=not order_by.startswith("-"),
|
||||
is_numeric=field_name.startswith("execution.parameters.")
|
||||
is_numeric=field_name.startswith("execution.parameters."),
|
||||
)
|
||||
|
||||
def _create_tasks(self):
|
||||
|
65
apiserver/tests/automated/test_projects_retrieval.py
Normal file
65
apiserver/tests/automated/test_projects_retrieval.py
Normal file
@ -0,0 +1,65 @@
|
||||
from boltons.iterutils import first
|
||||
|
||||
from apiserver.tests.automated import TestService
|
||||
|
||||
|
||||
class TestProjectsRetrieval(TestService):
|
||||
def setUp(self, **kwargs):
|
||||
super().setUp(version="2.13")
|
||||
|
||||
def test_active_user(self):
|
||||
user = self.api.users.get_current_user().user.id
|
||||
project1 = self.temp_project(name="Project retrieval1")
|
||||
project2 = self.temp_project(name="Project retrieval2")
|
||||
self.temp_task(project=project2)
|
||||
|
||||
projects = self.api.projects.get_all_ex().projects
|
||||
self.assertTrue({project1, project2}.issubset({p.id for p in projects}))
|
||||
|
||||
projects = self.api.projects.get_all_ex(active_users=[user]).projects
|
||||
ids = {p.id for p in projects}
|
||||
self.assertFalse(project1 in ids)
|
||||
self.assertTrue(project2 in ids)
|
||||
|
||||
def test_stats(self):
|
||||
project = self.temp_project()
|
||||
self.temp_task(project=project)
|
||||
self.temp_task(project=project)
|
||||
archived_task = self.temp_task(project=project)
|
||||
self.api.tasks.archive(tasks=[archived_task])
|
||||
|
||||
p = self._get_project(project)
|
||||
self.assertFalse("stats" in p)
|
||||
|
||||
p = self._get_project(project, include_stats=True)
|
||||
self.assertFalse("archived" in p.stats)
|
||||
self.assertTrue(p.stats.active.status_count.created, 2)
|
||||
|
||||
p = self._get_project(project, include_stats=True, stats_for_state=None)
|
||||
self.assertTrue(p.stats.active.status_count.created, 2)
|
||||
self.assertTrue(p.stats.archived.status_count.created, 1)
|
||||
|
||||
def _get_project(self, project, **kwargs):
|
||||
projects = self.api.projects.get_all_ex(id=[project], **kwargs).projects
|
||||
p = first(p for p in projects if p.id == project)
|
||||
self.assertIsNotNone(p)
|
||||
return p
|
||||
|
||||
def temp_project(self, **kwargs) -> str:
|
||||
self.update_missing(
|
||||
kwargs,
|
||||
name="Test projects retrieval",
|
||||
description="test",
|
||||
delete_params=dict(force=True),
|
||||
)
|
||||
return self.create_temp("projects", **kwargs)
|
||||
|
||||
def temp_task(self, **kwargs) -> str:
|
||||
self.update_missing(
|
||||
kwargs,
|
||||
type="testing",
|
||||
name="test projects retrieval",
|
||||
input=dict(view=dict()),
|
||||
delete_params=dict(force=True),
|
||||
)
|
||||
return self.create_temp("tasks", **kwargs)
|
@ -1,10 +1,14 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class StringEnum(Enum):
|
||||
class StringEnum(str, Enum):
|
||||
def __str__(self):
|
||||
return self.value
|
||||
|
||||
@classmethod
|
||||
def values(cls):
|
||||
return list(map(str, cls))
|
||||
|
||||
# noinspection PyMethodParameters
|
||||
def _generate_next_value_(name, start, count, last_values):
|
||||
return name
|
||||
return name
|
||||
|
Loading…
Reference in New Issue
Block a user