diff --git a/apiserver/apimodels/pipelines.py b/apiserver/apimodels/pipelines.py
new file mode 100644
index 0000000..ef8bd52
--- /dev/null
+++ b/apiserver/apimodels/pipelines.py
@@ -0,0 +1,19 @@
+from jsonmodels import models, fields
+
+from apiserver.apimodels import ListField
+
+
+class Arg(models.Base):
+    name = fields.StringField(required=True)
+    value = fields.StringField(required=True)
+
+
+class StartPipelineRequest(models.Base):
+    task = fields.StringField(required=True)
+    queue = fields.StringField(required=True)
+    args = ListField(Arg)
+
+
+class StartPipelineResponse(models.Base):
+    pipeline = fields.StringField(required=True)
+    enqueued = fields.BoolField(required=True)
diff --git a/apiserver/apimodels/projects.py b/apiserver/apimodels/projects.py
index fb4ae2f..447daa5 100644
--- a/apiserver/apimodels/projects.py
+++ b/apiserver/apimodels/projects.py
@@ -1,6 +1,6 @@
 from jsonmodels import models, fields
 
-from apiserver.apimodels import ListField, ActualEnumField
+from apiserver.apimodels import ListField, ActualEnumField, DictField
 from apiserver.apimodels.organization import TagsRequest
 from apiserver.database.model import EntityVisibility
 
@@ -51,8 +51,14 @@ class ProjectHyperparamValuesRequest(MultiProjectRequest):
     allow_public = fields.BoolField(default=True)
 
 
+class ProjectModelMetadataValuesRequest(MultiProjectRequest):
+    key = fields.StringField(required=True)
+    allow_public = fields.BoolField(default=True)
+
+
 class ProjectsGetRequest(models.Base):
     include_stats = fields.BoolField(default=False)
+    include_stats_filter = DictField()
     stats_with_children = fields.BoolField(default=True)
     stats_for_state = ActualEnumField(EntityVisibility, default=EntityVisibility.active)
     non_public = fields.BoolField(default=False)
diff --git a/apiserver/bll/project/project_bll.py b/apiserver/bll/project/project_bll.py
index d0b1bc2..e36d557 100644
--- a/apiserver/bll/project/project_bll.py
+++ b/apiserver/bll/project/project_bll.py
@@ -14,6 +14,7 @@ from typing import (
     TypeVar,
     Callable,
     Mapping,
+    Any,
 )
 
 from mongoengine import Q, Document
@@ -22,6 +23,7 @@ from apiserver import database
 from apiserver.apierrors import errors
 from apiserver.config_repo import config
 from apiserver.database.model import EntityVisibility, AttributedDocument
+from apiserver.database.model.base import GetMixin
 from apiserver.database.model.model import Model
 from apiserver.database.model.project import Project
 from apiserver.database.model.task.task import Task, TaskStatus, external_task_types
@@ -204,6 +206,7 @@ class ProjectBLL:
         tags: Sequence[str] = None,
         system_tags: Sequence[str] = None,
         default_output_destination: str = None,
+        parent_creation_params: dict = None,
     ) -> str:
         """
         Create a new project.
@@ -226,7 +229,12 @@ class ProjectBLL:
             created=now,
             last_update=now,
         )
-        parent = _ensure_project(company=company, user=user, name=location)
+        parent = _ensure_project(
+            company=company,
+            user=user,
+            name=location,
+            creation_params=parent_creation_params,
+        )
         _save_under_parent(project=project, parent=parent)
         if parent:
             parent.update(last_update=now)
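For reference, a minimal usage sketch of the new parent_creation_params argument (the project name and tag values below are illustrative only, not taken from this patch): when a nested project is created, any missing parent projects in the path are auto-created with the given fields instead of only an empty description.

    # Hypothetical caller: parents missing from the path are auto-created
    # with system_tags=["hidden"] in addition to the empty description.
    project_id = ProjectBLL.create(
        user=user_id,
        company=company_id,
        name="Pipelines/My Pipeline/.pipelines",
        description="",
        parent_creation_params=dict(system_tags=["hidden"], description=""),
    )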
@@ -244,13 +252,14 @@ class ProjectBLL:
         tags: Sequence[str] = None,
         system_tags: Sequence[str] = None,
         default_output_destination: str = None,
+        parent_creation_params: dict = None,
     ) -> str:
         """
         Find a project named `project_name` or create a new one.
         Returns project ID
         """
         if not project_id and not project_name:
-            raise ValueError("project id or name required")
+            raise errors.bad_request.ValidationError("project id or name required")
 
         if project_id:
             project = Project.objects(company=company, id=project_id).only("id").first()
@@ -271,6 +280,7 @@ class ProjectBLL:
             tags=tags,
             system_tags=system_tags,
             default_output_destination=default_output_destination,
+            parent_creation_params=parent_creation_params,
         )
 
     @classmethod
@@ -314,6 +324,7 @@ class ProjectBLL:
         company_id: str,
         project_ids: Sequence[str],
         specific_state: Optional[EntityVisibility] = None,
+        filter_: Mapping[str, Any] = None,
     ) -> Tuple[Sequence, Sequence]:
         archived = EntityVisibility.archived.value
 
@@ -337,10 +348,9 @@ class ProjectBLL:
         status_count_pipeline = [
             # count tasks per project per status
             {
-                "$match": {
-                    "company": {"$in": [None, "", company_id]},
-                    "project": {"$in": project_ids},
-                }
+                "$match": cls.get_match_conditions(
+                    company=company_id, project_ids=project_ids, filter_=filter_
+                )
             },
             ensure_valid_fields(),
             {
@@ -455,8 +465,9 @@ class ProjectBLL:
             # only count run time for these types of tasks
             {
                 "$match": {
-                    "company": {"$in": [None, "", company_id]},
-                    "project": {"$in": project_ids},
+                    **cls.get_match_conditions(
+                        company=company_id, project_ids=project_ids, filter_=filter_
+                    ),
                     **get_state_filter(),
                 }
             },
@@ -500,6 +511,7 @@ class ProjectBLL:
         project_ids: Sequence[str],
         specific_state: Optional[EntityVisibility] = None,
         include_children: bool = True,
+        filter_: Mapping[str, Any] = None,
    ) -> Tuple[Dict[str, dict], Dict[str, dict]]:
         if not project_ids:
             return {}, {}
@@ -516,6 +528,7 @@ class ProjectBLL:
             company,
             project_ids=list(project_ids_with_children),
             specific_state=specific_state,
+            filter_=filter_,
         )
 
         default_counts = dict.fromkeys(get_options(TaskStatus), 0)
@@ -589,10 +602,9 @@ class ProjectBLL:
 
             return {
                 "status_count": project_section_statuses,
-                "running_tasks": project_section_statuses.get(TaskStatus.in_progress),
                 "total_tasks": sum(project_section_statuses.values()),
                 "total_runtime": project_runtime.get(section, 0),
-                "completed_tasks": project_runtime.get(
+                "completed_tasks_24h": project_runtime.get(
                     f"{section}_recently_completed", 0
                 ),
                 "last_task_run": get_time_or_none(
@@ -652,6 +664,30 @@ class ProjectBLL:
 
         return res
 
+    @classmethod
+    def get_project_tags(
+        cls,
+        company_id: str,
+        include_system: bool,
+        projects: Sequence[str] = None,
+        filter_: Dict[str, Sequence[str]] = None,
+    ) -> Tuple[Sequence[str], Sequence[str]]:
+        with TimingContext("mongo", "get_tags_from_db"):
+            query = Q(company=company_id)
+            if filter_:
+                for name, vals in filter_.items():
+                    if vals:
+                        query &= GetMixin.get_list_field_query(name, vals)
+
+            if projects:
+                query &= Q(id__in=_ids_with_children(projects))
+
+            tags = Project.objects(query).distinct("tags")
+            system_tags = (
+                Project.objects(query).distinct("system_tags") if include_system else []
+            )
+            return tags, system_tags
+
     @classmethod
     def get_projects_with_active_user(
         cls,
@@ -708,6 +744,7 @@ class ProjectBLL:
             if include_subprojects:
                 projects = _ids_with_children(projects)
             query &= Q(project__in=projects)
+
         if state == EntityVisibility.archived:
             query &= Q(system_tags__in=[EntityVisibility.archived.value])
         elif state == EntityVisibility.active:
@@ -735,6 +772,7 @@ class ProjectBLL:
         if project_ids:
             project_ids = _ids_with_children(project_ids)
             query &= Q(project__in=project_ids)
+
         res = Task.objects(query).distinct(field="type")
         return set(res).intersection(external_task_types)
 
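A usage sketch for the new get_project_tags helper, mirroring the automated test added further down: values prefixed with "__$not" in the filter exclude projects carrying that system tag (IDs below are placeholders).

    # Tags of the given projects and their children, skipping projects
    # that are system-tagged as "hidden".
    tags, system_tags = ProjectBLL.get_project_tags(
        company_id,
        include_system=False,
        projects=[project_id],
        filter_={"system_tags": ["__$not", "hidden"]},
    )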
@@ -750,9 +788,34 @@ class ProjectBLL:
             query &= Q(project__in=project_ids)
         return Model.objects(query).distinct(field="framework")
 
+    @staticmethod
+    def get_match_conditions(
+        company: str, project_ids: Sequence[str], filter_: Mapping[str, Any]
+    ):
+        conditions = {
+            "company": {"$in": [None, "", company]},
+            "project": {"$in": project_ids},
+        }
+        if not filter_:
+            return conditions
+
+        for field in ("tags", "system_tags"):
+            field_filter = filter_.get(field)
+            if not field_filter:
+                continue
+            if not isinstance(field_filter, list) or not all(
+                isinstance(t, str) for t in field_filter
+            ):
+                raise errors.bad_request.ValidationError(
+                    f"List of strings expected for the field: {field}"
+                )
+            conditions[field] = {"$in": field_filter}
+
+        return conditions
+
     @classmethod
     def calc_own_contents(
-        cls, company: str, project_ids: Sequence[str]
+        cls, company: str, project_ids: Sequence[str], filter_: Mapping[str, Any] = None
     ) -> Dict[str, dict]:
         """
         Returns the amount of task/models per requested project
@@ -764,13 +827,12 @@ class ProjectBLL:
 
         pipeline = [
             {
-                "$match": {
-                    "company": {"$in": [None, "", company]},
-                    "project": {"$in": project_ids},
-                }
+                "$match": cls.get_match_conditions(
+                    company=company, project_ids=project_ids, filter_=filter_
+                )
             },
             {"$project": {"project": 1}},
-            {"$group": {"_id": "$project", "count": {"$sum": 1}}}
+            {"$group": {"_id": "$project", "count": {"$sum": 1}}},
         ]
 
         def get_agrregate_res(cls_: Type[AttributedDocument]) -> dict:
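To make the new filtering concrete, this is roughly the $match document get_match_conditions builds when a stats filter is supplied (company and project IDs are placeholders):

    ProjectBLL.get_match_conditions(
        company="company_1",
        project_ids=["project_1", "project_2"],
        filter_={"system_tags": ["pipeline"]},
    )
    # -> {
    #     "company": {"$in": [None, "", "company_1"]},
    #     "project": {"$in": ["project_1", "project_2"]},
    #     "system_tags": {"$in": ["pipeline"]},
    # }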
diff --git a/apiserver/bll/project/project_queries.py b/apiserver/bll/project/project_queries.py
index 43ffed2..e7a3a95 100644
--- a/apiserver/bll/project/project_queries.py
+++ b/apiserver/bll/project/project_queries.py
@@ -1,6 +1,6 @@
 import json
 from collections import OrderedDict
-from datetime import datetime, timedelta
+from datetime import datetime
 from typing import (
     Sequence,
     Optional,
@@ -28,12 +28,21 @@ class ProjectQueries:
     def _get_project_constraint(
         project_ids: Sequence[str], include_subprojects: bool
     ) -> dict:
+        """
+        If the passed projects value is None it means top-level projects
+        If the passed projects value is empty it means no project filtering
+        """
         if include_subprojects:
-            if project_ids is None:
+            if not project_ids:
                 return {}
             project_ids = _ids_with_children(project_ids)
 
-        return {"project": {"$in": project_ids if project_ids is not None else [None]}}
+        if project_ids is None:
+            project_ids = [None]
+        if not project_ids:
+            return {}
+
+        return {"project": {"$in": project_ids}}
 
     @staticmethod
     def _get_company_constraint(company_id: str, allow_public: bool = True) -> dict:
@@ -106,16 +115,11 @@ class ProjectQueries:
 
         return total, remaining, results
 
-    HyperParamValues = Tuple[int, Sequence[str]]
+    ParamValues = Tuple[int, Sequence[str]]
 
-    def _get_cached_hyperparam_values(
-        self, key: str, last_update: datetime
-    ) -> Optional[HyperParamValues]:
-        allowed_delta = timedelta(
-            seconds=config.get(
-                "services.tasks.hyperparam_values.cache_allowed_outdate_sec", 60
-            )
-        )
+    def _get_cached_param_values(
+        self, key: str, last_update: datetime, allowed_delta_sec=0
+    ) -> Optional[ParamValues]:
         try:
             cached = self.redis.get(key)
             if not cached:
@@ -123,12 +127,12 @@ class ProjectQueries:
 
             data = json.loads(cached)
             cached_last_update = datetime.fromtimestamp(data["last_update"])
-            if (last_update - cached_last_update) < allowed_delta:
+            if (last_update - cached_last_update).total_seconds() <= allowed_delta_sec:
                 return data["total"], data["values"]
         except Exception as ex:
-            log.error(f"Error retrieving hyperparam cached values: {str(ex)}")
+            log.error(f"Error retrieving params cached values: {str(ex)}")
 
-    def get_hyperparam_distinct_values(
+    def get_task_hyperparam_distinct_values(
         self,
         company_id: str,
         project_ids: Sequence[str],
@@ -136,7 +140,7 @@ class ProjectQueries:
         name: str,
         include_subprojects: bool,
         allow_public: bool = True,
-    ) -> HyperParamValues:
+    ) -> ParamValues:
         company_constraint = self._get_company_constraint(company_id, allow_public)
         project_constraint = self._get_project_constraint(
             project_ids, include_subprojects
@@ -158,8 +162,12 @@ class ProjectQueries:
 
         redis_key = f"hyperparam_values_{company_id}_{'_'.join(project_ids)}_{section}_{name}_{allow_public}"
         last_update = last_updated_task.last_update or datetime.utcnow()
-        cached_res = self._get_cached_hyperparam_values(
-            key=redis_key, last_update=last_update
+        cached_res = self._get_cached_param_values(
+            key=redis_key,
+            last_update=last_update,
+            allowed_delta_sec=config.get(
+                "services.tasks.hyperparam_values.cache_allowed_outdate_sec", 60
+            ),
         )
         if cached_res:
             return cached_res
@@ -290,3 +298,73 @@ class ProjectQueries:
         remaining = max(0, total - (len(results) + page * page_size))
 
         return total, remaining, results
+
+    def get_model_metadata_distinct_values(
+        self,
+        company_id: str,
+        project_ids: Sequence[str],
+        key: str,
+        include_subprojects: bool,
+        allow_public: bool = True,
+    ) -> ParamValues:
+        company_constraint = self._get_company_constraint(company_id, allow_public)
+        project_constraint = self._get_project_constraint(
+            project_ids, include_subprojects
+        )
+        key_path = f"metadata.{ParameterKeyEscaper.escape(key)}"
+        last_updated_model = (
+            Model.objects(
+                **company_constraint,
+                **project_constraint,
+                **{f"{key_path.replace('.', '__')}__exists": True},
+            )
+            .only("last_update")
+            .order_by("-last_update")
+            .limit(1)
+            .first()
+        )
+        if not last_updated_model:
+            return 0, []
+
+        redis_key = f"modelmetadata_values_{company_id}_{'_'.join(project_ids)}_{key}_{allow_public}"
+        last_update = last_updated_model.last_update or datetime.utcnow()
+        cached_res = self._get_cached_param_values(
+            key=redis_key, last_update=last_update
+        )
+        if cached_res:
+            return cached_res
+
+        max_values = config.get("services.models.metadata_values.max_count", 100)
+        pipeline = [
+            {
+                "$match": {
+                    **company_constraint,
+                    **project_constraint,
+                    key_path: {"$exists": True},
+                }
+            },
+            {"$project": {"value": f"${key_path}.value"}},
+            {"$group": {"_id": "$value"}},
+            {"$sort": {"_id": 1}},
+            {"$limit": max_values},
+            {
+                "$group": {
+                    "_id": 1,
+                    "total": {"$sum": 1},
+                    "results": {"$push": "$$ROOT._id"},
+                }
+            },
+        ]
+
+        result = next(Model.aggregate(pipeline, collation=Model._numeric_locale), None)
+        if not result:
+            return 0, []
+
+        total = int(result.get("total", 0))
+        values = result.get("results", [])
+
+        ttl = config.get("services.models.metadata_values.cache_ttl_sec", 86400)
+        cached = dict(last_update=last_update.timestamp(), total=total, values=values)
+        self.redis.setex(redis_key, ttl, json.dumps(cached))
+
+        return total, values
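Worth noting how the shared cache helper is parameterized: the hyperparam path passes the configurable 60-second staleness allowance, while the model-metadata path relies on the default allowed_delta_sec=0, so its cached entry is reused only if no matching model was updated after the entry was written. The Redis payload written by get_model_metadata_distinct_values has this shape (timestamp and values below are illustrative):

    # key: "modelmetadata_values_<company>_<project ids>_<key>_<allow_public>"
    {"last_update": 1645000000.0, "total": 2, "values": ["test_value", "test_value2"]}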
diff --git a/apiserver/bll/project/sub_projects.py b/apiserver/bll/project/sub_projects.py
index 51f7736..0c46a6e 100644
--- a/apiserver/bll/project/sub_projects.py
+++ b/apiserver/bll/project/sub_projects.py
@@ -25,7 +25,9 @@ def _validate_project_name(project_name: str) -> Tuple[str, str]:
     return name_separator.join(name_parts), name_separator.join(name_parts[:-1])
 
 
-def _ensure_project(company: str, user: str, name: str) -> Optional[Project]:
+def _ensure_project(
+    company: str, user: str, name: str, creation_params: dict = None
+) -> Optional[Project]:
     """
     Makes sure that the project with the given name exists
     If needed auto-create the project and all the missing projects in the path to it
@@ -48,9 +50,9 @@ def _ensure_project(company: str, user: str, name: str) -> Optional[Project]:
         created=now,
         last_update=now,
         name=name,
-        description="",
+        **(creation_params or dict(description="")),
     )
-    parent = _ensure_project(company, user, location)
+    parent = _ensure_project(company, user, location, creation_params=creation_params)
     _save_under_parent(project=project, parent=parent)
     if parent:
         parent.update(last_update=now)
diff --git a/apiserver/config/default/apiserver.conf b/apiserver/config/default/apiserver.conf
index 51ab237..d8b1e0b 100644
--- a/apiserver/config/default/apiserver.conf
+++ b/apiserver/config/default/apiserver.conf
@@ -112,6 +112,8 @@ workers {
     # Auto-register unknown workers on status reports and other calls
     auto_register: true
+    # Assume unknown workers have unregistered (i.e. do not raise an unregistered error)
+    auto_unregister: true
 
     # Timeout in seconds on task status update. If exceeded
     # then task can be stopped without communicating to the worker
     task_update_timeout: 600
diff --git a/apiserver/config/default/services/models.conf b/apiserver/config/default/services/models.conf
new file mode 100644
index 0000000..a637440
--- /dev/null
+++ b/apiserver/config/default/services/models.conf
@@ -0,0 +1,7 @@
+metadata_values {
+    # maximum number of distinct model metadata values to retrieve
+    max_count: 100
+
+    # cache TTL in seconds
+    cache_ttl_sec: 86400
+}
diff --git a/apiserver/schema/services/pipelines.conf b/apiserver/schema/services/pipelines.conf
new file mode 100644
index 0000000..091ff9b
--- /dev/null
+++ b/apiserver/schema/services/pipelines.conf
@@ -0,0 +1,47 @@
+_description: "Provides a management API for pipelines in the system."
+_definitions {
+}
+
+start_pipeline {
+    "2.17" {
+        description: "Start a pipeline"
+        request {
+            type: object
+            required: [ task ]
+            properties {
+                task {
+                    description: "ID of the task on which the pipeline will be based"
+                    type: string
+                }
+                queue {
+                    description: "Queue ID in which the created pipeline task will be enqueued"
+                    type: string
+                }
+                args {
+                    description: "Task arguments, name/value to be placed in the hyperparameters Args section"
+                    type: array
+                    items {
+                        type: object
+                        properties {
+                            name: { type: string }
+                            value: { type: [string, null] }
+                        }
+                    }
+                }
+            }
+        }
+        response {
+            type: object
+            properties {
+                pipeline {
+                    description: "ID of the new pipeline task"
+                    type: string
+                }
+                enqueued {
+                    description: "True if the task was successfully enqueued"
+                    type: boolean
+                }
+            }
+        }
+    }
+}
\ No newline at end of file
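An illustrative call for the new endpoint, written in the style of the automated tests' APIClient (IDs and argument values are placeholders):

    res = self.api.pipelines.start_pipeline(
        task="<template task id>",
        queue="<queue id>",
        args=[{"name": "dataset", "value": "v1"}],
    )
    # res.pipeline -> ID of the cloned pipeline task
    # res.enqueued -> True if the task was successfully enqueued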
diff --git a/apiserver/schema/services/projects.conf b/apiserver/schema/services/projects.conf
index 4a05850..0f47fdd 100644
--- a/apiserver/schema/services/projects.conf
+++ b/apiserver/schema/services/projects.conf
@@ -566,6 +566,19 @@ get_all_ex {
             default: true
         }
     }
+    "2.17": ${get_all_ex."2.16"} {
+        request.properties.include_stats_filter {
+            description: The filter for selecting entities that participate in statistics calculation
+            type: object
+            properties {
+                system_tags {
+                    description: The list of allowed system tags
+                    type: array
+                    items { type: string }
+                }
+            }
+        }
+    }
 }
 update {
     "2.1" {
@@ -913,6 +926,49 @@ get_hyper_parameters {
         }
     }
 }
+get_model_metadata_values {
+    "2.17" {
+        description: """Get a list of distinct values for the chosen model metadata key"""
+        request {
+            type: object
+            required: [key]
+            properties {
+                projects {
+                    description: "Project IDs"
+                    type: array
+                    items {type: string}
+                }
+                key {
+                    description: "Metadata key"
+                    type: string
+                }
+                allow_public {
+                    description: "If set to 'true' then collect values from both company and public models, otherwise company models only. The default is 'true'"
+                    type: boolean
+                }
+                include_subprojects {
+                    description: "If set to 'true' and the project field is set then the result includes metadata values from the subproject models"
+                    type: boolean
+                    default: true
+                }
+            }
+        }
+        response {
+            type: object
+            properties {
+                total {
+                    description: "Total number of distinct values"
+                    type: integer
+                }
+                values {
+                    description: "The list of the unique values"
+                    type: array
+                    items {type: string}
+                }
+            }
+        }
+    }
+}
 get_model_metadata_keys {
     "2.17" {
         description: """Get a list of all metadata keys used in models within the given project."""
@@ -962,6 +1018,13 @@ get_model_metadata_keys {
         }
     }
 }
+get_project_tags {
+    "2.17" {
+        description: "Get user and system tags used for the specified projects and their children"
+        request = ${_definitions.tags_request}
+        response = ${_definitions.tags_response}
+    }
+}
 get_task_tags {
     "2.8" {
         description: "Get user and system tags used for the tasks under the specified projects"
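Two illustrative calls against the extended projects API, again in the tests' APIClient style (project IDs and key names are placeholders):

    # Project statistics computed only over tasks system-tagged "pipeline"
    res = self.api.projects.get_all_ex(
        include_stats=True, include_stats_filter={"system_tags": ["pipeline"]}
    )

    # Distinct values stored under a model metadata key
    res = self.api.projects.get_model_metadata_values(
        projects=["<project id>"], key="test_key"
    )
    # res.total -> number of distinct values, res["values"] -> the values themselves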
diff --git a/apiserver/services/organization.py b/apiserver/services/organization.py
index 7139adb..7bc82a9 100644
--- a/apiserver/services/organization.py
+++ b/apiserver/services/organization.py
@@ -5,7 +5,7 @@ from apiserver.apimodels.organization import TagsRequest
 from apiserver.bll.organization import OrgBLL, Tags
 from apiserver.database.model import User
 from apiserver.service_repo import endpoint, APICall
-from apiserver.services.utils import get_tags_filter_dictionary, get_tags_response
+from apiserver.services.utils import get_tags_filter_dictionary, sort_tags_response
 
 
 org_bll = OrgBLL()
@@ -21,17 +21,13 @@ def get_tags(call: APICall, company, request: TagsRequest):
     for field, vals in tags.items():
         ret[field] |= vals
 
-    call.result.data = get_tags_response(ret)
+    call.result.data = sort_tags_response(ret)
 
 
 @endpoint("organization.get_user_companies")
 def get_user_companies(call: APICall, company_id: str, _):
     users = [
-        {
-            "id": u.id,
-            "name": u.name,
-            "avatar": u.avatar,
-        }
+        {"id": u.id, "name": u.name, "avatar": u.avatar}
         for u in User.objects(company=company_id).only("avatar", "name", "company")
     ]
 
diff --git a/apiserver/services/pipelines.py b/apiserver/services/pipelines.py
new file mode 100644
index 0000000..76d2657
--- /dev/null
+++ b/apiserver/services/pipelines.py
@@ -0,0 +1,68 @@
+import re
+
+from apiserver.apimodels.pipelines import StartPipelineResponse, StartPipelineRequest
+from apiserver.bll.organization import OrgBLL
+from apiserver.bll.project import ProjectBLL
+from apiserver.bll.task import TaskBLL
+from apiserver.bll.task.task_operations import enqueue_task
+from apiserver.database.model.project import Project
+from apiserver.database.model.task.task import Task
+from apiserver.service_repo import APICall, endpoint
+
+org_bll = OrgBLL()
+project_bll = ProjectBLL()
+task_bll = TaskBLL()
+
+
+def _update_task_name(task: Task):
+    if not task or not task.project:
+        return
+
+    project = Project.objects(id=task.project).only("name").first()
+    if not project:
+        return
+
+    _, _, name_prefix = project.name.rpartition("/")
+    name_mask = re.compile(rf"{re.escape(name_prefix)}( #\d+)?$")
+    count = Task.objects(
+        project=task.project, system_tags__in=["pipeline"], name=name_mask
+    ).count()
+    new_name = f"{name_prefix} #{count}" if count > 0 else name_prefix
+    task.update(name=new_name)
+
+
+@endpoint(
+    "pipelines.start_pipeline", response_data_model=StartPipelineResponse,
+)
+def start_pipeline(call: APICall, company_id: str, request: StartPipelineRequest):
+    hyperparams = None
+    if request.args:
+        hyperparams = {
+            "Args": {
+                str(arg.name): {
+                    "section": "Args",
+                    "name": str(arg.name),
+                    "value": str(arg.value),
+                }
+                for arg in request.args or []
+            }
+        }
+
+    task, _ = task_bll.clone_task(
+        company_id=company_id,
+        user_id=call.identity.user,
+        task_id=request.task,
+        hyperparams=hyperparams,
+    )
+
+    _update_task_name(task)
+
+    queued, res = enqueue_task(
+        task_id=task.id,
+        company_id=company_id,
+        queue_id=request.queue,
+        status_message="Starting pipeline",
+        status_reason="",
+    )
+
+    return StartPipelineResponse(pipeline=task.id, enqueued=bool(queued))
diff --git a/apiserver/services/projects.py b/apiserver/services/projects.py
index 098e2ab..c2b72cc 100644
--- a/apiserver/services/projects.py
+++ b/apiserver/services/projects.py
@@ -17,6 +17,7 @@ from apiserver.apimodels.projects import (
     MergeRequest,
     ProjectOrNoneRequest,
     ProjectRequest,
+    ProjectModelMetadataValuesRequest,
 )
 from apiserver.bll.organization import OrgBLL, Tags
 from apiserver.bll.project import ProjectBLL, ProjectQueries
@@ -35,7 +36,7 @@ from apiserver.services.utils import (
     conform_tag_fields,
     conform_output_tags,
     get_tags_filter_dictionary,
-    get_tags_response,
+    sort_tags_response,
 )
 from apiserver.timing_context import TimingContext
 
@@ -124,7 +125,9 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
         }
         if existing_requested_ids:
             contents = project_bll.calc_own_contents(
-                company=company_id, project_ids=list(existing_requested_ids)
+                company=company_id,
+                project_ids=list(existing_requested_ids),
+                filter_=request.include_stats_filter,
             )
             for project in projects:
                 project.update(**contents.get(project["id"], {}))
@@ -140,6 +143,7 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
             project_ids=list(project_ids),
             specific_state=request.stats_for_state,
             include_children=request.stats_with_children,
+            filter_=request.include_stats_filter,
         )
 
         for project in projects:
@@ -292,6 +296,23 @@ def get_model_metadata_keys(call: APICall, company_id: str, request: GetParamsRe
     }
 
 
+@endpoint("projects.get_model_metadata_values")
+def get_model_metadata_values(
+    call: APICall, company_id: str, request: ProjectModelMetadataValuesRequest
+):
+    total, values = project_queries.get_model_metadata_distinct_values(
+        company_id,
+        project_ids=request.projects,
+        key=request.key,
+        include_subprojects=request.include_subprojects,
+        allow_public=request.allow_public,
+    )
+    call.result.data = {
+        "total": total,
+        "values": values,
+    }
+
+
 @endpoint(
     "projects.get_hyper_parameters",
     min_version="2.9",
@@ -322,7 +343,7 @@ def get_hyper_parameters(call: APICall, company_id: str, request: GetParamsReque
 def get_hyperparam_values(
     call: APICall, company_id: str, request: ProjectHyperparamValuesRequest
 ):
-    total, values = project_queries.get_hyperparam_distinct_values(
+    total, values = project_queries.get_task_hyperparam_distinct_values(
         company_id,
         project_ids=request.projects,
         section=request.section,
@@ -336,6 +357,17 @@ def get_hyperparam_values(
     }
 
 
+@endpoint("projects.get_project_tags")
+def get_tags(call: APICall, company, request: ProjectTagsRequest):
+    tags, system_tags = project_bll.get_project_tags(
+        company,
+        include_system=request.include_system,
+        filter_=get_tags_filter_dictionary(request.filter),
+        projects=request.projects,
+    )
+    call.result.data = sort_tags_response({"tags": tags, "system_tags": system_tags})
+
+
 @endpoint(
     "projects.get_task_tags", min_version="2.8", request_data_model=ProjectTagsRequest
 )
@@ -347,7 +379,7 @@ def get_tags(call: APICall, company, request: ProjectTagsRequest):
         filter_=get_tags_filter_dictionary(request.filter),
         projects=request.projects,
     )
-    call.result.data = get_tags_response(ret)
+    call.result.data = sort_tags_response(ret)
 
 
 @endpoint(
@@ -361,7 +393,7 @@ def get_tags(call: APICall, company, request: ProjectTagsRequest):
         filter_=get_tags_filter_dictionary(request.filter),
         projects=request.projects,
     )
-    call.result.data = get_tags_response(ret)
+    call.result.data = sort_tags_response(ret)
 
 
 @endpoint(
diff --git a/apiserver/services/utils.py b/apiserver/services/utils.py
index a7893dd..f94d9ad 100644
--- a/apiserver/services/utils.py
+++ b/apiserver/services/utils.py
@@ -23,7 +23,7 @@ def get_tags_filter_dictionary(input_: Filter) -> dict:
     }
 
 
-def get_tags_response(ret: dict) -> dict:
+def sort_tags_response(ret: dict) -> dict:
     return {field: sorted(vals) for field, vals in ret.items()}
 
 
diff --git a/apiserver/tests/automated/test_project_tags.py b/apiserver/tests/automated/test_project_tags.py
index 95b09f4..f6647a7 100644
--- a/apiserver/tests/automated/test_project_tags.py
+++ b/apiserver/tests/automated/test_project_tags.py
@@ -4,10 +4,29 @@ from apiserver.tests.automated import TestService
 
 
 class TestProjectTags(TestService):
-    def setUp(self, version="2.12"):
-        super().setUp(version=version)
+    def test_project_own_tags(self):
+        p1_tags = ["Tag 1", "Tag 2"]
+        p1 = self.create_temp(
+            "projects", name="Test project tags1", description="test", tags=p1_tags
+        )
+        p2_tags = ["Tag 1", "Tag 3"]
+        p2 = self.create_temp(
+            "projects",
+            name="Test project tags2",
+            description="test",
+            tags=p2_tags,
+            system_tags=["hidden"],
+        )
 
-    def test_project_tags(self):
+        res = self.api.projects.get_project_tags(projects=[p1, p2])
+        self.assertEqual(set(res.tags), set(p1_tags) | set(p2_tags))
+
+        res = self.api.projects.get_project_tags(
+            projects=[p1, p2], filter={"system_tags": ["__$not", "hidden"]}
+        )
+        self.assertEqual(res.tags, p1_tags)
+
+    def test_project_entities_tags(self):
         tags_1 = ["Test tag 1", "Test tag 2"]
         tags_2 = ["Test tag 3", "Test tag 4"]
 
diff --git a/apiserver/tests/automated/test_queue_model_metadata.py b/apiserver/tests/automated/test_queue_model_metadata.py
index 2f57731..7d9f509 100644
--- a/apiserver/tests/automated/test_queue_model_metadata.py
+++ b/apiserver/tests/automated/test_queue_model_metadata.py
@@ -28,25 +28,33 @@ class TestQueueAndModelMetadata(TestService):
     def test_project_meta_query(self):
         self._temp_model("TestMetadata", metadata=self.meta1)
         project = self.temp_project(name="MetaParent")
+        test_key = "test_key"
+        test_key2 = "test_key2"
+        test_value = "test_value"
+        test_value2 = "test_value2"
         model_id = self._temp_model(
             "TestMetadata2",
             project=project,
             metadata={
-                "test_key": {"key": "test_key", "type": "str", "value": "test_value"},
-                "test_key2": {"key": "test_key2", "type": "str", "value": "test_value"},
+                test_key: {"key": test_key, "type": "str", "value": test_value},
+                test_key2: {"key": test_key2, "type": "str", "value": test_value2},
             },
         )
         res = self.api.projects.get_model_metadata_keys()
-        self.assertTrue({"test_key", "test_key2"}.issubset(set(res["keys"])))
+        self.assertTrue({test_key, test_key2}.issubset(set(res["keys"])))
         res = self.api.projects.get_model_metadata_keys(include_subprojects=False)
-        self.assertTrue("test_key" in res["keys"])
-        self.assertFalse("test_key2" in res["keys"])
+        self.assertTrue(test_key in res["keys"])
+        self.assertFalse(test_key2 in res["keys"])
 
         model = self.api.models.get_all_ex(
             id=[model_id], only_fields=["metadata.test_key"]
         ).models[0]
-        self.assertTrue("test_key" in model.metadata)
-        self.assertFalse("test_key2" in model.metadata)
+        self.assertTrue(test_key in model.metadata)
+        self.assertFalse(test_key2 in model.metadata)
+
+        res = self.api.projects.get_model_metadata_values(key=test_key)
+        self.assertEqual(res.total, 1)
+        self.assertEqual(res["values"], [test_value])
 
     def _test_meta_operations(
         self, service: APIClient.Service, entity: str, _id: str,
diff --git a/apiserver/tests/automated/test_subprojects.py b/apiserver/tests/automated/test_subprojects.py
index bcd6c21..fc1aa47 100644
--- a/apiserver/tests/automated/test_subprojects.py
+++ b/apiserver/tests/automated/test_subprojects.py
@@ -199,10 +199,10 @@ class TestSubProjects(TestService):
         res1 = next(p for p in res if p.id == project1)
         self.assertEqual(res1.stats["active"]["status_count"]["created"], 0)
         self.assertEqual(res1.stats["active"]["status_count"]["stopped"], 2)
+        self.assertEqual(res1.stats["active"]["status_count"]["in_progress"], 0)
         self.assertEqual(res1.stats["active"]["total_runtime"], 2)
-        self.assertEqual(res1.stats["active"]["completed_tasks"], 2)
+        self.assertEqual(res1.stats["active"]["completed_tasks_24h"], 2)
         self.assertEqual(res1.stats["active"]["total_tasks"], 2)
-        self.assertEqual(res1.stats["active"]["running_tasks"], 0)
         self.assertEqual(
             {sp.name for sp in res1.sub_projects},
             {
@@ -214,10 +214,10 @@ class TestSubProjects(TestService):
         res2 = next(p for p in res if p.id == project2)
         self.assertEqual(res2.stats["active"]["status_count"]["created"], 0)
         self.assertEqual(res2.stats["active"]["status_count"]["stopped"], 0)
+        self.assertEqual(res2.stats["active"]["status_count"]["in_progress"], 0)
+        self.assertEqual(res2.stats["active"]["status_count"]["completed"], 0)
         self.assertEqual(res2.stats["active"]["total_runtime"], 0)
-        self.assertEqual(res2.stats["active"]["completed_tasks"], 0)
         self.assertEqual(res2.stats["active"]["total_tasks"], 0)
-        self.assertEqual(res2.stats["active"]["running_tasks"], 0)
         self.assertEqual(res2.sub_projects, [])
 
     def _run_tasks(self, *tasks):
diff --git a/apiserver/tests/automated/test_tags.py b/apiserver/tests/automated/test_tags.py
index 231aacc..014c45c 100644
--- a/apiserver/tests/automated/test_tags.py
+++ b/apiserver/tests/automated/test_tags.py
@@ -133,6 +133,32 @@ class TestTags(TestService):
         ).models
         self.assertFound(model_id, [], models)
 
+    def testQueueTags(self):
+        q_id = self._temp_queue(system_tags=["default"])
+        queues = self.api.queues.get_all_ex(
+            name="Test tags", system_tags=["default"]
+        ).queues
+        self.assertFound(q_id, ["default"], queues)
+
+        queues = self.api.queues.get_all_ex(
+            name="Test tags", system_tags=["-default"]
+        ).queues
+        self.assertNotFound(q_id, queues)
+
+        self.api.queues.update(queue=q_id, system_tags=[])
+        queues = self.api.queues.get_all_ex(
+            name="Test tags", system_tags=["-default"]
+        ).queues
+        self.assertFound(q_id, [], queues)
+
+        # test default queue
+        queues = self.api.queues.get_all(system_tags=["default"]).queues
+        if queues:
+            self.assertEqual(queues[0].id, self.api.queues.get_default().id)
+        else:
+            self.api.queues.update(queue=q_id, system_tags=["default"])
+            self.assertEqual(q_id, self.api.queues.get_default().id)
+
     def testTaskTags(self):
         task_id = self._temp_task(
             name="Test tags", system_tags=["active"]
@@ -169,38 +195,11 @@ class TestTags(TestService):
         task = self.api.tasks.get_by_id(task=task_id).task
         self.assertEqual(task.status, "stopped")
 
-    def testQueueTags(self):
-        q_id = self._temp_queue(system_tags=["default"])
-        queues = self.api.queues.get_all_ex(
-            name="Test tags", system_tags=["default"]
-        ).queues
-        self.assertFound(q_id, ["default"], queues)
-
-        queues = self.api.queues.get_all_ex(
-            name="Test tags", system_tags=["-default"]
-        ).queues
-        self.assertNotFound(q_id, queues)
-
-        self.api.queues.update(queue=q_id, system_tags=[])
-        queues = self.api.queues.get_all_ex(
-            name="Test tags", system_tags=["-default"]
-        ).queues
-        self.assertFound(q_id, [], queues)
-
-        # test default queue
-        queues = self.api.queues.get_all(system_tags=["default"]).queues
-        if queues:
-            self.assertEqual(queues[0].id, self.api.queues.get_default().id)
-        else:
-            self.api.queues.update(queue=q_id, system_tags=["default"])
-            self.assertEqual(q_id, self.api.queues.get_default().id)
-
     def assertProjectStats(self, project: AttrDict):
         self.assertEqual(set(project.stats.keys()), {"active"})
         self.assertAlmostEqual(project.stats.active.total_runtime, 1, places=0)
-        self.assertEqual(project.stats.active.completed_tasks, 1)
+        self.assertEqual(project.stats.active.completed_tasks_24h, 1)
         self.assertEqual(project.stats.active.total_tasks, 1)
-        self.assertEqual(project.stats.active.running_tasks, 0)
         for status, count in project.stats.active.status_count.items():
             self.assertEqual(count, 1 if status == "stopped" else 0)