mirror of
https://github.com/clearml/clearml-server
synced 2025-04-28 17:51:24 +00:00
Add pipelines support
This commit is contained in:
parent
e1992e2054
commit
da8a45072f
19
apiserver/apimodels/pipelines.py
Normal file
19
apiserver/apimodels/pipelines.py
Normal file
@ -0,0 +1,19 @@
|
||||
from jsonmodels import models, fields
|
||||
|
||||
from apiserver.apimodels import ListField
|
||||
|
||||
|
||||
class Arg(models.Base):
    # A single pipeline argument: a name/value pair to be placed
    # in the cloned task's "Args" hyperparameters section
    name = fields.StringField(required=True)
    value = fields.StringField(required=True)
|
||||
|
||||
|
||||
class StartPipelineRequest(models.Base):
    # Request payload for pipelines.start_pipeline:
    # clone the given task, optionally override its Args, and enqueue the clone
    task = fields.StringField(required=True)  # ID of the task the pipeline is based on
    queue = fields.StringField(required=True)  # queue ID to enqueue the created pipeline task into
    args = ListField(Arg)  # optional overrides for the Args hyperparameters section
|
||||
|
||||
|
||||
class StartPipelineResponse(models.Base):
    # Response of pipelines.start_pipeline
    pipeline = fields.StringField(required=True)  # ID of the new pipeline task
    enqueued = fields.BoolField(required=True)  # True if the task was enqueued successfully
|
@ -1,6 +1,6 @@
|
||||
from jsonmodels import models, fields
|
||||
|
||||
from apiserver.apimodels import ListField, ActualEnumField
|
||||
from apiserver.apimodels import ListField, ActualEnumField, DictField
|
||||
from apiserver.apimodels.organization import TagsRequest
|
||||
from apiserver.database.model import EntityVisibility
|
||||
|
||||
@ -51,8 +51,14 @@ class ProjectHyperparamValuesRequest(MultiProjectRequest):
|
||||
allow_public = fields.BoolField(default=True)
|
||||
|
||||
|
||||
class ProjectModelMetadataValuesRequest(MultiProjectRequest):
    # Request for projects.get_model_metadata_values:
    # distinct values of a single model metadata key across the requested projects
    key = fields.StringField(required=True)  # metadata key to collect values for
    allow_public = fields.BoolField(default=True)  # also collect from public models
|
||||
|
||||
|
||||
class ProjectsGetRequest(models.Base):
    # Extra options for projects.get_all_ex
    include_stats = fields.BoolField(default=False)  # attach per-project task statistics
    include_stats_filter = DictField()  # filter (e.g. tags/system_tags) for entities counted in the stats
    stats_with_children = fields.BoolField(default=True)  # aggregate stats over subprojects too
    stats_for_state = ActualEnumField(EntityVisibility, default=EntityVisibility.active)
    non_public = fields.BoolField(default=False)  # exclude public projects
