Add pipelines support

This commit is contained in:
allegroai 2022-03-15 16:28:59 +02:00
parent e1992e2054
commit da8a45072f
17 changed files with 501 additions and 93 deletions

View File

@ -0,0 +1,19 @@
from jsonmodels import models, fields
from apiserver.apimodels import ListField
class Arg(models.Base):
    """A single named argument passed to a pipeline (name/value pair)."""

    # Argument name; placed under the "Args" hyperparams section by the caller
    name = fields.StringField(required=True)
    # Argument value (stored as a string)
    value = fields.StringField(required=True)
class StartPipelineRequest(models.Base):
    """Request model for the pipelines.start_pipeline endpoint."""

    # ID of the template task the pipeline task is cloned from
    task = fields.StringField(required=True)
    # ID of the queue the cloned task is enqueued into
    queue = fields.StringField(required=True)
    # Optional arguments to place in the cloned task's "Args" hyperparams section
    args = ListField(Arg)
class StartPipelineResponse(models.Base):
    """Response model for the pipelines.start_pipeline endpoint."""

    # ID of the newly created pipeline task
    pipeline = fields.StringField(required=True)
    # True if the created task was successfully enqueued
    enqueued = fields.BoolField(required=True)

View File

@ -1,6 +1,6 @@
from jsonmodels import models, fields from jsonmodels import models, fields
from apiserver.apimodels import ListField, ActualEnumField from apiserver.apimodels import ListField, ActualEnumField, DictField
from apiserver.apimodels.organization import TagsRequest from apiserver.apimodels.organization import TagsRequest
from apiserver.database.model import EntityVisibility from apiserver.database.model import EntityVisibility
@ -51,8 +51,14 @@ class ProjectHyperparamValuesRequest(MultiProjectRequest):
allow_public = fields.BoolField(default=True) allow_public = fields.BoolField(default=True)
class ProjectModelMetadataValuesRequest(MultiProjectRequest):
key = fields.StringField(required=True)
allow_public = fields.BoolField(default=True)
class ProjectsGetRequest(models.Base): class ProjectsGetRequest(models.Base):
include_stats = fields.BoolField(default=False) include_stats = fields.BoolField(default=False)
include_stats_filter = DictField()
stats_with_children = fields.BoolField(default=True) stats_with_children = fields.BoolField(default=True)
stats_for_state = ActualEnumField(EntityVisibility, default=EntityVisibility.active) stats_for_state = ActualEnumField(EntityVisibility, default=EntityVisibility.active)
non_public = fields.BoolField(default=False) non_public = fields.BoolField(default=False)

View File

