Add pipelines support

This commit is contained in:
allegroai 2022-03-15 16:28:59 +02:00
parent e1992e2054
commit da8a45072f
17 changed files with 501 additions and 93 deletions

View File

@ -0,0 +1,19 @@
from jsonmodels import models, fields
from apiserver.apimodels import ListField
class Arg(models.Base):
name = fields.StringField(required=True)
value = fields.StringField(required=True)
class StartPipelineRequest(models.Base):
task = fields.StringField(required=True)
queue = fields.StringField(required=True)
args = ListField(Arg)
class StartPipelineResponse(models.Base):
pipeline = fields.StringField(required=True)
enqueued = fields.BoolField(required=True)

View File

@ -1,6 +1,6 @@
from jsonmodels import models, fields
from apiserver.apimodels import ListField, ActualEnumField
from apiserver.apimodels import ListField, ActualEnumField, DictField
from apiserver.apimodels.organization import TagsRequest
from apiserver.database.model import EntityVisibility
@ -51,8 +51,14 @@ class ProjectHyperparamValuesRequest(MultiProjectRequest):
allow_public = fields.BoolField(default=True)
class ProjectModelMetadataValuesRequest(MultiProjectRequest):
key = fields.StringField(required=True)
allow_public = fields.BoolField(default=True)
class ProjectsGetRequest(models.Base):
include_stats = fields.BoolField(default=False)
include_stats_filter = DictField()
stats_with_children = fields.BoolField(default=True)
stats_for_state = ActualEnumField(EntityVisibility, default=EntityVisibility.active)
non_public = fields.BoolField(default=False)

View File