|
||||
|
@ -14,6 +14,7 @@ from typing import (
|
||||
TypeVar,
|
||||
Callable,
|
||||
Mapping,
|
||||
Any,
|
||||
)
|
||||
|
||||
from mongoengine import Q, Document
|
||||
@ -22,6 +23,7 @@ from apiserver import database
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model import EntityVisibility, AttributedDocument
|
||||
from apiserver.database.model.base import GetMixin
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.task.task import Task, TaskStatus, external_task_types
|
||||
@ -204,6 +206,7 @@ class ProjectBLL:
|
||||
tags: Sequence[str] = None,
|
||||
system_tags: Sequence[str] = None,
|
||||
default_output_destination: str = None,
|
||||
parent_creation_params: dict = None,
|
||||
) -> str:
|
||||
"""
|
||||
Create a new project.
|
||||
@ -226,7 +229,12 @@ class ProjectBLL:
|
||||
created=now,
|
||||
last_update=now,
|
||||
)
|
||||
parent = _ensure_project(company=company, user=user, name=location)
|
||||
parent = _ensure_project(
|
||||
company=company,
|
||||
user=user,
|
||||
name=location,
|
||||
creation_params=parent_creation_params,
|
||||
)
|
||||
_save_under_parent(project=project, parent=parent)
|
||||
if parent:
|
||||
parent.update(last_update=now)
|
||||
@ -244,13 +252,14 @@ class ProjectBLL:
|
||||
tags: Sequence[str] = None,
|
||||
system_tags: Sequence[str] = None,
|
||||
default_output_destination: str = None,
|
||||
parent_creation_params: dict = None,
|
||||
) -> str:
|
||||
"""
|
||||
Find a project named `project_name` or create a new one.
|
||||
Returns project ID
|
||||
"""
|
||||
if not project_id and not project_name:
|
||||
raise ValueError("project id or name required")
|
||||
raise errors.bad_request.ValidationError("project id or name required")
|
||||
|
||||
if project_id:
|
||||
project = Project.objects(company=company, id=project_id).only("id").first()
|
||||
@ -271,6 +280,7 @@ class ProjectBLL:
|
||||
tags=tags,
|
||||
system_tags=system_tags,
|
||||
default_output_destination=default_output_destination,
|
||||
parent_creation_params=parent_creation_params,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -314,6 +324,7 @@ class ProjectBLL:
|
||||
company_id: str,
|
||||
project_ids: Sequence[str],
|
||||
specific_state: Optional[EntityVisibility] = None,
|
||||
filter_: Mapping[str, Any] = None,
|
||||
) -> Tuple[Sequence, Sequence]:
|
||||
archived = EntityVisibility.archived.value
|
||||
|
||||
@ -337,10 +348,9 @@ class ProjectBLL:
|
||||
status_count_pipeline = [
|
||||
# count tasks per project per status
|
||||
{
|
||||
"$match": {
|
||||
"company": {"$in": [None, "", company_id]},
|
||||
"project": {"$in": project_ids},
|
||||
}
|
||||
"$match": cls.get_match_conditions(
|
||||
company=company_id, project_ids=project_ids, filter_=filter_
|
||||
)
|
||||
},
|
||||
ensure_valid_fields(),
|
||||
{
|
||||
@ -455,8 +465,9 @@ class ProjectBLL:
|
||||
# only count run time for these types of tasks
|
||||
{
|
||||
"$match": {
|
||||
"company": {"$in": [None, "", company_id]},
|
||||
"project": {"$in": project_ids},
|
||||
**cls.get_match_conditions(
|
||||
company=company_id, project_ids=project_ids, filter_=filter_
|
||||
),
|
||||
**get_state_filter(),
|
||||
}
|
||||
},
|
||||
@ -500,6 +511,7 @@ class ProjectBLL:
|
||||
project_ids: Sequence[str],
|
||||
specific_state: Optional[EntityVisibility] = None,
|
||||
include_children: bool = True,
|
||||
filter_: Mapping[str, Any] = None,
|
||||
) -> Tuple[Dict[str, dict], Dict[str, dict]]:
|
||||
if not project_ids:
|
||||
return {}, {}
|
||||
@ -516,6 +528,7 @@ class ProjectBLL:
|
||||
company,
|
||||
project_ids=list(project_ids_with_children),
|
||||
specific_state=specific_state,
|
||||
filter_=filter_,
|
||||
)
|
||||
|
||||
default_counts = dict.fromkeys(get_options(TaskStatus), 0)
|
||||
@ -589,10 +602,9 @@ class ProjectBLL:
|
||||
|
||||
return {
|
||||
"status_count": project_section_statuses,
|
||||
"running_tasks": project_section_statuses.get(TaskStatus.in_progress),
|
||||
"total_tasks": sum(project_section_statuses.values()),
|
||||
"total_runtime": project_runtime.get(section, 0),
|
||||
"completed_tasks": project_runtime.get(
|
||||
"completed_tasks_24h": project_runtime.get(
|
||||
f"{section}_recently_completed", 0
|
||||
),
|
||||
"last_task_run": get_time_or_none(
|
||||
@ -652,6 +664,30 @@ class ProjectBLL:
|
||||
|
||||
return res
|
||||
|
||||
@classmethod
def get_project_tags(
    cls,
    company_id: str,
    include_system: bool,
    projects: Sequence[str] = None,
    filter_: Dict[str, Sequence[str]] = None,
) -> Tuple[Sequence[str], Sequence[str]]:
    """
    Return the distinct (tags, system_tags) used by the company's projects.
    :param include_system: when False, system_tags is returned empty
    :param projects: if passed, restrict to these projects and their children
    :param filter_: optional list-field filters applied to the project query
    """
    with TimingContext("mongo", "get_tags_from_db"):
        query = Q(company=company_id)
        if filter_:
            for name, vals in filter_.items():
                if vals:
                    query &= GetMixin.get_list_field_query(name, vals)

        if projects:
            # tags of subprojects are included as well
            query &= Q(id__in=_ids_with_children(projects))

        tags = Project.objects(query).distinct("tags")
        system_tags = (
            Project.objects(query).distinct("system_tags") if include_system else []
        )
        return tags, system_tags
|
||||
|
||||
@classmethod
|
||||
def get_projects_with_active_user(
|
||||
cls,
|
||||
@ -708,6 +744,7 @@ class ProjectBLL:
|
||||
if include_subprojects:
|
||||
projects = _ids_with_children(projects)
|
||||
query &= Q(project__in=projects)
|
||||
|
||||
if state == EntityVisibility.archived:
|
||||
query &= Q(system_tags__in=[EntityVisibility.archived.value])
|
||||
elif state == EntityVisibility.active:
|
||||
@ -735,6 +772,7 @@ class ProjectBLL:
|
||||
if project_ids:
|
||||
project_ids = _ids_with_children(project_ids)
|
||||
query &= Q(project__in=project_ids)
|
||||
|
||||
res = Task.objects(query).distinct(field="type")
|
||||
return set(res).intersection(external_task_types)
|
||||
|
||||
@ -750,9 +788,34 @@ class ProjectBLL:
|
||||
query &= Q(project__in=project_ids)
|
||||
return Model.objects(query).distinct(field="framework")
|
||||
|
||||
@staticmethod
|
||||
def get_match_conditions(
|
||||
company: str, project_ids: Sequence[str], filter_: Mapping[str, Any]
|
||||
):
|
||||
conditions = {
|
||||
"company": {"$in": [None, "", company]},
|
||||
"project": {"$in": project_ids},
|
||||
}
|
||||
if not filter_:
|
||||
return conditions
|
||||
|
||||
for field in ("tags", "system_tags"):
|
||||
field_filter = filter_.get(field)
|
||||
if not field_filter:
|
||||
continue
|
||||
if not isinstance(field_filter, list) or not all(
|
||||
isinstance(t, str) for t in field_filter
|
||||
):
|
||||
raise errors.bad_request.ValidationError(
|
||||
f"List of strings expected for the field: {field}"
|
||||
)
|
||||
conditions[field] = {"$in": field_filter}
|
||||
|
||||
return conditions
|
||||
|
||||
@classmethod
|
||||
def calc_own_contents(
|
||||
cls, company: str, project_ids: Sequence[str]
|
||||
cls, company: str, project_ids: Sequence[str], filter_: Mapping[str, Any] = None
|
||||
) -> Dict[str, dict]:
|
||||
"""
|
||||
Returns the amount of task/models per requested project
|
||||
@ -764,13 +827,12 @@ class ProjectBLL:
|
||||
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
"company": {"$in": [None, "", company]},
|
||||
"project": {"$in": project_ids},
|
||||
}
|
||||
"$match": cls.get_match_conditions(
|
||||
company=company, project_ids=project_ids, filter_=filter_
|
||||
)
|
||||
},
|
||||
{"$project": {"project": 1}},
|
||||
{"$group": {"_id": "$project", "count": {"$sum": 1}}}
|
||||
{"$group": {"_id": "$project", "count": {"$sum": 1}}},
|
||||
]
|
||||
|
||||
def get_agrregate_res(cls_: Type[AttributedDocument]) -> dict:
|
||||
|
@ -1,6 +1,6 @@
|
||||
import json
|
||||
from collections import OrderedDict
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime
|
||||
from typing import (
|
||||
Sequence,
|
||||
Optional,
|
||||
@ -28,12 +28,21 @@ class ProjectQueries:
|
||||
def _get_project_constraint(
|
||||
project_ids: Sequence[str], include_subprojects: bool
|
||||
) -> dict:
|
||||
"""
|
||||
If passed projects is None means top level projects
|
||||
If passed projects is empty means no project filtering
|
||||
"""
|
||||
if include_subprojects:
|
||||
if project_ids is None:
|
||||
if not project_ids:
|
||||
return {}
|
||||
project_ids = _ids_with_children(project_ids)
|
||||
|
||||
return {"project": {"$in": project_ids if project_ids is not None else [None]}}
|
||||
if project_ids is None:
|
||||
project_ids = [None]
|
||||
if not project_ids:
|
||||
return {}
|
||||
|
||||
return {"project": {"$in": project_ids}}
|
||||
|
||||
@staticmethod
|
||||
def _get_company_constraint(company_id: str, allow_public: bool = True) -> dict:
|
||||
@ -106,16 +115,11 @@ class ProjectQueries:
|
||||
|
||||
return total, remaining, results
|
||||
|
||||
HyperParamValues = Tuple[int, Sequence[str]]
|
||||
ParamValues = Tuple[int, Sequence[str]]
|
||||
|
||||
def _get_cached_hyperparam_values(
|
||||
self, key: str, last_update: datetime
|
||||
) -> Optional[HyperParamValues]:
|
||||
allowed_delta = timedelta(
|
||||
seconds=config.get(
|
||||
"services.tasks.hyperparam_values.cache_allowed_outdate_sec", 60
|
||||
)
|
||||
)
|
||||
def _get_cached_param_values(
|
||||
self, key: str, last_update: datetime, allowed_delta_sec=0
|
||||
) -> Optional[ParamValues]:
|
||||
try:
|
||||
cached = self.redis.get(key)
|
||||
if not cached:
|
||||
@ -123,12 +127,12 @@ class ProjectQueries:
|
||||
|
||||
data = json.loads(cached)
|
||||
cached_last_update = datetime.fromtimestamp(data["last_update"])
|
||||
if (last_update - cached_last_update) < allowed_delta:
|
||||
if (last_update - cached_last_update).total_seconds() <= allowed_delta_sec:
|
||||
return data["total"], data["values"]
|
||||
except Exception as ex:
|
||||
log.error(f"Error retrieving hyperparam cached values: {str(ex)}")
|
||||
log.error(f"Error retrieving params cached values: {str(ex)}")
|
||||
|
||||
def get_hyperparam_distinct_values(
|
||||
def get_task_hyperparam_distinct_values(
|
||||
self,
|
||||
company_id: str,
|
||||
project_ids: Sequence[str],
|
||||
@ -136,7 +140,7 @@ class ProjectQueries:
|
||||
name: str,
|
||||
include_subprojects: bool,
|
||||
allow_public: bool = True,
|
||||
) -> HyperParamValues:
|
||||
) -> ParamValues:
|
||||
company_constraint = self._get_company_constraint(company_id, allow_public)
|
||||
project_constraint = self._get_project_constraint(
|
||||
project_ids, include_subprojects
|
||||
@ -158,8 +162,12 @@ class ProjectQueries:
|
||||
|
||||
redis_key = f"hyperparam_values_{company_id}_{'_'.join(project_ids)}_{section}_{name}_{allow_public}"
|
||||
last_update = last_updated_task.last_update or datetime.utcnow()
|
||||
cached_res = self._get_cached_hyperparam_values(
|
||||
key=redis_key, last_update=last_update
|
||||
cached_res = self._get_cached_param_values(
|
||||
key=redis_key,
|
||||
last_update=last_update,
|
||||
allowed_delta_sec=config.get(
|
||||
"services.tasks.hyperparam_values.cache_allowed_outdate_sec", 60
|
||||
),
|
||||
)
|
||||
if cached_res:
|
||||
return cached_res
|
||||
@ -290,3 +298,73 @@ class ProjectQueries:
|
||||
remaining = max(0, total - (len(results) + page * page_size))
|
||||
|
||||
return total, remaining, results
|
||||
|
||||
def get_model_metadata_distinct_values(
    self,
    company_id: str,
    project_ids: Sequence[str],
    key: str,
    include_subprojects: bool,
    allow_public: bool = True,
) -> ParamValues:
    """
    Return (total, values): the amount and the sorted list of distinct values
    stored under the given metadata key in the models of the requested
    projects. Results are cached in Redis, invalidated by model last_update.
    """
    company_constraint = self._get_company_constraint(company_id, allow_public)
    project_constraint = self._get_project_constraint(
        project_ids, include_subprojects
    )
    # metadata keys are escaped the same way they are stored in the DB
    key_path = f"metadata.{ParameterKeyEscaper.escape(key)}"
    last_updated_model = (
        Model.objects(
            **company_constraint,
            **project_constraint,
            **{f"{key_path.replace('.', '__')}__exists": True},
        )
        .only("last_update")
        .order_by("-last_update")
        .limit(1)
        .first()
    )
    if not last_updated_model:
        # no model carries this key in the requested scope
        return 0, []

    redis_key = f"modelmetadata_values_{company_id}_{'_'.join(project_ids)}_{key}_{allow_public}"
    last_update = last_updated_model.last_update or datetime.utcnow()
    # default allowed_delta_sec=0: the cache is valid only if no model
    # was updated after the cached entry was written
    cached_res = self._get_cached_param_values(
        key=redis_key, last_update=last_update
    )
    if cached_res:
        return cached_res

    max_values = config.get("services.models.metadata_values.max_count", 100)
    pipeline = [
        {
            "$match": {
                **company_constraint,
                **project_constraint,
                key_path: {"$exists": True},
            }
        },
        {"$project": {"value": f"${key_path}.value"}},
        {"$group": {"_id": "$value"}},
        {"$sort": {"_id": 1}},
        # cap the amount of returned distinct values; note that "total"
        # below is computed after $limit, so it never exceeds max_values
        {"$limit": max_values},
        {
            "$group": {
                "_id": 1,
                "total": {"$sum": 1},
                "results": {"$push": "$$ROOT._id"},
            }
        },
    ]

    result = next(Model.aggregate(pipeline, collation=Model._numeric_locale), None)
    if not result:
        return 0, []

    total = int(result.get("total", 0))
    values = result.get("results", [])

    ttl = config.get("services.models.metadata_values.cache_ttl_sec", 86400)
    cached = dict(last_update=last_update.timestamp(), total=total, values=values)
    self.redis.setex(redis_key, ttl, json.dumps(cached))

    return total, values
|
||||
|
@ -25,7 +25,9 @@ def _validate_project_name(project_name: str) -> Tuple[str, str]:
|
||||
return name_separator.join(name_parts), name_separator.join(name_parts[:-1])
|
||||
|
||||
|
||||
def _ensure_project(company: str, user: str, name: str) -> Optional[Project]:
|
||||
def _ensure_project(
|
||||
company: str, user: str, name: str, creation_params: dict = None
|
||||
) -> Optional[Project]:
|
||||
"""
|
||||
Makes sure that the project with the given name exists
|
||||
If needed auto-create the project and all the missing projects in the path to it
|
||||
@ -48,9 +50,9 @@ def _ensure_project(company: str, user: str, name: str) -> Optional[Project]:
|
||||
created=now,
|
||||
last_update=now,
|
||||
name=name,
|
||||
description="",
|
||||
**(creation_params or dict(description="")),
|
||||
)
|
||||
parent = _ensure_project(company, user, location)
|
||||
parent = _ensure_project(company, user, location, creation_params=creation_params)
|
||||
_save_under_parent(project=project, parent=parent)
|
||||
if parent:
|
||||
parent.update(last_update=now)
|
||||
|
@ -112,6 +112,8 @@
|
||||
workers {
|
||||
# Auto-register unknown workers on status reports and other calls
|
||||
auto_register: true
|
||||
# Assume unknown workers have unregistered (i.e. do not raise unregistered error)
|
||||
auto_unregister: true
|
||||
# Timeout in seconds on task status update. If exceeded
|
||||
# then task can be stopped without communicating to the worker
|
||||
task_update_timeout: 600
|
||||
|
7
apiserver/config/default/services/models.conf
Normal file
7
apiserver/config/default/services/models.conf
Normal file
@ -0,0 +1,7 @@
|
||||
metadata_values {
|
||||
# maximal amount of distinct model values to retrieve
|
||||
max_count: 100
|
||||
|
||||
# cache TTL in seconds
|
||||
cache_ttl_sec: 86400
|
||||
}
|
47
apiserver/schema/services/pipelines.conf
Normal file
47
apiserver/schema/services/pipelines.conf
Normal file
@ -0,0 +1,47 @@
|
||||
_description: "Provides a management API for pipelines in the system."
|
||||
_definitions {
|
||||
}
|
||||
|
||||
start_pipeline {
|
||||
"2.17" {
|
||||
description: "Start a pipeline"
|
||||
request {
|
||||
type: object
|
||||
required: [ task ]
|
||||
properties {
|
||||
task {
|
||||
description: "ID of the task on which the pipeline will be based"
|
||||
type: string
|
||||
}
|
||||
queue {
|
||||
description: "Queue ID in which the created pipeline task will be enqueued"
|
||||
type: string
|
||||
}
|
||||
args {
|
||||
description: "Task arguments, name/value to be placed in the hyperparameters Args section"
|
||||
type: array
|
||||
items {
|
||||
type: object
|
||||
properties {
|
||||
name: { type: string }
|
||||
value: { type: [string, null] }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
pipeline {
|
||||
description: "ID of the new pipeline task"
|
||||
type: string
|
||||
}
|
||||
enqueued {
|
||||
description: "True if the task was successfully enqueued"
|
||||
type: boolean
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -566,6 +566,19 @@ get_all_ex {
|
||||
default: true
|
||||
}
|
||||
}
|
||||
"2.17": ${get_all_ex."2.16"} {
|
||||
request.properties.include_stats_filter {
|
||||
description: The filter for selecting entities that participate in statistics calculation
|
||||
type: object
|
||||
properties {
|
||||
system_tags {
|
||||
description: The list of allowed system tags
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
update {
|
||||
"2.1" {
|
||||
@ -913,6 +926,49 @@ get_hyper_parameters {
|
||||
}
|
||||
}
|
||||
}
|
||||
get_model_metadata_values {
|
||||
"2.17" {
|
||||
description: """Get a list of distinct values for the chosen model metadata key"""
|
||||
request {
|
||||
type: object
|
||||
required: [key]
|
||||
properties {
|
||||
projects {
|
||||
description: "Project IDs"
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
key {
|
||||
description: "Metadata key"
|
||||
type: string
|
||||
}
|
||||
allow_public {
|
||||
description: "If set to 'true' then collect values from both company and public models, otherwise company models only. The default is 'true'"
|
||||
type: boolean
|
||||
}
|
||||
include_subprojects {
|
||||
description: "If set to 'true' and the project field is set then the result includes metadata values from the subproject models"
|
||||
type: boolean
|
||||
default: true
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
total {
|
||||
description: "Total number of distinct values"
|
||||
type: integer
|
||||
}
|
||||
values {
|
||||
description: "The list of the unique values"
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
get_model_metadata_keys {
|
||||
"2.17" {
|
||||
description: """Get a list of all metadata keys used in models within the given project."""
|
||||
@ -962,6 +1018,13 @@ get_model_metadata_keys {
|
||||
}
|
||||
}
|
||||
}
|
||||
get_project_tags {
|
||||
"2.17" {
|
||||
description: "Get user and system tags used for the specified projects and their children"
|
||||
request = ${_definitions.tags_request}
|
||||
response = ${_definitions.tags_response}
|
||||
}
|
||||
}
|
||||
get_task_tags {
|
||||
"2.8" {
|
||||
description: "Get user and system tags used for the tasks under the specified projects"
|
||||
|
@ -5,7 +5,7 @@ from apiserver.apimodels.organization import TagsRequest
|
||||
from apiserver.bll.organization import OrgBLL, Tags
|
||||
from apiserver.database.model import User
|
||||
from apiserver.service_repo import endpoint, APICall
|
||||
from apiserver.services.utils import get_tags_filter_dictionary, get_tags_response
|
||||
from apiserver.services.utils import get_tags_filter_dictionary, sort_tags_response
|
||||
|
||||
org_bll = OrgBLL()
|
||||
|
||||
@ -21,17 +21,13 @@ def get_tags(call: APICall, company, request: TagsRequest):
|
||||
for field, vals in tags.items():
|
||||
ret[field] |= vals
|
||||
|
||||
call.result.data = get_tags_response(ret)
|
||||
call.result.data = sort_tags_response(ret)
|
||||
|
||||
|
||||
@endpoint("organization.get_user_companies")
|
||||
def get_user_companies(call: APICall, company_id: str, _):
|
||||
users = [
|
||||
{
|
||||
"id": u.id,
|
||||
"name": u.name,
|
||||
"avatar": u.avatar,
|
||||
}
|
||||
{"id": u.id, "name": u.name, "avatar": u.avatar}
|
||||
for u in User.objects(company=company_id).only("avatar", "name", "company")
|
||||
]
|
||||
|
||||
|
68
apiserver/services/pipelines.py
Normal file
68
apiserver/services/pipelines.py
Normal file
@ -0,0 +1,68 @@
|
||||
import re
|
||||
|
||||
from apiserver.apimodels.pipelines import StartPipelineResponse, StartPipelineRequest
|
||||
from apiserver.bll.organization import OrgBLL
|
||||
from apiserver.bll.project import ProjectBLL
|
||||
from apiserver.bll.task import TaskBLL
|
||||
from apiserver.bll.task.task_operations import enqueue_task
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.service_repo import APICall, endpoint
|
||||
|
||||
org_bll = OrgBLL()
|
||||
project_bll = ProjectBLL()
|
||||
task_bll = TaskBLL()
|
||||
|
||||
|
||||
def _update_task_name(task: Task):
    """
    Rename a cloned pipeline task to "<project base name> #<serial>" where the
    serial is the amount of pipeline tasks already matching that base name in
    the task's project. No-op if the task has no project or the project is gone.
    """
    if not task or not task.project:
        return

    project = Project.objects(id=task.project).only("name").first()
    if not project:
        return

    # the base name is the last path component of the (possibly nested) project name
    _, _, name_prefix = project.name.rpartition("/")
    # match the exact base name, optionally followed by " #<number>"
    name_mask = re.compile(rf"{re.escape(name_prefix)}( #\d+)?$")
    count = Task.objects(
        project=task.project, system_tags__in=["pipeline"], name=name_mask
    ).count()
    new_name = f"{name_prefix} #{count}" if count > 0 else name_prefix
    task.update(name=new_name)
|
||||
|
||||
|
||||
@endpoint(
    "pipelines.start_pipeline", response_data_model=StartPipelineResponse,
)
def start_pipeline(call: APICall, company_id: str, request: StartPipelineRequest):
    """
    Clone the requested task, optionally overriding its "Args" hyperparameters
    with the passed arguments, rename the clone after its project and enqueue it.
    """
    hyperparams = None
    if request.args:
        # build the hyperparams override in the stored Args-section layout
        hyperparams = {
            "Args": {
                str(arg.name): {
                    "section": "Args",
                    "name": str(arg.name),
                    "value": str(arg.value),
                }
                for arg in request.args or []
            }
        }

    task, _ = task_bll.clone_task(
        company_id=company_id,
        user_id=call.identity.user,
        task_id=request.task,
        hyperparams=hyperparams,
    )

    # give the clone a unique "<project base name> #N" style name
    _update_task_name(task)

    queued, res = enqueue_task(
        task_id=task.id,
        company_id=company_id,
        queue_id=request.queue,
        status_message="Starting pipeline",
        status_reason="",
    )

    return StartPipelineResponse(pipeline=task.id, enqueued=bool(queued))
|
@ -17,6 +17,7 @@ from apiserver.apimodels.projects import (
|
||||
MergeRequest,
|
||||
ProjectOrNoneRequest,
|
||||
ProjectRequest,
|
||||
ProjectModelMetadataValuesRequest,
|
||||
)
|
||||
from apiserver.bll.organization import OrgBLL, Tags
|
||||
from apiserver.bll.project import ProjectBLL, ProjectQueries
|
||||
@ -35,7 +36,7 @@ from apiserver.services.utils import (
|
||||
conform_tag_fields,
|
||||
conform_output_tags,
|
||||
get_tags_filter_dictionary,
|
||||
get_tags_response,
|
||||
sort_tags_response,
|
||||
)
|
||||
from apiserver.timing_context import TimingContext
|
||||
|
||||
@ -124,7 +125,9 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
|
||||
}
|
||||
if existing_requested_ids:
|
||||
contents = project_bll.calc_own_contents(
|
||||
company=company_id, project_ids=list(existing_requested_ids)
|
||||
company=company_id,
|
||||
project_ids=list(existing_requested_ids),
|
||||
filter_=request.include_stats_filter,
|
||||
)
|
||||
for project in projects:
|
||||
project.update(**contents.get(project["id"], {}))
|
||||
@ -140,6 +143,7 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
|
||||
project_ids=list(project_ids),
|
||||
specific_state=request.stats_for_state,
|
||||
include_children=request.stats_with_children,
|
||||
filter_=request.include_stats_filter,
|
||||
)
|
||||
|
||||
for project in projects:
|
||||
@ -292,6 +296,23 @@ def get_model_metadata_keys(call: APICall, company_id: str, request: GetParamsRe
|
||||
}
|
||||
|
||||
|
||||
@endpoint("projects.get_model_metadata_values")
def get_model_metadata_values(
    call: APICall, company_id: str, request: ProjectModelMetadataValuesRequest
):
    """
    Return the total amount and the list of distinct values stored under the
    requested model metadata key across the requested projects.
    """
    total, values = project_queries.get_model_metadata_distinct_values(
        company_id,
        project_ids=request.projects,
        key=request.key,
        include_subprojects=request.include_subprojects,
        allow_public=request.allow_public,
    )
    call.result.data = {
        "total": total,
        "values": values,
    }
|
||||
|
||||
|
||||
@endpoint(
|
||||
"projects.get_hyper_parameters",
|
||||
min_version="2.9",
|
||||
@ -322,7 +343,7 @@ def get_hyper_parameters(call: APICall, company_id: str, request: GetParamsReque
|
||||
def get_hyperparam_values(
|
||||
call: APICall, company_id: str, request: ProjectHyperparamValuesRequest
|
||||
):
|
||||
total, values = project_queries.get_hyperparam_distinct_values(
|
||||
total, values = project_queries.get_task_hyperparam_distinct_values(
|
||||
company_id,
|
||||
project_ids=request.projects,
|
||||
section=request.section,
|
||||
@ -336,6 +357,17 @@ def get_hyperparam_values(
|
||||
}
|
||||
|
||||
|
||||
@endpoint("projects.get_project_tags")
def get_tags(call: APICall, company, request: ProjectTagsRequest):
    """
    Return sorted user and system tags of the requested projects and their
    children, optionally narrowed by the request filter.
    """
    tags, system_tags = project_bll.get_project_tags(
        company,
        include_system=request.include_system,
        filter_=get_tags_filter_dictionary(request.filter),
        projects=request.projects,
    )
    call.result.data = sort_tags_response({"tags": tags, "system_tags": system_tags})
|
||||
|
||||
|
||||
@endpoint(
|
||||
"projects.get_task_tags", min_version="2.8", request_data_model=ProjectTagsRequest
|
||||
)
|
||||
@ -347,7 +379,7 @@ def get_tags(call: APICall, company, request: ProjectTagsRequest):
|
||||
filter_=get_tags_filter_dictionary(request.filter),
|
||||
projects=request.projects,
|
||||
)
|
||||
call.result.data = get_tags_response(ret)
|
||||
call.result.data = sort_tags_response(ret)
|
||||
|
||||
|
||||
@endpoint(
|
||||
@ -361,7 +393,7 @@ def get_tags(call: APICall, company, request: ProjectTagsRequest):
|
||||
filter_=get_tags_filter_dictionary(request.filter),
|
||||
projects=request.projects,
|
||||
)
|
||||
call.result.data = get_tags_response(ret)
|
||||
call.result.data = sort_tags_response(ret)
|
||||
|
||||
|
||||
@endpoint(
|
||||
|
@ -23,7 +23,7 @@ def get_tags_filter_dictionary(input_: Filter) -> dict:
|
||||
}
|
||||
|
||||
|
||||
def get_tags_response(ret: dict) -> dict:
|
||||
def sort_tags_response(ret: dict) -> dict:
|
||||
return {field: sorted(vals) for field, vals in ret.items()}
|
||||
|
||||
|
||||
|
@ -4,10 +4,29 @@ from apiserver.tests.automated import TestService
|
||||
|
||||
|
||||
class TestProjectTags(TestService):
|
||||
def setUp(self, version="2.12"):
|
||||
super().setUp(version=version)
|
||||
def test_project_own_tags(self):
|
||||
p1_tags = ["Tag 1", "Tag 2"]
|
||||
p1 = self.create_temp(
|
||||
"projects", name="Test project tags1", description="test", tags=p1_tags
|
||||
)
|
||||
p2_tags = ["Tag 1", "Tag 3"]
|
||||
p2 = self.create_temp(
|
||||
"projects",
|
||||
name="Test project tags2",
|
||||
description="test",
|
||||
tags=p2_tags,
|
||||
system_tags=["hidden"],
|
||||
)
|
||||
|
||||
def test_project_tags(self):
|
||||
res = self.api.projects.get_project_tags(projects=[p1, p2])
|
||||
self.assertEqual(set(res.tags), set(p1_tags) | set(p2_tags))
|
||||
|
||||
res = self.api.projects.get_project_tags(
|
||||
projects=[p1, p2], filter={"system_tags": ["__$not", "hidden"]}
|
||||
)
|
||||
self.assertEqual(res.tags, p1_tags)
|
||||
|
||||
def test_project_entities_tags(self):
|
||||
tags_1 = ["Test tag 1", "Test tag 2"]
|
||||
tags_2 = ["Test tag 3", "Test tag 4"]
|
||||
|
||||
|
@ -28,25 +28,33 @@ class TestQueueAndModelMetadata(TestService):
|
||||
def test_project_meta_query(self):
|
||||
self._temp_model("TestMetadata", metadata=self.meta1)
|
||||
project = self.temp_project(name="MetaParent")
|
||||
test_key = "test_key"
|
||||
test_key2 = "test_key2"
|
||||
test_value = "test_value"
|
||||
test_value2 = "test_value2"
|
||||
model_id = self._temp_model(
|
||||
"TestMetadata2",
|
||||
project=project,
|
||||
metadata={
|
||||
"test_key": {"key": "test_key", "type": "str", "value": "test_value"},
|
||||
"test_key2": {"key": "test_key2", "type": "str", "value": "test_value"},
|
||||
test_key: {"key": test_key, "type": "str", "value": test_value},
|
||||
test_key2: {"key": test_key2, "type": "str", "value": test_value2},
|
||||
},
|
||||
)
|
||||
res = self.api.projects.get_model_metadata_keys()
|
||||
self.assertTrue({"test_key", "test_key2"}.issubset(set(res["keys"])))
|
||||
self.assertTrue({test_key, test_key2}.issubset(set(res["keys"])))
|
||||
res = self.api.projects.get_model_metadata_keys(include_subprojects=False)
|
||||
self.assertTrue("test_key" in res["keys"])
|
||||
self.assertFalse("test_key2" in res["keys"])
|
||||
self.assertTrue(test_key in res["keys"])
|
||||
self.assertFalse(test_key2 in res["keys"])
|
||||
|
||||
model = self.api.models.get_all_ex(
|
||||
id=[model_id], only_fields=["metadata.test_key"]
|
||||
).models[0]
|
||||
self.assertTrue("test_key" in model.metadata)
|
||||
self.assertFalse("test_key2" in model.metadata)
|
||||
self.assertTrue(test_key in model.metadata)
|
||||
self.assertFalse(test_key2 in model.metadata)
|
||||
|
||||
res = self.api.projects.get_model_metadata_values(key=test_key)
|
||||
self.assertEqual(res.total, 1)
|
||||
self.assertEqual(res["values"], [test_value])
|
||||
|
||||
def _test_meta_operations(
|
||||
self, service: APIClient.Service, entity: str, _id: str,
|
||||
|
@ -199,10 +199,10 @@ class TestSubProjects(TestService):
|
||||
res1 = next(p for p in res if p.id == project1)
|
||||
self.assertEqual(res1.stats["active"]["status_count"]["created"], 0)
|
||||
self.assertEqual(res1.stats["active"]["status_count"]["stopped"], 2)
|
||||
self.assertEqual(res1.stats["active"]["status_count"]["in_progress"], 0)
|
||||
self.assertEqual(res1.stats["active"]["total_runtime"], 2)
|
||||
self.assertEqual(res1.stats["active"]["completed_tasks"], 2)
|
||||
self.assertEqual(res1.stats["active"]["completed_tasks_24h"], 2)
|
||||
self.assertEqual(res1.stats["active"]["total_tasks"], 2)
|
||||
self.assertEqual(res1.stats["active"]["running_tasks"], 0)
|
||||
self.assertEqual(
|
||||
{sp.name for sp in res1.sub_projects},
|
||||
{
|
||||
@ -214,10 +214,10 @@ class TestSubProjects(TestService):
|
||||
res2 = next(p for p in res if p.id == project2)
|
||||
self.assertEqual(res2.stats["active"]["status_count"]["created"], 0)
|
||||
self.assertEqual(res2.stats["active"]["status_count"]["stopped"], 0)
|
||||
self.assertEqual(res2.stats["active"]["status_count"]["in_progress"], 0)
|
||||
self.assertEqual(res2.stats["active"]["status_count"]["completed"], 0)
|
||||
self.assertEqual(res2.stats["active"]["total_runtime"], 0)
|
||||
self.assertEqual(res2.stats["active"]["completed_tasks"], 0)
|
||||
self.assertEqual(res2.stats["active"]["total_tasks"], 0)
|
||||
self.assertEqual(res2.stats["active"]["running_tasks"], 0)
|
||||
self.assertEqual(res2.sub_projects, [])
|
||||
|
||||
def _run_tasks(self, *tasks):
|
||||
|
@ -133,6 +133,32 @@ class TestTags(TestService):
|
||||
).models
|
||||
self.assertFound(model_id, [], models)
|
||||
|
||||
def testQueueTags(self):
|
||||
q_id = self._temp_queue(system_tags=["default"])
|
||||
queues = self.api.queues.get_all_ex(
|
||||
name="Test tags", system_tags=["default"]
|
||||
).queues
|
||||
self.assertFound(q_id, ["default"], queues)
|
||||
|
||||
queues = self.api.queues.get_all_ex(
|
||||
name="Test tags", system_tags=["-default"]
|
||||
).queues
|
||||
self.assertNotFound(q_id, queues)
|
||||
|
||||
self.api.queues.update(queue=q_id, system_tags=[])
|
||||
queues = self.api.queues.get_all_ex(
|
||||
name="Test tags", system_tags=["-default"]
|
||||
).queues
|
||||
self.assertFound(q_id, [], queues)
|
||||
|
||||
# test default queue
|
||||
queues = self.api.queues.get_all(system_tags=["default"]).queues
|
||||
if queues:
|
||||
self.assertEqual(queues[0].id, self.api.queues.get_default().id)
|
||||
else:
|
||||
self.api.queues.update(queue=q_id, system_tags=["default"])
|
||||
self.assertEqual(q_id, self.api.queues.get_default().id)
|
||||
|
||||
def testTaskTags(self):
|
||||
task_id = self._temp_task(
|
||||
name="Test tags", system_tags=["active"]
|
||||
@ -169,38 +195,11 @@ class TestTags(TestService):
|
||||
task = self.api.tasks.get_by_id(task=task_id).task
|
||||
self.assertEqual(task.status, "stopped")
|
||||
|
||||
def testQueueTags(self):
|
||||
q_id = self._temp_queue(system_tags=["default"])
|
||||
queues = self.api.queues.get_all_ex(
|
||||
name="Test tags", system_tags=["default"]
|
||||
).queues
|
||||
self.assertFound(q_id, ["default"], queues)
|
||||
|
||||
queues = self.api.queues.get_all_ex(
|
||||
name="Test tags", system_tags=["-default"]
|
||||
).queues
|
||||
self.assertNotFound(q_id, queues)
|
||||
|
||||
self.api.queues.update(queue=q_id, system_tags=[])
|
||||
queues = self.api.queues.get_all_ex(
|
||||
name="Test tags", system_tags=["-default"]
|
||||
).queues
|
||||
self.assertFound(q_id, [], queues)
|
||||
|
||||
# test default queue
|
||||
queues = self.api.queues.get_all(system_tags=["default"]).queues
|
||||
if queues:
|
||||
self.assertEqual(queues[0].id, self.api.queues.get_default().id)
|
||||
else:
|
||||
self.api.queues.update(queue=q_id, system_tags=["default"])
|
||||
self.assertEqual(q_id, self.api.queues.get_default().id)
|
||||
|
||||
def assertProjectStats(self, project: AttrDict):
|
||||
self.assertEqual(set(project.stats.keys()), {"active"})
|
||||
self.assertAlmostEqual(project.stats.active.total_runtime, 1, places=0)
|
||||
self.assertEqual(project.stats.active.completed_tasks, 1)
|
||||
self.assertEqual(project.stats.active.completed_tasks_24h, 1)
|
||||
self.assertEqual(project.stats.active.total_tasks, 1)
|
||||
self.assertEqual(project.stats.active.running_tasks, 0)
|
||||
for status, count in project.stats.active.status_count.items():
|
||||
self.assertEqual(count, 1 if status == "stopped" else 0)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user