@ -14,6 +14,7 @@ from typing import (
TypeVar, TypeVar,
Callable, Callable,
Mapping, Mapping,
Any,
) )
from mongoengine import Q, Document from mongoengine import Q, Document
@ -22,6 +23,7 @@ from apiserver import database
from apiserver.apierrors import errors from apiserver.apierrors import errors
from apiserver.config_repo import config from apiserver.config_repo import config
from apiserver.database.model import EntityVisibility, AttributedDocument from apiserver.database.model import EntityVisibility, AttributedDocument
from apiserver.database.model.base import GetMixin
from apiserver.database.model.model import Model from apiserver.database.model.model import Model
from apiserver.database.model.project import Project from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task, TaskStatus, external_task_types from apiserver.database.model.task.task import Task, TaskStatus, external_task_types
@ -204,6 +206,7 @@ class ProjectBLL:
tags: Sequence[str] = None, tags: Sequence[str] = None,
system_tags: Sequence[str] = None, system_tags: Sequence[str] = None,
default_output_destination: str = None, default_output_destination: str = None,
parent_creation_params: dict = None,
) -> str: ) -> str:
""" """
Create a new project. Create a new project.
@ -226,7 +229,12 @@ class ProjectBLL:
created=now, created=now,
last_update=now, last_update=now,
) )
parent = _ensure_project(company=company, user=user, name=location) parent = _ensure_project(
company=company,
user=user,
name=location,
creation_params=parent_creation_params,
)
_save_under_parent(project=project, parent=parent) _save_under_parent(project=project, parent=parent)
if parent: if parent:
parent.update(last_update=now) parent.update(last_update=now)
@ -244,13 +252,14 @@ class ProjectBLL:
tags: Sequence[str] = None, tags: Sequence[str] = None,
system_tags: Sequence[str] = None, system_tags: Sequence[str] = None,
default_output_destination: str = None, default_output_destination: str = None,
parent_creation_params: dict = None,
) -> str: ) -> str:
""" """
Find a project named `project_name` or create a new one. Find a project named `project_name` or create a new one.
Returns project ID Returns project ID
""" """
if not project_id and not project_name: if not project_id and not project_name:
raise ValueError("project id or name required") raise errors.bad_request.ValidationError("project id or name required")
if project_id: if project_id:
project = Project.objects(company=company, id=project_id).only("id").first() project = Project.objects(company=company, id=project_id).only("id").first()
@ -271,6 +280,7 @@ class ProjectBLL:
tags=tags, tags=tags,
system_tags=system_tags, system_tags=system_tags,
default_output_destination=default_output_destination, default_output_destination=default_output_destination,
parent_creation_params=parent_creation_params,
) )
@classmethod @classmethod
@ -314,6 +324,7 @@ class ProjectBLL:
company_id: str, company_id: str,
project_ids: Sequence[str], project_ids: Sequence[str],
specific_state: Optional[EntityVisibility] = None, specific_state: Optional[EntityVisibility] = None,
filter_: Mapping[str, Any] = None,
) -> Tuple[Sequence, Sequence]: ) -> Tuple[Sequence, Sequence]:
archived = EntityVisibility.archived.value archived = EntityVisibility.archived.value
@ -337,10 +348,9 @@ class ProjectBLL:
status_count_pipeline = [ status_count_pipeline = [
# count tasks per project per status # count tasks per project per status
{ {
"$match": { "$match": cls.get_match_conditions(
"company": {"$in": [None, "", company_id]}, company=company_id, project_ids=project_ids, filter_=filter_
"project": {"$in": project_ids}, )
}
}, },
ensure_valid_fields(), ensure_valid_fields(),
{ {
@ -455,8 +465,9 @@ class ProjectBLL:
# only count run time for these types of tasks # only count run time for these types of tasks
{ {
"$match": { "$match": {
"company": {"$in": [None, "", company_id]}, **cls.get_match_conditions(
"project": {"$in": project_ids}, company=company_id, project_ids=project_ids, filter_=filter_
),
**get_state_filter(), **get_state_filter(),
} }
}, },
@ -500,6 +511,7 @@ class ProjectBLL:
project_ids: Sequence[str], project_ids: Sequence[str],
specific_state: Optional[EntityVisibility] = None, specific_state: Optional[EntityVisibility] = None,
include_children: bool = True, include_children: bool = True,
filter_: Mapping[str, Any] = None,
) -> Tuple[Dict[str, dict], Dict[str, dict]]: ) -> Tuple[Dict[str, dict], Dict[str, dict]]:
if not project_ids: if not project_ids:
return {}, {} return {}, {}
@ -516,6 +528,7 @@ class ProjectBLL:
company, company,
project_ids=list(project_ids_with_children), project_ids=list(project_ids_with_children),
specific_state=specific_state, specific_state=specific_state,
filter_=filter_,
) )
default_counts = dict.fromkeys(get_options(TaskStatus), 0) default_counts = dict.fromkeys(get_options(TaskStatus), 0)
@ -589,10 +602,9 @@ class ProjectBLL:
return { return {
"status_count": project_section_statuses, "status_count": project_section_statuses,
"running_tasks": project_section_statuses.get(TaskStatus.in_progress),
"total_tasks": sum(project_section_statuses.values()), "total_tasks": sum(project_section_statuses.values()),
"total_runtime": project_runtime.get(section, 0), "total_runtime": project_runtime.get(section, 0),
"completed_tasks": project_runtime.get( "completed_tasks_24h": project_runtime.get(
f"{section}_recently_completed", 0 f"{section}_recently_completed", 0
), ),
"last_task_run": get_time_or_none( "last_task_run": get_time_or_none(
@ -652,6 +664,30 @@ class ProjectBLL:
return res return res
@classmethod
def get_project_tags(
    cls,
    company_id: str,
    include_system: bool,
    projects: Sequence[str] = None,
    filter_: Dict[str, Sequence[str]] = None,
) -> Tuple[Sequence[str], Sequence[str]]:
    """
    Return the distinct (tags, system_tags) used on the company projects.

    :param include_system: when False, system_tags are not collected and an
        empty list is returned in their place
    :param projects: when passed, only these projects and their children
        are inspected; otherwise all company projects
    :param filter_: optional mapping of list-field name -> allowed values,
        converted to a query via GetMixin.get_list_field_query
    """
    with TimingContext("mongo", "get_tags_from_db"):
        query = Q(company=company_id)
        if filter_:
            # Narrow the query by each non-empty list-field filter
            for name, vals in filter_.items():
                if vals:
                    query &= GetMixin.get_list_field_query(name, vals)
        if projects:
            # Tags of subprojects are included as well
            query &= Q(id__in=_ids_with_children(projects))
        tags = Project.objects(query).distinct("tags")
        system_tags = (
            Project.objects(query).distinct("system_tags") if include_system else []
        )
        return tags, system_tags
@classmethod @classmethod
def get_projects_with_active_user( def get_projects_with_active_user(
cls, cls,
@ -708,6 +744,7 @@ class ProjectBLL:
if include_subprojects: if include_subprojects:
projects = _ids_with_children(projects) projects = _ids_with_children(projects)
query &= Q(project__in=projects) query &= Q(project__in=projects)
if state == EntityVisibility.archived: if state == EntityVisibility.archived:
query &= Q(system_tags__in=[EntityVisibility.archived.value]) query &= Q(system_tags__in=[EntityVisibility.archived.value])
elif state == EntityVisibility.active: elif state == EntityVisibility.active:
@ -735,6 +772,7 @@ class ProjectBLL:
if project_ids: if project_ids:
project_ids = _ids_with_children(project_ids) project_ids = _ids_with_children(project_ids)
query &= Q(project__in=project_ids) query &= Q(project__in=project_ids)
res = Task.objects(query).distinct(field="type") res = Task.objects(query).distinct(field="type")
return set(res).intersection(external_task_types) return set(res).intersection(external_task_types)
@ -750,9 +788,34 @@ class ProjectBLL:
query &= Q(project__in=project_ids) query &= Q(project__in=project_ids)
return Model.objects(query).distinct(field="framework") return Model.objects(query).distinct(field="framework")
@staticmethod
def get_match_conditions(
    company: str, project_ids: Sequence[str], filter_: Mapping[str, Any]
):
    """
    Build a mongo $match condition selecting the company/project entities,
    optionally narrowed by "tags" / "system_tags" filters.

    :raises errors.bad_request.ValidationError: if a tags filter is not a
        list of strings
    """
    conditions = {
        # company may be empty/None for public entities
        "company": {"$in": [None, "", company]},
        "project": {"$in": project_ids},
    }
    if not filter_:
        return conditions

    # Only tags-like fields are supported in the filter; other keys are ignored
    for field in ("tags", "system_tags"):
        field_filter = filter_.get(field)
        if not field_filter:
            continue
        if not isinstance(field_filter, list) or not all(
            isinstance(t, str) for t in field_filter
        ):
            raise errors.bad_request.ValidationError(
                f"List of strings expected for the field: {field}"
            )
        conditions[field] = {"$in": field_filter}

    return conditions
@classmethod @classmethod
def calc_own_contents( def calc_own_contents(
cls, company: str, project_ids: Sequence[str] cls, company: str, project_ids: Sequence[str], filter_: Mapping[str, Any] = None
) -> Dict[str, dict]: ) -> Dict[str, dict]:
""" """
Returns the amount of task/models per requested project Returns the amount of task/models per requested project
@ -764,13 +827,12 @@ class ProjectBLL:
pipeline = [ pipeline = [
{ {
"$match": { "$match": cls.get_match_conditions(
"company": {"$in": [None, "", company]}, company=company, project_ids=project_ids, filter_=filter_
"project": {"$in": project_ids}, )
}
}, },
{"$project": {"project": 1}}, {"$project": {"project": 1}},
{"$group": {"_id": "$project", "count": {"$sum": 1}}} {"$group": {"_id": "$project", "count": {"$sum": 1}}},
] ]
def get_agrregate_res(cls_: Type[AttributedDocument]) -> dict: def get_agrregate_res(cls_: Type[AttributedDocument]) -> dict:

View File

@ -1,6 +1,6 @@
import json import json
from collections import OrderedDict from collections import OrderedDict
from datetime import datetime, timedelta from datetime import datetime
from typing import ( from typing import (
Sequence, Sequence,
Optional, Optional,
@ -28,12 +28,21 @@ class ProjectQueries:
def _get_project_constraint( def _get_project_constraint(
project_ids: Sequence[str], include_subprojects: bool project_ids: Sequence[str], include_subprojects: bool
) -> dict: ) -> dict:
"""
If passed projects is None means top level projects
If passed projects is empty means no project filtering
"""
if include_subprojects: if include_subprojects:
if project_ids is None: if not project_ids:
return {} return {}
project_ids = _ids_with_children(project_ids) project_ids = _ids_with_children(project_ids)
return {"project": {"$in": project_ids if project_ids is not None else [None]}} if project_ids is None:
project_ids = [None]
if not project_ids:
return {}
return {"project": {"$in": project_ids}}
@staticmethod @staticmethod
def _get_company_constraint(company_id: str, allow_public: bool = True) -> dict: def _get_company_constraint(company_id: str, allow_public: bool = True) -> dict:
@ -106,16 +115,11 @@ class ProjectQueries:
return total, remaining, results return total, remaining, results
HyperParamValues = Tuple[int, Sequence[str]] ParamValues = Tuple[int, Sequence[str]]
def _get_cached_hyperparam_values( def _get_cached_param_values(
self, key: str, last_update: datetime self, key: str, last_update: datetime, allowed_delta_sec=0
) -> Optional[HyperParamValues]: ) -> Optional[ParamValues]:
allowed_delta = timedelta(
seconds=config.get(
"services.tasks.hyperparam_values.cache_allowed_outdate_sec", 60
)
)
try: try:
cached = self.redis.get(key) cached = self.redis.get(key)
if not cached: if not cached:
@ -123,12 +127,12 @@ class ProjectQueries:
data = json.loads(cached) data = json.loads(cached)
cached_last_update = datetime.fromtimestamp(data["last_update"]) cached_last_update = datetime.fromtimestamp(data["last_update"])
if (last_update - cached_last_update) < allowed_delta: if (last_update - cached_last_update).total_seconds() <= allowed_delta_sec:
return data["total"], data["values"] return data["total"], data["values"]
except Exception as ex: except Exception as ex:
log.error(f"Error retrieving hyperparam cached values: {str(ex)}") log.error(f"Error retrieving params cached values: {str(ex)}")
def get_hyperparam_distinct_values( def get_task_hyperparam_distinct_values(
self, self,
company_id: str, company_id: str,
project_ids: Sequence[str], project_ids: Sequence[str],
@ -136,7 +140,7 @@ class ProjectQueries:
name: str, name: str,
include_subprojects: bool, include_subprojects: bool,
allow_public: bool = True, allow_public: bool = True,
) -> HyperParamValues: ) -> ParamValues:
company_constraint = self._get_company_constraint(company_id, allow_public) company_constraint = self._get_company_constraint(company_id, allow_public)
project_constraint = self._get_project_constraint( project_constraint = self._get_project_constraint(
project_ids, include_subprojects project_ids, include_subprojects
@ -158,8 +162,12 @@ class ProjectQueries:
redis_key = f"hyperparam_values_{company_id}_{'_'.join(project_ids)}_{section}_{name}_{allow_public}" redis_key = f"hyperparam_values_{company_id}_{'_'.join(project_ids)}_{section}_{name}_{allow_public}"
last_update = last_updated_task.last_update or datetime.utcnow() last_update = last_updated_task.last_update or datetime.utcnow()
cached_res = self._get_cached_hyperparam_values( cached_res = self._get_cached_param_values(
key=redis_key, last_update=last_update key=redis_key,
last_update=last_update,
allowed_delta_sec=config.get(
"services.tasks.hyperparam_values.cache_allowed_outdate_sec", 60
),
) )
if cached_res: if cached_res:
return cached_res return cached_res
@ -290,3 +298,73 @@ class ProjectQueries:
remaining = max(0, total - (len(results) + page * page_size)) remaining = max(0, total - (len(results) + page * page_size))
return total, remaining, results return total, remaining, results
def get_model_metadata_distinct_values(
    self,
    company_id: str,
    project_ids: Sequence[str],
    key: str,
    include_subprojects: bool,
    allow_public: bool = True,
) -> ParamValues:
    """
    Return (total, values): the amount and sorted list of distinct values
    stored under the given metadata key in models of the requested projects.

    Results are cached in Redis; since no allowed_delta_sec is passed to
    _get_cached_param_values (default 0), a cached entry is reused only
    while no matching model was updated after it was written.

    NOTE(review): assumes project_ids is a sequence — the redis key join
    below fails on None; confirm against callers.
    """
    company_constraint = self._get_company_constraint(company_id, allow_public)
    project_constraint = self._get_project_constraint(
        project_ids, include_subprojects
    )
    # Escape the user-provided key so it can be used as a mongo field path
    key_path = f"metadata.{ParameterKeyEscaper.escape(key)}"

    # Find the most recent update among the models that carry this key;
    # used both for cache validation and as the cache entry timestamp
    last_updated_model = (
        Model.objects(
            **company_constraint,
            **project_constraint,
            **{f"{key_path.replace('.', '__')}__exists": True},
        )
        .only("last_update")
        .order_by("-last_update")
        .limit(1)
        .first()
    )
    if not last_updated_model:
        # No model in the requested scope has this metadata key
        return 0, []

    redis_key = f"modelmetadata_values_{company_id}_{'_'.join(project_ids)}_{key}_{allow_public}"
    last_update = last_updated_model.last_update or datetime.utcnow()
    cached_res = self._get_cached_param_values(
        key=redis_key, last_update=last_update
    )
    if cached_res:
        return cached_res

    max_values = config.get("services.models.metadata_values.max_count", 100)
    pipeline = [
        {
            "$match": {
                **company_constraint,
                **project_constraint,
                key_path: {"$exists": True},
            }
        },
        {"$project": {"value": f"${key_path}.value"}},
        # Distinct values, sorted, capped at max_values, then folded into a
        # single document with their total count
        {"$group": {"_id": "$value"}},
        {"$sort": {"_id": 1}},
        {"$limit": max_values},
        {
            "$group": {
                "_id": 1,
                "total": {"$sum": 1},
                "results": {"$push": "$$ROOT._id"},
            }
        },
    ]
    result = next(Model.aggregate(pipeline, collation=Model._numeric_locale), None)
    if not result:
        return 0, []

    total = int(result.get("total", 0))
    values = result.get("results", [])

    # Cache the computed values together with the last_update timestamp
    ttl = config.get("services.models.metadata_values.cache_ttl_sec", 86400)
    cached = dict(last_update=last_update.timestamp(), total=total, values=values)
    self.redis.setex(redis_key, ttl, json.dumps(cached))

    return total, values

View File

@ -25,7 +25,9 @@ def _validate_project_name(project_name: str) -> Tuple[str, str]:
return name_separator.join(name_parts), name_separator.join(name_parts[:-1]) return name_separator.join(name_parts), name_separator.join(name_parts[:-1])
def _ensure_project(company: str, user: str, name: str) -> Optional[Project]: def _ensure_project(
company: str, user: str, name: str, creation_params: dict = None
) -> Optional[Project]:
""" """
Makes sure that the project with the given name exists Makes sure that the project with the given name exists
If needed auto-create the project and all the missing projects in the path to it If needed auto-create the project and all the missing projects in the path to it
@ -48,9 +50,9 @@ def _ensure_project(company: str, user: str, name: str) -> Optional[Project]:
created=now, created=now,
last_update=now, last_update=now,
name=name, name=name,
description="", **(creation_params or dict(description="")),
) )
parent = _ensure_project(company, user, location) parent = _ensure_project(company, user, location, creation_params=creation_params)
_save_under_parent(project=project, parent=parent) _save_under_parent(project=project, parent=parent)
if parent: if parent:
parent.update(last_update=now) parent.update(last_update=now)

View File

@ -112,6 +112,8 @@
workers { workers {
# Auto-register unknown workers on status reports and other calls # Auto-register unknown workers on status reports and other calls
auto_register: true auto_register: true
# Assume unknown workers have unregistered (i.e. do not raise unregistered error)
auto_unregister: true
# Timeout in seconds on task status update. If exceeded # Timeout in seconds on task status update. If exceeded
# then task can be stopped without communicating to the worker # then task can be stopped without communicating to the worker
task_update_timeout: 600 task_update_timeout: 600

View File

@ -0,0 +1,7 @@
metadata_values {
# maximal amount of distinct model values to retrieve
max_count: 100
# cache ttl sec
cache_ttl_sec: 86400
}

View File

@ -0,0 +1,47 @@
_description: "Provides a management API for pipelines in the system."
_definitions {
}
start_pipeline {
"2.17" {
description: "Start a pipeline"
request {
type: object
required: [ task ]
properties {
task {
description: "ID of the task on which the pipeline will be based"
type: string
}
queue {
description: "Queue ID in which the created pipeline task will be enqueued"
type: string
}
args {
description: "Task arguments, name/value to be placed in the hyperparameters Args section"
type: array
items {
type: object
properties {
name: { type: string }
value: { type: [string, null] }
}
}
}
}
}
response {
type: object
properties {
pipeline {
description: "ID of the new pipeline task"
type: string
}
enqueued {
description: "True if the task was successfully enqueued"
type: boolean
}
}
}
}
}

View File

@ -566,6 +566,19 @@ get_all_ex {
default: true default: true
} }
} }
"2.17": ${get_all_ex."2.16"} {
request.properties.include_stats_filter {
description: The filter for selecting entities that participate in statistics calculation
type: object
properties {
system_tags {
description: The list of allowed system tags
type: array
items { type: string }
}
}
}
}
} }
update { update {
"2.1" { "2.1" {
@ -913,6 +926,49 @@ get_hyper_parameters {
} }
} }
} }
get_model_metadata_values {
"2.17" {
description: """Get a list of distinct values for the chosen model metadata key"""
request {
type: object
required: [key]
properties {
projects {
description: "Project IDs"
type: array
items {type: string}
}
key {
description: "Metadata key"
type: string
}
allow_public {
description: "If set to 'true' then collect values from both company and public models, otherwise company models only. The default is 'true'"
type: boolean
}
include_subprojects {
description: "If set to 'true' and the project field is set then the result includes metadata values from the subproject models"
type: boolean
default: true
}
}
}
response {
type: object
properties {
total {
description: "Total number of distinct values"
type: integer
}
values {
description: "The list of the unique values"
type: array
items {type: string}
}
}
}
}
}
get_model_metadata_keys { get_model_metadata_keys {
"2.17" { "2.17" {
description: """Get a list of all metadata keys used in models within the given project.""" description: """Get a list of all metadata keys used in models within the given project."""
@ -962,6 +1018,13 @@ get_model_metadata_keys {
} }
} }
} }
get_project_tags {
"2.17" {
description: "Get user and system tags used for the specified projects and their children"
request = ${_definitions.tags_request}
response = ${_definitions.tags_response}
}
}
get_task_tags { get_task_tags {
"2.8" { "2.8" {
description: "Get user and system tags used for the tasks under the specified projects" description: "Get user and system tags used for the tasks under the specified projects"

View File

@ -5,7 +5,7 @@ from apiserver.apimodels.organization import TagsRequest
from apiserver.bll.organization import OrgBLL, Tags from apiserver.bll.organization import OrgBLL, Tags
from apiserver.database.model import User from apiserver.database.model import User
from apiserver.service_repo import endpoint, APICall from apiserver.service_repo import endpoint, APICall
from apiserver.services.utils import get_tags_filter_dictionary, get_tags_response from apiserver.services.utils import get_tags_filter_dictionary, sort_tags_response
org_bll = OrgBLL() org_bll = OrgBLL()
@ -21,17 +21,13 @@ def get_tags(call: APICall, company, request: TagsRequest):
for field, vals in tags.items(): for field, vals in tags.items():
ret[field] |= vals ret[field] |= vals
call.result.data = get_tags_response(ret) call.result.data = sort_tags_response(ret)
@endpoint("organization.get_user_companies") @endpoint("organization.get_user_companies")
def get_user_companies(call: APICall, company_id: str, _): def get_user_companies(call: APICall, company_id: str, _):
users = [ users = [
{ {"id": u.id, "name": u.name, "avatar": u.avatar}
"id": u.id,
"name": u.name,
"avatar": u.avatar,
}
for u in User.objects(company=company_id).only("avatar", "name", "company") for u in User.objects(company=company_id).only("avatar", "name", "company")
] ]

View File

@ -0,0 +1,68 @@
import re
from apiserver.apimodels.pipelines import StartPipelineResponse, StartPipelineRequest
from apiserver.bll.organization import OrgBLL
from apiserver.bll.project import ProjectBLL
from apiserver.bll.task import TaskBLL
from apiserver.bll.task.task_operations import enqueue_task
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task
from apiserver.service_repo import APICall, endpoint
# Service singletons shared by the endpoint handlers in this module
org_bll = OrgBLL()
project_bll = ProjectBLL()
task_bll = TaskBLL()
def _update_task_name(task: Task):
    """
    Rename a cloned pipeline task after its containing project.

    The new name is the last path component of the project name, suffixed
    with " #<count>" when other pipeline tasks with the same base name
    already exist in the project. No-op when the task or its project
    cannot be resolved.
    """
    if not task or not task.project:
        return
    project = Project.objects(id=task.project).only("name").first()
    if not project:
        return
    _, _, name_prefix = project.name.rpartition("/")
    # Anchor the pattern at the start as well: mongo $regex matching is not
    # implicitly anchored, so without "^" any task name merely *ending* with
    # the prefix (e.g. "Other <prefix>") would be counted too
    name_mask = re.compile(rf"^{re.escape(name_prefix)}( #\d+)?$")
    count = Task.objects(
        project=task.project, system_tags__in=["pipeline"], name=name_mask
    ).count()
    new_name = f"{name_prefix} #{count}" if count > 0 else name_prefix
    task.update(name=new_name)
@endpoint(
    "pipelines.start_pipeline", response_data_model=StartPipelineResponse,
)
def start_pipeline(call: APICall, company_id: str, request: StartPipelineRequest):
    """
    Start a pipeline: clone the template task, optionally override its
    "Args" hyperparams section with the passed arguments, rename the clone
    after its project and enqueue it into the requested queue.
    """
    hyperparams = None
    if request.args:
        # Pack the passed name/value pairs into the "Args" hyperparams section
        hyperparams = {
            "Args": {
                str(arg.name): {
                    "section": "Args",
                    "name": str(arg.name),
                    "value": str(arg.value),
                }
                for arg in request.args
            }
        }

    task, _ = task_bll.clone_task(
        company_id=company_id,
        user_id=call.identity.user,
        task_id=request.task,
        hyperparams=hyperparams,
    )

    _update_task_name(task)

    # Only the queued flag is used; the enqueue result payload is not needed
    queued, _ = enqueue_task(
        task_id=task.id,
        company_id=company_id,
        queue_id=request.queue,
        status_message="Starting pipeline",
        status_reason="",
    )

    return StartPipelineResponse(pipeline=task.id, enqueued=bool(queued))

View File

@ -17,6 +17,7 @@ from apiserver.apimodels.projects import (
MergeRequest, MergeRequest,
ProjectOrNoneRequest, ProjectOrNoneRequest,
ProjectRequest, ProjectRequest,
ProjectModelMetadataValuesRequest,
) )
from apiserver.bll.organization import OrgBLL, Tags from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL, ProjectQueries from apiserver.bll.project import ProjectBLL, ProjectQueries
@ -35,7 +36,7 @@ from apiserver.services.utils import (
conform_tag_fields, conform_tag_fields,
conform_output_tags, conform_output_tags,
get_tags_filter_dictionary, get_tags_filter_dictionary,
get_tags_response, sort_tags_response,
) )
from apiserver.timing_context import TimingContext from apiserver.timing_context import TimingContext
@ -124,7 +125,9 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
} }
if existing_requested_ids: if existing_requested_ids:
contents = project_bll.calc_own_contents( contents = project_bll.calc_own_contents(
company=company_id, project_ids=list(existing_requested_ids) company=company_id,
project_ids=list(existing_requested_ids),
filter_=request.include_stats_filter,
) )
for project in projects: for project in projects:
project.update(**contents.get(project["id"], {})) project.update(**contents.get(project["id"], {}))
@ -140,6 +143,7 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
project_ids=list(project_ids), project_ids=list(project_ids),
specific_state=request.stats_for_state, specific_state=request.stats_for_state,
include_children=request.stats_with_children, include_children=request.stats_with_children,
filter_=request.include_stats_filter,
) )
for project in projects: for project in projects:
@ -292,6 +296,23 @@ def get_model_metadata_keys(call: APICall, company_id: str, request: GetParamsRe
} }
@endpoint("projects.get_model_metadata_values")
def get_model_metadata_values(
    call: APICall, company_id: str, request: ProjectModelMetadataValuesRequest
):
    # Return the total amount and the list of distinct values stored under
    # the requested metadata key in models of the requested projects
    total, values = project_queries.get_model_metadata_distinct_values(
        company_id,
        project_ids=request.projects,
        key=request.key,
        include_subprojects=request.include_subprojects,
        allow_public=request.allow_public,
    )
    call.result.data = {
        "total": total,
        "values": values,
    }