@ -14,6 +14,7 @@ from typing import (
TypeVar,
Callable,
Mapping,
Any,
)
from mongoengine import Q, Document
@ -22,6 +23,7 @@ from apiserver import database
from apiserver.apierrors import errors
from apiserver.config_repo import config
from apiserver.database.model import EntityVisibility, AttributedDocument
from apiserver.database.model.base import GetMixin
from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task, TaskStatus, external_task_types
@ -204,6 +206,7 @@ class ProjectBLL:
tags: Sequence[str] = None,
system_tags: Sequence[str] = None,
default_output_destination: str = None,
parent_creation_params: dict = None,
) -> str:
"""
Create a new project.
@ -226,7 +229,12 @@ class ProjectBLL:
created=now,
last_update=now,
)
parent = _ensure_project(company=company, user=user, name=location)
parent = _ensure_project(
company=company,
user=user,
name=location,
creation_params=parent_creation_params,
)
_save_under_parent(project=project, parent=parent)
if parent:
parent.update(last_update=now)
@ -244,13 +252,14 @@ class ProjectBLL:
tags: Sequence[str] = None,
system_tags: Sequence[str] = None,
default_output_destination: str = None,
parent_creation_params: dict = None,
) -> str:
"""
Find a project named `project_name` or create a new one.
Returns project ID
"""
if not project_id and not project_name:
raise ValueError("project id or name required")
raise errors.bad_request.ValidationError("project id or name required")
if project_id:
project = Project.objects(company=company, id=project_id).only("id").first()
@ -271,6 +280,7 @@ class ProjectBLL:
tags=tags,
system_tags=system_tags,
default_output_destination=default_output_destination,
parent_creation_params=parent_creation_params,
)
@classmethod
@ -314,6 +324,7 @@ class ProjectBLL:
company_id: str,
project_ids: Sequence[str],
specific_state: Optional[EntityVisibility] = None,
filter_: Mapping[str, Any] = None,
) -> Tuple[Sequence, Sequence]:
archived = EntityVisibility.archived.value
@ -337,10 +348,9 @@ class ProjectBLL:
status_count_pipeline = [
# count tasks per project per status
{
"$match": {
"company": {"$in": [None, "", company_id]},
"project": {"$in": project_ids},
}
"$match": cls.get_match_conditions(
company=company_id, project_ids=project_ids, filter_=filter_
)
},
ensure_valid_fields(),
{
@ -455,8 +465,9 @@ class ProjectBLL:
# only count run time for these types of tasks
{
"$match": {
"company": {"$in": [None, "", company_id]},
"project": {"$in": project_ids},
**cls.get_match_conditions(
company=company_id, project_ids=project_ids, filter_=filter_
),
**get_state_filter(),
}
},
@ -500,6 +511,7 @@ class ProjectBLL:
project_ids: Sequence[str],
specific_state: Optional[EntityVisibility] = None,
include_children: bool = True,
filter_: Mapping[str, Any] = None,
) -> Tuple[Dict[str, dict], Dict[str, dict]]:
if not project_ids:
return {}, {}
@ -516,6 +528,7 @@ class ProjectBLL:
company,
project_ids=list(project_ids_with_children),
specific_state=specific_state,
filter_=filter_,
)
default_counts = dict.fromkeys(get_options(TaskStatus), 0)
@ -589,10 +602,9 @@ class ProjectBLL:
return {
"status_count": project_section_statuses,
"running_tasks": project_section_statuses.get(TaskStatus.in_progress),
"total_tasks": sum(project_section_statuses.values()),
"total_runtime": project_runtime.get(section, 0),
"completed_tasks": project_runtime.get(
"completed_tasks_24h": project_runtime.get(
f"{section}_recently_completed", 0
),
"last_task_run": get_time_or_none(
@ -652,6 +664,30 @@ class ProjectBLL:
return res
@classmethod
def get_project_tags(
cls,
company_id: str,
include_system: bool,
projects: Sequence[str] = None,
filter_: Dict[str, Sequence[str]] = None,
) -> Tuple[Sequence[str], Sequence[str]]:
with TimingContext("mongo", "get_tags_from_db"):
query = Q(company=company_id)
if filter_:
for name, vals in filter_.items():
if vals:
query &= GetMixin.get_list_field_query(name, vals)
if projects:
query &= Q(id__in=_ids_with_children(projects))
tags = Project.objects(query).distinct("tags")
system_tags = (
Project.objects(query).distinct("system_tags") if include_system else []
)
return tags, system_tags
@classmethod
def get_projects_with_active_user(
cls,
@ -708,6 +744,7 @@ class ProjectBLL:
if include_subprojects:
projects = _ids_with_children(projects)
query &= Q(project__in=projects)
if state == EntityVisibility.archived:
query &= Q(system_tags__in=[EntityVisibility.archived.value])
elif state == EntityVisibility.active:
@ -735,6 +772,7 @@ class ProjectBLL:
if project_ids:
project_ids = _ids_with_children(project_ids)
query &= Q(project__in=project_ids)
res = Task.objects(query).distinct(field="type")
return set(res).intersection(external_task_types)
@ -750,9 +788,34 @@ class ProjectBLL:
query &= Q(project__in=project_ids)
return Model.objects(query).distinct(field="framework")
@staticmethod
def get_match_conditions(
company: str, project_ids: Sequence[str], filter_: Mapping[str, Any]
):
conditions = {
"company": {"$in": [None, "", company]},
"project": {"$in": project_ids},
}
if not filter_:
return conditions
for field in ("tags", "system_tags"):
field_filter = filter_.get(field)
if not field_filter:
continue
if not isinstance(field_filter, list) or not all(
isinstance(t, str) for t in field_filter
):
raise errors.bad_request.ValidationError(
f"List of strings expected for the field: {field}"
)
conditions[field] = {"$in": field_filter}
return conditions
@classmethod
def calc_own_contents(
cls, company: str, project_ids: Sequence[str]
cls, company: str, project_ids: Sequence[str], filter_: Mapping[str, Any] = None
) -> Dict[str, dict]:
"""
Returns the amount of task/models per requested project
@ -764,13 +827,12 @@ class ProjectBLL:
pipeline = [
{
"$match": {
"company": {"$in": [None, "", company]},
"project": {"$in": project_ids},
}
"$match": cls.get_match_conditions(
company=company, project_ids=project_ids, filter_=filter_
)
},
{"$project": {"project": 1}},
{"$group": {"_id": "$project", "count": {"$sum": 1}}}
{"$group": {"_id": "$project", "count": {"$sum": 1}}},
]
def get_agrregate_res(cls_: Type[AttributedDocument]) -> dict:

View File

