Compare commits

25 Commits

Author SHA1 Message Date
allegroai
69737308fe Version bump to v1.4.0 2022-04-18 16:38:22 +03:00
allegroai
a6dbea808a Add indices for task.last_update and task.status_changed 2022-04-18 16:37:22 +03:00
allegroai
5131b17901 Support not returning hidden sub-projects when include_stats is specified without search_hidden 2022-04-18 16:36:14 +03:00
allegroai
5f21c3a56d Add support for searching hidden projects and tasks 2022-04-18 16:34:18 +03:00
allegroai
2350ac64ed Fix internal error on count task events if there is no events index 2022-04-18 16:31:02 +03:00
allegroai
d146127c18 Add events.clear_scroll endpoint to clear event search scrolls 2022-04-18 16:29:57 +03:00
Mal Miller
abd65e103e Ensure agent-services waits for API server to be ready (#129) 2022-03-31 11:10:45 +03:00
pollfly
bf65ea7bd0 Resize admonitions (#126) 2022-03-27 15:04:43 +03:00
pollfly
73e278a8ed Add deprecation notes to legacy docs (#124) 2022-03-23 23:51:55 +02:00
Zied ANDOLSI
d92dfbbdb7 Allow ClearML to be served with a URL path prefix (#121)
* add server root url

* [Feature Request] Add proxy_pass for root url other than /

* [Feature Request] Add proxy_pass for root url other than /

* add support for web sub path

* add support for web sub path

* use default conf instead of creating a custom one

* code review: move cp command in if block

* Add commented env var in the docker-compose file

Co-authored-by: Zied ANDOLSI <zandolsi@prophesee.ai>
2022-03-22 17:21:58 +02:00
Zied ANDOLSI
5c1e419eb5 Allow overriding clearml web git url on build (#122)
* add server root url

* [Feature Request] Add possibility to override clearml web git url

Co-authored-by: Zied ANDOLSI <zandolsi@prophesee.ai>
2022-03-17 14:35:50 +02:00
allegroai
124684f53f Version bump to v1.3.0 2022-03-15 16:34:35 +02:00
allegroai
455b5d6758 Fix pre-populate to convert model metadata from the old format 2022-03-15 16:30:14 +02:00
allegroai
c04e2e498b Support credentials label and last_used_from fields 2022-03-15 16:29:37 +02:00
allegroai
da8a45072f Add pipelines support 2022-03-15 16:28:59 +02:00
allegroai
e1992e2054 Fix queue metrics calculation 2022-03-15 16:28:49 +02:00
allegroai
c17cedd93a Support disabling response compression in fileserver 2022-03-15 16:27:31 +02:00
allegroai
b6ad8f8790 Add support for worker auto-unregister (instead of raising an error) 2022-03-15 16:25:14 +02:00
allegroai
5acc7eebc3 Set API version to 2.17 2022-03-15 16:22:51 +02:00
allegroai
941927dfcd Return fixed fileserver header 2022-03-15 16:21:52 +02:00
allegroai
02933a9c93 Support disabling response compression
Return fixed server header
2022-03-15 16:21:14 +02:00
allegroai
e537651f29 Better support for assets upload/download 2022-03-15 16:19:52 +02:00
allegroai
af09fba755 Add metadata dict support for models, queues
Add more info for projects
2022-03-15 16:18:57 +02:00
Reuben Morais
04ea9018a3 Add missing g++ dep to server build (#111) 2022-02-21 22:14:22 +02:00
allegroai
ff7e1be24f Updated docker-compose files for v1.2.0 2022-02-14 15:27:23 +02:00
70 changed files with 1621 additions and 570 deletions

View File

@@ -81,6 +81,7 @@ class Credentials(Base):
class CredentialsResponse(Credentials):
secret_key = StringField()
last_used = DateTimeField(default=None)
last_used_from = StringField()
class CreateCredentialsRequest(Base):

View File

@@ -137,3 +137,7 @@ class TaskPlotsRequest(Base):
scroll_id: str = StringField()
no_scroll: bool = BoolField(default=False)
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
class ClearScrollRequest(Base):
scroll_id: str = StringField()

View File

@@ -1,7 +1,7 @@
from typing import Sequence
from jsonmodels import validators
from jsonmodels.fields import StringField
from jsonmodels.fields import StringField, BoolField
from jsonmodels.models import Base
from apiserver.apimodels import ListField
@@ -21,3 +21,4 @@ class AddOrUpdateMetadata(Base):
metadata: Sequence[MetadataItem] = ListField(
[MetadataItem], validators=validators.Length(minimum_value=1)
)
replace_metadata = BoolField(default=False)

View File

@@ -30,7 +30,7 @@ class CreateModelRequest(models.Base):
ready = fields.BoolField(default=True)
ui_cache = DictField()
task = fields.StringField()
metadata = ListField(items_types=[MetadataItem])
metadata = DictField(value_types=[MetadataItem])
class CreateModelResponse(models.Base):

View File

@@ -0,0 +1,19 @@
from jsonmodels import models, fields
from apiserver.apimodels import ListField
class Arg(models.Base):
    """A single name/value argument pair; both parts are mandatory strings"""

    name = fields.StringField(required=True)
    value = fields.StringField(required=True)
class StartPipelineRequest(models.Base):
    """
    Request model for starting a pipeline: mandatory task and queue
    strings plus an optional list of Arg name/value pairs
    """

    task = fields.StringField(required=True)
    queue = fields.StringField(required=True)
    args = ListField(Arg)
class StartPipelineResponse(models.Base):
    """Response model for starting a pipeline: pipeline string and enqueued flag"""

    pipeline = fields.StringField(required=True)
    enqueued = fields.BoolField(required=True)

View File

@@ -1,6 +1,6 @@
from jsonmodels import models, fields
from apiserver.apimodels import ListField, ActualEnumField
from apiserver.apimodels import ListField, ActualEnumField, DictField
from apiserver.apimodels.organization import TagsRequest
from apiserver.database.model import EntityVisibility
@@ -51,11 +51,18 @@ class ProjectHyperparamValuesRequest(MultiProjectRequest):
allow_public = fields.BoolField(default=True)
class ProjectModelMetadataValuesRequest(MultiProjectRequest):
key = fields.StringField(required=True)
allow_public = fields.BoolField(default=True)
class ProjectsGetRequest(models.Base):
include_stats = fields.BoolField(default=False)
include_stats_filter = DictField()
stats_with_children = fields.BoolField(default=True)
stats_for_state = ActualEnumField(EntityVisibility, default=EntityVisibility.active)
non_public = fields.BoolField(default=False)
active_users = fields.ListField(str)
check_own_contents = fields.BoolField(default=False)
shallow_search = fields.BoolField(default=False)
search_hidden = fields.BoolField(default=False)

View File

@@ -2,7 +2,7 @@ from jsonmodels import validators
from jsonmodels.fields import StringField, IntField, BoolField, FloatField
from jsonmodels.models import Base
from apiserver.apimodels import ListField
from apiserver.apimodels import ListField, DictField
from apiserver.apimodels.metadata import (
MetadataItem,
DeleteMetadata,
@@ -19,13 +19,18 @@ class CreateRequest(Base):
name = StringField(required=True)
tags = ListField(items_types=[str])
system_tags = ListField(items_types=[str])
metadata = ListField(items_types=[MetadataItem])
metadata = DictField(value_types=[MetadataItem])
class QueueRequest(Base):
queue = StringField(required=True)
class GetNextTaskRequest(QueueRequest):
queue = StringField(required=True)
get_task_info = BoolField(default=False)
class DeleteRequest(QueueRequest):
force = BoolField(default=False)
@@ -34,7 +39,7 @@ class UpdateRequest(QueueRequest):
name = StringField()
tags = ListField(items_types=[str])
system_tags = ListField(items_types=[str])
metadata = ListField(items_types=[MetadataItem])
metadata = DictField(value_types=[MetadataItem])
class TaskRequest(QueueRequest):

View File

@@ -162,7 +162,7 @@ class AuthBLL:
access_key=get_client_id(), secret_key=get_secret_key(), label=label
)
user.credentials.append(
Credentials(key=cred.access_key, secret=cred.secret_key)
Credentials(key=cred.access_key, secret=cred.secret_key, label=label)
)
user.save()

View File

@@ -8,7 +8,7 @@ from datetime import datetime
from operator import attrgetter
from typing import Sequence, Set, Tuple, Optional, List, Mapping, Union
from elasticsearch import helpers
import elasticsearch
from elasticsearch.helpers import BulkIndexError
from mongoengine import Q
from nested_dict import nested_dict
@@ -48,6 +48,9 @@ MAX_LONG = 2 ** 63 - 1
MIN_LONG = -(2 ** 63)
log = config.logger(__file__)
class PlotFields:
valid_plot = "valid_plot"
plot_len = "plot_len"
@@ -219,7 +222,7 @@ class EventBLL(object):
with TimingContext("es", "events_add_batch"):
# TODO: replace it with helpers.parallel_bulk in the future once the parallel pool leak is fixed
with closing(
helpers.streaming_bulk(
elasticsearch.helpers.streaming_bulk(
self.es,
actions,
chunk_size=chunk_size,
@@ -1005,3 +1008,16 @@ class EventBLL(object):
)
return es_res.get("deleted", 0)
def clear_scroll(self, scroll_id: str):
    """
    Best-effort release of the ES scroll identified by scroll_id

    Nothing is done for the empty scroll marker. Not-found and bad-request
    errors returned by ES are ignored, any other failure is logged and
    swallowed so that the caller never gets an exception from this call
    """
    if scroll_id == self.empty_scroll:
        return

    # noinspection PyBroadException
    try:
        self.es.clear_scroll(scroll_id=scroll_id)
    except (
        elasticsearch.exceptions.NotFoundError,
        elasticsearch.exceptions.RequestError,
    ):
        # deliberately ignored: there is nothing left to clear for such scroll ids
        pass
    except Exception:
        log.exception("Failed clearing scroll %s", scroll_id)

View File

@@ -67,6 +67,9 @@ class EventsIterator:
task_id: str,
metric_variants: MetricVariants = None,
) -> int:
if check_empty_data(self.es, company_id, event_type):
return 0
query, _ = self._get_initial_query_and_must(task_id, metric_variants)
es_req = {
"query": query,

View File

@@ -7,6 +7,7 @@ from apiserver.bll.task.utils import deleted_prefix
from apiserver.database.model import EntityVisibility
from apiserver.database.model.model import Model
from apiserver.database.model.task.task import Task, TaskStatus
from .metadata import Metadata
class ModelBLL:

View File

@@ -0,0 +1,111 @@
from typing import Sequence, Union, Mapping
from mongoengine import Document
from apiserver.apierrors import errors
from apiserver.apimodels.metadata import MetadataItem
from apiserver.database.model.base import GetMixin
from apiserver.service_repo import APICall
from apiserver.utilities.parameter_key_escaper import (
ParameterKeyEscaper,
mongoengine_safe,
)
from apiserver.config_repo import config
from apiserver.timing_context import TimingContext
log = config.logger(__file__)
class Metadata:
    """
    Helpers for converting metadata between its API form and the escaped
    form used for storage, and for escaping metadata paths in query data
    """

    @staticmethod
    def metadata_from_api(
        api_data: Union[Mapping[str, MetadataItem], Sequence[MetadataItem]]
    ) -> dict:
        """
        Convert API metadata into a dict of item structs keyed by the escaped key

        :param api_data: either a key -> item mapping or a sequence of items
            (in the latter case the key is taken from each item)
        :return: empty dict for empty/None input
        """
        if not api_data:
            return {}

        # any mapping is accepted (as the annotation promises), not only plain dicts
        if isinstance(api_data, Mapping):
            return {
                ParameterKeyEscaper.escape(k): v.to_struct()
                for k, v in api_data.items()
            }

        return {
            ParameterKeyEscaper.escape(item.key): item.to_struct() for item in api_data
        }

    @classmethod
    def edit_metadata(
        cls,
        obj: Document,
        items: Sequence[MetadataItem],
        replace_metadata: bool,
        **more_updates,
    ) -> int:
        """
        Update the metadata of the passed document

        :param items: the metadata items to store
        :param replace_metadata: if set then the whole metadata field is replaced
            by the passed items, otherwise only the passed keys are set
        :param more_updates: additional update commands passed to obj.update as is
        :return: the result of obj.update
        """
        with TimingContext("mongo", "edit_metadata"):
            update_cmds = dict()
            metadata = cls.metadata_from_api(items)
            if replace_metadata:
                update_cmds["set__metadata"] = metadata
            else:
                for key, value in metadata.items():
                    update_cmds[f"set__metadata__{mongoengine_safe(key)}"] = value

            return obj.update(**update_cmds, **more_updates)

    @classmethod
    def delete_metadata(cls, obj: Document, keys: Sequence[str], **more_updates) -> int:
        """
        Unset the passed metadata keys (duplicates are ignored) on the document

        :param more_updates: additional update commands passed to obj.update as is
        :return: the result of obj.update
        """
        with TimingContext("mongo", "delete_metadata"):
            return obj.update(
                **{
                    f"unset__metadata__{ParameterKeyEscaper.escape(key)}": 1
                    for key in set(keys)
                },
                **more_updates,
            )

    @staticmethod
    def _process_path(path: str):
        """
        Frontend does a partial escaping on the path so that all the '.' in key names
        are escaped. Need to unescape each part and apply a full mongo escaping

        :raise errors.bad_request.ValidationError: if the path does not consist
            of 2 or 3 dot separated parts
        """
        parts = path.split(".")
        if len(parts) < 2 or len(parts) > 3:
            raise errors.bad_request.ValidationError("invalid field", path=path)
        return ".".join(
            ParameterKeyEscaper.escape(ParameterKeyEscaper.unescape(p)) for p in parts
        )

    @classmethod
    def escape_paths(cls, paths: Sequence[str]) -> Sequence[str]:
        """
        Return a new list where every path that refers to metadata (with or without
        the leading '-') is re-escaped; all the other paths are returned unchanged
        """
        metadata_prefixes = ("metadata.", "-metadata.")
        return [
            cls._process_path(path) if path.startswith(metadata_prefixes) else path
            for path in paths
        ]

    @classmethod
    def escape_query_parameters(cls, call: APICall) -> dict:
        """
        Build a copy of call.data where the metadata paths are escaped in the
        top level keys, in the projection and in the ordering.
        Empty call.data is returned as is
        """
        if not call.data:
            return call.data

        keys = list(call.data)
        call_data = {
            safe_key: call.data[key]
            for key, safe_key in zip(keys, cls.escape_paths(keys))
        }

        projection = GetMixin.get_projection(call_data)
        if projection:
            GetMixin.set_projection(call_data, cls.escape_paths(projection))

        ordering = GetMixin.get_ordering(call_data)
        if ordering:
            GetMixin.set_ordering(call_data, cls.escape_paths(ordering))

        return call_data

View File

@@ -6,6 +6,7 @@ from redis import Redis
from apiserver.config_repo import config
from apiserver.bll.project import project_ids_with_children
from apiserver.database.model import EntityVisibility
from apiserver.database.model.base import GetMixin
from apiserver.database.model.model import Model
from apiserver.database.model.task.task import Task
@@ -42,6 +43,8 @@ class _TagsCache:
query &= GetMixin.get_list_field_query(name, vals)
if project:
query &= Q(project__in=project_ids_with_children([project]))
else:
query &= Q(system_tags__nin=[EntityVisibility.hidden.value])
return self.db_cls.objects(query).distinct(field)

View File

@@ -14,6 +14,7 @@ from typing import (
TypeVar,
Callable,
Mapping,
Any,
)
from mongoengine import Q, Document
@@ -22,6 +23,7 @@ from apiserver import database
from apiserver.apierrors import errors
from apiserver.config_repo import config
from apiserver.database.model import EntityVisibility, AttributedDocument
from apiserver.database.model.base import GetMixin
from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task, TaskStatus, external_task_types
@@ -204,6 +206,7 @@ class ProjectBLL:
tags: Sequence[str] = None,
system_tags: Sequence[str] = None,
default_output_destination: str = None,
parent_creation_params: dict = None,
) -> str:
"""
Create a new project.
@@ -226,7 +229,12 @@ class ProjectBLL:
created=now,
last_update=now,
)
parent = _ensure_project(company=company, user=user, name=location)
parent = _ensure_project(
company=company,
user=user,
name=location,
creation_params=parent_creation_params,
)
_save_under_parent(project=project, parent=parent)
if parent:
parent.update(last_update=now)
@@ -244,13 +252,14 @@ class ProjectBLL:
tags: Sequence[str] = None,
system_tags: Sequence[str] = None,
default_output_destination: str = None,
parent_creation_params: dict = None,
) -> str:
"""
Find a project named `project_name` or create a new one.
Returns project ID
"""
if not project_id and not project_name:
raise ValueError("project id or name required")
raise errors.bad_request.ValidationError("project id or name required")
if project_id:
project = Project.objects(company=company, id=project_id).only("id").first()
@@ -271,6 +280,7 @@ class ProjectBLL:
tags=tags,
system_tags=system_tags,
default_output_destination=default_output_destination,
parent_creation_params=parent_creation_params,
)
@classmethod
@@ -314,6 +324,7 @@ class ProjectBLL:
company_id: str,
project_ids: Sequence[str],
specific_state: Optional[EntityVisibility] = None,
filter_: Mapping[str, Any] = None,
) -> Tuple[Sequence, Sequence]:
archived = EntityVisibility.archived.value
@@ -337,10 +348,9 @@ class ProjectBLL:
status_count_pipeline = [
# count tasks per project per status
{
"$match": {
"company": {"$in": [None, "", company_id]},
"project": {"$in": project_ids},
}
"$match": cls.get_match_conditions(
company=company_id, project_ids=project_ids, filter_=filter_
)
},
ensure_valid_fields(),
{
@@ -388,6 +398,17 @@ class ProjectBLL:
}
}
def max_started_subquery(condition):
return {
"$max": {
"$cond": {
"if": condition,
"then": "$started",
"else": datetime.min,
}
}
}
def runtime_subquery(additional_cond):
return {
# the sum of
@@ -431,14 +452,23 @@ class ProjectBLL:
group_step[f"{state.value}_recently_completed"] = completed_after_subquery(
cond, time_thresh=time_thresh
)
group_step[f"{state.value}_max_task_started"] = max_started_subquery(cond)
def get_state_filter() -> dict:
if not specific_state:
return {}
if specific_state == EntityVisibility.archived:
return {"system_tags": {"$eq": EntityVisibility.archived.value}}
return {"system_tags": {"$ne": EntityVisibility.archived.value}}
runtime_pipeline = [
# only count run time for these types of tasks
{
"$match": {
"company": {"$in": [None, "", company_id]},
"type": {"$in": ["training", "testing", "annotation"]},
"project": {"$in": project_ids},
**cls.get_match_conditions(
company=company_id, project_ids=project_ids, filter_=filter_
),
**get_state_filter(),
}
},
ensure_valid_fields(),
@@ -481,12 +511,14 @@ class ProjectBLL:
project_ids: Sequence[str],
specific_state: Optional[EntityVisibility] = None,
include_children: bool = True,
return_hidden_children: bool = False,
filter_: Mapping[str, Any] = None,
) -> Tuple[Dict[str, dict], Dict[str, dict]]:
if not project_ids:
return {}, {}
child_projects = (
_get_sub_projects(project_ids, _only=("id", "name"))
_get_sub_projects(project_ids, _only=("id", "name", "system_tags"))
if include_children
else {}
)
@@ -497,6 +529,7 @@ class ProjectBLL:
company,
project_ids=list(project_ids_with_children),
specific_state=specific_state,
filter_=filter_,
)
default_counts = dict.fromkeys(get_options(TaskStatus), 0)
@@ -547,6 +580,8 @@ class ProjectBLL:
) -> Dict[str, dict]:
return {
section: a.get(section, 0) + b.get(section, 0)
if not section.endswith("max_task_started")
else max(a.get(section) or datetime.min, b.get(section) or datetime.min)
for section in set(a) | set(b)
}
@@ -562,14 +597,20 @@ class ProjectBLL:
project_section_statuses = nested_get(
status_count, (project_id, section), default=default_counts
)
def get_time_or_none(value):
return value if value != datetime.min else None
return {
"status_count": project_section_statuses,
"running_tasks": project_section_statuses.get(TaskStatus.in_progress),
"total_tasks": sum(project_section_statuses.values()),
"total_runtime": project_runtime.get(section, 0),
"completed_tasks": project_runtime.get(
"completed_tasks_24h": project_runtime.get(
f"{section}_recently_completed", 0
),
"last_task_run": get_time_or_none(
project_runtime.get(f"{section}_max_task_started", datetime.min)
),
}
report_for_states = [
@@ -586,9 +627,24 @@ class ProjectBLL:
for project in project_ids
}
def filter_child_projects(project: str) -> Sequence[Project]:
non_filtered_children = child_projects.get(project, [])
if not non_filtered_children or return_hidden_children:
return non_filtered_children
return [
c
for c in non_filtered_children
if not c.system_tags
or EntityVisibility.hidden.value not in c.system_tags
]
children = {
project: sorted(
[{"id": c.id, "name": c.name} for c in child_projects.get(project, [])],
[
{"id": c.id, "name": c.name}
for c in filter_child_projects(project)
],
key=itemgetter("name"),
)
for project in project_ids
@@ -624,6 +680,30 @@ class ProjectBLL:
return res
# Returns a (tags, system_tags) pair with the distinct values found on the
# projects of the company; system_tags is an empty list unless include_system is set
@classmethod
def get_project_tags(
cls,
company_id: str,
include_system: bool,
projects: Sequence[str] = None,
filter_: Dict[str, Sequence[str]] = None,
) -> Tuple[Sequence[str], Sequence[str]]:
with TimingContext("mongo", "get_tags_from_db"):
query = Q(company=company_id)
if filter_:
# each non-empty filter entry narrows the query by a list field condition
for name, vals in filter_.items():
if vals:
query &= GetMixin.get_list_field_query(name, vals)
if projects:
# presumably expands the passed ids with their sub-projects; verify against _ids_with_children
query &= Q(id__in=_ids_with_children(projects))
tags = Project.objects(query).distinct("tags")
# system tags are queried only on demand
system_tags = (
Project.objects(query).distinct("system_tags") if include_system else []
)
return tags, system_tags
@classmethod
def get_projects_with_active_user(
cls,
@@ -676,10 +756,14 @@ class ProjectBLL:
If projects is None or empty then get parents for all the company tasks
"""
query = Q(company=company_id)
if projects:
if include_subprojects:
projects = _ids_with_children(projects)
query &= Q(project__in=projects)
else:
query &= Q(system_tags__nin=[EntityVisibility.hidden.value])
if state == EntityVisibility.archived:
query &= Q(system_tags__in=[EntityVisibility.archived.value])
elif state == EntityVisibility.active:
@@ -707,6 +791,8 @@ class ProjectBLL:
if project_ids:
project_ids = _ids_with_children(project_ids)
query &= Q(project__in=project_ids)
else:
query &= Q(system_tags__nin=[EntityVisibility.hidden.value])
res = Task.objects(query).distinct(field="type")
return set(res).intersection(external_task_types)
@@ -722,8 +808,35 @@ class ProjectBLL:
query &= Q(project__in=project_ids)
return Model.objects(query).distinct(field="framework")
@staticmethod
def get_match_conditions(
    company: str, project_ids: Sequence[str], filter_: Mapping[str, Any]
):
    """
    Build the $match conditions for the passed company and projects.
    If filter_ contains non-empty "tags" or "system_tags" then each of them
    should be a list of strings and is added as an additional $in condition

    :raise errors.bad_request.ValidationError: if a filter value is not a list of strings
    """
    match = {
        "company": {"$in": [None, "", company]},
        "project": {"$in": project_ids},
    }
    if filter_:
        for name in ("tags", "system_tags"):
            values = filter_.get(name)
            if values:
                is_str_list = isinstance(values, list) and all(
                    isinstance(v, str) for v in values
                )
                if not is_str_list:
                    raise errors.bad_request.ValidationError(
                        f"List of strings expected for the field: {name}"
                    )
                match[name] = {"$in": values}

    return match
@classmethod
def calc_own_contents(cls, company: str, project_ids: Sequence[str]) -> Dict[str, dict]:
def calc_own_contents(
cls, company: str, project_ids: Sequence[str], filter_: Mapping[str, Any] = None
) -> Dict[str, dict]:
"""
Returns the amount of task/models per requested project
Use separate aggregation calls on Task/Model instead of lookup
@@ -734,35 +847,21 @@ class ProjectBLL:
pipeline = [
{
"$match": {
"company": {"$in": [None, "", company]},
"project": {"$in": project_ids},
}
"$match": cls.get_match_conditions(
company=company, project_ids=project_ids, filter_=filter_
)
},
{
"$project": {"project": 1}
},
{
"$group": {
"_id": "$project",
"count": {"$sum": 1},
}
}
{"$project": {"project": 1}},
{"$group": {"_id": "$project", "count": {"$sum": 1}}},
]
def get_agrregate_res(cls_: Type[AttributedDocument]) -> dict:
return {
data["_id"]: data["count"]
for data in cls_.aggregate(pipeline)
}
return {data["_id"]: data["count"] for data in cls_.aggregate(pipeline)}
with TimingContext("mongo", "get_security_groups"):
tasks = get_agrregate_res(Task)
models = get_agrregate_res(Model)
return {
pid: {
"own_tasks": tasks.get(pid, 0),
"own_models": models.get(pid, 0),
}
pid: {"own_tasks": tasks.get(pid, 0), "own_models": models.get(pid, 0)}
for pid in project_ids
}

View File

@@ -1,6 +1,6 @@
import json
from collections import OrderedDict
from datetime import datetime, timedelta
from datetime import datetime
from typing import (
Sequence,
Optional,
@@ -10,6 +10,7 @@ from typing import (
from redis import StrictRedis
from apiserver.config_repo import config
from apiserver.database.model.model import Model
from apiserver.database.model.task.task import Task
from apiserver.redis_manager import redman
from apiserver.utilities.dicts import nested_get
@@ -27,12 +28,21 @@ class ProjectQueries:
def _get_project_constraint(
project_ids: Sequence[str], include_subprojects: bool
) -> dict:
"""
If passed projects is None means top level projects
If passed projects is empty means no project filtering
"""
if include_subprojects:
if project_ids is None:
if not project_ids:
return {}
project_ids = _ids_with_children(project_ids)
return {"project": {"$in": project_ids if project_ids is not None else [None]}}
if project_ids is None:
project_ids = [None]
if not project_ids:
return {}
return {"project": {"$in": project_ids}}
@staticmethod
def _get_company_constraint(company_id: str, allow_public: bool = True) -> dict:
@@ -105,16 +115,11 @@ class ProjectQueries:
return total, remaining, results
HyperParamValues = Tuple[int, Sequence[str]]
ParamValues = Tuple[int, Sequence[str]]
def _get_cached_hyperparam_values(
self, key: str, last_update: datetime
) -> Optional[HyperParamValues]:
allowed_delta = timedelta(
seconds=config.get(
"services.tasks.hyperparam_values.cache_allowed_outdate_sec", 60
)
)
def _get_cached_param_values(
self, key: str, last_update: datetime, allowed_delta_sec=0
) -> Optional[ParamValues]:
try:
cached = self.redis.get(key)
if not cached:
@@ -122,12 +127,12 @@ class ProjectQueries:
data = json.loads(cached)
cached_last_update = datetime.fromtimestamp(data["last_update"])
if (last_update - cached_last_update) < allowed_delta:
if (last_update - cached_last_update).total_seconds() <= allowed_delta_sec:
return data["total"], data["values"]
except Exception as ex:
log.error(f"Error retrieving hyperparam cached values: {str(ex)}")
log.error(f"Error retrieving params cached values: {str(ex)}")
def get_hyperparam_distinct_values(
def get_task_hyperparam_distinct_values(
self,
company_id: str,
project_ids: Sequence[str],
@@ -135,7 +140,7 @@ class ProjectQueries:
name: str,
include_subprojects: bool,
allow_public: bool = True,
) -> HyperParamValues:
) -> ParamValues:
company_constraint = self._get_company_constraint(company_id, allow_public)
project_constraint = self._get_project_constraint(
project_ids, include_subprojects
@@ -157,8 +162,12 @@ class ProjectQueries:
redis_key = f"hyperparam_values_{company_id}_{'_'.join(project_ids)}_{section}_{name}_{allow_public}"
last_update = last_updated_task.last_update or datetime.utcnow()
cached_res = self._get_cached_hyperparam_values(
key=redis_key, last_update=last_update
cached_res = self._get_cached_param_values(
key=redis_key,
last_update=last_update,
allowed_delta_sec=config.get(
"services.tasks.hyperparam_values.cache_allowed_outdate_sec", 60
),
)
if cached_res:
return cached_res
@@ -239,3 +248,123 @@ class ProjectQueries:
result = Task.aggregate(pipeline)
return [r["metrics"][0] for r in result]
@classmethod
def get_model_metadata_keys(
    cls,
    company_id,
    project_ids: Sequence[str],
    include_subprojects: bool,
    page: int = 0,
    page_size: int = 500,
) -> Tuple[int, int, Sequence[dict]]:
    """
    Get a page of the distinct metadata keys used by the models matching
    the company and project constraints, sorted by the (escaped) key

    :param page: zero based page number (negative values are treated as 0)
    :param page_size: amount of keys per page (values below 1 are treated as 1)
    :return: a tuple of
        - the total amount of distinct keys (over all the pages)
        - the amount of keys remaining after the returned page
        - the unescaped keys of the requested page
        NOTE(review): the annotation declares a sequence of dicts while the unescaped
        "_id" values are returned; confirm against ParameterKeyEscaper.unescape
    """
    page = max(0, page)
    page_size = max(1, page_size)
    pipeline = [
        {
            "$match": {
                **cls._get_company_constraint(company_id),
                **cls._get_project_constraint(project_ids, include_subprojects),
                "metadata": {"$exists": True, "$gt": {}},
            }
        },
        {"$project": {"metadata": {"$objectToArray": "$metadata"}}},
        {"$unwind": "$metadata"},
        {"$group": {"_id": "$metadata.k"}},
        {"$sort": {"_id": 1}},
        # the total should be calculated over all the distinct keys, i.e. before
        # the paging is applied. Counting after $skip/$limit yields the page length
        # which makes "remaining" always 0
        {
            "$group": {
                "_id": 1,
                "total": {"$sum": 1},
                "results": {"$push": "$$ROOT"},
            }
        },
        {
            "$project": {
                "total": 1,
                "results": {"$slice": ["$results", page * page_size, page_size]},
            }
        },
    ]

    result = next(Model.aggregate(pipeline), None)

    total = 0
    remaining = 0
    results = []

    if result:
        total = int(result.get("total", -1))
        results = [
            ParameterKeyEscaper.unescape(r.get("_id"))
            for r in result.get("results", [])
        ]
        remaining = max(0, total - (len(results) + page * page_size))

    return total, remaining, results
def get_model_metadata_distinct_values(
    self,
    company_id: str,
    project_ids: Sequence[str],
    key: str,
    include_subprojects: bool,
    allow_public: bool = True,
) -> ParamValues:
    """
    Get the distinct values stored under the metadata key in the models
    matching the company and project constraints

    The result is cached in redis together with the last update time of the
    most recently updated matching model. The cached result is returned only
    if it was not calculated for an older last update time

    :return: a tuple of the amount of the returned values and the sorted values.
        No more than "services.models.metadata_values.max_count" values are returned
    """
    company_constraint = self._get_company_constraint(company_id, allow_public)
    project_constraint = self._get_project_constraint(
        project_ids, include_subprojects
    )
    key_path = f"metadata.{ParameterKeyEscaper.escape(key)}"
    last_updated_model = (
        Model.objects(
            **company_constraint,
            **project_constraint,
            **{f"{key_path.replace('.', '__')}__exists": True},
        )
        .only("last_update")
        .order_by("-last_update")
        .limit(1)
        .first()
    )
    if not last_updated_model:
        return 0, []

    # project_ids may be None, so it cannot be joined directly. None is also kept
    # distinct from an empty list since the project constraint treats them differently
    projects_key = "none" if project_ids is None else "_".join(project_ids)
    # include_subprojects changes the set of the matched models so it should be
    # a part of the key, otherwise the results of one mode can be served for another
    redis_key = (
        f"modelmetadata_values_{company_id}_{projects_key}"
        f"_{key}_{allow_public}_{include_subprojects}"
    )
    last_update = last_updated_model.last_update or datetime.utcnow()
    cached_res = self._get_cached_param_values(
        key=redis_key, last_update=last_update
    )
    if cached_res:
        return cached_res

    max_values = config.get("services.models.metadata_values.max_count", 100)
    pipeline = [
        {
            "$match": {
                **company_constraint,
                **project_constraint,
                key_path: {"$exists": True},
            }
        },
        {"$project": {"value": f"${key_path}.value"}},
        {"$group": {"_id": "$value"}},
        {"$sort": {"_id": 1}},
        {"$limit": max_values},
        {
            "$group": {
                "_id": 1,
                "total": {"$sum": 1},
                "results": {"$push": "$$ROOT._id"},
            }
        },
    ]

    result = next(Model.aggregate(pipeline, collation=Model._numeric_locale), None)
    if not result:
        return 0, []

    total = int(result.get("total", 0))
    values = result.get("results", [])

    ttl = config.get("services.models.metadata_values.cache_ttl_sec", 86400)
    cached = dict(last_update=last_update.timestamp(), total=total, values=values)
    self.redis.setex(redis_key, ttl, json.dumps(cached))
    return total, values

View File

@@ -25,7 +25,9 @@ def _validate_project_name(project_name: str) -> Tuple[str, str]:
return name_separator.join(name_parts), name_separator.join(name_parts[:-1])
def _ensure_project(company: str, user: str, name: str) -> Optional[Project]:
def _ensure_project(
company: str, user: str, name: str, creation_params: dict = None
) -> Optional[Project]:
"""
Makes sure that the project with the given name exists
If needed auto-create the project and all the missing projects in the path to it
@@ -48,9 +50,9 @@ def _ensure_project(company: str, user: str, name: str) -> Optional[Project]:
created=now,
last_update=now,
name=name,
description="",
**(creation_params or dict(description="")),
)
parent = _ensure_project(company, user, location)
parent = _ensure_project(company, user, location, creation_params=creation_params)
_save_under_parent(project=project, parent=parent)
if parent:
parent.update(last_update=now)

View File

@@ -32,7 +32,7 @@ class QueueBLL(object):
name: str,
tags: Optional[Sequence[str]] = None,
system_tags: Optional[Sequence[str]] = None,
metadata: Optional[Sequence[dict]] = None,
metadata: Optional[dict] = None,
) -> Queue:
"""Creates a queue"""
with translate_errors_context():
@@ -187,13 +187,15 @@ class QueueBLL(object):
if any(e.task == task_id for e in queue.entries):
raise errors.bad_request.TaskAlreadyQueued(task=task_id)
self.metrics.log_queue_metrics_to_es(company_id=company_id, queues=[queue])
entry = Entry(added=datetime.utcnow(), task=task_id)
query = dict(id=queue_id, company=company_id)
res = Queue.objects(entries__task__ne=task_id, **query).update_one(
push__entries=entry, last_update=datetime.utcnow(), upsert=False
)
queue.reload()
self.metrics.log_queue_metrics_to_es(company_id=company_id, queues=[queue])
if not res:
raise errors.bad_request.InvalidQueueOrTaskNotQueued(
task=task_id, **query
@@ -233,7 +235,6 @@ class QueueBLL(object):
queue = self.get_queue_with_task(
company_id=company_id, queue_id=queue_id, task_id=task_id
)
self.metrics.log_queue_metrics_to_es(company_id, queues=[queue])
entries_to_remove = [e for e in queue.entries if e.task == task_id]
query = dict(id=queue_id, company=company_id)
@@ -241,6 +242,9 @@ class QueueBLL(object):
pull_all__entries=entries_to_remove, last_update=datetime.utcnow()
)
queue.reload()
self.metrics.log_queue_metrics_to_es(company_id=company_id, queues=[queue])
return len(entries_to_remove) if res else 0
def reposition_task(

View File

@@ -113,7 +113,7 @@ class WorkerBLL:
res = self.redis.delete(
company_id, self._get_worker_key(company_id, user_id, worker)
)
if not res:
if not res and not config.get("apiserver.workers.auto_unregister", False):
raise bad_request.WorkerNotRegistered(worker=worker)
def status_report(

View File

@@ -112,6 +112,8 @@
workers {
# Auto-register unknown workers on status reports and other calls
auto_register: true
# Assume unknown workers have unregistered (i.e. do not raise unregistered error)
auto_unregister: true
# Timeout in seconds on task status update. If exceeded
# then task can be stopped without communicating to the worker
task_update_timeout: 600

View File

@@ -0,0 +1,7 @@
metadata_values {
# maximal amount of distinct model values to retrieve
max_count: 100
# cache ttl sec
cache_ttl_sec: 86400
}

View File

@@ -60,3 +60,4 @@ def validate_id(cls, company, **kwargs):
class EntityVisibility(Enum):
active = "active"
archived = "archived"
hidden = "hidden"

View File

@@ -50,6 +50,7 @@ class Credentials(EmbeddedDocument):
secret = StringField(required=True)
label = StringField()
last_used = DateTimeField()
last_used_from = StringField()
class User(DbModelMixin, AuthDocument):

View File

@@ -95,6 +95,7 @@ class GetMixin(PropsMixin):
}
MultiFieldParameters = namedtuple("MultiFieldParameters", "pattern fields")
_numeric_locale = {"locale": "en_US", "numericOrdering": True}
_field_collation_overrides = {}
class QueryParameterOptions(object):
@@ -599,7 +600,7 @@ class GetMixin(PropsMixin):
return size
@classmethod
def get_data_with_scroll_and_filter_support(
def get_data_with_scroll_support(
cls,
query_dict: dict,
data_getter: Callable[[], Sequence[dict]],
@@ -629,15 +630,12 @@ class GetMixin(PropsMixin):
if cls._start_key in query_dict:
query_dict[cls._start_key] = query_dict[cls._start_key] + len(data)
def update_state(returned_len: int):
if not state:
return
if state:
state.position = query_dict[cls._start_key]
cls.get_cache_manager().set_state(state)
if ret_params is not None:
ret_params["scroll_id"] = state.id
update_state(len(data))
return data
@classmethod
@@ -770,7 +768,7 @@ class GetMixin(PropsMixin):
override_projection=override_projection,
override_collation=override_collation,
)
return cls.get_data_with_scroll_and_filter_support(
return cls.get_data_with_scroll_support(
query_dict=query_dict, data_getter=data_getter, ret_params=ret_params,
)

View File

@@ -1,10 +1,8 @@
from typing import Sequence
from mongoengine import (
StringField,
DateTimeField,
BooleanField,
EmbeddedDocumentListField,
EmbeddedDocumentField,
)
from apiserver.database import Database, strict
@@ -12,6 +10,7 @@ from apiserver.database.fields import (
StrippedStringField,
SafeDictField,
SafeSortedListField,
SafeMapField,
)
from apiserver.database.model import AttributedDocument
from apiserver.database.model.base import GetMixin
@@ -22,6 +21,10 @@ from apiserver.database.model.task.task import Task
class Model(AttributedDocument):
_field_collation_overrides = {
"metadata.": AttributedDocument._numeric_locale,
}
meta = {
"db_alias": Database.backend,
"strict": strict,
@@ -30,8 +33,6 @@ class Model(AttributedDocument):
"project",
"task",
"last_update",
"metadata.key",
"metadata.type",
("company", "framework"),
("company", "name"),
("company", "user"),
@@ -63,6 +64,7 @@ class Model(AttributedDocument):
"project",
"task",
"parent",
"metadata.*",
),
datetime_fields=("last_update",),
)
@@ -86,6 +88,6 @@ class Model(AttributedDocument):
default=dict, user_set_allowed=True, exclude_by_default=True
)
company_origin = StringField(exclude_by_default=True)
metadata: Sequence[MetadataItem] = EmbeddedDocumentListField(
MetadataItem, default=list, user_set_allowed=True
metadata = SafeMapField(
field=EmbeddedDocumentField(MetadataItem), user_set_allowed=True
)

View File

@@ -1,16 +1,19 @@
from typing import Sequence
from mongoengine import (
Document,
EmbeddedDocument,
StringField,
DateTimeField,
EmbeddedDocumentListField,
EmbeddedDocumentField,
)
from apiserver.database import Database, strict
from apiserver.database.fields import StrippedStringField, SafeSortedListField
from apiserver.database.model import DbModelMixin
from apiserver.database.fields import (
StrippedStringField,
SafeSortedListField,
SafeMapField,
)
from apiserver.database.model import DbModelMixin, AttributedDocument
from apiserver.database.model.base import ProperDictMixin, GetMixin
from apiserver.database.model.company import Company
from apiserver.database.model.metadata import MetadataItem
@@ -19,23 +22,25 @@ from apiserver.database.model.task.task import Task
class Entry(EmbeddedDocument, ProperDictMixin):
""" Entry representing a task waiting in the queue """
task = StringField(required=True, reference_field=Task)
''' Task ID '''
""" Task ID """
added = DateTimeField(required=True)
''' Added to the queue '''
""" Added to the queue """
class Queue(DbModelMixin, Document):
_field_collation_overrides = {
"metadata.": AttributedDocument._numeric_locale,
}
get_all_query_options = GetMixin.QueryParameterOptions(
pattern_fields=("name",),
list_fields=("tags", "system_tags", "id"),
pattern_fields=("name",), list_fields=("tags", "system_tags", "id", "metadata.*"),
)
meta = {
'db_alias': Database.backend,
'strict': strict,
"indexes": ["metadata.key", "metadata.type"],
"db_alias": Database.backend,
"strict": strict,
}
id = StringField(primary_key=True)
@@ -44,10 +49,12 @@ class Queue(DbModelMixin, Document):
)
company = StringField(required=True, reference_field=Company)
created = DateTimeField(required=True)
tags = SafeSortedListField(StringField(required=True), default=list, user_set_allowed=True)
tags = SafeSortedListField(
StringField(required=True), default=list, user_set_allowed=True
)
system_tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
entries = EmbeddedDocumentListField(Entry, default=list)
last_update = DateTimeField()
metadata: Sequence[MetadataItem] = EmbeddedDocumentListField(
MetadataItem, default=list, user_set_allowed=True
metadata = SafeMapField(
field=EmbeddedDocumentField(MetadataItem), user_set_allowed=True
)

View File

@@ -159,11 +159,10 @@ external_task_types = set(get_options(TaskType))
class Task(AttributedDocument):
_numeric_locale = {"locale": "en_US", "numericOrdering": True}
_field_collation_overrides = {
"execution.parameters.": _numeric_locale,
"last_metrics.": _numeric_locale,
"hyperparams.": _numeric_locale,
"execution.parameters.": AttributedDocument._numeric_locale,
"last_metrics.": AttributedDocument._numeric_locale,
"hyperparams.": AttributedDocument._numeric_locale,
}
meta = {
@@ -176,6 +175,8 @@ class Task(AttributedDocument):
"active_duration",
"parent",
"project",
"last_update",
"status_changed",
"models.input.model",
("company", "name"),
("company", "user"),
@@ -184,7 +185,10 @@ class Task(AttributedDocument):
("company", "type", "system_tags", "status"),
("company", "project", "type", "system_tags", "status"),
("status", "last_update"), # for maintenance tasks
{"fields": ["company", "project"], "collation": _numeric_locale},
{
"fields": ["company", "project"],
"collation": AttributedDocument._numeric_locale,
},
{
"name": "%s.task.main_text_index" % Database.backend,
"fields": [

View File

@@ -21,6 +21,7 @@ from typing import (
Union,
Mapping,
IO,
Callable,
)
from urllib.parse import unquote, urlparse
from zipfile import ZipFile, ZIP_BZIP2
@@ -54,6 +55,7 @@ from apiserver.database.model.task.task import (
from apiserver.database.utils import get_options
from apiserver.utilities import json
from apiserver.utilities.dicts import nested_get, nested_set, nested_delete
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
class PrePopulate:
@@ -744,6 +746,19 @@ class PrePopulate:
module = importlib.import_module(module_name)
return getattr(module, class_name)
@staticmethod
def _upgrade_model_data(model_data: dict) -> dict:
    """
    Convert model metadata stored as a list of items into a dictionary
    keyed by the escaped item key.

    Only list-typed metadata is converted; list entries that are not
    dictionaries or that have no "key" are dropped. The passed dictionary
    is updated in place and also returned.
    """
    field = "metadata"
    legacy_items = model_data.get(field)
    if not isinstance(legacy_items, list):
        # Nothing to convert: metadata is missing or is not a list
        return model_data

    converted = {}
    for entry in legacy_items:
        if isinstance(entry, dict) and "key" in entry:
            converted[ParameterKeyEscaper.escape(entry["key"])] = entry
    model_data[field] = converted
    return model_data
@staticmethod
def _upgrade_task_data(task_data: dict) -> dict:
"""
@@ -828,9 +843,14 @@ class PrePopulate:
print(f"Writing {cls_.__name__.lower()}s into database")
tasks = []
override_project_count = 0
data_upgrade_funcs: Mapping[Type, Callable] = {
cls.task_cls: cls._upgrade_task_data,
cls.model_cls: cls._upgrade_model_data,
}
for item in cls.json_lines(f):
if cls_ == cls.task_cls:
item = json.dumps(cls._upgrade_task_data(task_data=json.loads(item)))
upgrade_func = data_upgrade_funcs.get(cls_)
if upgrade_func:
item = json.dumps(upgrade_func(json.loads(item)))
doc = cls_.from_json(item, created=True)
if hasattr(doc, "user"):

View File

@@ -0,0 +1,29 @@
from pymongo.collection import Collection
from pymongo.database import Database
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
from .utils import _drop_all_indices_from_collections
def _convert_metadata(db: Database, name):
    """
    Convert list-style metadata into a dictionary keyed by the escaped item
    key for every matching document of the collection called `name`.

    List entries that are not dictionaries or that have no "key" are dropped.
    """
    metadata_field = "metadata"
    collection: Collection = db[name]
    # BSON type 4 is "array": select only the documents whose metadata is a list
    legacy_docs = collection.find(
        filter={metadata_field: {"$exists": True, "$type": 4}},
        projection=(metadata_field,),
    )
    for doc in legacy_docs:
        converted = {}
        for entry in doc.get(metadata_field, []):
            if isinstance(entry, dict) and "key" in entry:
                converted[ParameterKeyEscaper.escape(entry["key"])] = entry
        collection.update_one(
            {"_id": doc["_id"]}, {"$set": {"metadata": converted}},
        )
# Migration entry point: runs _convert_metadata on the "model" and "queue"
# collections and calls _drop_all_indices_from_collections for the same
# collections.
# NOTE(review): the index drop is presumably there so that indexes get rebuilt
# from the current model definitions - confirm against the migration utils.
def migrate_backend(db: Database):
collections = ["model", "queue"]
for name in collections:
_convert_metadata(db, name)
_drop_all_indices_from_collections(db, collections)

View File

@@ -24,6 +24,10 @@ _definitions {
description: ""
format: "date-time"
}
last_used_from {
type: string
description: ""
}
}
}
role {
@@ -226,6 +230,12 @@ create_credentials {
}
}
}
"2.17": ${create_credentials."2.1"} {
request.properties.label {
type: string
description: Optional credentials label
}
}
}
get_credentials {

View File

@@ -1304,3 +1304,24 @@ scalar_metrics_iter_raw {
}
}
}
clear_scroll {
"2.18" {
description: "Clear an open Scroll ID"
request {
type: object
required: [
scroll_id
]
properties {
scroll_id {
description: "Scroll ID as returned by previous events service calls"
type: string
}
}
}
response {
type: object
additionalProperties: false
}
}
}

View File

@@ -61,14 +61,14 @@ _definitions {
type: string
}
tags {
description: "User-defined tags list"
type: array
description: "User-defined tags"
items { type: string }
}
system_tags {
description: "System tags list. This field is reserved for system use, please don't use it."
type: array
items {type: string}
description: "System tags. This field is reserved for system use, please don't use it."
items { type: string }
}
framework {
description: "Framework on which the model is based. Should be identical to the framework of the task which created the model"
@@ -98,9 +98,11 @@ _definitions {
additionalProperties: true
}
metadata {
type: array
description: "Model metadata"
items {"$ref": "#/definitions/metadata_item"}
type: object
additionalProperties {
"$ref": "#/definitions/metadata_item"
}
}
}
}
@@ -407,7 +409,7 @@ update_for_task {
system_tags {
description: "System tags list. This field is reserved for system use, please don't use it."
type: array
items {type: string}
items { type: string }
}
override_model_id {
description: "Override model ID. If provided, this model is updated in the task. Exactly one of override_model_id or uri is required."
@@ -473,7 +475,7 @@ create {
system_tags {
description: "System tags list. This field is reserved for system use, please don't use it."
type: array
items {type: string}
items { type: string }
}
framework {
description: "Framework on which the model is based. Case insensitive. Should be identical to the framework of the task which created the model."
@@ -529,9 +531,11 @@ create {
}
"2.13": ${create."2.1"} {
metadata {
type: array
description: "Model metadata"
items {"$ref": "#/definitions/metadata_item"}
type: object
additionalProperties {
"$ref": "#/definitions/metadata_item"
}
}
}
}
@@ -568,7 +572,7 @@ edit {
system_tags {
description: "System tags list. This field is reserved for system use, please don't use it."
type: array
items {type: string}
items { type: string }
}
framework {
description: "Framework on which the model is based. Case insensitive. Should be identical to the framework of the task which created the model."
@@ -624,9 +628,11 @@ edit {
}
"2.13": ${edit."2.1"} {
metadata {
type: array
description: "Model metadata"
items {"$ref": "#/definitions/metadata_item"}
type: object
additionalProperties {
"$ref": "#/definitions/metadata_item"
}
}
}
}
@@ -657,7 +663,7 @@ update {
system_tags {
description: "System tags list. This field is reserved for system use, please don't use it."
type: array
items {type: string}
items { type: string }
}
ready {
description: "Indication if the model is final and can be used by other tasks Default is false."
@@ -707,9 +713,11 @@ update {
}
"2.13": ${update."2.1"} {
metadata {
type: array
description: "Model metadata"
items {"$ref": "#/definitions/metadata_item"}
type: object
additionalProperties {
"$ref": "#/definitions/metadata_item"
}
}
}
}
@@ -718,7 +726,7 @@ publish_many {
description: Publish models
request {
properties {
ids.description: "IDs of models to publish"
ids.description: "IDs of the models to publish"
force_publish_task {
description: "Publish the associated tasks (if exist) even if they are not in the 'stopped' state. Optional, the default value is False."
type: boolean
@@ -779,7 +787,7 @@ archive_many {
description: Archive models
request {
properties {
ids.description: "IDs of models to archive"
ids.description: "IDs of the models to archive"
}
}
response {
@@ -815,10 +823,9 @@ delete_many {
description: Delete models
request {
properties {
ids.description: "IDs of models to delete"
ids.description: "IDs of the models to delete"
force {
description: """Force. Required if there are tasks that use the model as an execution model, or if the model's creating task is published.
"""
description: "Force. Required if there are tasks that use the model as an execution model, or if the model's creating task is published."
type: boolean
}
}
@@ -975,6 +982,11 @@ add_or_update_metadata {
description: "Metadata items to add or update"
items {"$ref": "#/definitions/metadata_item"}
}
replace_metadata {
description: "If set then all the metadata items will be replaced with the provided ones. Otherwise only the provided metadata items will be updated or added"
type: boolean
default: false
}
}
}
response {

View File

@@ -0,0 +1,47 @@
_description: "Provides a management API for pipelines in the system."
_definitions {
}
start_pipeline {
"2.17" {
description: "Start a pipeline"
request {
type: object
required: [ task ]
properties {
task {
description: "ID of the task on which the pipeline will be based"
type: string
}
queue {
description: "Queue ID in which the created pipeline task will be enqueued"
type: string
}
args {
description: "Task arguments, name/value to be placed in the hyperparameters Args section"
type: array
items {
type: object
properties {
name: { type: string }
value: { type: [string, null] }
}
}
}
}
}
response {
type: object
properties {
pipeline {
description: "ID of the new pipeline task"
type: string
}
enqueued {
description: "True if the task was successfully enqueued"
type: boolean
}
}
}
}
}

View File

@@ -42,15 +42,20 @@ _definitions {
type: string
format: "date-time"
}
last_update {
description: "Last update time"
type: string
format: "date-time"
}
tags {
type: array
description: "User-defined tags"
type: array
items { type: string }
}
system_tags {
type: array
description: "System tags. This field is reserved for system use, please don't use it."
items {type: string}
type: array
items { type: string }
}
default_output_destination {
description: "The default output destination URL for new tasks under this project"
@@ -70,6 +75,18 @@ _definitions {
description: "Total run time of all tasks in project (in seconds)"
type: integer
}
total_tasks {
description: "Number of tasks"
type: integer
}
completed_tasks_24h {
description: "Number of tasks completed in the last 24 hours"
type: integer
}
last_task_run {
description: "The most recent started time of a task"
type: integer
}
status_count {
description: "Status counts"
type: object
@@ -78,6 +95,10 @@ _definitions {
description: "Number of 'created' tasks in project"
type: integer
}
completed {
description: "Number of 'completed' tasks in project"
type: integer
}
queued {
description: "Number of 'queued' tasks in project"
type: integer
@@ -158,14 +179,14 @@ _definitions {
format: "date-time"
}
tags {
type: array
description: "User-defined tags"
type: array
items { type: string }
}
system_tags {
type: array
description: "System tags. This field is reserved for system use, please don't use it."
items {type: string}
type: array
items { type: string }
}
default_output_destination {
description: "The default output destination URL for new tasks under this project"
@@ -299,14 +320,14 @@ create {
type: string
}
tags {
type: array
description: "User-defined tags"
type: array
items { type: string }
}
system_tags {
type: array
description: "System tags. This field is reserved for system use, please don't use it."
items {type: string}
type: array
items { type: string }
}
default_output_destination {
description: "The default output destination URL for new tasks under this project"
@@ -419,7 +440,6 @@ get_all {
description: "Projects list"
type: array
items { "$ref": "#/definitions/projects_get_all_response_single" }
}
}
}
@@ -435,7 +455,14 @@ get_all {
}
}
}
"2.15": ${get_all."2.13"} {
"2.14": ${get_all."2.13"} {
request.properties.search_hidden {
description: "If set to 'true' then hidden projects are included in the search results"
type: boolean
default: false
}
}
"2.15": ${get_all."2.14"} {
request {
properties {
scroll_id {
@@ -516,7 +543,14 @@ get_all_ex {
}
}
}
"2.15": ${get_all_ex."2.13"} {
"2.14": ${get_all_ex."2.13"} {
request.properties.search_hidden {
description: "If set to 'true' then hidden projects are included in the search results"
type: boolean
default: false
}
}
"2.15": ${get_all_ex."2.14"} {
request {
properties {
scroll_id {
@@ -545,39 +579,16 @@ get_all_ex {
type: boolean
default: true
}
response {
}
"2.17": ${get_all_ex."2.16"} {
request.properties.include_stats_filter {
description: The filter for selecting entities that participate in statistics calculation
type: object
properties {
stats {
properties {
active.properties {
total_tasks {
description: "Number of tasks"
type: integer
}
completed_tasks {
description: "Number of tasks completed in the last 24 hours"
type: integer
}
running_tasks {
description: "Number of running tasks"
type: integer
}
}
archived.properties {
total_tasks {
description: "Number of tasks"
type: integer
}
completed_tasks {
description: "Number of tasks completed in the last 24 hours"
type: integer
}
running_tasks {
description: "Number of running tasks"
type: integer
}
}
}
system_tags {
description: The list of allowed system tags
type: array
items { type: string }
}
}
}
@@ -603,14 +614,14 @@ update {
type: string
}
tags {
description: "User-defined tags list"
type: array
description: "User-defined tags"
items { type: string }
}
system_tags {
description: "System tags list. This field is reserved for system use, please don't use it."
type: array
description: "System tags. This field is reserved for system use, please don't use it."
items {type: string}
items { type: string }
}
default_output_destination {
description: "The default output destination URL for new tasks under this project"
@@ -748,7 +759,6 @@ delete {
type: boolean
default: false
}
}
}
response {
@@ -881,6 +891,7 @@ get_hyper_parameters {
description: """Get a list of all hyper parameter sections and names used in tasks within the given project."""
request {
type: object
required: [project]
properties {
project {
description: "Project ID"
@@ -929,6 +940,105 @@ get_hyper_parameters {
}
}
}
get_model_metadata_values {
"2.17" {
description: """Get a list of distinct values for the chosen model metadata key"""
request {
type: object
required: [key]
properties {
projects {
description: "Project IDs"
type: array
items {type: string}
}
key {
description: "Metadata key"
type: string
}
allow_public {
description: "If set to 'true' then collect values from both company and public models otherwise company models only. The default is 'true'"
type: boolean
}
include_subprojects {
description: "If set to 'true' and the project field is set then the result includes metadata values from the subproject models"
type: boolean
default: true
}
}
}
response {
type: object
properties {
total {
description: "Total number of distinct values"
type: integer
}
values {
description: "The list of the unique values"
type: array
items {type: string}
}
}
}
}
}
get_model_metadata_keys {
"2.17" {
description: """Get a list of all metadata keys used in models within the given project."""
request {
type: object
required: [project]
properties {
project {
description: "Project ID"
type: string
}
include_subprojects {
description: "If set to 'true' and the project field is set then the result includes metadata keys from the subproject models"
type: boolean
default: true
}
page {
description: "Page number"
default: 0
type: integer
}
page_size {
description: "Page size"
default: 500
type: integer
}
}
}
response {
type: object
properties {
keys {
description: "A list of model keys"
type: array
items {type: string}
}
remaining {
description: "Remaining results"
type: integer
}
total {
description: "Total number of results"
type: integer
}
}
}
}
}
get_project_tags {
"2.17" {
description: "Get user and system tags used for the specified projects and their children"
request = ${_definitions.tags_request}
response = ${_definitions.tags_response}
}
}
get_task_tags {
"2.8" {
description: "Get user and system tags used for the tasks under the specified projects"
@@ -936,7 +1046,6 @@ get_task_tags {
response = ${_definitions.tags_response}
}
}
get_model_tags {
"2.8" {
description: "Get user and system tags used for the models under the specified projects"
@@ -1058,4 +1167,4 @@ get_task_parents {
}
}
}
}
}

View File

@@ -79,9 +79,11 @@ _definitions {
items { "$ref": "#/definitions/entry" }
}
metadata {
type: array
description: "Queue metadata"
items {"$ref": "#/definitions/metadata_item"}
type: object
additionalProperties {
"$ref": "#/definitions/metadata_item"
}
}
}
}
@@ -281,6 +283,15 @@ create {
}
}
}
"2.13": ${create."2.4"} {
metadata {
description: "Queue metadata"
type: object
additionalProperties {
"$ref": "#/definitions/metadata_item"
}
}
}
}
update {
"2.4" {
@@ -322,7 +333,15 @@ update {
type: object
additionalProperties: true
}
}
}
}
"2.13": ${update."2.4"} {
metadata {
description: "Queue metadata"
type: object
additionalProperties {
"$ref": "#/definitions/metadata_item"
}
}
}
@@ -632,6 +651,11 @@ add_or_update_metadata {
description: "Metadata items to add or update"
items {"$ref": "#/definitions/metadata_item"}
}
replace_metadata {
description: "If set then all the metadata items will be replaced with the provided ones. Otherwise only the provided metadata items will be updated or added"
type: boolean
default: false
}
}
}
response {

View File

@@ -685,7 +685,14 @@ get_all_ex {
}
}
}
"2.15": ${get_all_ex."2.13"} {
"2.14": ${get_all_ex."2.13"} {
request.properties.search_hidden {
description: "If set to 'true' then hidden tasks are included in the search results"
type: boolean
default: false
}
}
"2.15": ${get_all_ex."2.14"} {
request {
properties {
scroll_id {
@@ -822,7 +829,14 @@ get_all {
}
}
}
"2.15": ${get_all."2.1"} {
"2.14": ${get_all."2.1"} {
request.properties.search_hidden {
description: "If set to 'true' then hidden tasks are included in the search results"
type: boolean
default: false
}
}
"2.15": ${get_all."2.14"} {
request {
properties {
scroll_id {

View File

@@ -25,6 +25,7 @@ from apiserver.server_init.request_handlers import RequestHandlers
from apiserver.service_repo import ServiceRepo
from apiserver.sync import distributed_lock
from apiserver.updates import check_updates_thread
from apiserver.utilities.env import get_bool
from apiserver.utilities.threads_manager import ThreadsManager
log = config.logger(__file__)
@@ -46,10 +47,13 @@ class AppSequence:
def _attach_request_handlers(self, request_handlers: RequestHandlers):
self.app.before_first_request(request_handlers.before_app_first_request)
self.app.before_request(request_handlers.before_request)
self.app.after_request(request_handlers.after_request)
def _configure(self):
CORS(self.app, **config.get("apiserver.cors"))
Compress(self.app)
if get_bool("CLEARML_COMPRESS_RESP", default=True):
Compress(self.app)
self.app.config["SECRET_KEY"] = config.get(
"secure.http.session_secret.apiserver"

View File

@@ -18,6 +18,7 @@ log = config.logger(__file__)
class RequestHandlers:
_request_strip_prefix = config.get("apiserver.request.strip_prefix", None)
_server_header = config.get("apiserver.response.headers.server", "clearml")
def before_app_first_request(self):
pass
@@ -28,6 +29,9 @@ class RequestHandlers:
if "/static/" in request.path:
return
if request.content_encoding:
return f"Content encoding is not supported ({request.content_encoding})", 415
try:
call = self._create_api_call(request)
load_data_callback = partial(self._load_call_data, req=request)
@@ -81,6 +85,10 @@ class RequestHandlers:
log.exception(f"Failed processing request {request.url}: {ex}")
return f"Failed processing request {request.url}", 500
def after_request(self, response):
    """
    Set the "server" header of the outgoing response to the configured
    server header value and return the same response object.
    """
    server_name = self._server_header
    response.headers["server"] = server_name
    return response
@staticmethod
def _apply_multi_dict(body: dict, md: ImmutableMultiDict):
def convert_value(v: str):

View File

@@ -95,8 +95,8 @@ class DataContainer(object):
@raw_data.setter
def raw_data(self, value):
assert isinstance(
value, string_types + (types.GeneratorType,)
), "Raw data must be a string type or generator"
value, string_types + (types.GeneratorType, bytes)
), "Raw data must be a string type or bytes or generator"
self._raw_data = value
@property
@@ -395,6 +395,10 @@ class APICall(DataContainer):
self._auth_cookie = auth_cookie
self._json_flags = {}
@property
def files(self):
    """Read-only accessor for the call's `_files` attribute."""
    return self._files
@property
def id(self):
return self._id

View File

@@ -51,7 +51,7 @@ def authorize_token(jwt_token, *_, **__):
)
def authorize_credentials(auth_data, service, action, call_data_items):
def authorize_credentials(auth_data, service, action, call):
"""Validate credentials against service/action and request data (dicts).
Returns a new basic object (auth payload)
"""
@@ -100,7 +100,12 @@ def authorize_credentials(auth_data, service, action, call_data_items):
if not fixed_user:
# In case these are proper credentials, update last used time
User.objects(id=user.id, credentials__key=access_key).update(
**{"set__credentials__$__last_used": datetime.utcnow()}
**{
"set__credentials__$__last_used": datetime.utcnow(),
"set__credentials__$__last_used_from": call.get_worker(
default=call.real_ip
),
}
)
with TimingContext("mongo", "company_by_id"):

View File

@@ -38,7 +38,7 @@ class ServiceRepo(object):
"""If the check is set, parsing will fail for endpoint request with the version that is greater than the current
maximum """
_max_version = PartialVersion("2.16")
_max_version = PartialVersion("2.17")
""" Maximum version number (the highest min_version value across all endpoints) """
_endpoint_exp = (

View File

@@ -69,7 +69,7 @@ def validate_auth(endpoint, call):
auth = call.authorization or ""
auth_type, _, auth_data = auth.partition(" ")
authorize_func = get_auth_func(auth_type)
call.auth = authorize_func(auth_data, service, action, call.batched_data)
call.auth = authorize_func(auth_data, service, action, call)
except Exception:
if endpoint.authorize:
# if endpoint requires authorization, re-raise exception

View File

@@ -161,7 +161,10 @@ def get_credentials(call: APICall, _, __):
call.result.data_model = GetCredentialsResponse(
credentials=[
CredentialsResponse(
access_key=c.key, last_used=c.last_used, label=c.label
access_key=c.key,
last_used=c.last_used,
label=c.label,
last_used_from=c.last_used_from,
)
for c in user.credentials
]

View File

@@ -25,6 +25,7 @@ from apiserver.apimodels.events import (
TaskPlotsRequest,
TaskEventsRequest,
ScalarMetricsIterRawRequest,
ClearScrollRequest,
)
from apiserver.bll.event import EventBLL
from apiserver.bll.event.event_common import EventType, MetricVariants
@@ -936,3 +937,9 @@ def scalar_metrics_iter_raw(
scroll_id=scroll.get_scroll_id(),
variants=variants,
)
@endpoint("events.clear_scroll", min_version="2.18")
def clear_scroll(_, __, request: ClearScrollRequest):
    """
    Clear the scroll identified by `request.scroll_id` through the events BLL.
    Does nothing when no scroll ID was provided.
    """
    scroll_id = request.scroll_id
    if not scroll_id:
        return
    event_bll.clear_scroll(scroll_id)

View File

@@ -21,16 +21,14 @@ from apiserver.apimodels.models import (
ModelsPublishManyRequest,
ModelsDeleteManyRequest,
)
from apiserver.bll.model import ModelBLL
from apiserver.bll.model import ModelBLL, Metadata
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL, project_ids_with_children
from apiserver.bll.task import TaskBLL
from apiserver.bll.task.task_operations import publish_task
from apiserver.bll.util import run_batch_operation
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model import validate_id
from apiserver.database.model.metadata import metadata_add_or_update, metadata_delete
from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import (
@@ -50,8 +48,8 @@ from apiserver.services.utils import (
conform_tag_fields,
conform_output_tags,
ModelsBackwardsCompatibility,
validate_metadata,
get_metadata_from_api,
unescape_metadata,
escape_metadata,
)
from apiserver.timing_context import TimingContext
@@ -64,19 +62,20 @@ project_bll = ProjectBLL()
def get_by_id(call: APICall, company_id, _):
model_id = call.data["model"]
with translate_errors_context():
models = Model.get_many(
company=company_id,
query_dict=call.data,
query=Q(id=model_id),
allow_public=True,
Metadata.escape_query_parameters(call)
models = Model.get_many(
company=company_id,
query_dict=call.data,
query=Q(id=model_id),
allow_public=True,
)
if not models:
raise errors.bad_request.InvalidModelId(
"no such public or company model", id=model_id, company=company_id,
)
if not models:
raise errors.bad_request.InvalidModelId(
"no such public or company model", id=model_id, company=company_id,
)
conform_output_tags(call, models[0])
call.result.data = {"model": models[0]}
conform_output_tags(call, models[0])
unescape_metadata(call, models[0])
call.result.data = {"model": models[0]}
@endpoint("models.get_by_task_id", required_fields=["task"])
@@ -86,25 +85,25 @@ def get_by_task_id(call: APICall, company_id, _):
task_id = call.data["task"]
with translate_errors_context():
query = dict(id=task_id, company=company_id)
task = Task.get(_only=["models"], **query)
if not task:
raise errors.bad_request.InvalidTaskId(**query)
if not task.models or not task.models.output:
raise errors.bad_request.MissingTaskFields(field="models.output")
query = dict(id=task_id, company=company_id)
task = Task.get(_only=["models"], **query)
if not task:
raise errors.bad_request.InvalidTaskId(**query)
if not task.models or not task.models.output:
raise errors.bad_request.MissingTaskFields(field="models.output")
model_id = task.models.output[-1].model
model = Model.objects(
Q(id=model_id) & get_company_or_none_constraint(company_id)
).first()
if not model:
raise errors.bad_request.InvalidModelId(
"no such public or company model", id=model_id, company=company_id,
)
model_dict = model.to_proper_dict()
conform_output_tags(call, model_dict)
call.result.data = {"model": model_dict}
model_id = task.models.output[-1].model
model = Model.objects(
Q(id=model_id) & get_company_or_none_constraint(company_id)
).first()
if not model:
raise errors.bad_request.InvalidModelId(
"no such public or company model", id=model_id, company=company_id,
)
model_dict = model.to_proper_dict()
conform_output_tags(call, model_dict)
unescape_metadata(call, model_dict)
call.result.data = {"model": model_dict}
def _process_include_subprojects(call_data: dict):
@@ -121,47 +120,50 @@ def _process_include_subprojects(call_data: dict):
@endpoint("models.get_all_ex", required_fields=[])
def get_all_ex(call: APICall, company_id, _):
conform_tag_fields(call, call.data)
with translate_errors_context():
_process_include_subprojects(call.data)
with TimingContext("mongo", "models_get_all_ex"):
ret_params = {}
models = Model.get_many_with_join(
company=company_id,
query_dict=call.data,
allow_public=True,
ret_params=ret_params,
)
conform_output_tags(call, models)
call.result.data = {"models": models, **ret_params}
_process_include_subprojects(call.data)
Metadata.escape_query_parameters(call)
with TimingContext("mongo", "models_get_all_ex"):
ret_params = {}
models = Model.get_many_with_join(
company=company_id,
query_dict=call.data,
allow_public=True,
ret_params=ret_params,
)
conform_output_tags(call, models)
unescape_metadata(call, models)
call.result.data = {"models": models, **ret_params}
@endpoint("models.get_by_id_ex", required_fields=["id"])
def get_by_id_ex(call: APICall, company_id, _):
conform_tag_fields(call, call.data)
with translate_errors_context():
with TimingContext("mongo", "models_get_by_id_ex"):
models = Model.get_many_with_join(
company=company_id, query_dict=call.data, allow_public=True
)
conform_output_tags(call, models)
call.result.data = {"models": models}
Metadata.escape_query_parameters(call)
with TimingContext("mongo", "models_get_by_id_ex"):
models = Model.get_many_with_join(
company=company_id, query_dict=call.data, allow_public=True
)
conform_output_tags(call, models)
unescape_metadata(call, models)
call.result.data = {"models": models}
@endpoint("models.get_all", required_fields=[])
def get_all(call: APICall, company_id, _):
conform_tag_fields(call, call.data)
with translate_errors_context():
with TimingContext("mongo", "models_get_all"):
ret_params = {}
models = Model.get_many(
company=company_id,
parameters=call.data,
query_dict=call.data,
allow_public=True,
ret_params=ret_params,
)
conform_output_tags(call, models)
call.result.data = {"models": models, **ret_params}
Metadata.escape_query_parameters(call)
with TimingContext("mongo", "models_get_all"):
ret_params = {}
models = Model.get_many(
company=company_id,
parameters=call.data,
query_dict=call.data,
allow_public=True,
ret_params=ret_params,
)
conform_output_tags(call, models)
unescape_metadata(call, models)
call.result.data = {"models": models, **ret_params}
@endpoint("models.get_frameworks", request_data_model=GetFrameworksRequest)
@@ -189,15 +191,22 @@ create_fields = {
"metadata": list,
}
last_update_fields = ("uri", "framework", "design", "labels", "ready", "metadata", "system_tags", "tags")
last_update_fields = (
"uri",
"framework",
"design",
"labels",
"ready",
"metadata",
"system_tags",
"tags",
)
def parse_model_fields(call, valid_fields):
fields = parse_from_call(call.data, valid_fields, Model.get_fields())
conform_tag_fields(call, fields, validate=True)
metadata = fields.get("metadata")
if metadata:
validate_metadata(metadata)
escape_metadata(fields)
return fields
@@ -231,82 +240,80 @@ def update_for_task(call: APICall, company_id, _):
"exactly one field is required", fields=("uri", "override_model_id")
)
with translate_errors_context():
query = dict(id=task_id, company=company_id)
task = Task.get_for_writing(
id=task_id,
company=company_id,
_only=["models", "execution", "name", "status", "project"],
)
if not task:
raise errors.bad_request.InvalidTaskId(**query)
query = dict(id=task_id, company=company_id)
task = Task.get_for_writing(
id=task_id,
allowed_states = [TaskStatus.created, TaskStatus.in_progress]
if task.status not in allowed_states:
raise errors.bad_request.InvalidTaskStatus(
f"model can only be updated for tasks in the {allowed_states} states",
**query,
)
if override_model_id:
model = ModelBLL.get_company_model_by_id(
company_id=company_id, model_id=override_model_id
)
else:
if "name" not in call.data:
# use task name if name not provided
call.data["name"] = task.name
if "comment" not in call.data:
call.data["comment"] = f"Created by task `{task.name}` ({task.id})"
if task.models and task.models.output:
# model exists, update
model_id = task.models.output[-1].model
res = _update_model(call, company_id, model_id=model_id).to_struct()
res.update({"id": model_id, "created": False})
call.result.data = res
return
# new model, create
fields = parse_model_fields(call, create_fields)
# create and save model
now = datetime.utcnow()
model = Model(
id=database.utils.id(),
created=now,
last_update=now,
user=call.identity.user,
company=company_id,
_only=["models", "execution", "name", "status", "project"],
project=task.project,
framework=task.execution.framework,
parent=task.models.input[0].model
if task.models and task.models.input
else None,
design=task.execution.model_desc,
labels=task.execution.model_labels,
ready=(task.status == TaskStatus.published),
**fields,
)
if not task:
raise errors.bad_request.InvalidTaskId(**query)
model.save()
_update_cached_tags(company_id, project=model.project, fields=fields)
allowed_states = [TaskStatus.created, TaskStatus.in_progress]
if task.status not in allowed_states:
raise errors.bad_request.InvalidTaskStatus(
f"model can only be updated for tasks in the {allowed_states} states",
**query,
TaskBLL.update_statistics(
task_id=task_id,
company_id=company_id,
last_iteration_max=iteration,
models__output=[
ModelItem(
model=model.id,
name=TaskModelNames[TaskModelTypes.output],
updated=datetime.utcnow(),
)
],
)
if override_model_id:
model = ModelBLL.get_company_model_by_id(
company_id=company_id, model_id=override_model_id
)
else:
if "name" not in call.data:
# use task name if name not provided
call.data["name"] = task.name
if "comment" not in call.data:
call.data["comment"] = f"Created by task `{task.name}` ({task.id})"
if task.models and task.models.output:
# model exists, update
model_id = task.models.output[-1].model
res = _update_model(call, company_id, model_id=model_id).to_struct()
res.update({"id": model_id, "created": False})
call.result.data = res
return
# new model, create
fields = parse_model_fields(call, create_fields)
# create and save model
now = datetime.utcnow()
model = Model(
id=database.utils.id(),
created=now,
last_update=now,
user=call.identity.user,
company=company_id,
project=task.project,
framework=task.execution.framework,
parent=task.models.input[0].model
if task.models and task.models.input
else None,
design=task.execution.model_desc,
labels=task.execution.model_labels,
ready=(task.status == TaskStatus.published),
**fields,
)
model.save()
_update_cached_tags(company_id, project=model.project, fields=fields)
TaskBLL.update_statistics(
task_id=task_id,
company_id=company_id,
last_iteration_max=iteration,
models__output=[
ModelItem(
model=model.id,
name=TaskModelNames[TaskModelTypes.output],
updated=datetime.utcnow(),
)
],
)
call.result.data = {"id": model.id, "created": True}
call.result.data = {"id": model.id, "created": True}
@endpoint(
@@ -319,36 +326,33 @@ def create(call: APICall, company_id, req_model: CreateModelRequest):
if req_model.public:
company_id = ""
with translate_errors_context():
project = req_model.project
if project:
validate_id(Project, company=company_id, project=project)
project = req_model.project
if project:
validate_id(Project, company=company_id, project=project)
task = req_model.task
req_data = req_model.to_struct()
if task:
validate_task(company_id, req_data)
task = req_model.task
req_data = req_model.to_struct()
if task:
validate_task(company_id, req_data)
fields = filter_fields(Model, req_data)
conform_tag_fields(call, fields, validate=True)
escape_metadata(fields)
fields = filter_fields(Model, req_data)
conform_tag_fields(call, fields, validate=True)
# create and save model
now = datetime.utcnow()
model = Model(
id=database.utils.id(),
user=call.identity.user,
company=company_id,
created=now,
last_update=now,
**fields,
)
model.save()
_update_cached_tags(company_id, project=model.project, fields=fields)
validate_metadata(fields.get("metadata"))
# create and save model
now = datetime.utcnow()
model = Model(
id=database.utils.id(),
user=call.identity.user,
company=company_id,
created=now,
last_update=now,
**fields,
)
model.save()
_update_cached_tags(company_id, project=model.project, fields=fields)
call.result.data_model = CreateModelResponse(id=model.id, created=True)
call.result.data_model = CreateModelResponse(id=model.id, created=True)
def prepare_update_fields(call, company_id, fields: dict):
@@ -383,6 +387,7 @@ def prepare_update_fields(call, company_id, fields: dict):
)
conform_tag_fields(call, fields, validate=True)
escape_metadata(fields)
return fields
@@ -394,89 +399,85 @@ def validate_task(company_id, fields: dict):
def edit(call: APICall, company_id, _):
model_id = call.data["model"]
with translate_errors_context():
model = ModelBLL.get_company_model_by_id(
company_id=company_id, model_id=model_id
model = ModelBLL.get_company_model_by_id(
company_id=company_id, model_id=model_id
)
fields = parse_model_fields(call, create_fields)
fields = prepare_update_fields(call, company_id, fields)
for key in fields:
field = getattr(model, key, None)
value = fields[key]
if (
field
and isinstance(value, dict)
and isinstance(field, EmbeddedDocument)
):
d = field.to_mongo(use_db_field=False).to_dict()
d.update(value)
fields[key] = d
iteration = call.data.get("iteration")
task_id = model.task or fields.get("task")
if task_id and iteration is not None:
TaskBLL.update_statistics(
task_id=task_id, company_id=company_id, last_iteration_max=iteration,
)
fields = parse_model_fields(call, create_fields)
fields = prepare_update_fields(call, company_id, fields)
if fields:
if any(uf in fields for uf in last_update_fields):
fields.update(last_update=datetime.utcnow())
for key in fields:
field = getattr(model, key, None)
value = fields[key]
if (
field
and isinstance(value, dict)
and isinstance(field, EmbeddedDocument)
):
d = field.to_mongo(use_db_field=False).to_dict()
d.update(value)
fields[key] = d
iteration = call.data.get("iteration")
task_id = model.task or fields.get("task")
if task_id and iteration is not None:
TaskBLL.update_statistics(
task_id=task_id, company_id=company_id, last_iteration_max=iteration,
)
if fields:
if any(uf in fields for uf in last_update_fields):
fields.update(last_update=datetime.utcnow())
updated = model.update(upsert=False, **fields)
if updated:
new_project = fields.get("project", model.project)
if new_project != model.project:
_reset_cached_tags(
company_id, projects=[new_project, model.project]
)
else:
_update_cached_tags(
company_id, project=model.project, fields=fields
)
conform_output_tags(call, fields)
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
else:
call.result.data_model = UpdateResponse(updated=0)
updated = model.update(upsert=False, **fields)
if updated:
new_project = fields.get("project", model.project)
if new_project != model.project:
_reset_cached_tags(
company_id, projects=[new_project, model.project]
)
else:
_update_cached_tags(
company_id, project=model.project, fields=fields
)
conform_output_tags(call, fields)
unescape_metadata(call, fields)
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
else:
call.result.data_model = UpdateResponse(updated=0)
def _update_model(call: APICall, company_id, model_id=None):
model_id = model_id or call.data["model"]
with translate_errors_context():
model = ModelBLL.get_company_model_by_id(
company_id=company_id, model_id=model_id
model = ModelBLL.get_company_model_by_id(
company_id=company_id, model_id=model_id
)
data = prepare_update_fields(call, company_id, call.data)
task_id = data.get("task")
iteration = data.get("iteration")
if task_id and iteration is not None:
TaskBLL.update_statistics(
task_id=task_id, company_id=company_id, last_iteration_max=iteration,
)
data = prepare_update_fields(call, company_id, call.data)
updated_count, updated_fields = Model.safe_update(company_id, model.id, data)
if updated_count:
if any(uf in updated_fields for uf in last_update_fields):
model.update(upsert=False, last_update=datetime.utcnow())
task_id = data.get("task")
iteration = data.get("iteration")
if task_id and iteration is not None:
TaskBLL.update_statistics(
task_id=task_id, company_id=company_id, last_iteration_max=iteration,
new_project = updated_fields.get("project", model.project)
if new_project != model.project:
_reset_cached_tags(company_id, projects=[new_project, model.project])
else:
_update_cached_tags(
company_id, project=model.project, fields=updated_fields
)
metadata = data.get("metadata")
if metadata:
validate_metadata(metadata)
updated_count, updated_fields = Model.safe_update(company_id, model.id, data)
if updated_count:
if any(uf in updated_fields for uf in last_update_fields):
model.update(upsert=False, last_update=datetime.utcnow())
new_project = updated_fields.get("project", model.project)
if new_project != model.project:
_reset_cached_tags(company_id, projects=[new_project, model.project])
else:
_update_cached_tags(
company_id, project=model.project, fields=updated_fields
)
conform_output_tags(call, updated_fields)
return UpdateResponse(updated=updated_count, fields=updated_fields)
conform_output_tags(call, updated_fields)
unescape_metadata(call, updated_fields)
return UpdateResponse(updated=updated_count, fields=updated_fields)
@endpoint(
@@ -641,26 +642,25 @@ def add_or_update_metadata(
_: APICall, company_id: str, request: AddOrUpdateMetadataRequest
):
model_id = request.model
ModelBLL.get_company_model_by_id(company_id=company_id, model_id=model_id)
updated = metadata_add_or_update(
cls=Model, _id=model_id, items=get_metadata_from_api(request.metadata),
)
if updated:
Model.objects(id=model_id).update_one(last_update=datetime.utcnow())
return {"updated": updated}
model = ModelBLL.get_company_model_by_id(company_id=company_id, model_id=model_id)
return {
"updated": Metadata.edit_metadata(
model,
items=request.metadata,
replace_metadata=request.replace_metadata,
last_update=datetime.utcnow(),
)
}
@endpoint("models.delete_metadata", min_version="2.13")
def delete_metadata(_: APICall, company_id: str, request: DeleteMetadataRequest):
model_id = request.model
ModelBLL.get_company_model_by_id(
model = ModelBLL.get_company_model_by_id(
company_id=company_id, model_id=model_id, only_fields=("id",)
)
updated = metadata_delete(cls=Model, _id=model_id, keys=request.keys)
if updated:
Model.objects(id=model_id).update_one(last_update=datetime.utcnow())
return {"updated": updated}
return {
"updated": Metadata.delete_metadata(
model, keys=request.keys, last_update=datetime.utcnow()
)
}

View File

@@ -5,7 +5,7 @@ from apiserver.apimodels.organization import TagsRequest
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.database.model import User
from apiserver.service_repo import endpoint, APICall
from apiserver.services.utils import get_tags_filter_dictionary, get_tags_response
from apiserver.services.utils import get_tags_filter_dictionary, sort_tags_response
org_bll = OrgBLL()
@@ -21,17 +21,13 @@ def get_tags(call: APICall, company, request: TagsRequest):
for field, vals in tags.items():
ret[field] |= vals
call.result.data = get_tags_response(ret)
call.result.data = sort_tags_response(ret)
@endpoint("organization.get_user_companies")
def get_user_companies(call: APICall, company_id: str, _):
users = [
{
"id": u.id,
"name": u.name,
"avatar": u.avatar,
}
{"id": u.id, "name": u.name, "avatar": u.avatar}
for u in User.objects(company=company_id).only("avatar", "name", "company")
]

View File

@@ -0,0 +1,68 @@
import re
from apiserver.apimodels.pipelines import StartPipelineResponse, StartPipelineRequest
from apiserver.bll.organization import OrgBLL
from apiserver.bll.project import ProjectBLL
from apiserver.bll.task import TaskBLL
from apiserver.bll.task.task_operations import enqueue_task
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task
from apiserver.service_repo import APICall, endpoint
org_bll = OrgBLL()
project_bll = ProjectBLL()
task_bll = TaskBLL()
def _update_task_name(task: Task):
    """
    Rename a pipeline task after the last path component of its project name.

    The task gets the bare project name when no pipeline task in the project
    is already named "<name>" or "<name> #<digits>"; otherwise it gets
    "<name> #<count>", where <count> is the number of such tasks.
    Nothing is done when the task is missing, has no project, or the project
    cannot be found.
    """
    if not task or not task.project:
        return

    project = Project.objects(id=task.project).only("name").first()
    if not project:
        return

    # Only the part after the last "/" of the project name is used
    _, _, name_prefix = project.name.rpartition("/")
    # Anchor the pattern on both ends: a regex query matches anywhere inside
    # the value unless "^" is given, so without it every task whose name merely
    # ends with the project name would be counted as well
    name_mask = re.compile(rf"^{re.escape(name_prefix)}( #\d+)?$")
    count = Task.objects(
        project=task.project, system_tags__in=["pipeline"], name=name_mask
    ).count()
    new_name = f"{name_prefix} #{count}" if count > 0 else name_prefix
    task.update(name=new_name)
@endpoint("pipelines.start_pipeline", response_data_model=StartPipelineResponse)
def start_pipeline(call: APICall, company_id: str, request: StartPipelineRequest):
    """
    Clone the requested task, rename the clone and enqueue it.

    When the request carries args they are passed to the clone as an "Args"
    hyperparameters section, with names and values converted to strings.
    Returns a StartPipelineResponse with the id of the cloned task and whether
    it was enqueued.
    """
    hyperparams = None
    if request.args:
        args_section = {}
        for arg in request.args:
            arg_name = str(arg.name)
            args_section[arg_name] = {
                "section": "Args",
                "name": arg_name,
                "value": str(arg.value),
            }
        hyperparams = {"Args": args_section}

    task, _ = task_bll.clone_task(
        company_id=company_id,
        user_id=call.identity.user,
        task_id=request.task,
        hyperparams=hyperparams,
    )

    _update_task_name(task)

    # The second member of the enqueue result is not needed here
    queued, _enqueue_info = enqueue_task(
        task_id=task.id,
        company_id=company_id,
        queue_id=request.queue,
        status_message="Starting pipeline",
        status_reason="",
    )

    return StartPipelineResponse(pipeline=task.id, enqueued=bool(queued))

View File

@@ -17,6 +17,7 @@ from apiserver.apimodels.projects import (
MergeRequest,
ProjectOrNoneRequest,
ProjectRequest,
ProjectModelMetadataValuesRequest,
)
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL, ProjectQueries
@@ -25,6 +26,7 @@ from apiserver.bll.project.project_cleanup import (
validate_project_delete,
)
from apiserver.database.errors import translate_errors_context
from apiserver.database.model import EntityVisibility
from apiserver.database.model.project import Project
from apiserver.database.utils import (
parse_from_call,
@@ -35,7 +37,7 @@ from apiserver.services.utils import (
conform_tag_fields,
conform_output_tags,
get_tags_filter_dictionary,
get_tags_response,
sort_tags_response,
)
from apiserver.timing_context import TimingContext
@@ -72,6 +74,16 @@ def get_by_id(call):
call.result.data = {"project": project_dict}
def _hidden_query(search_hidden: bool, ids: Sequence) -> Q:
    """
    Build the extra query condition that filters out hidden entities.

    An empty condition is returned when hidden entities were explicitly
    requested (search_hidden) or when specific ids were passed; otherwise
    the condition requires system_tags to differ from the hidden value.
    """
    exclude_hidden = not (search_hidden or ids)
    if exclude_hidden:
        return Q(system_tags__ne=EntityVisibility.hidden.value)

    return Q()
def _adjust_search_parameters(data: dict, shallow_search: bool):
"""
1. Make sure that there is no external query on path
@@ -90,12 +102,14 @@ def _adjust_search_parameters(data: dict, shallow_search: bool):
@endpoint("projects.get_all_ex", request_data_model=ProjectsGetRequest)
def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
conform_tag_fields(call, call.data)
allow_public = not request.non_public
data = call.data
conform_tag_fields(call, data)
allow_public = not request.non_public
requested_ids = data.get("id")
_adjust_search_parameters(
data, shallow_search=request.shallow_search,
)
with TimingContext("mongo", "projects_get_all"):
data = call.data
if request.active_users:
ids = project_bll.get_projects_with_active_user(
company=company_id,
@@ -104,16 +118,14 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
allow_public=allow_public,
)
if not ids:
call.result.data = {"projects": []}
return
return {"projects": []}
data["id"] = ids
_adjust_search_parameters(data, shallow_search=request.shallow_search)
ret_params = {}
projects = Project.get_many_with_join(
projects: Sequence[dict] = Project.get_many_with_join(
company=company_id,
query_dict=data,
query=_hidden_query(search_hidden=request.search_hidden, ids=requested_ids),
allow_public=allow_public,
ret_params=ret_params,
)
@@ -124,7 +136,9 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
}
if existing_requested_ids:
contents = project_bll.calc_own_contents(
company=company_id, project_ids=list(existing_requested_ids)
company=company_id,
project_ids=list(existing_requested_ids),
filter_=request.include_stats_filter,
)
for project in projects:
project.update(**contents.get(project["id"], {}))
@@ -140,6 +154,8 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
project_ids=list(project_ids),
specific_state=request.stats_for_state,
include_children=request.stats_with_children,
return_hidden_children=request.search_hidden,
filter_=request.include_stats_filter,
)
for project in projects:
@@ -151,20 +167,24 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
@endpoint("projects.get_all")
def get_all(call: APICall):
conform_tag_fields(call, call.data)
data = call.data
_adjust_search_parameters(data, shallow_search=data.get("shallow_search", False))
with translate_errors_context(), TimingContext("mongo", "projects_get_all"):
conform_tag_fields(call, data)
_adjust_search_parameters(
data, shallow_search=data.get("shallow_search", False),
)
with TimingContext("mongo", "projects_get_all"):
ret_params = {}
projects = Project.get_many(
company=call.identity.company,
query_dict=data,
query=_hidden_query(
search_hidden=data.get("search_hidden"), ids=data.get("id")
),
parameters=data,
allow_public=True,
ret_params=ret_params,
)
conform_output_tags(call, projects)
call.result.data = {"projects": projects, **ret_params}
@@ -275,6 +295,40 @@ def get_unique_metric_variants(
call.result.data = {"metrics": metrics}
@endpoint("projects.get_model_metadata_keys")
def get_model_metadata_keys(call: APICall, company_id: str, request: GetParamsRequest):
    """
    Return one page of model metadata keys together with the total and
    remaining counts, as reported by project_queries.get_model_metadata_keys.

    The query is limited to the requested project when one is given;
    otherwise no project restriction is passed.
    """
    project_ids = [request.project] if request.project else None
    total, remaining, keys = project_queries.get_model_metadata_keys(
        company_id,
        project_ids=project_ids,
        include_subprojects=request.include_subprojects,
        page=request.page,
        page_size=request.page_size,
    )

    call.result.data = dict(total=total, remaining=remaining, keys=keys)
@endpoint("projects.get_model_metadata_values")
def get_model_metadata_values(
    call: APICall, company_id: str, request: ProjectModelMetadataValuesRequest
):
    """
    Return the distinct values of the requested model metadata key and their
    total count, as reported by
    project_queries.get_model_metadata_distinct_values.
    """
    query_args = dict(
        project_ids=request.projects,
        key=request.key,
        include_subprojects=request.include_subprojects,
        allow_public=request.allow_public,
    )
    total, values = project_queries.get_model_metadata_distinct_values(
        company_id, **query_args
    )

    call.result.data = dict(total=total, values=values)
@endpoint(
"projects.get_hyper_parameters",
min_version="2.9",
@@ -305,7 +359,7 @@ def get_hyper_parameters(call: APICall, company_id: str, request: GetParamsReque
def get_hyperparam_values(
call: APICall, company_id: str, request: ProjectHyperparamValuesRequest
):
total, values = project_queries.get_hyperparam_distinct_values(
total, values = project_queries.get_task_hyperparam_distinct_values(
company_id,
project_ids=request.projects,
section=request.section,
@@ -319,6 +373,17 @@ def get_hyperparam_values(
}
@endpoint("projects.get_project_tags")
def get_tags(call: APICall, company, request: ProjectTagsRequest):
    """
    Return the tags and system tags of the company projects, passed through
    sort_tags_response, optionally limited by the request filter and the
    requested projects.
    """
    tags_filter = get_tags_filter_dictionary(request.filter)
    tags, system_tags = project_bll.get_project_tags(
        company,
        include_system=request.include_system,
        filter_=tags_filter,
        projects=request.projects,
    )
    unsorted = {"tags": tags, "system_tags": system_tags}
    call.result.data = sort_tags_response(unsorted)
@endpoint(
"projects.get_task_tags", min_version="2.8", request_data_model=ProjectTagsRequest
)
@@ -330,7 +395,7 @@ def get_tags(call: APICall, company, request: ProjectTagsRequest):
filter_=get_tags_filter_dictionary(request.filter),
projects=request.projects,
)
call.result.data = get_tags_response(ret)
call.result.data = sort_tags_response(ret)
@endpoint(
@@ -344,7 +409,7 @@ def get_tags(call: APICall, company, request: ProjectTagsRequest):
filter_=get_tags_filter_dictionary(request.filter),
projects=request.projects,
)
call.result.data = get_tags_response(ret)
call.result.data = sort_tags_response(ret)
@endpoint(

View File

@@ -13,17 +13,19 @@ from apiserver.apimodels.queues import (
QueueMetrics,
AddOrUpdateMetadataRequest,
DeleteMetadataRequest,
GetNextTaskRequest,
)
from apiserver.bll.model import Metadata
from apiserver.bll.queue import QueueBLL
from apiserver.bll.workers import WorkerBLL
from apiserver.database.model.metadata import metadata_add_or_update, metadata_delete
from apiserver.database.model.queue import Queue
from apiserver.database.model.task.task import Task
from apiserver.service_repo import APICall, endpoint
from apiserver.services.utils import (
conform_tag_fields,
conform_output_tags,
conform_tags,
get_metadata_from_api,
escape_metadata,
unescape_metadata,
)
from apiserver.utilities import extract_properties_to_lists
@@ -36,6 +38,7 @@ def get_by_id(call: APICall, company_id, req_model: QueueRequest):
queue = queue_bll.get_by_id(company_id, req_model.queue)
queue_dict = queue.to_proper_dict()
conform_output_tags(call, queue_dict)
unescape_metadata(call, queue_dict)
call.result.data = {"queue": queue_dict}
@@ -49,13 +52,13 @@ def get_by_id(call: APICall):
def get_all_ex(call: APICall):
conform_tag_fields(call, call.data)
ret_params = {}
Metadata.escape_query_parameters(call)
queues = queue_bll.get_queue_infos(
company_id=call.identity.company,
query_dict=call.data,
ret_params=ret_params,
company_id=call.identity.company, query_dict=call.data, ret_params=ret_params,
)
conform_output_tags(call, queues)
unescape_metadata(call, queues)
call.result.data = {"queues": queues, **ret_params}
@@ -63,13 +66,12 @@ def get_all_ex(call: APICall):
def get_all(call: APICall):
conform_tag_fields(call, call.data)
ret_params = {}
Metadata.escape_query_parameters(call)
queues = queue_bll.get_all(
company_id=call.identity.company,
query_dict=call.data,
ret_params=ret_params,
company_id=call.identity.company, query_dict=call.data, ret_params=ret_params,
)
conform_output_tags(call, queues)
unescape_metadata(call, queues)
call.result.data = {"queues": queues, **ret_params}
@@ -83,7 +85,7 @@ def create(call: APICall, company_id, request: CreateRequest):
name=request.name,
tags=tags,
system_tags=system_tags,
metadata=get_metadata_from_api(request.metadata),
metadata=Metadata.metadata_from_api(request.metadata),
)
call.result.data = {"id": queue.id}
@@ -97,10 +99,12 @@ def create(call: APICall, company_id, request: CreateRequest):
def update(call: APICall, company_id, req_model: UpdateRequest):
data = call.data_model_for_partial_update
conform_tag_fields(call, data, validate=True)
escape_metadata(data)
updated, fields = queue_bll.update(
company_id=company_id, queue_id=req_model.queue, **data
)
conform_output_tags(call, fields)
unescape_metadata(call, fields)
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
@@ -121,11 +125,19 @@ def add_task(call: APICall, company_id, req_model: TaskRequest):
}
@endpoint("queues.get_next_task", min_version="2.4", request_data_model=QueueRequest)
def get_next_task(call: APICall, company_id, req_model: QueueRequest):
task = queue_bll.get_next_task(company_id=company_id, queue_id=req_model.queue)
if task:
call.result.data = {"entry": task.to_proper_dict()}
@endpoint("queues.get_next_task", request_data_model=GetNextTaskRequest)
def get_next_task(call: APICall, company_id, req_model: GetNextTaskRequest):
    """
    Return the next entry of the requested queue under the "entry" key.

    When get_task_info is requested and the task of the entry is found, its
    company and user are added under "task_info". The call result is left
    untouched when the queue returns no entry.
    """
    entry = queue_bll.get_next_task(company_id=company_id, queue_id=req_model.queue)
    if not entry:
        return

    result = {"entry": entry.to_proper_dict()}
    # The task is looked up only when the caller asked for its info
    task = Task.objects(id=entry.task).first() if req_model.get_task_info else None
    if task:
        result["task_info"] = {"company": task.company, "user": task.user}

    call.result.data = result
@endpoint("queues.remove_task", min_version="2.4", request_data_model=TaskRequest)
@@ -245,21 +257,19 @@ def get_queue_metrics(
@endpoint("queues.add_or_update_metadata", min_version="2.13")
def add_or_update_metadata(
_: APICall, company_id: str, request: AddOrUpdateMetadataRequest
call: APICall, company_id: str, request: AddOrUpdateMetadataRequest
):
queue_id = request.queue
queue_bll.get_by_id(company_id=company_id, queue_id=queue_id, only=("id",))
queue = queue_bll.get_by_id(company_id=company_id, queue_id=queue_id, only=("id",))
return {
"updated": metadata_add_or_update(
cls=Queue, _id=queue_id, items=get_metadata_from_api(request.metadata),
"updated": Metadata.edit_metadata(
queue, items=request.metadata, replace_metadata=request.replace_metadata
)
}
@endpoint("queues.delete_metadata", min_version="2.13")
def delete_metadata(_: APICall, company_id: str, request: DeleteMetadataRequest):
def delete_metadata(call: APICall, company_id: str, request: DeleteMetadataRequest):
queue_id = request.queue
queue_bll.get_by_id(company_id=company_id, queue_id=queue_id, only=("id",))
return {"updated": metadata_delete(cls=Queue, _id=queue_id, keys=request.keys)}
queue = queue_bll.get_by_id(company_id=company_id, queue_id=queue_id, only=("id",))
return {"updated": Metadata.delete_metadata(queue, keys=request.keys)}

View File

@@ -98,6 +98,7 @@ from apiserver.bll.task.task_operations import (
from apiserver.bll.task.utils import update_task, get_task_for_update, deleted_prefix
from apiserver.bll.util import SetFieldsResolver, run_batch_operation
from apiserver.database.errors import translate_errors_context
from apiserver.database.model import EntityVisibility
from apiserver.database.model.task.output import Output
from apiserver.database.model.task.task import (
Task,
@@ -213,6 +214,16 @@ def _process_include_subprojects(call_data: dict):
call_data["project"] = project_ids_with_children(project_ids)
def _hidden_query(data: dict) -> Q:
    """
    Build the extra query condition that filters out hidden tasks.

    An empty condition is returned when the call data asks for hidden
    entities ("search_hidden") or names specific ids ("id"); otherwise the
    condition requires system_tags to differ from the hidden value.
    """
    wants_hidden = data.get("search_hidden")
    if not wants_hidden and not data.get("id"):
        return Q(system_tags__ne=EntityVisibility.hidden.value)

    return Q()
@endpoint("tasks.get_all_ex", required_fields=[])
def get_all_ex(call: APICall, company_id, _):
conform_tag_fields(call, call.data)
@@ -225,6 +236,7 @@ def get_all_ex(call: APICall, company_id, _):
tasks = Task.get_many_with_join(
company=company_id,
query_dict=call_data,
query=_hidden_query(call_data),
allow_public=True,
ret_params=ret_params,
)
@@ -259,6 +271,7 @@ def get_all(call: APICall, company_id, _):
company=company_id,
parameters=call_data,
query_dict=call_data,
query=_hidden_query(call_data),
allow_public=True,
ret_params=ret_params,
)

View File

@@ -2,7 +2,6 @@ from datetime import datetime
from typing import Union, Sequence, Tuple
from apiserver.apierrors import errors
from apiserver.apimodels.metadata import MetadataItem as ApiMetadataItem
from apiserver.apimodels.organization import Filter
from apiserver.database.model.base import GetMixin
from apiserver.database.model.task.task import TaskModelTypes, TaskModelNames
@@ -24,7 +23,7 @@ def get_tags_filter_dictionary(input_: Filter) -> dict:
}
def get_tags_response(ret: dict) -> dict:
def sort_tags_response(ret: dict) -> dict:
return {field: sorted(vals) for field, vals in ret.items()}
@@ -222,22 +221,38 @@ class DockerCmdBackwardsCompatibility:
nested_set(task, cls.field, docker_cmd)
def validate_metadata(metadata: Sequence[dict]):
def escape_metadata(document: dict):
"""
Escape special characters in metadata keys
"""
metadata = document.get("metadata")
if not metadata:
return
keys = [m.get("key") for m in metadata]
unique_keys = set(keys)
unique_keys.discard(None)
if len(keys) != len(set(keys)):
raise errors.bad_request.ValidationError("Metadata keys should be unique")
document["metadata"] = {
ParameterKeyEscaper.escape(k): v
for k, v in metadata.items()
}
def get_metadata_from_api(api_metadata: Sequence[ApiMetadataItem]) -> Sequence:
if not api_metadata:
return api_metadata
def unescape_metadata(call: APICall, documents: Union[dict, Sequence[dict]]):
"""
Unescape special characters in metadata keys
"""
if isinstance(documents, dict):
documents = [documents]
metadata = [m.to_struct() for m in api_metadata]
validate_metadata(metadata)
old_client = call.requested_endpoint_version <= PartialVersion("2.16")
for doc in documents:
if old_client and "metadata" in doc:
doc["metadata"] = []
continue
return metadata
metadata = doc.get("metadata")
if not metadata:
continue
doc["metadata"] = {
ParameterKeyEscaper.unescape(k): v
for k, v in metadata.items()
}

View File

@@ -4,10 +4,29 @@ from apiserver.tests.automated import TestService
class TestProjectTags(TestService):
def setUp(self, version="2.12"):
super().setUp(version=version)
def test_project_own_tags(self):
p1_tags = ["Tag 1", "Tag 2"]
p1 = self.create_temp(
"projects", name="Test project tags1", description="test", tags=p1_tags
)
p2_tags = ["Tag 1", "Tag 3"]
p2 = self.create_temp(
"projects",
name="Test project tags2",
description="test",
tags=p2_tags,
system_tags=["hidden"],
)
def test_project_tags(self):
res = self.api.projects.get_project_tags(projects=[p1, p2])
self.assertEqual(set(res.tags), set(p1_tags) | set(p2_tags))
res = self.api.projects.get_project_tags(
projects=[p1, p2], filter={"system_tags": ["__$not", "hidden"]}
)
self.assertEqual(res.tags, p1_tags)
def test_project_entities_tags(self):
tags_1 = ["Test tag 1", "Test tag 2"]
tags_2 = ["Test tag 3", "Test tag 4"]

View File

@@ -1,12 +1,11 @@
from functools import partial
from typing import Sequence
from apiserver.tests.api_client import APIClient
from apiserver.tests.automated import TestService
class TestQueueAndModelMetadata(TestService):
meta1 = [{"key": "test_key", "type": "str", "value": "test_value"}]
meta1 = {"test_key": {"key": "test_key", "type": "str", "value": "test_value"}}
def test_queue_metas(self):
queue_id = self._temp_queue("TestMetadata", metadata=self.meta1)
@@ -23,20 +22,51 @@ class TestQueueAndModelMetadata(TestService):
)
model_id = self._temp_model("TestMetadata1")
self.api.models.edit(model=model_id, metadata=[self.meta1[0]])
self.api.models.edit(model=model_id, metadata=self.meta1)
self._assertMeta(service=service, entity=entity, _id=model_id, meta=self.meta1)
def test_project_meta_query(self):
self._temp_model("TestMetadata", metadata=self.meta1)
project = self.temp_project(name="MetaParent")
test_key = "test_key"
test_key2 = "test_key2"
test_value = "test_value"
test_value2 = "test_value2"
model_id = self._temp_model(
"TestMetadata2",
project=project,
metadata={
test_key: {"key": test_key, "type": "str", "value": test_value},
test_key2: {"key": test_key2, "type": "str", "value": test_value2},
},
)
res = self.api.projects.get_model_metadata_keys()
self.assertTrue({test_key, test_key2}.issubset(set(res["keys"])))
res = self.api.projects.get_model_metadata_keys(include_subprojects=False)
self.assertTrue(test_key in res["keys"])
self.assertFalse(test_key2 in res["keys"])
model = self.api.models.get_all_ex(
id=[model_id], only_fields=["metadata.test_key"]
).models[0]
self.assertTrue(test_key in model.metadata)
self.assertFalse(test_key2 in model.metadata)
res = self.api.projects.get_model_metadata_values(key=test_key)
self.assertEqual(res.total, 1)
self.assertEqual(res["values"], [test_value])
def _test_meta_operations(
self, service: APIClient.Service, entity: str, _id: str,
):
assert_meta = partial(self._assertMeta, service=service, entity=entity)
assert_meta(_id=_id, meta=self.meta1)
meta2 = [
{"key": "test1", "type": "str", "value": "data1"},
{"key": "test2", "type": "str", "value": "data2"},
{"key": "test3", "type": "str", "value": "data3"},
]
meta2 = {
"test1": {"key": "test1", "type": "str", "value": "data1"},
"test2": {"key": "test2", "type": "str", "value": "data2"},
"test3": {"key": "test3", "type": "str", "value": "data3"},
}
service.update(**{entity: _id, "metadata": meta2})
assert_meta(_id=_id, meta=meta2)
@@ -48,16 +78,17 @@ class TestQueueAndModelMetadata(TestService):
]
res = service.add_or_update_metadata(**{entity: _id, "metadata": updates})
self.assertEqual(res.updated, 1)
assert_meta(_id=_id, meta=[meta2[0], *updates])
assert_meta(_id=_id, meta={**meta2, **{u["key"]: u for u in updates}})
res = service.delete_metadata(
**{entity: _id, "keys": [f"test{idx}" for idx in range(2, 6)]}
)
self.assertEqual(res.updated, 1)
assert_meta(_id=_id, meta=meta2[:1])
# noinspection PyTypeChecker
assert_meta(_id=_id, meta=dict(list(meta2.items())[:1]))
def _assertMeta(
self, service: APIClient.Service, entity: str, _id: str, meta: Sequence[dict]
self, service: APIClient.Service, entity: str, _id: str, meta: dict
):
res = service.get_all_ex(id=[_id])[f"{entity}s"][0]
self.assertEqual(res.metadata, meta)

View File

@@ -199,10 +199,10 @@ class TestSubProjects(TestService):
res1 = next(p for p in res if p.id == project1)
self.assertEqual(res1.stats["active"]["status_count"]["created"], 0)
self.assertEqual(res1.stats["active"]["status_count"]["stopped"], 2)
self.assertEqual(res1.stats["active"]["status_count"]["in_progress"], 0)
self.assertEqual(res1.stats["active"]["total_runtime"], 2)
self.assertEqual(res1.stats["active"]["completed_tasks"], 2)
self.assertEqual(res1.stats["active"]["completed_tasks_24h"], 2)
self.assertEqual(res1.stats["active"]["total_tasks"], 2)
self.assertEqual(res1.stats["active"]["running_tasks"], 0)
self.assertEqual(
{sp.name for sp in res1.sub_projects},
{
@@ -214,10 +214,10 @@ class TestSubProjects(TestService):
res2 = next(p for p in res if p.id == project2)
self.assertEqual(res2.stats["active"]["status_count"]["created"], 0)
self.assertEqual(res2.stats["active"]["status_count"]["stopped"], 0)
self.assertEqual(res2.stats["active"]["status_count"]["in_progress"], 0)
self.assertEqual(res2.stats["active"]["status_count"]["completed"], 0)
self.assertEqual(res2.stats["active"]["total_runtime"], 0)
self.assertEqual(res2.stats["active"]["completed_tasks"], 0)
self.assertEqual(res2.stats["active"]["total_tasks"], 0)
self.assertEqual(res2.stats["active"]["running_tasks"], 0)
self.assertEqual(res2.sub_projects, [])
def _run_tasks(self, *tasks):

View File

@@ -133,6 +133,32 @@ class TestTags(TestService):
).models
self.assertFound(model_id, [], models)
def testQueueTags(self):
q_id = self._temp_queue(system_tags=["default"])
queues = self.api.queues.get_all_ex(
name="Test tags", system_tags=["default"]
).queues
self.assertFound(q_id, ["default"], queues)
queues = self.api.queues.get_all_ex(
name="Test tags", system_tags=["-default"]
).queues
self.assertNotFound(q_id, queues)
self.api.queues.update(queue=q_id, system_tags=[])
queues = self.api.queues.get_all_ex(
name="Test tags", system_tags=["-default"]
).queues
self.assertFound(q_id, [], queues)
# test default queue
queues = self.api.queues.get_all(system_tags=["default"]).queues
if queues:
self.assertEqual(queues[0].id, self.api.queues.get_default().id)
else:
self.api.queues.update(queue=q_id, system_tags=["default"])
self.assertEqual(q_id, self.api.queues.get_default().id)
def testTaskTags(self):
task_id = self._temp_task(
name="Test tags", system_tags=["active"]
@@ -169,38 +195,11 @@ class TestTags(TestService):
task = self.api.tasks.get_by_id(task=task_id).task
self.assertEqual(task.status, "stopped")
def testQueueTags(self):
q_id = self._temp_queue(system_tags=["default"])
queues = self.api.queues.get_all_ex(
name="Test tags", system_tags=["default"]
).queues
self.assertFound(q_id, ["default"], queues)
queues = self.api.queues.get_all_ex(
name="Test tags", system_tags=["-default"]
).queues
self.assertNotFound(q_id, queues)
self.api.queues.update(queue=q_id, system_tags=[])
queues = self.api.queues.get_all_ex(
name="Test tags", system_tags=["-default"]
).queues
self.assertFound(q_id, [], queues)
# test default queue
queues = self.api.queues.get_all(system_tags=["default"]).queues
if queues:
self.assertEqual(queues[0].id, self.api.queues.get_default().id)
else:
self.api.queues.update(queue=q_id, system_tags=["default"])
self.assertEqual(q_id, self.api.queues.get_default().id)
def assertProjectStats(self, project: AttrDict):
self.assertEqual(set(project.stats.keys()), {"active"})
self.assertAlmostEqual(project.stats.active.total_runtime, 1, places=0)
self.assertEqual(project.stats.active.completed_tasks, 1)
self.assertEqual(project.stats.active.completed_tasks_24h, 1)
self.assertEqual(project.stats.active.total_tasks, 1)
self.assertEqual(project.stats.active.running_tasks, 0)
for status, count in project.stats.active.status_count.items():
self.assertEqual(count, 1 if status == "stopped" else 0)

View File

@@ -0,0 +1,14 @@
from distutils.util import strtobool
from os import getenv
from typing import Optional
def get_bool(*keys: str, default: Optional[bool] = None) -> Optional[bool]:
    """Read a boolean flag from the first of the given environment variables that is set.

    The keys are checked in order; the first variable present in the
    environment (even when set to an empty string) supplies the value.

    The value is matched case-insensitively:
    ``y``, ``yes``, ``t``, ``true``, ``on``, ``1`` give ``True`` and
    ``n``, ``no``, ``f``, ``false``, ``off``, ``0`` give ``False``.
    Any other value falls back to its truthiness, so an unrecognized
    non-empty string is ``True`` and an empty string is ``False``.

    :param keys: environment variable names, in priority order
    :param default: value returned when none of the variables is set
    :return: the parsed flag, or ``default`` when no variable is set
    """
    value = next((env for env in map(getenv, keys) if env is not None), None)
    if value is None:
        return default

    # Parsed inline instead of via distutils.util.strtobool: distutils is
    # deprecated (PEP 632) and removed in Python 3.12. The accepted tokens
    # and the truthiness fallback are the same as before.
    normalized = value.lower()
    if normalized in ("y", "yes", "t", "true", "on", "1"):
        return True
    if normalized in ("n", "no", "f", "false", "off", "0"):
        return False
    return bool(value)

View File

@@ -1 +1 @@
__version__ = "1.2.0"
__version__ = "1.4.0"

View File

@@ -1,9 +1,11 @@
FROM centos/nodejs-12-centos7 AS webapp
ARG CLEARML_WEB_GIT_URL=https://github.com/allegroai/clearml-web.git
USER root
WORKDIR /opt
RUN git clone https://github.com/allegroai/clearml-web.git
RUN git clone ${CLEARML_WEB_GIT_URL} clearml-web
RUN mv clearml-web /opt/open-webapp
COPY --chmod=744 docker/build/internal_files/build_webapp.sh /tmp/internal_files/
RUN /bin/bash -c '/tmp/internal_files/build_webapp.sh'
@@ -18,6 +20,7 @@ COPY --from=staging_image /opt/clearml/ /opt/clearml/
COPY --chmod=744 docker/build/internal_files/final_image_preparation.sh /tmp/internal_files/
COPY docker/build/internal_files/clearml.conf.template /tmp/internal_files/
COPY docker/build/internal_files/clearml_subpath.conf.template /tmp/internal_files/
RUN /bin/bash -c '/tmp/internal_files/final_image_preparation.sh'
COPY --from=webapp /opt/open-webapp/build /usr/share/nginx/html

View File

@@ -41,6 +41,7 @@ http {
server_name _;
root /usr/share/nginx/html;
proxy_http_version 1.1;
client_max_body_size 0;
# compression
gzip on;

View File

@@ -0,0 +1,21 @@
location /${CLEARML_SERVER_SUB_PATH} {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header Host $host;
proxy_pass http://localhost:80;
rewrite /${CLEARML_SERVER_SUB_PATH}/(.*) /$1 break;
}
location /${CLEARML_SERVER_SUB_PATH}/api {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header Host $host;
proxy_pass http://localhost:80/api;
rewrite /${CLEARML_SERVER_SUB_PATH}/api/(.*) /api/$1 break;
}
location /${CLEARML_SERVER_SUB_PATH}/files {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header Host $host;
proxy_pass http://localhost:80/files;
rewrite /${CLEARML_SERVER_SUB_PATH}/files/(.*) /files/$1 break;
rewrite /${CLEARML_SERVER_SUB_PATH}/files /files/ break;
}

View File

@@ -48,9 +48,16 @@ EOF
export NGINX_APISERVER_ADDR=${NGINX_APISERVER_ADDRESS:-http://apiserver:8008}
export NGINX_FILESERVER_ADDR=${NGINX_FILESERVER_ADDRESS:-http://fileserver:8081}
envsubst '${NGINX_APISERVER_ADDR} ${NGINX_FILESERVER_ADDR}' < /etc/nginx/clearml.conf.template > /etc/nginx/nginx.conf
if [[ -n "${CLEARML_SERVER_SUB_PATH}" ]]; then
envsubst '${CLEARML_SERVER_SUB_PATH}' < /etc/nginx/clearml_subpath.conf.template > /etc/nginx/default.d/clearml_subpath.conf
cp /usr/share/nginx/html/env.js /usr/share/nginx/html/env.js.origin
envsubst '${CLEARML_SERVER_SUB_PATH}' < /usr/share/nginx/html/env.js.origin > /usr/share/nginx/html/env.js
cp /usr/share/nginx/html/index.html /usr/share/nginx/html/index.html.origin
sed 's/href="\/"/href="\/'${CLEARML_SERVER_SUB_PATH}'\/"/' /usr/share/nginx/html/index.html.origin > /usr/share/nginx/html/index.html
fi
#start the server
/usr/sbin/nginx -g "daemon off;"

View File

@@ -5,7 +5,7 @@ set -o pipefail
yum update -y
yum install -y https://dl.fedoraproject.org/pub/epel/epel-release-latest-7.noarch.rpm
yum install -y python36 python36-pip nginx gcc python3-devel gettext
yum install -y python36 python36-pip nginx gcc gcc-c++ python3-devel gettext
yum -y upgrade
python3 -m pip install -r /opt/clearml/fileserver/requirements.txt
python3 -m pip install -r /opt/clearml/apiserver/requirements.txt
@@ -15,4 +15,5 @@ ln -s /dev/stdout /var/log/nginx/access.log
ln -s /dev/stderr /var/log/nginx/error.log
mv /etc/nginx/nginx.conf /etc/nginx/nginx.conf.orig
mv /tmp/internal_files/clearml.conf.template /etc/nginx/clearml.conf.template
mv /tmp/internal_files/clearml_subpath.conf.template /etc/nginx/clearml_subpath.conf.template
yum clean all

View File

@@ -89,12 +89,12 @@ services:
networks:
- backend
container_name: clearml-mongo
image: mongo:3.6.23
image: mongo:4.4.9
restart: unless-stopped
command: --setParameter internalQueryExecMaxBlockingSortBytes=196100200
command: --setParameter internalQueryMaxBlockingSortMemoryUsageBytes=196100200
volumes:
- c:/opt/clearml/data/mongo/db:/data/db
- c:/opt/clearml/data/mongo/configdb:/data/configdb
- c:/opt/clearml/data/mongo_4/db:/data/db
- c:/opt/clearml/data/mongo_4/configdb:/data/configdb
redis:
networks:

View File

@@ -88,12 +88,12 @@ services:
networks:
- backend
container_name: clearml-mongo
image: mongo:3.6.23
image: mongo:4.4.9
restart: unless-stopped
command: --setParameter internalQueryExecMaxBlockingSortBytes=196100200
command: --setParameter internalQueryMaxBlockingSortMemoryUsageBytes=196100200
volumes:
- /opt/clearml/data/mongo/db:/data/db
- /opt/clearml/data/mongo/configdb:/data/configdb
- /opt/clearml/data/mongo_4/db:/data/db
- /opt/clearml/data/mongo_4/configdb:/data/configdb
redis:
networks:
@@ -108,6 +108,8 @@ services:
command:
- webserver
container_name: clearml-webserver
# environment:
# CLEARML_SERVER_SUB_PATH : clearml-web # Allow ClearML to be served with a URL path prefix.
image: allegroai/clearml:latest
restart: unless-stopped
depends_on:
@@ -152,6 +154,8 @@ services:
- /opt/clearml/agent:/root/.clearml
depends_on:
- apiserver
entrypoint: >
bash -c "curl --retry 10 --retry-delay 10 --retry-connrefused 'http://apiserver:8008/debug.ping' && /usr/agent/entrypoint.sh"
networks:
backend:

View File

@@ -1,5 +1,8 @@
# trains-server FAQ
## **NOTE**: This page's information is deprecated. See the [ClearML documentation](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server) for up-to-date deployment instructions
Launching **trains-server**
* How do I launch **trains-server** on:

View File

@@ -1,5 +1,7 @@
# Deploying **trains-server** on AWS
## **NOTE**: These instructions are deprecated. See the [ClearML documentation](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server) for up-to-date deployment instructions
To easily deploy **trains-server** on AWS, use one of our pre-built Amazon Machine Images (AMIs).
We provide AMIs per region for each released version of **trains-server**, see [Released versions](#released-versions) below.

View File

@@ -1,5 +1,7 @@
# Deploying Trains Server on Google Cloud Platform
## **NOTE**: These instructions are deprecated. See the [ClearML documentation](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server) for up-to-date deployment instructions
To easily deploy Trains Server on GCP, use one of our pre-built GCP Custom Images.
We provide Custom Images for each released version of Trains Server, see [Released versions](#released-versions) below.

View File

@@ -1,5 +1,7 @@
# Launching the **trains-server** Docker in Linux or macOS
## **NOTE**: These instructions are deprecated. See the [ClearML documentation](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server) for up-to-date deployment instructions
For Linux or macOS, use our pre-built Docker image for easy deployment. The latest Docker images can be found [here](https://hub.docker.com/r/allegroai/trains).
For Linux users:

View File

@@ -1,5 +1,7 @@
# Launching the **trains-server** Docker in Windows 10
## **NOTE**: These instructions are deprecated. See the [ClearML documentation](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server) for up-to-date deployment instructions
For Windows, we recommend launching our pre-built Docker image on a Linux virtual machine.
However, you can launch **trains-server** on Windows 10 using Docker Desktop for Windows (see the Docker [System Requirements](https://docs.docker.com/docker-for-windows/install/#system-requirements)).

View File

@@ -13,12 +13,15 @@ from werkzeug.exceptions import NotFound
from werkzeug.security import safe_join
from config import config
from utils import get_env_bool
DEFAULT_UPLOAD_FOLDER = "/mnt/fileserver"
app = Flask(__name__)
CORS(app, **config.get("fileserver.cors"))
Compress(app)
if get_env_bool("CLEARML_COMPRESS_RESP", default=True):
Compress(app)
app.config["UPLOAD_FOLDER"] = first(
(os.environ.get(f"{prefix}_UPLOAD_FOLDER") for prefix in ("CLEARML", "TRAINS")),
@@ -29,6 +32,20 @@ app.config["SEND_FILE_MAX_AGE_DEFAULT"] = config.get(
)
@app.before_request
def before_request():
    """Reject requests that declare a content encoding.

    Returns a 415 response tuple when the incoming request carries a
    content encoding; otherwise returns ``None``.
    """
    encoding = request.content_encoding
    if not encoding:
        return None
    return f"Content encoding is not supported ({encoding})", 415
@app.after_request
def after_request(response):
    """Set the ``server`` header on the outgoing response.

    The header value is read from the ``fileserver.response.headers.server``
    configuration key and defaults to ``"clearml"``.

    :param response: the response object about to be sent
    :return: the same response object, with the header set
    """
    server_name = config.get("fileserver.response.headers.server", "clearml")
    response.headers["server"] = server_name
    return response
@app.route("/", methods=["POST"])
def upload():
results = []
@@ -54,7 +71,10 @@ def download(path):
mimetype = "application/octet-stream" if encoding == "gzip" else None
response = send_from_directory(
app.config["UPLOAD_FOLDER"], path, as_attachment=as_attachment, mimetype=mimetype
app.config["UPLOAD_FOLDER"],
path,
as_attachment=as_attachment,
mimetype=mimetype,
)
if config.get("fileserver.download.disable_browser_caching", False):
headers = response.headers
@@ -68,12 +88,7 @@ def download(path):
@app.route("/<path:path>", methods=["DELETE"])
def delete(path):
real_path = Path(
safe_join(
os.fspath(app.config["UPLOAD_FOLDER"]),
os.fspath(path)
)
)
real_path = Path(safe_join(os.fspath(app.config["UPLOAD_FOLDER"]), os.fspath(path)))
if not real_path.exists() or not real_path.is_file():
abort(Response(f"File {str(path)} not found", 404))

14
fileserver/utils.py Normal file
View File

@@ -0,0 +1,14 @@
from distutils.util import strtobool
from os import getenv
from typing import Optional
def get_env_bool(*keys: str, default: Optional[bool] = None) -> Optional[bool]:
    """Return a boolean read from the first set environment variable among ``keys``.

    Variables are looked up in the order given. A variable counts as set
    whenever it exists in the environment, including when it is empty.

    Recognized values (case-insensitive) are ``y``/``yes``/``t``/``true``/
    ``on``/``1`` for ``True`` and ``n``/``no``/``f``/``false``/``off``/``0``
    for ``False``. Anything else is judged by truthiness: a non-empty
    string is ``True``, an empty string is ``False``.

    :param keys: environment variable names, in priority order
    :param default: value returned when none of the variables is set
    :return: the parsed flag, or ``default`` when no variable is set
    """
    value = next((env for env in map(getenv, keys) if env is not None), None)
    if value is None:
        return default

    # The truth tokens are matched here rather than through
    # distutils.util.strtobool, because distutils is deprecated (PEP 632)
    # and no longer ships with Python 3.12+. Results are unchanged.
    lowered = value.lower()
    if lowered in ("y", "yes", "t", "true", "on", "1"):
        return True
    if lowered in ("n", "no", "f", "false", "off", "0"):
        return False
    return bool(value)