@endpoint( @endpoint(
"projects.get_hyper_parameters", "projects.get_hyper_parameters",
min_version="2.9", min_version="2.9",
@ -322,7 +343,7 @@ def get_hyper_parameters(call: APICall, company_id: str, request: GetParamsReque
def get_hyperparam_values( def get_hyperparam_values(
call: APICall, company_id: str, request: ProjectHyperparamValuesRequest call: APICall, company_id: str, request: ProjectHyperparamValuesRequest
): ):
total, values = project_queries.get_hyperparam_distinct_values( total, values = project_queries.get_task_hyperparam_distinct_values(
company_id, company_id,
project_ids=request.projects, project_ids=request.projects,
section=request.section, section=request.section,
@ -336,6 +357,17 @@ def get_hyperparam_values(
} }
@endpoint("projects.get_project_tags")
def get_tags(call: APICall, company, request: ProjectTagsRequest):
    # Collect tags from the project documents themselves (and their
    # children), then return them sorted
    tags, system_tags = project_bll.get_project_tags(
        company,
        include_system=request.include_system,
        filter_=get_tags_filter_dictionary(request.filter),
        projects=request.projects,
    )
    call.result.data = sort_tags_response({"tags": tags, "system_tags": system_tags})
@endpoint( @endpoint(
"projects.get_task_tags", min_version="2.8", request_data_model=ProjectTagsRequest "projects.get_task_tags", min_version="2.8", request_data_model=ProjectTagsRequest
) )
@ -347,7 +379,7 @@ def get_tags(call: APICall, company, request: ProjectTagsRequest):
filter_=get_tags_filter_dictionary(request.filter), filter_=get_tags_filter_dictionary(request.filter),
projects=request.projects, projects=request.projects,
) )
call.result.data = get_tags_response(ret) call.result.data = sort_tags_response(ret)
@endpoint( @endpoint(
@ -361,7 +393,7 @@ def get_tags(call: APICall, company, request: ProjectTagsRequest):
filter_=get_tags_filter_dictionary(request.filter), filter_=get_tags_filter_dictionary(request.filter),
projects=request.projects, projects=request.projects,
) )
call.result.data = get_tags_response(ret) call.result.data = sort_tags_response(ret)
@endpoint( @endpoint(

View File

@ -23,7 +23,7 @@ def get_tags_filter_dictionary(input_: Filter) -> dict:
} }
def get_tags_response(ret: dict) -> dict: def sort_tags_response(ret: dict) -> dict:
return {field: sorted(vals) for field, vals in ret.items()} return {field: sorted(vals) for field, vals in ret.items()}

View File

@ -4,10 +4,29 @@ from apiserver.tests.automated import TestService
class TestProjectTags(TestService): class TestProjectTags(TestService):
def setUp(self, version="2.12"): def test_project_own_tags(self):
super().setUp(version=version) p1_tags = ["Tag 1", "Tag 2"]
p1 = self.create_temp(
"projects", name="Test project tags1", description="test", tags=p1_tags
)
p2_tags = ["Tag 1", "Tag 3"]
p2 = self.create_temp(
"projects",
name="Test project tags2",
description="test",
tags=p2_tags,
system_tags=["hidden"],
)
def test_project_tags(self): res = self.api.projects.get_project_tags(projects=[p1, p2])
self.assertEqual(set(res.tags), set(p1_tags) | set(p2_tags))
res = self.api.projects.get_project_tags(
projects=[p1, p2], filter={"system_tags": ["__$not", "hidden"]}
)
self.assertEqual(res.tags, p1_tags)
def test_project_entities_tags(self):
tags_1 = ["Test tag 1", "Test tag 2"] tags_1 = ["Test tag 1", "Test tag 2"]
tags_2 = ["Test tag 3", "Test tag 4"] tags_2 = ["Test tag 3", "Test tag 4"]

View File

@ -28,25 +28,33 @@ class TestQueueAndModelMetadata(TestService):
def test_project_meta_query(self): def test_project_meta_query(self):
self._temp_model("TestMetadata", metadata=self.meta1) self._temp_model("TestMetadata", metadata=self.meta1)
project = self.temp_project(name="MetaParent") project = self.temp_project(name="MetaParent")
test_key = "test_key"
test_key2 = "test_key2"
test_value = "test_value"
test_value2 = "test_value2"
model_id = self._temp_model( model_id = self._temp_model(
"TestMetadata2", "TestMetadata2",
project=project, project=project,
metadata={ metadata={
"test_key": {"key": "test_key", "type": "str", "value": "test_value"}, test_key: {"key": test_key, "type": "str", "value": test_value},
"test_key2": {"key": "test_key2", "type": "str", "value": "test_value"}, test_key2: {"key": test_key2, "type": "str", "value": test_value2},
}, },
) )
res = self.api.projects.get_model_metadata_keys() res = self.api.projects.get_model_metadata_keys()
self.assertTrue({"test_key", "test_key2"}.issubset(set(res["keys"]))) self.assertTrue({test_key, test_key2}.issubset(set(res["keys"])))
res = self.api.projects.get_model_metadata_keys(include_subprojects=False) res = self.api.projects.get_model_metadata_keys(include_subprojects=False)
self.assertTrue("test_key" in res["keys"]) self.assertTrue(test_key in res["keys"])
self.assertFalse("test_key2" in res["keys"]) self.assertFalse(test_key2 in res["keys"])
model = self.api.models.get_all_ex( model = self.api.models.get_all_ex(
id=[model_id], only_fields=["metadata.test_key"] id=[model_id], only_fields=["metadata.test_key"]
).models[0] ).models[0]
self.assertTrue("test_key" in model.metadata) self.assertTrue(test_key in model.metadata)
self.assertFalse("test_key2" in model.metadata) self.assertFalse(test_key2 in model.metadata)
res = self.api.projects.get_model_metadata_values(key=test_key)
self.assertEqual(res.total, 1)
self.assertEqual(res["values"], [test_value])
def _test_meta_operations( def _test_meta_operations(
self, service: APIClient.Service, entity: str, _id: str, self, service: APIClient.Service, entity: str, _id: str,

View File

@ -199,10 +199,10 @@ class TestSubProjects(TestService):
res1 = next(p for p in res if p.id == project1) res1 = next(p for p in res if p.id == project1)
self.assertEqual(res1.stats["active"]["status_count"]["created"], 0) self.assertEqual(res1.stats["active"]["status_count"]["created"], 0)
self.assertEqual(res1.stats["active"]["status_count"]["stopped"], 2) self.assertEqual(res1.stats["active"]["status_count"]["stopped"], 2)
self.assertEqual(res1.stats["active"]["status_count"]["in_progress"], 0)
self.assertEqual(res1.stats["active"]["total_runtime"], 2) self.assertEqual(res1.stats["active"]["total_runtime"], 2)
self.assertEqual(res1.stats["active"]["completed_tasks"], 2) self.assertEqual(res1.stats["active"]["completed_tasks_24h"], 2)
self.assertEqual(res1.stats["active"]["total_tasks"], 2) self.assertEqual(res1.stats["active"]["total_tasks"], 2)
self.assertEqual(res1.stats["active"]["running_tasks"], 0)
self.assertEqual( self.assertEqual(
{sp.name for sp in res1.sub_projects}, {sp.name for sp in res1.sub_projects},
{ {
@ -214,10 +214,10 @@ class TestSubProjects(TestService):
res2 = next(p for p in res if p.id == project2) res2 = next(p for p in res if p.id == project2)
self.assertEqual(res2.stats["active"]["status_count"]["created"], 0) self.assertEqual(res2.stats["active"]["status_count"]["created"], 0)
self.assertEqual(res2.stats["active"]["status_count"]["stopped"], 0) self.assertEqual(res2.stats["active"]["status_count"]["stopped"], 0)
self.assertEqual(res2.stats["active"]["status_count"]["in_progress"], 0)
self.assertEqual(res2.stats["active"]["status_count"]["completed"], 0)
self.assertEqual(res2.stats["active"]["total_runtime"], 0) self.assertEqual(res2.stats["active"]["total_runtime"], 0)
self.assertEqual(res2.stats["active"]["completed_tasks"], 0)
self.assertEqual(res2.stats["active"]["total_tasks"], 0) self.assertEqual(res2.stats["active"]["total_tasks"], 0)
self.assertEqual(res2.stats["active"]["running_tasks"], 0)
self.assertEqual(res2.sub_projects, []) self.assertEqual(res2.sub_projects, [])
def _run_tasks(self, *tasks): def _run_tasks(self, *tasks):

View File

@ -133,6 +133,32 @@ class TestTags(TestService):
).models ).models
self.assertFound(model_id, [], models) self.assertFound(model_id, [], models)
def testQueueTags(self):
q_id = self._temp_queue(system_tags=["default"])
queues = self.api.queues.get_all_ex(
name="Test tags", system_tags=["default"]
).queues
self.assertFound(q_id, ["default"], queues)
queues = self.api.queues.get_all_ex(
name="Test tags", system_tags=["-default"]
).queues
self.assertNotFound(q_id, queues)
self.api.queues.update(queue=q_id, system_tags=[])
queues = self.api.queues.get_all_ex(
name="Test tags", system_tags=["-default"]
).queues
self.assertFound(q_id, [], queues)
# test default queue
queues = self.api.queues.get_all(system_tags=["default"]).queues
if queues:
self.assertEqual(queues[0].id, self.api.queues.get_default().id)
else:
self.api.queues.update(queue=q_id, system_tags=["default"])
self.assertEqual(q_id, self.api.queues.get_default().id)
def testTaskTags(self): def testTaskTags(self):
task_id = self._temp_task( task_id = self._temp_task(
name="Test tags", system_tags=["active"] name="Test tags", system_tags=["active"]
@ -169,38 +195,11 @@ class TestTags(TestService):
task = self.api.tasks.get_by_id(task=task_id).task task = self.api.tasks.get_by_id(task=task_id).task
self.assertEqual(task.status, "stopped") self.assertEqual(task.status, "stopped")
def testQueueTags(self):
q_id = self._temp_queue(system_tags=["default"])
queues = self.api.queues.get_all_ex(
name="Test tags", system_tags=["default"]
).queues
self.assertFound(q_id, ["default"], queues)
queues = self.api.queues.get_all_ex(
name="Test tags", system_tags=["-default"]
).queues
self.assertNotFound(q_id, queues)
self.api.queues.update(queue=q_id, system_tags=[])
queues = self.api.queues.get_all_ex(
name="Test tags", system_tags=["-default"]
).queues
self.assertFound(q_id, [], queues)
# test default queue
queues = self.api.queues.get_all(system_tags=["default"]).queues
if queues:
self.assertEqual(queues[0].id, self.api.queues.get_default().id)
else:
self.api.queues.update(queue=q_id, system_tags=["default"])
self.assertEqual(q_id, self.api.queues.get_default().id)
def assertProjectStats(self, project: AttrDict): def assertProjectStats(self, project: AttrDict):
self.assertEqual(set(project.stats.keys()), {"active"}) self.assertEqual(set(project.stats.keys()), {"active"})
self.assertAlmostEqual(project.stats.active.total_runtime, 1, places=0) self.assertAlmostEqual(project.stats.active.total_runtime, 1, places=0)
self.assertEqual(project.stats.active.completed_tasks, 1) self.assertEqual(project.stats.active.completed_tasks_24h, 1)
self.assertEqual(project.stats.active.total_tasks, 1) self.assertEqual(project.stats.active.total_tasks, 1)
self.assertEqual(project.stats.active.running_tasks, 0)
for status, count in project.stats.active.status_count.items(): for status, count in project.stats.active.status_count.items():
self.assertEqual(count, 1 if status == "stopped" else 0) self.assertEqual(count, 1 if status == "stopped" else 0)