@ -1,6 +1,6 @@
import json
from collections import OrderedDict
from datetime import datetime, timedelta
from datetime import datetime
from typing import (
Sequence,
Optional,
@ -28,12 +28,21 @@ class ProjectQueries:
def _get_project_constraint(
project_ids: Sequence[str], include_subprojects: bool
) -> dict:
"""
If passed projects is None means top level projects
If passed projects is empty means no project filtering
"""
if include_subprojects:
if project_ids is None:
if not project_ids:
return {}
project_ids = _ids_with_children(project_ids)
return {"project": {"$in": project_ids if project_ids is not None else [None]}}
if project_ids is None:
project_ids = [None]
if not project_ids:
return {}
return {"project": {"$in": project_ids}}
@staticmethod
def _get_company_constraint(company_id: str, allow_public: bool = True) -> dict:
@ -106,16 +115,11 @@ class ProjectQueries:
return total, remaining, results
HyperParamValues = Tuple[int, Sequence[str]]
ParamValues = Tuple[int, Sequence[str]]
def _get_cached_hyperparam_values(
self, key: str, last_update: datetime
) -> Optional[HyperParamValues]:
allowed_delta = timedelta(
seconds=config.get(
"services.tasks.hyperparam_values.cache_allowed_outdate_sec", 60
)
)
def _get_cached_param_values(
self, key: str, last_update: datetime, allowed_delta_sec=0
) -> Optional[ParamValues]:
try:
cached = self.redis.get(key)
if not cached:
@ -123,12 +127,12 @@ class ProjectQueries:
data = json.loads(cached)
cached_last_update = datetime.fromtimestamp(data["last_update"])
if (last_update - cached_last_update) < allowed_delta:
if (last_update - cached_last_update).total_seconds() <= allowed_delta_sec:
return data["total"], data["values"]
except Exception as ex:
log.error(f"Error retrieving hyperparam cached values: {str(ex)}")
log.error(f"Error retrieving params cached values: {str(ex)}")
def get_hyperparam_distinct_values(
def get_task_hyperparam_distinct_values(
self,
company_id: str,
project_ids: Sequence[str],
@ -136,7 +140,7 @@ class ProjectQueries:
name: str,
include_subprojects: bool,
allow_public: bool = True,
) -> HyperParamValues:
) -> ParamValues:
company_constraint = self._get_company_constraint(company_id, allow_public)
project_constraint = self._get_project_constraint(
project_ids, include_subprojects
@ -158,8 +162,12 @@ class ProjectQueries:
redis_key = f"hyperparam_values_{company_id}_{'_'.join(project_ids)}_{section}_{name}_{allow_public}"
last_update = last_updated_task.last_update or datetime.utcnow()
cached_res = self._get_cached_hyperparam_values(
key=redis_key, last_update=last_update
cached_res = self._get_cached_param_values(
key=redis_key,
last_update=last_update,
allowed_delta_sec=config.get(
"services.tasks.hyperparam_values.cache_allowed_outdate_sec", 60
),
)
if cached_res:
return cached_res
@ -290,3 +298,73 @@ class ProjectQueries:
remaining = max(0, total - (len(results) + page * page_size))
return total, remaining, results
def get_model_metadata_distinct_values(
self,
company_id: str,
project_ids: Sequence[str],
key: str,
include_subprojects: bool,
allow_public: bool = True,
) -> ParamValues:
company_constraint = self._get_company_constraint(company_id, allow_public)
project_constraint = self._get_project_constraint(
project_ids, include_subprojects
)
key_path = f"metadata.{ParameterKeyEscaper.escape(key)}"
last_updated_model = (
Model.objects(
**company_constraint,
**project_constraint,
**{f"{key_path.replace('.', '__')}__exists": True},
)
.only("last_update")
.order_by("-last_update")
.limit(1)
.first()
)
if not last_updated_model:
return 0, []
redis_key = f"modelmetadata_values_{company_id}_{'_'.join(project_ids)}_{key}_{allow_public}"
last_update = last_updated_model.last_update or datetime.utcnow()
cached_res = self._get_cached_param_values(
key=redis_key, last_update=last_update
)
if cached_res:
return cached_res
max_values = config.get("services.models.metadata_values.max_count", 100)
pipeline = [
{
"$match": {
**company_constraint,
**project_constraint,
key_path: {"$exists": True},
}
},
{"$project": {"value": f"${key_path}.value"}},
{"$group": {"_id": "$value"}},
{"$sort": {"_id": 1}},
{"$limit": max_values},
{
"$group": {
"_id": 1,
"total": {"$sum": 1},
"results": {"$push": "$$ROOT._id"},
}
},
]
result = next(Model.aggregate(pipeline, collation=Model._numeric_locale), None)
if not result:
return 0, []
total = int(result.get("total", 0))
values = result.get("results", [])
ttl = config.get("services.models.metadata_values.cache_ttl_sec", 86400)
cached = dict(last_update=last_update.timestamp(), total=total, values=values)
self.redis.setex(redis_key, ttl, json.dumps(cached))
return total, values

View File

@ -25,7 +25,9 @@ def _validate_project_name(project_name: str) -> Tuple[str, str]:
return name_separator.join(name_parts), name_separator.join(name_parts[:-1])
def _ensure_project(company: str, user: str, name: str) -> Optional[Project]:
def _ensure_project(
company: str, user: str, name: str, creation_params: dict = None
) -> Optional[Project]:
"""
Makes sure that the project with the given name exists
If needed auto-create the project and all the missing projects in the path to it
@ -48,9 +50,9 @@ def _ensure_project(company: str, user: str, name: str) -> Optional[Project]:
created=now,
last_update=now,
name=name,
description="",
**(creation_params or dict(description="")),
)
parent = _ensure_project(company, user, location)
parent = _ensure_project(company, user, location, creation_params=creation_params)
_save_under_parent(project=project, parent=parent)
if parent:
parent.update(last_update=now)

View File

@ -112,6 +112,8 @@
workers {
# Auto-register unknown workers on status reports and other calls
auto_register: true
# Assume unknow workers have unregistered (i.e. do not raise unregistered error)
auto_unregister: true
# Timeout in seconds on task status update. If exceeded
# then task can be stopped without communicating to the worker
task_update_timeout: 600

View File

@ -0,0 +1,7 @@
metadata_values {
# maximal amount of distinct model values to retrieve
max_count: 100
# cache ttl sec
cache_ttl_sec: 86400
}

View File

@ -0,0 +1,47 @@
_description: "Provides a management API for pipelines in the system."
_definitions {
}
start_pipeline {
"2.17" {
description: "Start a pipeline"
request {
type: object
required: [ task ]
properties {
task {
description: "ID of the task on which the pipeline will be based"
type: string
}
queue {
description: "Queue ID in which the created pipeline task will be enqueued"
type: string
}
args {
description: "Task arguments, name/value to be placed in the hyperparameters Args section"
type: array
items {
type: object
properties {
name: { type: string }
value: { type: [string, null] }
}
}
}
}
}
response {
type: object
properties {
pipeline {
description: "ID of the new pipeline task"
type: string
}
enqueued {
description: "True if the task was successfuly enqueued"
type: boolean
}
}
}
}
}

View File

@ -566,6 +566,19 @@ get_all_ex {
default: true
}
}
"2.17": ${get_all_ex."2.16"} {
request.properties.include_stats_filter {
description: The filter for selecting entities that participate in statistics calculation
type: object
properties {
system_tags {
description: The list of allowed system tags
type: array
items { type: string }
}
}
}
}
}
update {
"2.1" {
@ -913,6 +926,49 @@ get_hyper_parameters {
}
}
}
get_model_metadata_values {
"2.17" {
description: """Get a list of distinct values for the chosen model metadata key"""
request {
type: object
required: [key]
properties {
projects {
description: "Project IDs"
type: array
items {type: string}
}
key {
description: "Metadata key"
type: string
}
allow_public {
description: "If set to 'true' then collect values from both company and public models otherwise company modeels only. The default is 'true'"
type: boolean
}
include_subprojects {
description: "If set to 'true' and the project field is set then the result includes metadata values from the subproject models"
type: boolean
default: true
}
}
}
response {
type: object
properties {
total {
description: "Total number of distinct values"
type: integer
}
values {
description: "The list of the unique values"
type: array
items {type: string}
}
}
}
}
}
get_model_metadata_keys {
"2.17" {
description: """Get a list of all metadata keys used in models within the given project."""
@ -962,6 +1018,13 @@ get_model_metadata_keys {
}
}
}
get_project_tags {
"2.17" {
description: "Get user and system tags used for the specified projects and their children"
request = ${_definitions.tags_request}
response = ${_definitions.tags_response}
}
}
get_task_tags {
"2.8" {
description: "Get user and system tags used for the tasks under the specified projects"

View File

@ -5,7 +5,7 @@ from apiserver.apimodels.organization import TagsRequest
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.database.model import User
from apiserver.service_repo import endpoint, APICall
from apiserver.services.utils import get_tags_filter_dictionary, get_tags_response
from apiserver.services.utils import get_tags_filter_dictionary, sort_tags_response
org_bll = OrgBLL()
@ -21,17 +21,13 @@ def get_tags(call: APICall, company, request: TagsRequest):
for field, vals in tags.items():
ret[field] |= vals
call.result.data = get_tags_response(ret)
call.result.data = sort_tags_response(ret)
@endpoint("organization.get_user_companies")
def get_user_companies(call: APICall, company_id: str, _):
users = [
{
"id": u.id,
"name": u.name,
"avatar": u.avatar,
}
{"id": u.id, "name": u.name, "avatar": u.avatar}
for u in User.objects(company=company_id).only("avatar", "name", "company")
]

View File

@ -0,0 +1,68 @@
import re
from apiserver.apimodels.pipelines import StartPipelineResponse, StartPipelineRequest
from apiserver.bll.organization import OrgBLL
from apiserver.bll.project import ProjectBLL
from apiserver.bll.task import TaskBLL
from apiserver.bll.task.task_operations import enqueue_task
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task
from apiserver.service_repo import APICall, endpoint
org_bll = OrgBLL()
project_bll = ProjectBLL()
task_bll = TaskBLL()
def _update_task_name(task: Task):
if not task or not task.project:
return
project = Project.objects(id=task.project).only("name").first()
if not project:
return
_, _, name_prefix = project.name.rpartition("/")
name_mask = re.compile(rf"{re.escape(name_prefix)}( #\d+)?$")
count = Task.objects(
project=task.project, system_tags__in=["pipeline"], name=name_mask
).count()
new_name = f"{name_prefix} #{count}" if count > 0 else name_prefix
task.update(name=new_name)
@endpoint(
"pipelines.start_pipeline", response_data_model=StartPipelineResponse,
)
def start_pipeline(call: APICall, company_id: str, request: StartPipelineRequest):
hyperparams = None
if request.args:
hyperparams = {
"Args": {
str(arg.name): {
"section": "Args",
"name": str(arg.name),
"value": str(arg.value),
}
for arg in request.args or []
}
}
task, _ = task_bll.clone_task(
company_id=company_id,
user_id=call.identity.user,
task_id=request.task,
hyperparams=hyperparams,
)
_update_task_name(task)
queued, res = enqueue_task(
task_id=task.id,
company_id=company_id,
queue_id=request.queue,
status_message="Starting pipeline",
status_reason="",
)
return StartPipelineResponse(pipeline=task.id, enqueued=bool(queued))

View File

@ -17,6 +17,7 @@ from apiserver.apimodels.projects import (
MergeRequest,
ProjectOrNoneRequest,
ProjectRequest,
ProjectModelMetadataValuesRequest,
)
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL, ProjectQueries
@ -35,7 +36,7 @@ from apiserver.services.utils import (
conform_tag_fields,
conform_output_tags,
get_tags_filter_dictionary,
get_tags_response,
sort_tags_response,
)
from apiserver.timing_context import TimingContext
@ -124,7 +125,9 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
}
if existing_requested_ids:
contents = project_bll.calc_own_contents(
company=company_id, project_ids=list(existing_requested_ids)
company=company_id,
project_ids=list(existing_requested_ids),
filter_=request.include_stats_filter,
)
for project in projects:
project.update(**contents.get(project["id"], {}))
@ -140,6 +143,7 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
project_ids=list(project_ids),
specific_state=request.stats_for_state,
include_children=request.stats_with_children,
filter_=request.include_stats_filter,
)
for project in projects:
@ -292,6 +296,23 @@ def get_model_metadata_keys(call: APICall, company_id: str, request: GetParamsRe
}
@endpoint("projects.get_model_metadata_values")
def get_model_metadata_values(
call: APICall, company_id: str, request: ProjectModelMetadataValuesRequest
):
total, values = project_queries.get_model_metadata_distinct_values(
company_id,
project_ids=request.projects,
key=request.key,
include_subprojects=request.include_subprojects,
allow_public=request.allow_public,
)
call.result.data = {
"total": total,
"values": values,
}
@endpoint(
"projects.get_hyper_parameters",
min_version="2.9",
@ -322,7 +343,7 @@ def get_hyper_parameters(call: APICall, company_id: str, request: GetParamsReque
def get_hyperparam_values(
call: APICall, company_id: str, request: ProjectHyperparamValuesRequest
):
total, values = project_queries.get_hyperparam_distinct_values(
total, values = project_queries.get_task_hyperparam_distinct_values(
company_id,
project_ids=request.projects,
section=request.section,
@ -336,6 +357,17 @@ def get_hyperparam_values(
}
@endpoint("projects.get_project_tags")
def get_tags(call: APICall, company, request: ProjectTagsRequest):
tags, system_tags = project_bll.get_project_tags(
company,
include_system=request.include_system,
filter_=get_tags_filter_dictionary(request.filter),
projects=request.projects,
)
call.result.data = sort_tags_response({"tags": tags, "system_tags": system_tags})
@endpoint(
"projects.get_task_tags", min_version="2.8", request_data_model=ProjectTagsRequest
)
@ -347,7 +379,7 @@ def get_tags(call: APICall, company, request: ProjectTagsRequest):
filter_=get_tags_filter_dictionary(request.filter),
projects=request.projects,
)
call.result.data = get_tags_response(ret)
call.result.data = sort_tags_response(ret)
@endpoint(
@ -361,7 +393,7 @@ def get_tags(call: APICall, company, request: ProjectTagsRequest):
filter_=get_tags_filter_dictionary(request.filter),
projects=request.projects,
)
call.result.data = get_tags_response(ret)
call.result.data = sort_tags_response(ret)
@endpoint(

View File

@ -23,7 +23,7 @@ def get_tags_filter_dictionary(input_: Filter) -> dict:
}
def get_tags_response(ret: dict) -> dict:
def sort_tags_response(ret: dict) -> dict:
return {field: sorted(vals) for field, vals in ret.items()}

View File

@ -4,10 +4,29 @@ from apiserver.tests.automated import TestService
class TestProjectTags(TestService):
def setUp(self, version="2.12"):
super().setUp(version=version)
def test_project_own_tags(self):
p1_tags = ["Tag 1", "Tag 2"]
p1 = self.create_temp(
"projects", name="Test project tags1", description="test", tags=p1_tags
)
p2_tags = ["Tag 1", "Tag 3"]
p2 = self.create_temp(
"projects",
name="Test project tags2",
description="test",
tags=p2_tags,
system_tags=["hidden"],
)
def test_project_tags(self):
res = self.api.projects.get_project_tags(projects=[p1, p2])
self.assertEqual(set(res.tags), set(p1_tags) | set(p2_tags))
res = self.api.projects.get_project_tags(
projects=[p1, p2], filter={"system_tags": ["__$not", "hidden"]}
)
self.assertEqual(res.tags, p1_tags)
def test_project_entities_tags(self):
tags_1 = ["Test tag 1", "Test tag 2"]
tags_2 = ["Test tag 3", "Test tag 4"]

View File

@ -28,25 +28,33 @@ class TestQueueAndModelMetadata(TestService):
def test_project_meta_query(self):
self._temp_model("TestMetadata", metadata=self.meta1)
project = self.temp_project(name="MetaParent")
test_key = "test_key"
test_key2 = "test_key2"
test_value = "test_value"
test_value2 = "test_value2"
model_id = self._temp_model(
"TestMetadata2",
project=project,
metadata={
"test_key": {"key": "test_key", "type": "str", "value": "test_value"},
"test_key2": {"key": "test_key2", "type": "str", "value": "test_value"},
test_key: {"key": test_key, "type": "str", "value": test_value},
test_key2: {"key": test_key2, "type": "str", "value": test_value2},
},
)
res = self.api.projects.get_model_metadata_keys()
self.assertTrue({"test_key", "test_key2"}.issubset(set(res["keys"])))
self.assertTrue({test_key, test_key2}.issubset(set(res["keys"])))
res = self.api.projects.get_model_metadata_keys(include_subprojects=False)
self.assertTrue("test_key" in res["keys"])
self.assertFalse("test_key2" in res["keys"])
self.assertTrue(test_key in res["keys"])
self.assertFalse(test_key2 in res["keys"])
model = self.api.models.get_all_ex(
id=[model_id], only_fields=["metadata.test_key"]
).models[0]
self.assertTrue("test_key" in model.metadata)
self.assertFalse("test_key2" in model.metadata)
self.assertTrue(test_key in model.metadata)
self.assertFalse(test_key2 in model.metadata)
res = self.api.projects.get_model_metadata_values(key=test_key)
self.assertEqual(res.total, 1)
self.assertEqual(res["values"], [test_value])
def _test_meta_operations(
self, service: APIClient.Service, entity: str, _id: str,

View File

@ -199,10 +199,10 @@ class TestSubProjects(TestService):
res1 = next(p for p in res if p.id == project1)
self.assertEqual(res1.stats["active"]["status_count"]["created"], 0)
self.assertEqual(res1.stats["active"]["status_count"]["stopped"], 2)
self.assertEqual(res1.stats["active"]["status_count"]["in_progress"], 0)
self.assertEqual(res1.stats["active"]["total_runtime"], 2)
self.assertEqual(res1.stats["active"]["completed_tasks"], 2)
self.assertEqual(res1.stats["active"]["completed_tasks_24h"], 2)
self.assertEqual(res1.stats["active"]["total_tasks"], 2)
self.assertEqual(res1.stats["active"]["running_tasks"], 0)
self.assertEqual(
{sp.name for sp in res1.sub_projects},
{
@ -214,10 +214,10 @@ class TestSubProjects(TestService):
res2 = next(p for p in res if p.id == project2)
self.assertEqual(res2.stats["active"]["status_count"]["created"], 0)
self.assertEqual(res2.stats["active"]["status_count"]["stopped"], 0)
self.assertEqual(res2.stats["active"]["status_count"]["in_progress"], 0)
self.assertEqual(res2.stats["active"]["status_count"]["completed"], 0)
self.assertEqual(res2.stats["active"]["total_runtime"], 0)
self.assertEqual(res2.stats["active"]["completed_tasks"], 0)
self.assertEqual(res2.stats["active"]["total_tasks"], 0)
self.assertEqual(res2.stats["active"]["running_tasks"], 0)
self.assertEqual(res2.sub_projects, [])
def _run_tasks(self, *tasks):

View File

@ -133,6 +133,32 @@ class TestTags(TestService):
).models
self.assertFound(model_id, [], models)
def testQueueTags(self):
q_id = self._temp_queue(system_tags=["default"])
queues = self.api.queues.get_all_ex(
name="Test tags", system_tags=["default"]
).queues
self.assertFound(q_id, ["default"], queues)
queues = self.api.queues.get_all_ex(
name="Test tags", system_tags=["-default"]
).queues
self.assertNotFound(q_id, queues)
self.api.queues.update(queue=q_id, system_tags=[])
queues = self.api.queues.get_all_ex(
name="Test tags", system_tags=["-default"]
).queues
self.assertFound(q_id, [], queues)
# test default queue
queues = self.api.queues.get_all(system_tags=["default"]).queues
if queues:
self.assertEqual(queues[0].id, self.api.queues.get_default().id)
else:
self.api.queues.update(queue=q_id, system_tags=["default"])
self.assertEqual(q_id, self.api.queues.get_default().id)
def testTaskTags(self):
task_id = self._temp_task(
name="Test tags", system_tags=["active"]
@ -169,38 +195,11 @@ class TestTags(TestService):
task = self.api.tasks.get_by_id(task=task_id).task
self.assertEqual(task.status, "stopped")
def testQueueTags(self):
q_id = self._temp_queue(system_tags=["default"])
queues = self.api.queues.get_all_ex(
name="Test tags", system_tags=["default"]
).queues
self.assertFound(q_id, ["default"], queues)
queues = self.api.queues.get_all_ex(
name="Test tags", system_tags=["-default"]
).queues
self.assertNotFound(q_id, queues)
self.api.queues.update(queue=q_id, system_tags=[])
queues = self.api.queues.get_all_ex(
name="Test tags", system_tags=["-default"]
).queues
self.assertFound(q_id, [], queues)
# test default queue
queues = self.api.queues.get_all(system_tags=["default"]).queues
if queues:
self.assertEqual(queues[0].id, self.api.queues.get_default().id)
else:
self.api.queues.update(queue=q_id, system_tags=["default"])
self.assertEqual(q_id, self.api.queues.get_default().id)
def assertProjectStats(self, project: AttrDict):
self.assertEqual(set(project.stats.keys()), {"active"})
self.assertAlmostEqual(project.stats.active.total_runtime, 1, places=0)
self.assertEqual(project.stats.active.completed_tasks, 1)
self.assertEqual(project.stats.active.completed_tasks_24h, 1)
self.assertEqual(project.stats.active.total_tasks, 1)
self.assertEqual(project.stats.active.running_tasks, 0)
for status, count in project.stats.active.status_count.items():
self.assertEqual(count, 1 if status == "stopped" else 0)