mirror of
https://github.com/clearml/clearml-server
synced 2025-06-26 23:15:47 +00:00
Compare commits
25 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
69737308fe | ||
|
|
a6dbea808a | ||
|
|
5131b17901 | ||
|
|
5f21c3a56d | ||
|
|
2350ac64ed | ||
|
|
d146127c18 | ||
|
|
abd65e103e | ||
|
|
bf65ea7bd0 | ||
|
|
73e278a8ed | ||
|
|
d92dfbbdb7 | ||
|
|
5c1e419eb5 | ||
|
|
124684f53f | ||
|
|
455b5d6758 | ||
|
|
c04e2e498b | ||
|
|
da8a45072f | ||
|
|
e1992e2054 | ||
|
|
c17cedd93a | ||
|
|
b6ad8f8790 | ||
|
|
5acc7eebc3 | ||
|
|
941927dfcd | ||
|
|
02933a9c93 | ||
|
|
e537651f29 | ||
|
|
af09fba755 | ||
|
|
04ea9018a3 | ||
|
|
ff7e1be24f |
@@ -81,6 +81,7 @@ class Credentials(Base):
|
||||
class CredentialsResponse(Credentials):
|
||||
secret_key = StringField()
|
||||
last_used = DateTimeField(default=None)
|
||||
last_used_from = StringField()
|
||||
|
||||
|
||||
class CreateCredentialsRequest(Base):
|
||||
|
||||
@@ -137,3 +137,7 @@ class TaskPlotsRequest(Base):
|
||||
scroll_id: str = StringField()
|
||||
no_scroll: bool = BoolField(default=False)
|
||||
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
|
||||
|
||||
|
||||
class ClearScrollRequest(Base):
|
||||
scroll_id: str = StringField()
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Sequence
|
||||
|
||||
from jsonmodels import validators
|
||||
from jsonmodels.fields import StringField
|
||||
from jsonmodels.fields import StringField, BoolField
|
||||
from jsonmodels.models import Base
|
||||
|
||||
from apiserver.apimodels import ListField
|
||||
@@ -21,3 +21,4 @@ class AddOrUpdateMetadata(Base):
|
||||
metadata: Sequence[MetadataItem] = ListField(
|
||||
[MetadataItem], validators=validators.Length(minimum_value=1)
|
||||
)
|
||||
replace_metadata = BoolField(default=False)
|
||||
|
||||
@@ -30,7 +30,7 @@ class CreateModelRequest(models.Base):
|
||||
ready = fields.BoolField(default=True)
|
||||
ui_cache = DictField()
|
||||
task = fields.StringField()
|
||||
metadata = ListField(items_types=[MetadataItem])
|
||||
metadata = DictField(value_types=[MetadataItem])
|
||||
|
||||
|
||||
class CreateModelResponse(models.Base):
|
||||
|
||||
19
apiserver/apimodels/pipelines.py
Normal file
19
apiserver/apimodels/pipelines.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from jsonmodels import models, fields
|
||||
|
||||
from apiserver.apimodels import ListField
|
||||
|
||||
|
||||
class Arg(models.Base):
|
||||
name = fields.StringField(required=True)
|
||||
value = fields.StringField(required=True)
|
||||
|
||||
|
||||
class StartPipelineRequest(models.Base):
|
||||
task = fields.StringField(required=True)
|
||||
queue = fields.StringField(required=True)
|
||||
args = ListField(Arg)
|
||||
|
||||
|
||||
class StartPipelineResponse(models.Base):
|
||||
pipeline = fields.StringField(required=True)
|
||||
enqueued = fields.BoolField(required=True)
|
||||
@@ -1,6 +1,6 @@
|
||||
from jsonmodels import models, fields
|
||||
|
||||
from apiserver.apimodels import ListField, ActualEnumField
|
||||
from apiserver.apimodels import ListField, ActualEnumField, DictField
|
||||
from apiserver.apimodels.organization import TagsRequest
|
||||
from apiserver.database.model import EntityVisibility
|
||||
|
||||
@@ -51,11 +51,18 @@ class ProjectHyperparamValuesRequest(MultiProjectRequest):
|
||||
allow_public = fields.BoolField(default=True)
|
||||
|
||||
|
||||
class ProjectModelMetadataValuesRequest(MultiProjectRequest):
|
||||
key = fields.StringField(required=True)
|
||||
allow_public = fields.BoolField(default=True)
|
||||
|
||||
|
||||
class ProjectsGetRequest(models.Base):
|
||||
include_stats = fields.BoolField(default=False)
|
||||
include_stats_filter = DictField()
|
||||
stats_with_children = fields.BoolField(default=True)
|
||||
stats_for_state = ActualEnumField(EntityVisibility, default=EntityVisibility.active)
|
||||
non_public = fields.BoolField(default=False)
|
||||
active_users = fields.ListField(str)
|
||||
check_own_contents = fields.BoolField(default=False)
|
||||
shallow_search = fields.BoolField(default=False)
|
||||
search_hidden = fields.BoolField(default=False)
|
||||
|
||||
@@ -2,7 +2,7 @@ from jsonmodels import validators
|
||||
from jsonmodels.fields import StringField, IntField, BoolField, FloatField
|
||||
from jsonmodels.models import Base
|
||||
|
||||
from apiserver.apimodels import ListField
|
||||
from apiserver.apimodels import ListField, DictField
|
||||
from apiserver.apimodels.metadata import (
|
||||
MetadataItem,
|
||||
DeleteMetadata,
|
||||
@@ -19,13 +19,18 @@ class CreateRequest(Base):
|
||||
name = StringField(required=True)
|
||||
tags = ListField(items_types=[str])
|
||||
system_tags = ListField(items_types=[str])
|
||||
metadata = ListField(items_types=[MetadataItem])
|
||||
metadata = DictField(value_types=[MetadataItem])
|
||||
|
||||
|
||||
class QueueRequest(Base):
|
||||
queue = StringField(required=True)
|
||||
|
||||
|
||||
class GetNextTaskRequest(QueueRequest):
|
||||
queue = StringField(required=True)
|
||||
get_task_info = BoolField(default=False)
|
||||
|
||||
|
||||
class DeleteRequest(QueueRequest):
|
||||
force = BoolField(default=False)
|
||||
|
||||
@@ -34,7 +39,7 @@ class UpdateRequest(QueueRequest):
|
||||
name = StringField()
|
||||
tags = ListField(items_types=[str])
|
||||
system_tags = ListField(items_types=[str])
|
||||
metadata = ListField(items_types=[MetadataItem])
|
||||
metadata = DictField(value_types=[MetadataItem])
|
||||
|
||||
|
||||
class TaskRequest(QueueRequest):
|
||||
|
||||
@@ -162,7 +162,7 @@ class AuthBLL:
|
||||
access_key=get_client_id(), secret_key=get_secret_key(), label=label
|
||||
)
|
||||
user.credentials.append(
|
||||
Credentials(key=cred.access_key, secret=cred.secret_key)
|
||||
Credentials(key=cred.access_key, secret=cred.secret_key, label=label)
|
||||
)
|
||||
user.save()
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from datetime import datetime
|
||||
from operator import attrgetter
|
||||
from typing import Sequence, Set, Tuple, Optional, List, Mapping, Union
|
||||
|
||||
from elasticsearch import helpers
|
||||
import elasticsearch
|
||||
from elasticsearch.helpers import BulkIndexError
|
||||
from mongoengine import Q
|
||||
from nested_dict import nested_dict
|
||||
@@ -48,6 +48,9 @@ MAX_LONG = 2 ** 63 - 1
|
||||
MIN_LONG = -(2 ** 63)
|
||||
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class PlotFields:
|
||||
valid_plot = "valid_plot"
|
||||
plot_len = "plot_len"
|
||||
@@ -219,7 +222,7 @@ class EventBLL(object):
|
||||
with TimingContext("es", "events_add_batch"):
|
||||
# TODO: replace it with helpers.parallel_bulk in the future once the parallel pool leak is fixed
|
||||
with closing(
|
||||
helpers.streaming_bulk(
|
||||
elasticsearch.helpers.streaming_bulk(
|
||||
self.es,
|
||||
actions,
|
||||
chunk_size=chunk_size,
|
||||
@@ -1005,3 +1008,16 @@ class EventBLL(object):
|
||||
)
|
||||
|
||||
return es_res.get("deleted", 0)
|
||||
|
||||
def clear_scroll(self, scroll_id: str):
|
||||
if scroll_id == self.empty_scroll:
|
||||
return
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
self.es.clear_scroll(scroll_id=scroll_id)
|
||||
except elasticsearch.exceptions.NotFoundError:
|
||||
pass
|
||||
except elasticsearch.exceptions.RequestError:
|
||||
pass
|
||||
except Exception as ex:
|
||||
log.exception("Failed clearing scroll %s", scroll_id)
|
||||
|
||||
@@ -67,6 +67,9 @@ class EventsIterator:
|
||||
task_id: str,
|
||||
metric_variants: MetricVariants = None,
|
||||
) -> int:
|
||||
if check_empty_data(self.es, company_id, event_type):
|
||||
return 0
|
||||
|
||||
query, _ = self._get_initial_query_and_must(task_id, metric_variants)
|
||||
es_req = {
|
||||
"query": query,
|
||||
|
||||
@@ -7,6 +7,7 @@ from apiserver.bll.task.utils import deleted_prefix
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.task.task import Task, TaskStatus
|
||||
from .metadata import Metadata
|
||||
|
||||
|
||||
class ModelBLL:
|
||||
|
||||
111
apiserver/bll/model/metadata.py
Normal file
111
apiserver/bll/model/metadata.py
Normal file
@@ -0,0 +1,111 @@
|
||||
from typing import Sequence, Union, Mapping
|
||||
|
||||
from mongoengine import Document
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apimodels.metadata import MetadataItem
|
||||
from apiserver.database.model.base import GetMixin
|
||||
from apiserver.service_repo import APICall
|
||||
from apiserver.utilities.parameter_key_escaper import (
|
||||
ParameterKeyEscaper,
|
||||
mongoengine_safe,
|
||||
)
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.timing_context import TimingContext
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class Metadata:
|
||||
@staticmethod
|
||||
def metadata_from_api(
|
||||
api_data: Union[Mapping[str, MetadataItem], Sequence[MetadataItem]]
|
||||
) -> dict:
|
||||
if not api_data:
|
||||
return {}
|
||||
|
||||
if isinstance(api_data, dict):
|
||||
return {
|
||||
ParameterKeyEscaper.escape(k): v.to_struct()
|
||||
for k, v in api_data.items()
|
||||
}
|
||||
|
||||
return {
|
||||
ParameterKeyEscaper.escape(item.key): item.to_struct() for item in api_data
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def edit_metadata(
|
||||
cls,
|
||||
obj: Document,
|
||||
items: Sequence[MetadataItem],
|
||||
replace_metadata: bool,
|
||||
**more_updates,
|
||||
) -> int:
|
||||
with TimingContext("mongo", "edit_metadata"):
|
||||
update_cmds = dict()
|
||||
metadata = cls.metadata_from_api(items)
|
||||
if replace_metadata:
|
||||
update_cmds["set__metadata"] = metadata
|
||||
else:
|
||||
for key, value in metadata.items():
|
||||
update_cmds[f"set__metadata__{mongoengine_safe(key)}"] = value
|
||||
|
||||
return obj.update(**update_cmds, **more_updates)
|
||||
|
||||
@classmethod
|
||||
def delete_metadata(cls, obj: Document, keys: Sequence[str], **more_updates) -> int:
|
||||
with TimingContext("mongo", "delete_metadata"):
|
||||
return obj.update(
|
||||
**{
|
||||
f"unset__metadata__{ParameterKeyEscaper.escape(key)}": 1
|
||||
for key in set(keys)
|
||||
},
|
||||
**more_updates,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _process_path(path: str):
|
||||
"""
|
||||
Frontend does a partial escaping on the path so the all '.' in key names are escaped
|
||||
Need to unescape and apply a full mongo escaping
|
||||
"""
|
||||
parts = path.split(".")
|
||||
if len(parts) < 2 or len(parts) > 3:
|
||||
raise errors.bad_request.ValidationError("invalid field", path=path)
|
||||
return ".".join(
|
||||
ParameterKeyEscaper.escape(ParameterKeyEscaper.unescape(p)) for p in parts
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def escape_paths(cls, paths: Sequence[str]) -> Sequence[str]:
|
||||
for prefix in (
|
||||
"metadata.",
|
||||
"-metadata.",
|
||||
):
|
||||
paths = [
|
||||
cls._process_path(path) if path.startswith(prefix) else path
|
||||
for path in paths
|
||||
]
|
||||
return paths
|
||||
|
||||
@classmethod
|
||||
def escape_query_parameters(cls, call: APICall) -> dict:
|
||||
if not call.data:
|
||||
return call.data
|
||||
|
||||
keys = list(call.data)
|
||||
call_data = {
|
||||
safe_key: call.data[key]
|
||||
for key, safe_key in zip(keys, Metadata.escape_paths(keys))
|
||||
}
|
||||
|
||||
projection = GetMixin.get_projection(call_data)
|
||||
if projection:
|
||||
GetMixin.set_projection(call_data, Metadata.escape_paths(projection))
|
||||
|
||||
ordering = GetMixin.get_ordering(call_data)
|
||||
if ordering:
|
||||
GetMixin.set_ordering(call_data, Metadata.escape_paths(ordering))
|
||||
|
||||
return call_data
|
||||
@@ -6,6 +6,7 @@ from redis import Redis
|
||||
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.bll.project import project_ids_with_children
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.model.base import GetMixin
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.task.task import Task
|
||||
@@ -42,6 +43,8 @@ class _TagsCache:
|
||||
query &= GetMixin.get_list_field_query(name, vals)
|
||||
if project:
|
||||
query &= Q(project__in=project_ids_with_children([project]))
|
||||
else:
|
||||
query &= Q(system_tags__nin=[EntityVisibility.hidden.value])
|
||||
|
||||
return self.db_cls.objects(query).distinct(field)
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ from typing import (
|
||||
TypeVar,
|
||||
Callable,
|
||||
Mapping,
|
||||
Any,
|
||||
)
|
||||
|
||||
from mongoengine import Q, Document
|
||||
@@ -22,6 +23,7 @@ from apiserver import database
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model import EntityVisibility, AttributedDocument
|
||||
from apiserver.database.model.base import GetMixin
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.task.task import Task, TaskStatus, external_task_types
|
||||
@@ -204,6 +206,7 @@ class ProjectBLL:
|
||||
tags: Sequence[str] = None,
|
||||
system_tags: Sequence[str] = None,
|
||||
default_output_destination: str = None,
|
||||
parent_creation_params: dict = None,
|
||||
) -> str:
|
||||
"""
|
||||
Create a new project.
|
||||
@@ -226,7 +229,12 @@ class ProjectBLL:
|
||||
created=now,
|
||||
last_update=now,
|
||||
)
|
||||
parent = _ensure_project(company=company, user=user, name=location)
|
||||
parent = _ensure_project(
|
||||
company=company,
|
||||
user=user,
|
||||
name=location,
|
||||
creation_params=parent_creation_params,
|
||||
)
|
||||
_save_under_parent(project=project, parent=parent)
|
||||
if parent:
|
||||
parent.update(last_update=now)
|
||||
@@ -244,13 +252,14 @@ class ProjectBLL:
|
||||
tags: Sequence[str] = None,
|
||||
system_tags: Sequence[str] = None,
|
||||
default_output_destination: str = None,
|
||||
parent_creation_params: dict = None,
|
||||
) -> str:
|
||||
"""
|
||||
Find a project named `project_name` or create a new one.
|
||||
Returns project ID
|
||||
"""
|
||||
if not project_id and not project_name:
|
||||
raise ValueError("project id or name required")
|
||||
raise errors.bad_request.ValidationError("project id or name required")
|
||||
|
||||
if project_id:
|
||||
project = Project.objects(company=company, id=project_id).only("id").first()
|
||||
@@ -271,6 +280,7 @@ class ProjectBLL:
|
||||
tags=tags,
|
||||
system_tags=system_tags,
|
||||
default_output_destination=default_output_destination,
|
||||
parent_creation_params=parent_creation_params,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -314,6 +324,7 @@ class ProjectBLL:
|
||||
company_id: str,
|
||||
project_ids: Sequence[str],
|
||||
specific_state: Optional[EntityVisibility] = None,
|
||||
filter_: Mapping[str, Any] = None,
|
||||
) -> Tuple[Sequence, Sequence]:
|
||||
archived = EntityVisibility.archived.value
|
||||
|
||||
@@ -337,10 +348,9 @@ class ProjectBLL:
|
||||
status_count_pipeline = [
|
||||
# count tasks per project per status
|
||||
{
|
||||
"$match": {
|
||||
"company": {"$in": [None, "", company_id]},
|
||||
"project": {"$in": project_ids},
|
||||
}
|
||||
"$match": cls.get_match_conditions(
|
||||
company=company_id, project_ids=project_ids, filter_=filter_
|
||||
)
|
||||
},
|
||||
ensure_valid_fields(),
|
||||
{
|
||||
@@ -388,6 +398,17 @@ class ProjectBLL:
|
||||
}
|
||||
}
|
||||
|
||||
def max_started_subquery(condition):
|
||||
return {
|
||||
"$max": {
|
||||
"$cond": {
|
||||
"if": condition,
|
||||
"then": "$started",
|
||||
"else": datetime.min,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def runtime_subquery(additional_cond):
|
||||
return {
|
||||
# the sum of
|
||||
@@ -431,14 +452,23 @@ class ProjectBLL:
|
||||
group_step[f"{state.value}_recently_completed"] = completed_after_subquery(
|
||||
cond, time_thresh=time_thresh
|
||||
)
|
||||
group_step[f"{state.value}_max_task_started"] = max_started_subquery(cond)
|
||||
|
||||
def get_state_filter() -> dict:
|
||||
if not specific_state:
|
||||
return {}
|
||||
if specific_state == EntityVisibility.archived:
|
||||
return {"system_tags": {"$eq": EntityVisibility.archived.value}}
|
||||
return {"system_tags": {"$ne": EntityVisibility.archived.value}}
|
||||
|
||||
runtime_pipeline = [
|
||||
# only count run time for these types of tasks
|
||||
{
|
||||
"$match": {
|
||||
"company": {"$in": [None, "", company_id]},
|
||||
"type": {"$in": ["training", "testing", "annotation"]},
|
||||
"project": {"$in": project_ids},
|
||||
**cls.get_match_conditions(
|
||||
company=company_id, project_ids=project_ids, filter_=filter_
|
||||
),
|
||||
**get_state_filter(),
|
||||
}
|
||||
},
|
||||
ensure_valid_fields(),
|
||||
@@ -481,12 +511,14 @@ class ProjectBLL:
|
||||
project_ids: Sequence[str],
|
||||
specific_state: Optional[EntityVisibility] = None,
|
||||
include_children: bool = True,
|
||||
return_hidden_children: bool = False,
|
||||
filter_: Mapping[str, Any] = None,
|
||||
) -> Tuple[Dict[str, dict], Dict[str, dict]]:
|
||||
if not project_ids:
|
||||
return {}, {}
|
||||
|
||||
child_projects = (
|
||||
_get_sub_projects(project_ids, _only=("id", "name"))
|
||||
_get_sub_projects(project_ids, _only=("id", "name", "system_tags"))
|
||||
if include_children
|
||||
else {}
|
||||
)
|
||||
@@ -497,6 +529,7 @@ class ProjectBLL:
|
||||
company,
|
||||
project_ids=list(project_ids_with_children),
|
||||
specific_state=specific_state,
|
||||
filter_=filter_,
|
||||
)
|
||||
|
||||
default_counts = dict.fromkeys(get_options(TaskStatus), 0)
|
||||
@@ -547,6 +580,8 @@ class ProjectBLL:
|
||||
) -> Dict[str, dict]:
|
||||
return {
|
||||
section: a.get(section, 0) + b.get(section, 0)
|
||||
if not section.endswith("max_task_started")
|
||||
else max(a.get(section) or datetime.min, b.get(section) or datetime.min)
|
||||
for section in set(a) | set(b)
|
||||
}
|
||||
|
||||
@@ -562,14 +597,20 @@ class ProjectBLL:
|
||||
project_section_statuses = nested_get(
|
||||
status_count, (project_id, section), default=default_counts
|
||||
)
|
||||
|
||||
def get_time_or_none(value):
|
||||
return value if value != datetime.min else None
|
||||
|
||||
return {
|
||||
"status_count": project_section_statuses,
|
||||
"running_tasks": project_section_statuses.get(TaskStatus.in_progress),
|
||||
"total_tasks": sum(project_section_statuses.values()),
|
||||
"total_runtime": project_runtime.get(section, 0),
|
||||
"completed_tasks": project_runtime.get(
|
||||
"completed_tasks_24h": project_runtime.get(
|
||||
f"{section}_recently_completed", 0
|
||||
),
|
||||
"last_task_run": get_time_or_none(
|
||||
project_runtime.get(f"{section}_max_task_started", datetime.min)
|
||||
),
|
||||
}
|
||||
|
||||
report_for_states = [
|
||||
@@ -586,9 +627,24 @@ class ProjectBLL:
|
||||
for project in project_ids
|
||||
}
|
||||
|
||||
def filter_child_projects(project: str) -> Sequence[Project]:
|
||||
non_filtered_children = child_projects.get(project, [])
|
||||
if not non_filtered_children or return_hidden_children:
|
||||
return non_filtered_children
|
||||
|
||||
return [
|
||||
c
|
||||
for c in non_filtered_children
|
||||
if not c.system_tags
|
||||
or EntityVisibility.hidden.value not in c.system_tags
|
||||
]
|
||||
|
||||
children = {
|
||||
project: sorted(
|
||||
[{"id": c.id, "name": c.name} for c in child_projects.get(project, [])],
|
||||
[
|
||||
{"id": c.id, "name": c.name}
|
||||
for c in filter_child_projects(project)
|
||||
],
|
||||
key=itemgetter("name"),
|
||||
)
|
||||
for project in project_ids
|
||||
@@ -624,6 +680,30 @@ class ProjectBLL:
|
||||
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
def get_project_tags(
|
||||
cls,
|
||||
company_id: str,
|
||||
include_system: bool,
|
||||
projects: Sequence[str] = None,
|
||||
filter_: Dict[str, Sequence[str]] = None,
|
||||
) -> Tuple[Sequence[str], Sequence[str]]:
|
||||
with TimingContext("mongo", "get_tags_from_db"):
|
||||
query = Q(company=company_id)
|
||||
if filter_:
|
||||
for name, vals in filter_.items():
|
||||
if vals:
|
||||
query &= GetMixin.get_list_field_query(name, vals)
|
||||
|
||||
if projects:
|
||||
query &= Q(id__in=_ids_with_children(projects))
|
||||
|
||||
tags = Project.objects(query).distinct("tags")
|
||||
system_tags = (
|
||||
Project.objects(query).distinct("system_tags") if include_system else []
|
||||
)
|
||||
return tags, system_tags
|
||||
|
||||
@classmethod
|
||||
def get_projects_with_active_user(
|
||||
cls,
|
||||
@@ -676,10 +756,14 @@ class ProjectBLL:
|
||||
If projects is None or empty then get parents for all the company tasks
|
||||
"""
|
||||
query = Q(company=company_id)
|
||||
|
||||
if projects:
|
||||
if include_subprojects:
|
||||
projects = _ids_with_children(projects)
|
||||
query &= Q(project__in=projects)
|
||||
else:
|
||||
query &= Q(system_tags__nin=[EntityVisibility.hidden.value])
|
||||
|
||||
if state == EntityVisibility.archived:
|
||||
query &= Q(system_tags__in=[EntityVisibility.archived.value])
|
||||
elif state == EntityVisibility.active:
|
||||
@@ -707,6 +791,8 @@ class ProjectBLL:
|
||||
if project_ids:
|
||||
project_ids = _ids_with_children(project_ids)
|
||||
query &= Q(project__in=project_ids)
|
||||
else:
|
||||
query &= Q(system_tags__nin=[EntityVisibility.hidden.value])
|
||||
res = Task.objects(query).distinct(field="type")
|
||||
return set(res).intersection(external_task_types)
|
||||
|
||||
@@ -722,8 +808,35 @@ class ProjectBLL:
|
||||
query &= Q(project__in=project_ids)
|
||||
return Model.objects(query).distinct(field="framework")
|
||||
|
||||
@staticmethod
|
||||
def get_match_conditions(
|
||||
company: str, project_ids: Sequence[str], filter_: Mapping[str, Any]
|
||||
):
|
||||
conditions = {
|
||||
"company": {"$in": [None, "", company]},
|
||||
"project": {"$in": project_ids},
|
||||
}
|
||||
if not filter_:
|
||||
return conditions
|
||||
|
||||
for field in ("tags", "system_tags"):
|
||||
field_filter = filter_.get(field)
|
||||
if not field_filter:
|
||||
continue
|
||||
if not isinstance(field_filter, list) or not all(
|
||||
isinstance(t, str) for t in field_filter
|
||||
):
|
||||
raise errors.bad_request.ValidationError(
|
||||
f"List of strings expected for the field: {field}"
|
||||
)
|
||||
conditions[field] = {"$in": field_filter}
|
||||
|
||||
return conditions
|
||||
|
||||
@classmethod
|
||||
def calc_own_contents(cls, company: str, project_ids: Sequence[str]) -> Dict[str, dict]:
|
||||
def calc_own_contents(
|
||||
cls, company: str, project_ids: Sequence[str], filter_: Mapping[str, Any] = None
|
||||
) -> Dict[str, dict]:
|
||||
"""
|
||||
Returns the amount of task/models per requested project
|
||||
Use separate aggregation calls on Task/Model instead of lookup
|
||||
@@ -734,35 +847,21 @@ class ProjectBLL:
|
||||
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
"company": {"$in": [None, "", company]},
|
||||
"project": {"$in": project_ids},
|
||||
}
|
||||
"$match": cls.get_match_conditions(
|
||||
company=company, project_ids=project_ids, filter_=filter_
|
||||
)
|
||||
},
|
||||
{
|
||||
"$project": {"project": 1}
|
||||
},
|
||||
{
|
||||
"$group": {
|
||||
"_id": "$project",
|
||||
"count": {"$sum": 1},
|
||||
}
|
||||
}
|
||||
{"$project": {"project": 1}},
|
||||
{"$group": {"_id": "$project", "count": {"$sum": 1}}},
|
||||
]
|
||||
|
||||
def get_agrregate_res(cls_: Type[AttributedDocument]) -> dict:
|
||||
return {
|
||||
data["_id"]: data["count"]
|
||||
for data in cls_.aggregate(pipeline)
|
||||
}
|
||||
return {data["_id"]: data["count"] for data in cls_.aggregate(pipeline)}
|
||||
|
||||
with TimingContext("mongo", "get_security_groups"):
|
||||
tasks = get_agrregate_res(Task)
|
||||
models = get_agrregate_res(Model)
|
||||
return {
|
||||
pid: {
|
||||
"own_tasks": tasks.get(pid, 0),
|
||||
"own_models": models.get(pid, 0),
|
||||
}
|
||||
pid: {"own_tasks": tasks.get(pid, 0), "own_models": models.get(pid, 0)}
|
||||
for pid in project_ids
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import json
|
||||
from collections import OrderedDict
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime
|
||||
from typing import (
|
||||
Sequence,
|
||||
Optional,
|
||||
@@ -10,6 +10,7 @@ from typing import (
|
||||
from redis import StrictRedis
|
||||
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.redis_manager import redman
|
||||
from apiserver.utilities.dicts import nested_get
|
||||
@@ -27,12 +28,21 @@ class ProjectQueries:
|
||||
def _get_project_constraint(
|
||||
project_ids: Sequence[str], include_subprojects: bool
|
||||
) -> dict:
|
||||
"""
|
||||
If passed projects is None means top level projects
|
||||
If passed projects is empty means no project filtering
|
||||
"""
|
||||
if include_subprojects:
|
||||
if project_ids is None:
|
||||
if not project_ids:
|
||||
return {}
|
||||
project_ids = _ids_with_children(project_ids)
|
||||
|
||||
return {"project": {"$in": project_ids if project_ids is not None else [None]}}
|
||||
if project_ids is None:
|
||||
project_ids = [None]
|
||||
if not project_ids:
|
||||
return {}
|
||||
|
||||
return {"project": {"$in": project_ids}}
|
||||
|
||||
@staticmethod
|
||||
def _get_company_constraint(company_id: str, allow_public: bool = True) -> dict:
|
||||
@@ -105,16 +115,11 @@ class ProjectQueries:
|
||||
|
||||
return total, remaining, results
|
||||
|
||||
HyperParamValues = Tuple[int, Sequence[str]]
|
||||
ParamValues = Tuple[int, Sequence[str]]
|
||||
|
||||
def _get_cached_hyperparam_values(
|
||||
self, key: str, last_update: datetime
|
||||
) -> Optional[HyperParamValues]:
|
||||
allowed_delta = timedelta(
|
||||
seconds=config.get(
|
||||
"services.tasks.hyperparam_values.cache_allowed_outdate_sec", 60
|
||||
)
|
||||
)
|
||||
def _get_cached_param_values(
|
||||
self, key: str, last_update: datetime, allowed_delta_sec=0
|
||||
) -> Optional[ParamValues]:
|
||||
try:
|
||||
cached = self.redis.get(key)
|
||||
if not cached:
|
||||
@@ -122,12 +127,12 @@ class ProjectQueries:
|
||||
|
||||
data = json.loads(cached)
|
||||
cached_last_update = datetime.fromtimestamp(data["last_update"])
|
||||
if (last_update - cached_last_update) < allowed_delta:
|
||||
if (last_update - cached_last_update).total_seconds() <= allowed_delta_sec:
|
||||
return data["total"], data["values"]
|
||||
except Exception as ex:
|
||||
log.error(f"Error retrieving hyperparam cached values: {str(ex)}")
|
||||
log.error(f"Error retrieving params cached values: {str(ex)}")
|
||||
|
||||
def get_hyperparam_distinct_values(
|
||||
def get_task_hyperparam_distinct_values(
|
||||
self,
|
||||
company_id: str,
|
||||
project_ids: Sequence[str],
|
||||
@@ -135,7 +140,7 @@ class ProjectQueries:
|
||||
name: str,
|
||||
include_subprojects: bool,
|
||||
allow_public: bool = True,
|
||||
) -> HyperParamValues:
|
||||
) -> ParamValues:
|
||||
company_constraint = self._get_company_constraint(company_id, allow_public)
|
||||
project_constraint = self._get_project_constraint(
|
||||
project_ids, include_subprojects
|
||||
@@ -157,8 +162,12 @@ class ProjectQueries:
|
||||
|
||||
redis_key = f"hyperparam_values_{company_id}_{'_'.join(project_ids)}_{section}_{name}_{allow_public}"
|
||||
last_update = last_updated_task.last_update or datetime.utcnow()
|
||||
cached_res = self._get_cached_hyperparam_values(
|
||||
key=redis_key, last_update=last_update
|
||||
cached_res = self._get_cached_param_values(
|
||||
key=redis_key,
|
||||
last_update=last_update,
|
||||
allowed_delta_sec=config.get(
|
||||
"services.tasks.hyperparam_values.cache_allowed_outdate_sec", 60
|
||||
),
|
||||
)
|
||||
if cached_res:
|
||||
return cached_res
|
||||
@@ -239,3 +248,123 @@ class ProjectQueries:
|
||||
|
||||
result = Task.aggregate(pipeline)
|
||||
return [r["metrics"][0] for r in result]
|
||||
|
||||
@classmethod
|
||||
def get_model_metadata_keys(
|
||||
cls,
|
||||
company_id,
|
||||
project_ids: Sequence[str],
|
||||
include_subprojects: bool,
|
||||
page: int = 0,
|
||||
page_size: int = 500,
|
||||
) -> Tuple[int, int, Sequence[dict]]:
|
||||
page = max(0, page)
|
||||
page_size = max(1, page_size)
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
**cls._get_company_constraint(company_id),
|
||||
**cls._get_project_constraint(project_ids, include_subprojects),
|
||||
"metadata": {"$exists": True, "$gt": {}},
|
||||
}
|
||||
},
|
||||
{"$project": {"metadata": {"$objectToArray": "$metadata"}}},
|
||||
{"$unwind": "$metadata"},
|
||||
{"$group": {"_id": "$metadata.k"}},
|
||||
{"$sort": {"_id": 1}},
|
||||
{"$skip": page * page_size},
|
||||
{"$limit": page_size},
|
||||
{
|
||||
"$group": {
|
||||
"_id": 1,
|
||||
"total": {"$sum": 1},
|
||||
"results": {"$push": "$$ROOT"},
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
result = next(Model.aggregate(pipeline), None)
|
||||
|
||||
total = 0
|
||||
remaining = 0
|
||||
results = []
|
||||
|
||||
if result:
|
||||
total = int(result.get("total", -1))
|
||||
results = [
|
||||
ParameterKeyEscaper.unescape(r.get("_id"))
|
||||
for r in result.get("results", [])
|
||||
]
|
||||
remaining = max(0, total - (len(results) + page * page_size))
|
||||
|
||||
return total, remaining, results
|
||||
|
||||
def get_model_metadata_distinct_values(
|
||||
self,
|
||||
company_id: str,
|
||||
project_ids: Sequence[str],
|
||||
key: str,
|
||||
include_subprojects: bool,
|
||||
allow_public: bool = True,
|
||||
) -> ParamValues:
|
||||
company_constraint = self._get_company_constraint(company_id, allow_public)
|
||||
project_constraint = self._get_project_constraint(
|
||||
project_ids, include_subprojects
|
||||
)
|
||||
key_path = f"metadata.{ParameterKeyEscaper.escape(key)}"
|
||||
last_updated_model = (
|
||||
Model.objects(
|
||||
**company_constraint,
|
||||
**project_constraint,
|
||||
**{f"{key_path.replace('.', '__')}__exists": True},
|
||||
)
|
||||
.only("last_update")
|
||||
.order_by("-last_update")
|
||||
.limit(1)
|
||||
.first()
|
||||
)
|
||||
if not last_updated_model:
|
||||
return 0, []
|
||||
|
||||
redis_key = f"modelmetadata_values_{company_id}_{'_'.join(project_ids)}_{key}_{allow_public}"
|
||||
last_update = last_updated_model.last_update or datetime.utcnow()
|
||||
cached_res = self._get_cached_param_values(
|
||||
key=redis_key, last_update=last_update
|
||||
)
|
||||
if cached_res:
|
||||
return cached_res
|
||||
|
||||
max_values = config.get("services.models.metadata_values.max_count", 100)
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
**company_constraint,
|
||||
**project_constraint,
|
||||
key_path: {"$exists": True},
|
||||
}
|
||||
},
|
||||
{"$project": {"value": f"${key_path}.value"}},
|
||||
{"$group": {"_id": "$value"}},
|
||||
{"$sort": {"_id": 1}},
|
||||
{"$limit": max_values},
|
||||
{
|
||||
"$group": {
|
||||
"_id": 1,
|
||||
"total": {"$sum": 1},
|
||||
"results": {"$push": "$$ROOT._id"},
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
result = next(Model.aggregate(pipeline, collation=Model._numeric_locale), None)
|
||||
if not result:
|
||||
return 0, []
|
||||
|
||||
total = int(result.get("total", 0))
|
||||
values = result.get("results", [])
|
||||
|
||||
ttl = config.get("services.models.metadata_values.cache_ttl_sec", 86400)
|
||||
cached = dict(last_update=last_update.timestamp(), total=total, values=values)
|
||||
self.redis.setex(redis_key, ttl, json.dumps(cached))
|
||||
|
||||
return total, values
|
||||
|
||||
@@ -25,7 +25,9 @@ def _validate_project_name(project_name: str) -> Tuple[str, str]:
|
||||
return name_separator.join(name_parts), name_separator.join(name_parts[:-1])
|
||||
|
||||
|
||||
def _ensure_project(company: str, user: str, name: str) -> Optional[Project]:
|
||||
def _ensure_project(
|
||||
company: str, user: str, name: str, creation_params: dict = None
|
||||
) -> Optional[Project]:
|
||||
"""
|
||||
Makes sure that the project with the given name exists
|
||||
If needed auto-create the project and all the missing projects in the path to it
|
||||
@@ -48,9 +50,9 @@ def _ensure_project(company: str, user: str, name: str) -> Optional[Project]:
|
||||
created=now,
|
||||
last_update=now,
|
||||
name=name,
|
||||
description="",
|
||||
**(creation_params or dict(description="")),
|
||||
)
|
||||
parent = _ensure_project(company, user, location)
|
||||
parent = _ensure_project(company, user, location, creation_params=creation_params)
|
||||
_save_under_parent(project=project, parent=parent)
|
||||
if parent:
|
||||
parent.update(last_update=now)
|
||||
|
||||
@@ -32,7 +32,7 @@ class QueueBLL(object):
|
||||
name: str,
|
||||
tags: Optional[Sequence[str]] = None,
|
||||
system_tags: Optional[Sequence[str]] = None,
|
||||
metadata: Optional[Sequence[dict]] = None,
|
||||
metadata: Optional[dict] = None,
|
||||
) -> Queue:
|
||||
"""Creates a queue"""
|
||||
with translate_errors_context():
|
||||
@@ -187,13 +187,15 @@ class QueueBLL(object):
|
||||
if any(e.task == task_id for e in queue.entries):
|
||||
raise errors.bad_request.TaskAlreadyQueued(task=task_id)
|
||||
|
||||
self.metrics.log_queue_metrics_to_es(company_id=company_id, queues=[queue])
|
||||
|
||||
entry = Entry(added=datetime.utcnow(), task=task_id)
|
||||
query = dict(id=queue_id, company=company_id)
|
||||
res = Queue.objects(entries__task__ne=task_id, **query).update_one(
|
||||
push__entries=entry, last_update=datetime.utcnow(), upsert=False
|
||||
)
|
||||
|
||||
queue.reload()
|
||||
self.metrics.log_queue_metrics_to_es(company_id=company_id, queues=[queue])
|
||||
|
||||
if not res:
|
||||
raise errors.bad_request.InvalidQueueOrTaskNotQueued(
|
||||
task=task_id, **query
|
||||
@@ -233,7 +235,6 @@ class QueueBLL(object):
|
||||
queue = self.get_queue_with_task(
|
||||
company_id=company_id, queue_id=queue_id, task_id=task_id
|
||||
)
|
||||
self.metrics.log_queue_metrics_to_es(company_id, queues=[queue])
|
||||
|
||||
entries_to_remove = [e for e in queue.entries if e.task == task_id]
|
||||
query = dict(id=queue_id, company=company_id)
|
||||
@@ -241,6 +242,9 @@ class QueueBLL(object):
|
||||
pull_all__entries=entries_to_remove, last_update=datetime.utcnow()
|
||||
)
|
||||
|
||||
queue.reload()
|
||||
self.metrics.log_queue_metrics_to_es(company_id=company_id, queues=[queue])
|
||||
|
||||
return len(entries_to_remove) if res else 0
|
||||
|
||||
def reposition_task(
|
||||
|
||||
@@ -113,7 +113,7 @@ class WorkerBLL:
|
||||
res = self.redis.delete(
|
||||
company_id, self._get_worker_key(company_id, user_id, worker)
|
||||
)
|
||||
if not res:
|
||||
if not res and not config.get("apiserver.workers.auto_unregister", False):
|
||||
raise bad_request.WorkerNotRegistered(worker=worker)
|
||||
|
||||
def status_report(
|
||||
|
||||
@@ -112,6 +112,8 @@
|
||||
workers {
|
||||
# Auto-register unknown workers on status reports and other calls
|
||||
auto_register: true
|
||||
# Assume unknow workers have unregistered (i.e. do not raise unregistered error)
|
||||
auto_unregister: true
|
||||
# Timeout in seconds on task status update. If exceeded
|
||||
# then task can be stopped without communicating to the worker
|
||||
task_update_timeout: 600
|
||||
|
||||
7
apiserver/config/default/services/models.conf
Normal file
7
apiserver/config/default/services/models.conf
Normal file
@@ -0,0 +1,7 @@
|
||||
metadata_values {
|
||||
# maximal amount of distinct model values to retrieve
|
||||
max_count: 100
|
||||
|
||||
# cache ttl sec
|
||||
cache_ttl_sec: 86400
|
||||
}
|
||||
@@ -60,3 +60,4 @@ def validate_id(cls, company, **kwargs):
|
||||
class EntityVisibility(Enum):
|
||||
active = "active"
|
||||
archived = "archived"
|
||||
hidden = "hidden"
|
||||
|
||||
@@ -50,6 +50,7 @@ class Credentials(EmbeddedDocument):
|
||||
secret = StringField(required=True)
|
||||
label = StringField()
|
||||
last_used = DateTimeField()
|
||||
last_used_from = StringField()
|
||||
|
||||
|
||||
class User(DbModelMixin, AuthDocument):
|
||||
|
||||
@@ -95,6 +95,7 @@ class GetMixin(PropsMixin):
|
||||
}
|
||||
MultiFieldParameters = namedtuple("MultiFieldParameters", "pattern fields")
|
||||
|
||||
_numeric_locale = {"locale": "en_US", "numericOrdering": True}
|
||||
_field_collation_overrides = {}
|
||||
|
||||
class QueryParameterOptions(object):
|
||||
@@ -599,7 +600,7 @@ class GetMixin(PropsMixin):
|
||||
return size
|
||||
|
||||
@classmethod
|
||||
def get_data_with_scroll_and_filter_support(
|
||||
def get_data_with_scroll_support(
|
||||
cls,
|
||||
query_dict: dict,
|
||||
data_getter: Callable[[], Sequence[dict]],
|
||||
@@ -629,15 +630,12 @@ class GetMixin(PropsMixin):
|
||||
if cls._start_key in query_dict:
|
||||
query_dict[cls._start_key] = query_dict[cls._start_key] + len(data)
|
||||
|
||||
def update_state(returned_len: int):
|
||||
if not state:
|
||||
return
|
||||
if state:
|
||||
state.position = query_dict[cls._start_key]
|
||||
cls.get_cache_manager().set_state(state)
|
||||
if ret_params is not None:
|
||||
ret_params["scroll_id"] = state.id
|
||||
|
||||
update_state(len(data))
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
@@ -770,7 +768,7 @@ class GetMixin(PropsMixin):
|
||||
override_projection=override_projection,
|
||||
override_collation=override_collation,
|
||||
)
|
||||
return cls.get_data_with_scroll_and_filter_support(
|
||||
return cls.get_data_with_scroll_support(
|
||||
query_dict=query_dict, data_getter=data_getter, ret_params=ret_params,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
from typing import Sequence
|
||||
|
||||
from mongoengine import (
|
||||
StringField,
|
||||
DateTimeField,
|
||||
BooleanField,
|
||||
EmbeddedDocumentListField,
|
||||
EmbeddedDocumentField,
|
||||
)
|
||||
|
||||
from apiserver.database import Database, strict
|
||||
@@ -12,6 +10,7 @@ from apiserver.database.fields import (
|
||||
StrippedStringField,
|
||||
SafeDictField,
|
||||
SafeSortedListField,
|
||||
SafeMapField,
|
||||
)
|
||||
from apiserver.database.model import AttributedDocument
|
||||
from apiserver.database.model.base import GetMixin
|
||||
@@ -22,6 +21,10 @@ from apiserver.database.model.task.task import Task
|
||||
|
||||
|
||||
class Model(AttributedDocument):
|
||||
_field_collation_overrides = {
|
||||
"metadata.": AttributedDocument._numeric_locale,
|
||||
}
|
||||
|
||||
meta = {
|
||||
"db_alias": Database.backend,
|
||||
"strict": strict,
|
||||
@@ -30,8 +33,6 @@ class Model(AttributedDocument):
|
||||
"project",
|
||||
"task",
|
||||
"last_update",
|
||||
"metadata.key",
|
||||
"metadata.type",
|
||||
("company", "framework"),
|
||||
("company", "name"),
|
||||
("company", "user"),
|
||||
@@ -63,6 +64,7 @@ class Model(AttributedDocument):
|
||||
"project",
|
||||
"task",
|
||||
"parent",
|
||||
"metadata.*",
|
||||
),
|
||||
datetime_fields=("last_update",),
|
||||
)
|
||||
@@ -86,6 +88,6 @@ class Model(AttributedDocument):
|
||||
default=dict, user_set_allowed=True, exclude_by_default=True
|
||||
)
|
||||
company_origin = StringField(exclude_by_default=True)
|
||||
metadata: Sequence[MetadataItem] = EmbeddedDocumentListField(
|
||||
MetadataItem, default=list, user_set_allowed=True
|
||||
metadata = SafeMapField(
|
||||
field=EmbeddedDocumentField(MetadataItem), user_set_allowed=True
|
||||
)
|
||||
|
||||
@@ -1,16 +1,19 @@
|
||||
from typing import Sequence
|
||||
|
||||
from mongoengine import (
|
||||
Document,
|
||||
EmbeddedDocument,
|
||||
StringField,
|
||||
DateTimeField,
|
||||
EmbeddedDocumentListField,
|
||||
EmbeddedDocumentField,
|
||||
)
|
||||
|
||||
from apiserver.database import Database, strict
|
||||
from apiserver.database.fields import StrippedStringField, SafeSortedListField
|
||||
from apiserver.database.model import DbModelMixin
|
||||
from apiserver.database.fields import (
|
||||
StrippedStringField,
|
||||
SafeSortedListField,
|
||||
SafeMapField,
|
||||
)
|
||||
from apiserver.database.model import DbModelMixin, AttributedDocument
|
||||
from apiserver.database.model.base import ProperDictMixin, GetMixin
|
||||
from apiserver.database.model.company import Company
|
||||
from apiserver.database.model.metadata import MetadataItem
|
||||
@@ -19,23 +22,25 @@ from apiserver.database.model.task.task import Task
|
||||
|
||||
class Entry(EmbeddedDocument, ProperDictMixin):
|
||||
""" Entry representing a task waiting in the queue """
|
||||
|
||||
task = StringField(required=True, reference_field=Task)
|
||||
''' Task ID '''
|
||||
""" Task ID """
|
||||
added = DateTimeField(required=True)
|
||||
''' Added to the queue '''
|
||||
""" Added to the queue """
|
||||
|
||||
|
||||
class Queue(DbModelMixin, Document):
|
||||
_field_collation_overrides = {
|
||||
"metadata.": AttributedDocument._numeric_locale,
|
||||
}
|
||||
|
||||
get_all_query_options = GetMixin.QueryParameterOptions(
|
||||
pattern_fields=("name",),
|
||||
list_fields=("tags", "system_tags", "id"),
|
||||
pattern_fields=("name",), list_fields=("tags", "system_tags", "id", "metadata.*"),
|
||||
)
|
||||
|
||||
meta = {
|
||||
'db_alias': Database.backend,
|
||||
'strict': strict,
|
||||
"indexes": ["metadata.key", "metadata.type"],
|
||||
"db_alias": Database.backend,
|
||||
"strict": strict,
|
||||
}
|
||||
|
||||
id = StringField(primary_key=True)
|
||||
@@ -44,10 +49,12 @@ class Queue(DbModelMixin, Document):
|
||||
)
|
||||
company = StringField(required=True, reference_field=Company)
|
||||
created = DateTimeField(required=True)
|
||||
tags = SafeSortedListField(StringField(required=True), default=list, user_set_allowed=True)
|
||||
tags = SafeSortedListField(
|
||||
StringField(required=True), default=list, user_set_allowed=True
|
||||
)
|
||||
system_tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
|
||||
entries = EmbeddedDocumentListField(Entry, default=list)
|
||||
last_update = DateTimeField()
|
||||
metadata: Sequence[MetadataItem] = EmbeddedDocumentListField(
|
||||
MetadataItem, default=list, user_set_allowed=True
|
||||
metadata = SafeMapField(
|
||||
field=EmbeddedDocumentField(MetadataItem), user_set_allowed=True
|
||||
)
|
||||
|
||||
@@ -159,11 +159,10 @@ external_task_types = set(get_options(TaskType))
|
||||
|
||||
|
||||
class Task(AttributedDocument):
|
||||
_numeric_locale = {"locale": "en_US", "numericOrdering": True}
|
||||
_field_collation_overrides = {
|
||||
"execution.parameters.": _numeric_locale,
|
||||
"last_metrics.": _numeric_locale,
|
||||
"hyperparams.": _numeric_locale,
|
||||
"execution.parameters.": AttributedDocument._numeric_locale,
|
||||
"last_metrics.": AttributedDocument._numeric_locale,
|
||||
"hyperparams.": AttributedDocument._numeric_locale,
|
||||
}
|
||||
|
||||
meta = {
|
||||
@@ -176,6 +175,8 @@ class Task(AttributedDocument):
|
||||
"active_duration",
|
||||
"parent",
|
||||
"project",
|
||||
"last_update",
|
||||
"status_changed",
|
||||
"models.input.model",
|
||||
("company", "name"),
|
||||
("company", "user"),
|
||||
@@ -184,7 +185,10 @@ class Task(AttributedDocument):
|
||||
("company", "type", "system_tags", "status"),
|
||||
("company", "project", "type", "system_tags", "status"),
|
||||
("status", "last_update"), # for maintenance tasks
|
||||
{"fields": ["company", "project"], "collation": _numeric_locale},
|
||||
{
|
||||
"fields": ["company", "project"],
|
||||
"collation": AttributedDocument._numeric_locale,
|
||||
},
|
||||
{
|
||||
"name": "%s.task.main_text_index" % Database.backend,
|
||||
"fields": [
|
||||
|
||||
@@ -21,6 +21,7 @@ from typing import (
|
||||
Union,
|
||||
Mapping,
|
||||
IO,
|
||||
Callable,
|
||||
)
|
||||
from urllib.parse import unquote, urlparse
|
||||
from zipfile import ZipFile, ZIP_BZIP2
|
||||
@@ -54,6 +55,7 @@ from apiserver.database.model.task.task import (
|
||||
from apiserver.database.utils import get_options
|
||||
from apiserver.utilities import json
|
||||
from apiserver.utilities.dicts import nested_get, nested_set, nested_delete
|
||||
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
|
||||
|
||||
|
||||
class PrePopulate:
|
||||
@@ -744,6 +746,19 @@ class PrePopulate:
|
||||
module = importlib.import_module(module_name)
|
||||
return getattr(module, class_name)
|
||||
|
||||
@staticmethod
|
||||
def _upgrade_model_data(model_data: dict) -> dict:
|
||||
metadata_key = "metadata"
|
||||
metadata = model_data.get(metadata_key)
|
||||
if isinstance(metadata, list):
|
||||
metadata = {
|
||||
ParameterKeyEscaper.escape(item["key"]): item
|
||||
for item in metadata
|
||||
if isinstance(item, dict) and "key" in item
|
||||
}
|
||||
model_data[metadata_key] = metadata
|
||||
return model_data
|
||||
|
||||
@staticmethod
|
||||
def _upgrade_task_data(task_data: dict) -> dict:
|
||||
"""
|
||||
@@ -828,9 +843,14 @@ class PrePopulate:
|
||||
print(f"Writing {cls_.__name__.lower()}s into database")
|
||||
tasks = []
|
||||
override_project_count = 0
|
||||
data_upgrade_funcs: Mapping[Type, Callable] = {
|
||||
cls.task_cls: cls._upgrade_task_data,
|
||||
cls.model_cls: cls._upgrade_model_data,
|
||||
}
|
||||
for item in cls.json_lines(f):
|
||||
if cls_ == cls.task_cls:
|
||||
item = json.dumps(cls._upgrade_task_data(task_data=json.loads(item)))
|
||||
upgrade_func = data_upgrade_funcs.get(cls_)
|
||||
if upgrade_func:
|
||||
item = json.dumps(upgrade_func(json.loads(item)))
|
||||
|
||||
doc = cls_.from_json(item, created=True)
|
||||
if hasattr(doc, "user"):
|
||||
|
||||
29
apiserver/mongo/migrations/1_3_0.py
Normal file
29
apiserver/mongo/migrations/1_3_0.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from pymongo.collection import Collection
|
||||
from pymongo.database import Database
|
||||
|
||||
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
|
||||
from .utils import _drop_all_indices_from_collections
|
||||
|
||||
|
||||
def _convert_metadata(db: Database, name):
|
||||
collection: Collection = db[name]
|
||||
|
||||
metadata_field = "metadata"
|
||||
query = {metadata_field: {"$exists": True, "$type": 4}}
|
||||
for doc in collection.find(filter=query, projection=(metadata_field,)):
|
||||
metadata = {
|
||||
ParameterKeyEscaper.escape(item["key"]): item
|
||||
for item in doc.get(metadata_field, [])
|
||||
if isinstance(item, dict) and "key" in item
|
||||
}
|
||||
collection.update_one(
|
||||
{"_id": doc["_id"]}, {"$set": {"metadata": metadata}},
|
||||
)
|
||||
|
||||
|
||||
def migrate_backend(db: Database):
|
||||
collections = ["model", "queue"]
|
||||
for name in collections:
|
||||
_convert_metadata(db, name)
|
||||
|
||||
_drop_all_indices_from_collections(db, collections)
|
||||
@@ -24,6 +24,10 @@ _definitions {
|
||||
description: ""
|
||||
format: "date-time"
|
||||
}
|
||||
last_used_from {
|
||||
type: string
|
||||
description: ""
|
||||
}
|
||||
}
|
||||
}
|
||||
role {
|
||||
@@ -226,6 +230,12 @@ create_credentials {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.17": ${create_credentials."2.1"} {
|
||||
request.properties.label {
|
||||
type: string
|
||||
description: Optional credentials label
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
get_credentials {
|
||||
|
||||
@@ -1304,3 +1304,24 @@ scalar_metrics_iter_raw {
|
||||
}
|
||||
}
|
||||
}
|
||||
clear_scroll {
|
||||
"2.18" {
|
||||
description: "Clear an open Scroll ID"
|
||||
request {
|
||||
type: object
|
||||
required: [
|
||||
scroll_id
|
||||
]
|
||||
properties {
|
||||
scroll_id {
|
||||
description: "Scroll ID as returned by previous events service calls"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
additionalProperties: false
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -61,14 +61,14 @@ _definitions {
|
||||
type: string
|
||||
}
|
||||
tags {
|
||||
description: "User-defined tags list"
|
||||
type: array
|
||||
description: "User-defined tags"
|
||||
items { type: string }
|
||||
}
|
||||
system_tags {
|
||||
description: "System tags list. This field is reserved for system use, please don't use it."
|
||||
type: array
|
||||
items {type: string}
|
||||
description: "System tags. This field is reserved for system use, please don't use it."
|
||||
items { type: string }
|
||||
}
|
||||
framework {
|
||||
description: "Framework on which the model is based. Should be identical to the framework of the task which created the model"
|
||||
@@ -98,9 +98,11 @@ _definitions {
|
||||
additionalProperties: true
|
||||
}
|
||||
metadata {
|
||||
type: array
|
||||
description: "Model metadata"
|
||||
items {"$ref": "#/definitions/metadata_item"}
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/metadata_item"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -407,7 +409,7 @@ update_for_task {
|
||||
system_tags {
|
||||
description: "System tags list. This field is reserved for system use, please don't use it."
|
||||
type: array
|
||||
items {type: string}
|
||||
items { type: string }
|
||||
}
|
||||
override_model_id {
|
||||
description: "Override model ID. If provided, this model is updated in the task. Exactly one of override_model_id or uri is required."
|
||||
@@ -473,7 +475,7 @@ create {
|
||||
system_tags {
|
||||
description: "System tags list. This field is reserved for system use, please don't use it."
|
||||
type: array
|
||||
items {type: string}
|
||||
items { type: string }
|
||||
}
|
||||
framework {
|
||||
description: "Framework on which the model is based. Case insensitive. Should be identical to the framework of the task which created the model."
|
||||
@@ -529,9 +531,11 @@ create {
|
||||
}
|
||||
"2.13": ${create."2.1"} {
|
||||
metadata {
|
||||
type: array
|
||||
description: "Model metadata"
|
||||
items {"$ref": "#/definitions/metadata_item"}
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/metadata_item"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -568,7 +572,7 @@ edit {
|
||||
system_tags {
|
||||
description: "System tags list. This field is reserved for system use, please don't use it."
|
||||
type: array
|
||||
items {type: string}
|
||||
items { type: string }
|
||||
}
|
||||
framework {
|
||||
description: "Framework on which the model is based. Case insensitive. Should be identical to the framework of the task which created the model."
|
||||
@@ -624,9 +628,11 @@ edit {
|
||||
}
|
||||
"2.13": ${edit."2.1"} {
|
||||
metadata {
|
||||
type: array
|
||||
description: "Model metadata"
|
||||
items {"$ref": "#/definitions/metadata_item"}
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/metadata_item"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -657,7 +663,7 @@ update {
|
||||
system_tags {
|
||||
description: "System tags list. This field is reserved for system use, please don't use it."
|
||||
type: array
|
||||
items {type: string}
|
||||
items { type: string }
|
||||
}
|
||||
ready {
|
||||
description: "Indication if the model is final and can be used by other tasks Default is false."
|
||||
@@ -707,9 +713,11 @@ update {
|
||||
}
|
||||
"2.13": ${update."2.1"} {
|
||||
metadata {
|
||||
type: array
|
||||
description: "Model metadata"
|
||||
items {"$ref": "#/definitions/metadata_item"}
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/metadata_item"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -718,7 +726,7 @@ publish_many {
|
||||
description: Publish models
|
||||
request {
|
||||
properties {
|
||||
ids.description: "IDs of models to publish"
|
||||
ids.description: "IDs of the models to publish"
|
||||
force_publish_task {
|
||||
description: "Publish the associated tasks (if exist) even if they are not in the 'stopped' state. Optional, the default value is False."
|
||||
type: boolean
|
||||
@@ -779,7 +787,7 @@ archive_many {
|
||||
description: Archive models
|
||||
request {
|
||||
properties {
|
||||
ids.description: "IDs of models to archive"
|
||||
ids.description: "IDs of the models to archive"
|
||||
}
|
||||
}
|
||||
response {
|
||||
@@ -815,10 +823,9 @@ delete_many {
|
||||
description: Delete models
|
||||
request {
|
||||
properties {
|
||||
ids.description: "IDs of models to delete"
|
||||
ids.description: "IDs of the models to delete"
|
||||
force {
|
||||
description: """Force. Required if there are tasks that use the model as an execution model, or if the model's creating task is published.
|
||||
"""
|
||||
description: "Force. Required if there are tasks that use the model as an execution model, or if the model's creating task is published."
|
||||
type: boolean
|
||||
}
|
||||
}
|
||||
@@ -975,6 +982,11 @@ add_or_update_metadata {
|
||||
description: "Metadata items to add or update"
|
||||
items {"$ref": "#/definitions/metadata_item"}
|
||||
}
|
||||
replace_metadata {
|
||||
description: "If set then the all the metadata items will be replaced with the provided ones. Otherwise only the provided metadata items will be updated or added"
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
|
||||
47
apiserver/schema/services/pipelines.conf
Normal file
47
apiserver/schema/services/pipelines.conf
Normal file
@@ -0,0 +1,47 @@
|
||||
_description: "Provides a management API for pipelines in the system."
|
||||
_definitions {
|
||||
}
|
||||
|
||||
start_pipeline {
|
||||
"2.17" {
|
||||
description: "Start a pipeline"
|
||||
request {
|
||||
type: object
|
||||
required: [ task ]
|
||||
properties {
|
||||
task {
|
||||
description: "ID of the task on which the pipeline will be based"
|
||||
type: string
|
||||
}
|
||||
queue {
|
||||
description: "Queue ID in which the created pipeline task will be enqueued"
|
||||
type: string
|
||||
}
|
||||
args {
|
||||
description: "Task arguments, name/value to be placed in the hyperparameters Args section"
|
||||
type: array
|
||||
items {
|
||||
type: object
|
||||
properties {
|
||||
name: { type: string }
|
||||
value: { type: [string, null] }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
pipeline {
|
||||
description: "ID of the new pipeline task"
|
||||
type: string
|
||||
}
|
||||
enqueued {
|
||||
description: "True if the task was successfuly enqueued"
|
||||
type: boolean
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -42,15 +42,20 @@ _definitions {
|
||||
type: string
|
||||
format: "date-time"
|
||||
}
|
||||
last_update {
|
||||
description: "Last update time"
|
||||
type: string
|
||||
format: "date-time"
|
||||
}
|
||||
tags {
|
||||
type: array
|
||||
description: "User-defined tags"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
system_tags {
|
||||
type: array
|
||||
description: "System tags. This field is reserved for system use, please don't use it."
|
||||
items {type: string}
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
default_output_destination {
|
||||
description: "The default output destination URL for new tasks under this project"
|
||||
@@ -70,6 +75,18 @@ _definitions {
|
||||
description: "Total run time of all tasks in project (in seconds)"
|
||||
type: integer
|
||||
}
|
||||
total_tasks {
|
||||
description: "Number of tasks"
|
||||
type: integer
|
||||
}
|
||||
completed_tasks_24h {
|
||||
description: "Number of tasks completed in the last 24 hours"
|
||||
type: integer
|
||||
}
|
||||
last_task_run {
|
||||
description: "The most recent started time of a task"
|
||||
type: integer
|
||||
}
|
||||
status_count {
|
||||
description: "Status counts"
|
||||
type: object
|
||||
@@ -78,6 +95,10 @@ _definitions {
|
||||
description: "Number of 'created' tasks in project"
|
||||
type: integer
|
||||
}
|
||||
completed {
|
||||
description: "Number of 'completed' tasks in project"
|
||||
type: integer
|
||||
}
|
||||
queued {
|
||||
description: "Number of 'queued' tasks in project"
|
||||
type: integer
|
||||
@@ -158,14 +179,14 @@ _definitions {
|
||||
format: "date-time"
|
||||
}
|
||||
tags {
|
||||
type: array
|
||||
description: "User-defined tags"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
system_tags {
|
||||
type: array
|
||||
description: "System tags. This field is reserved for system use, please don't use it."
|
||||
items {type: string}
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
default_output_destination {
|
||||
description: "The default output destination URL for new tasks under this project"
|
||||
@@ -299,14 +320,14 @@ create {
|
||||
type: string
|
||||
}
|
||||
tags {
|
||||
type: array
|
||||
description: "User-defined tags"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
system_tags {
|
||||
type: array
|
||||
description: "System tags. This field is reserved for system use, please don't use it."
|
||||
items {type: string}
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
default_output_destination {
|
||||
description: "The default output destination URL for new tasks under this project"
|
||||
@@ -419,7 +440,6 @@ get_all {
|
||||
description: "Projects list"
|
||||
type: array
|
||||
items { "$ref": "#/definitions/projects_get_all_response_single" }
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -435,7 +455,14 @@ get_all {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.15": ${get_all."2.13"} {
|
||||
"2.14": ${get_all."2.13"} {
|
||||
request.properties.search_hidden {
|
||||
description: "If set to 'true' then hidden projects are included in the search results"
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
}
|
||||
"2.15": ${get_all."2.14"} {
|
||||
request {
|
||||
properties {
|
||||
scroll_id {
|
||||
@@ -516,7 +543,14 @@ get_all_ex {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.15": ${get_all_ex."2.13"} {
|
||||
"2.14": ${get_all_ex."2.13"} {
|
||||
request.properties.search_hidden {
|
||||
description: "If set to 'true' then hidden projects are included in the search results"
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
}
|
||||
"2.15": ${get_all_ex."2.14"} {
|
||||
request {
|
||||
properties {
|
||||
scroll_id {
|
||||
@@ -545,39 +579,16 @@ get_all_ex {
|
||||
type: boolean
|
||||
default: true
|
||||
}
|
||||
response {
|
||||
}
|
||||
"2.17": ${get_all_ex."2.16"} {
|
||||
request.properties.include_stats_filter {
|
||||
description: The filter for selecting entities that participate in statistics calculation
|
||||
type: object
|
||||
properties {
|
||||
stats {
|
||||
properties {
|
||||
active.properties {
|
||||
total_tasks {
|
||||
description: "Number of tasks"
|
||||
type: integer
|
||||
}
|
||||
completed_tasks {
|
||||
description: "Number of tasks completed in the last 24 hours"
|
||||
type: integer
|
||||
}
|
||||
running_tasks {
|
||||
description: "Number of running tasks"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
archived.properties {
|
||||
total_tasks {
|
||||
description: "Number of tasks"
|
||||
type: integer
|
||||
}
|
||||
completed_tasks {
|
||||
description: "Number of tasks completed in the last 24 hours"
|
||||
type: integer
|
||||
}
|
||||
running_tasks {
|
||||
description: "Number of running tasks"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
system_tags {
|
||||
description: The list of allowed system tags
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -603,14 +614,14 @@ update {
|
||||
type: string
|
||||
}
|
||||
tags {
|
||||
description: "User-defined tags list"
|
||||
type: array
|
||||
description: "User-defined tags"
|
||||
items { type: string }
|
||||
}
|
||||
system_tags {
|
||||
description: "System tags list. This field is reserved for system use, please don't use it."
|
||||
type: array
|
||||
description: "System tags. This field is reserved for system use, please don't use it."
|
||||
items {type: string}
|
||||
items { type: string }
|
||||
}
|
||||
default_output_destination {
|
||||
description: "The default output destination URL for new tasks under this project"
|
||||
@@ -748,7 +759,6 @@ delete {
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
response {
|
||||
@@ -881,6 +891,7 @@ get_hyper_parameters {
|
||||
description: """Get a list of all hyper parameter sections and names used in tasks within the given project."""
|
||||
request {
|
||||
type: object
|
||||
required: [project]
|
||||
properties {
|
||||
project {
|
||||
description: "Project ID"
|
||||
@@ -929,6 +940,105 @@ get_hyper_parameters {
|
||||
}
|
||||
}
|
||||
}
|
||||
get_model_metadata_values {
|
||||
"2.17" {
|
||||
description: """Get a list of distinct values for the chosen model metadata key"""
|
||||
request {
|
||||
type: object
|
||||
required: [key]
|
||||
properties {
|
||||
projects {
|
||||
description: "Project IDs"
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
key {
|
||||
description: "Metadata key"
|
||||
type: string
|
||||
}
|
||||
allow_public {
|
||||
description: "If set to 'true' then collect values from both company and public models otherwise company modeels only. The default is 'true'"
|
||||
type: boolean
|
||||
}
|
||||
include_subprojects {
|
||||
description: "If set to 'true' and the project field is set then the result includes metadata values from the subproject models"
|
||||
type: boolean
|
||||
default: true
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
total {
|
||||
description: "Total number of distinct values"
|
||||
type: integer
|
||||
}
|
||||
values {
|
||||
description: "The list of the unique values"
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
get_model_metadata_keys {
|
||||
"2.17" {
|
||||
description: """Get a list of all metadata keys used in models within the given project."""
|
||||
request {
|
||||
type: object
|
||||
required: [project]
|
||||
properties {
|
||||
project {
|
||||
description: "Project ID"
|
||||
type: string
|
||||
}
|
||||
include_subprojects {
|
||||
description: "If set to 'true' and the project field is set then the result includes metadate keys from the subproject models"
|
||||
type: boolean
|
||||
default: true
|
||||
}
|
||||
|
||||
page {
|
||||
description: "Page number"
|
||||
default: 0
|
||||
type: integer
|
||||
}
|
||||
page_size {
|
||||
description: "Page size"
|
||||
default: 500
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
keys {
|
||||
description: "A list of model keys"
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
remaining {
|
||||
description: "Remaining results"
|
||||
type: integer
|
||||
}
|
||||
total {
|
||||
description: "Total number of results"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
get_project_tags {
|
||||
"2.17" {
|
||||
description: "Get user and system tags used for the specified projects and their children"
|
||||
request = ${_definitions.tags_request}
|
||||
response = ${_definitions.tags_response}
|
||||
}
|
||||
}
|
||||
get_task_tags {
|
||||
"2.8" {
|
||||
description: "Get user and system tags used for the tasks under the specified projects"
|
||||
@@ -936,7 +1046,6 @@ get_task_tags {
|
||||
response = ${_definitions.tags_response}
|
||||
}
|
||||
}
|
||||
|
||||
get_model_tags {
|
||||
"2.8" {
|
||||
description: "Get user and system tags used for the models under the specified projects"
|
||||
@@ -1058,4 +1167,4 @@ get_task_parents {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -79,9 +79,11 @@ _definitions {
|
||||
items { "$ref": "#/definitions/entry" }
|
||||
}
|
||||
metadata {
|
||||
type: array
|
||||
description: "Queue metadata"
|
||||
items {"$ref": "#/definitions/metadata_item"}
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/metadata_item"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -281,6 +283,15 @@ create {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.13": ${create."2.4"} {
|
||||
metadata {
|
||||
description: "Queue metadata"
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/metadata_item"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
update {
|
||||
"2.4" {
|
||||
@@ -322,7 +333,15 @@ update {
|
||||
type: object
|
||||
additionalProperties: true
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.13": ${update."2.4"} {
|
||||
metadata {
|
||||
description: "Queue metadata"
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/metadata_item"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -632,6 +651,11 @@ add_or_update_metadata {
|
||||
description: "Metadata items to add or update"
|
||||
items {"$ref": "#/definitions/metadata_item"}
|
||||
}
|
||||
replace_metadata {
|
||||
description: "If set then the all the metadata items will be replaced with the provided ones. Otherwise only the provided metadata items will be updated or added"
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
|
||||
@@ -685,7 +685,14 @@ get_all_ex {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.15": ${get_all_ex."2.13"} {
|
||||
"2.14": ${get_all_ex."2.13"} {
|
||||
request.properties.search_hidden {
|
||||
description: "If set to 'true' then hidden tasks are included in the search results"
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
}
|
||||
"2.15": ${get_all_ex."2.14"} {
|
||||
request {
|
||||
properties {
|
||||
scroll_id {
|
||||
@@ -822,7 +829,14 @@ get_all {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.15": ${get_all."2.1"} {
|
||||
"2.14": ${get_all."2.1"} {
|
||||
request.properties.search_hidden {
|
||||
description: "If set to 'true' then hidden tasks are included in the search results"
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
}
|
||||
"2.15": ${get_all."2.14"} {
|
||||
request {
|
||||
properties {
|
||||
scroll_id {
|
||||
|
||||
@@ -25,6 +25,7 @@ from apiserver.server_init.request_handlers import RequestHandlers
|
||||
from apiserver.service_repo import ServiceRepo
|
||||
from apiserver.sync import distributed_lock
|
||||
from apiserver.updates import check_updates_thread
|
||||
from apiserver.utilities.env import get_bool
|
||||
from apiserver.utilities.threads_manager import ThreadsManager
|
||||
|
||||
log = config.logger(__file__)
|
||||
@@ -46,10 +47,13 @@ class AppSequence:
|
||||
def _attach_request_handlers(self, request_handlers: RequestHandlers):
|
||||
self.app.before_first_request(request_handlers.before_app_first_request)
|
||||
self.app.before_request(request_handlers.before_request)
|
||||
self.app.after_request(request_handlers.after_request)
|
||||
|
||||
def _configure(self):
|
||||
CORS(self.app, **config.get("apiserver.cors"))
|
||||
Compress(self.app)
|
||||
|
||||
if get_bool("CLEARML_COMPRESS_RESP", default=True):
|
||||
Compress(self.app)
|
||||
|
||||
self.app.config["SECRET_KEY"] = config.get(
|
||||
"secure.http.session_secret.apiserver"
|
||||
|
||||
@@ -18,6 +18,7 @@ log = config.logger(__file__)
|
||||
|
||||
class RequestHandlers:
|
||||
_request_strip_prefix = config.get("apiserver.request.strip_prefix", None)
|
||||
_server_header = config.get("apiserver.response.headers.server", "clearml")
|
||||
|
||||
def before_app_first_request(self):
|
||||
pass
|
||||
@@ -28,6 +29,9 @@ class RequestHandlers:
|
||||
if "/static/" in request.path:
|
||||
return
|
||||
|
||||
if request.content_encoding:
|
||||
return f"Content encoding is not supported ({request.content_encoding})", 415
|
||||
|
||||
try:
|
||||
call = self._create_api_call(request)
|
||||
load_data_callback = partial(self._load_call_data, req=request)
|
||||
@@ -81,6 +85,10 @@ class RequestHandlers:
|
||||
log.exception(f"Failed processing request {request.url}: {ex}")
|
||||
return f"Failed processing request {request.url}", 500
|
||||
|
||||
def after_request(self, response):
|
||||
response.headers["server"] = self._server_header
|
||||
return response
|
||||
|
||||
@staticmethod
|
||||
def _apply_multi_dict(body: dict, md: ImmutableMultiDict):
|
||||
def convert_value(v: str):
|
||||
|
||||
@@ -95,8 +95,8 @@ class DataContainer(object):
|
||||
@raw_data.setter
|
||||
def raw_data(self, value):
|
||||
assert isinstance(
|
||||
value, string_types + (types.GeneratorType,)
|
||||
), "Raw data must be a string type or generator"
|
||||
value, string_types + (types.GeneratorType, bytes)
|
||||
), "Raw data must be a string type or bytes or generator"
|
||||
self._raw_data = value
|
||||
|
||||
@property
|
||||
@@ -395,6 +395,10 @@ class APICall(DataContainer):
|
||||
self._auth_cookie = auth_cookie
|
||||
self._json_flags = {}
|
||||
|
||||
@property
|
||||
def files(self):
|
||||
return self._files
|
||||
|
||||
@property
|
||||
def id(self):
|
||||
return self._id
|
||||
|
||||
@@ -51,7 +51,7 @@ def authorize_token(jwt_token, *_, **__):
|
||||
)
|
||||
|
||||
|
||||
def authorize_credentials(auth_data, service, action, call_data_items):
|
||||
def authorize_credentials(auth_data, service, action, call):
|
||||
"""Validate credentials against service/action and request data (dicts).
|
||||
Returns a new basic object (auth payload)
|
||||
"""
|
||||
@@ -100,7 +100,12 @@ def authorize_credentials(auth_data, service, action, call_data_items):
|
||||
if not fixed_user:
|
||||
# In case these are proper credentials, update last used time
|
||||
User.objects(id=user.id, credentials__key=access_key).update(
|
||||
**{"set__credentials__$__last_used": datetime.utcnow()}
|
||||
**{
|
||||
"set__credentials__$__last_used": datetime.utcnow(),
|
||||
"set__credentials__$__last_used_from": call.get_worker(
|
||||
default=call.real_ip
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
with TimingContext("mongo", "company_by_id"):
|
||||
|
||||
@@ -38,7 +38,7 @@ class ServiceRepo(object):
|
||||
"""If the check is set, parsing will fail for endpoint request with the version that is grater than the current
|
||||
maximum """
|
||||
|
||||
_max_version = PartialVersion("2.16")
|
||||
_max_version = PartialVersion("2.17")
|
||||
""" Maximum version number (the highest min_version value across all endpoints) """
|
||||
|
||||
_endpoint_exp = (
|
||||
|
||||
@@ -69,7 +69,7 @@ def validate_auth(endpoint, call):
|
||||
auth = call.authorization or ""
|
||||
auth_type, _, auth_data = auth.partition(" ")
|
||||
authorize_func = get_auth_func(auth_type)
|
||||
call.auth = authorize_func(auth_data, service, action, call.batched_data)
|
||||
call.auth = authorize_func(auth_data, service, action, call)
|
||||
except Exception:
|
||||
if endpoint.authorize:
|
||||
# if endpoint requires authorization, re-raise exception
|
||||
|
||||
@@ -161,7 +161,10 @@ def get_credentials(call: APICall, _, __):
|
||||
call.result.data_model = GetCredentialsResponse(
|
||||
credentials=[
|
||||
CredentialsResponse(
|
||||
access_key=c.key, last_used=c.last_used, label=c.label
|
||||
access_key=c.key,
|
||||
last_used=c.last_used,
|
||||
label=c.label,
|
||||
last_used_from=c.last_used_from,
|
||||
)
|
||||
for c in user.credentials
|
||||
]
|
||||
|
||||
@@ -25,6 +25,7 @@ from apiserver.apimodels.events import (
|
||||
TaskPlotsRequest,
|
||||
TaskEventsRequest,
|
||||
ScalarMetricsIterRawRequest,
|
||||
ClearScrollRequest,
|
||||
)
|
||||
from apiserver.bll.event import EventBLL
|
||||
from apiserver.bll.event.event_common import EventType, MetricVariants
|
||||
@@ -936,3 +937,9 @@ def scalar_metrics_iter_raw(
|
||||
scroll_id=scroll.get_scroll_id(),
|
||||
variants=variants,
|
||||
)
|
||||
|
||||
|
||||
@endpoint("events.clear_scroll", min_version="2.18")
|
||||
def clear_scroll(_, __, request: ClearScrollRequest):
|
||||
if request.scroll_id:
|
||||
event_bll.clear_scroll(request.scroll_id)
|
||||
|
||||
@@ -21,16 +21,14 @@ from apiserver.apimodels.models import (
|
||||
ModelsPublishManyRequest,
|
||||
ModelsDeleteManyRequest,
|
||||
)
|
||||
from apiserver.bll.model import ModelBLL
|
||||
from apiserver.bll.model import ModelBLL, Metadata
|
||||
from apiserver.bll.organization import OrgBLL, Tags
|
||||
from apiserver.bll.project import ProjectBLL, project_ids_with_children
|
||||
from apiserver.bll.task import TaskBLL
|
||||
from apiserver.bll.task.task_operations import publish_task
|
||||
from apiserver.bll.util import run_batch_operation
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model import validate_id
|
||||
from apiserver.database.model.metadata import metadata_add_or_update, metadata_delete
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.task.task import (
|
||||
@@ -50,8 +48,8 @@ from apiserver.services.utils import (
|
||||
conform_tag_fields,
|
||||
conform_output_tags,
|
||||
ModelsBackwardsCompatibility,
|
||||
validate_metadata,
|
||||
get_metadata_from_api,
|
||||
unescape_metadata,
|
||||
escape_metadata,
|
||||
)
|
||||
from apiserver.timing_context import TimingContext
|
||||
|
||||
@@ -64,19 +62,20 @@ project_bll = ProjectBLL()
|
||||
def get_by_id(call: APICall, company_id, _):
|
||||
model_id = call.data["model"]
|
||||
|
||||
with translate_errors_context():
|
||||
models = Model.get_many(
|
||||
company=company_id,
|
||||
query_dict=call.data,
|
||||
query=Q(id=model_id),
|
||||
allow_public=True,
|
||||
Metadata.escape_query_parameters(call)
|
||||
models = Model.get_many(
|
||||
company=company_id,
|
||||
query_dict=call.data,
|
||||
query=Q(id=model_id),
|
||||
allow_public=True,
|
||||
)
|
||||
if not models:
|
||||
raise errors.bad_request.InvalidModelId(
|
||||
"no such public or company model", id=model_id, company=company_id,
|
||||
)
|
||||
if not models:
|
||||
raise errors.bad_request.InvalidModelId(
|
||||
"no such public or company model", id=model_id, company=company_id,
|
||||
)
|
||||
conform_output_tags(call, models[0])
|
||||
call.result.data = {"model": models[0]}
|
||||
conform_output_tags(call, models[0])
|
||||
unescape_metadata(call, models[0])
|
||||
call.result.data = {"model": models[0]}
|
||||
|
||||
|
||||
@endpoint("models.get_by_task_id", required_fields=["task"])
|
||||
@@ -86,25 +85,25 @@ def get_by_task_id(call: APICall, company_id, _):
|
||||
|
||||
task_id = call.data["task"]
|
||||
|
||||
with translate_errors_context():
|
||||
query = dict(id=task_id, company=company_id)
|
||||
task = Task.get(_only=["models"], **query)
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(**query)
|
||||
if not task.models or not task.models.output:
|
||||
raise errors.bad_request.MissingTaskFields(field="models.output")
|
||||
query = dict(id=task_id, company=company_id)
|
||||
task = Task.get(_only=["models"], **query)
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(**query)
|
||||
if not task.models or not task.models.output:
|
||||
raise errors.bad_request.MissingTaskFields(field="models.output")
|
||||
|
||||
model_id = task.models.output[-1].model
|
||||
model = Model.objects(
|
||||
Q(id=model_id) & get_company_or_none_constraint(company_id)
|
||||
).first()
|
||||
if not model:
|
||||
raise errors.bad_request.InvalidModelId(
|
||||
"no such public or company model", id=model_id, company=company_id,
|
||||
)
|
||||
model_dict = model.to_proper_dict()
|
||||
conform_output_tags(call, model_dict)
|
||||
call.result.data = {"model": model_dict}
|
||||
model_id = task.models.output[-1].model
|
||||
model = Model.objects(
|
||||
Q(id=model_id) & get_company_or_none_constraint(company_id)
|
||||
).first()
|
||||
if not model:
|
||||
raise errors.bad_request.InvalidModelId(
|
||||
"no such public or company model", id=model_id, company=company_id,
|
||||
)
|
||||
model_dict = model.to_proper_dict()
|
||||
conform_output_tags(call, model_dict)
|
||||
unescape_metadata(call, model_dict)
|
||||
call.result.data = {"model": model_dict}
|
||||
|
||||
|
||||
def _process_include_subprojects(call_data: dict):
|
||||
@@ -121,47 +120,50 @@ def _process_include_subprojects(call_data: dict):
|
||||
@endpoint("models.get_all_ex", required_fields=[])
|
||||
def get_all_ex(call: APICall, company_id, _):
|
||||
conform_tag_fields(call, call.data)
|
||||
with translate_errors_context():
|
||||
_process_include_subprojects(call.data)
|
||||
with TimingContext("mongo", "models_get_all_ex"):
|
||||
ret_params = {}
|
||||
models = Model.get_many_with_join(
|
||||
company=company_id,
|
||||
query_dict=call.data,
|
||||
allow_public=True,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
conform_output_tags(call, models)
|
||||
call.result.data = {"models": models, **ret_params}
|
||||
_process_include_subprojects(call.data)
|
||||
Metadata.escape_query_parameters(call)
|
||||
with TimingContext("mongo", "models_get_all_ex"):
|
||||
ret_params = {}
|
||||
models = Model.get_many_with_join(
|
||||
company=company_id,
|
||||
query_dict=call.data,
|
||||
allow_public=True,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
conform_output_tags(call, models)
|
||||
unescape_metadata(call, models)
|
||||
call.result.data = {"models": models, **ret_params}
|
||||
|
||||
|
||||
@endpoint("models.get_by_id_ex", required_fields=["id"])
|
||||
def get_by_id_ex(call: APICall, company_id, _):
|
||||
conform_tag_fields(call, call.data)
|
||||
with translate_errors_context():
|
||||
with TimingContext("mongo", "models_get_by_id_ex"):
|
||||
models = Model.get_many_with_join(
|
||||
company=company_id, query_dict=call.data, allow_public=True
|
||||
)
|
||||
conform_output_tags(call, models)
|
||||
call.result.data = {"models": models}
|
||||
Metadata.escape_query_parameters(call)
|
||||
with TimingContext("mongo", "models_get_by_id_ex"):
|
||||
models = Model.get_many_with_join(
|
||||
company=company_id, query_dict=call.data, allow_public=True
|
||||
)
|
||||
conform_output_tags(call, models)
|
||||
unescape_metadata(call, models)
|
||||
call.result.data = {"models": models}
|
||||
|
||||
|
||||
@endpoint("models.get_all", required_fields=[])
|
||||
def get_all(call: APICall, company_id, _):
|
||||
conform_tag_fields(call, call.data)
|
||||
with translate_errors_context():
|
||||
with TimingContext("mongo", "models_get_all"):
|
||||
ret_params = {}
|
||||
models = Model.get_many(
|
||||
company=company_id,
|
||||
parameters=call.data,
|
||||
query_dict=call.data,
|
||||
allow_public=True,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
conform_output_tags(call, models)
|
||||
call.result.data = {"models": models, **ret_params}
|
||||
Metadata.escape_query_parameters(call)
|
||||
with TimingContext("mongo", "models_get_all"):
|
||||
ret_params = {}
|
||||
models = Model.get_many(
|
||||
company=company_id,
|
||||
parameters=call.data,
|
||||
query_dict=call.data,
|
||||
allow_public=True,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
conform_output_tags(call, models)
|
||||
unescape_metadata(call, models)
|
||||
call.result.data = {"models": models, **ret_params}
|
||||
|
||||
|
||||
@endpoint("models.get_frameworks", request_data_model=GetFrameworksRequest)
|
||||
@@ -189,15 +191,22 @@ create_fields = {
|
||||
"metadata": list,
|
||||
}
|
||||
|
||||
last_update_fields = ("uri", "framework", "design", "labels", "ready", "metadata", "system_tags", "tags")
|
||||
last_update_fields = (
|
||||
"uri",
|
||||
"framework",
|
||||
"design",
|
||||
"labels",
|
||||
"ready",
|
||||
"metadata",
|
||||
"system_tags",
|
||||
"tags",
|
||||
)
|
||||
|
||||
|
||||
def parse_model_fields(call, valid_fields):
|
||||
fields = parse_from_call(call.data, valid_fields, Model.get_fields())
|
||||
conform_tag_fields(call, fields, validate=True)
|
||||
metadata = fields.get("metadata")
|
||||
if metadata:
|
||||
validate_metadata(metadata)
|
||||
escape_metadata(fields)
|
||||
return fields
|
||||
|
||||
|
||||
@@ -231,82 +240,80 @@ def update_for_task(call: APICall, company_id, _):
|
||||
"exactly one field is required", fields=("uri", "override_model_id")
|
||||
)
|
||||
|
||||
with translate_errors_context():
|
||||
query = dict(id=task_id, company=company_id)
|
||||
task = Task.get_for_writing(
|
||||
id=task_id,
|
||||
company=company_id,
|
||||
_only=["models", "execution", "name", "status", "project"],
|
||||
)
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(**query)
|
||||
|
||||
query = dict(id=task_id, company=company_id)
|
||||
task = Task.get_for_writing(
|
||||
id=task_id,
|
||||
allowed_states = [TaskStatus.created, TaskStatus.in_progress]
|
||||
if task.status not in allowed_states:
|
||||
raise errors.bad_request.InvalidTaskStatus(
|
||||
f"model can only be updated for tasks in the {allowed_states} states",
|
||||
**query,
|
||||
)
|
||||
|
||||
if override_model_id:
|
||||
model = ModelBLL.get_company_model_by_id(
|
||||
company_id=company_id, model_id=override_model_id
|
||||
)
|
||||
else:
|
||||
if "name" not in call.data:
|
||||
# use task name if name not provided
|
||||
call.data["name"] = task.name
|
||||
|
||||
if "comment" not in call.data:
|
||||
call.data["comment"] = f"Created by task `{task.name}` ({task.id})"
|
||||
|
||||
if task.models and task.models.output:
|
||||
# model exists, update
|
||||
model_id = task.models.output[-1].model
|
||||
res = _update_model(call, company_id, model_id=model_id).to_struct()
|
||||
res.update({"id": model_id, "created": False})
|
||||
call.result.data = res
|
||||
return
|
||||
|
||||
# new model, create
|
||||
fields = parse_model_fields(call, create_fields)
|
||||
|
||||
# create and save model
|
||||
now = datetime.utcnow()
|
||||
model = Model(
|
||||
id=database.utils.id(),
|
||||
created=now,
|
||||
last_update=now,
|
||||
user=call.identity.user,
|
||||
company=company_id,
|
||||
_only=["models", "execution", "name", "status", "project"],
|
||||
project=task.project,
|
||||
framework=task.execution.framework,
|
||||
parent=task.models.input[0].model
|
||||
if task.models and task.models.input
|
||||
else None,
|
||||
design=task.execution.model_desc,
|
||||
labels=task.execution.model_labels,
|
||||
ready=(task.status == TaskStatus.published),
|
||||
**fields,
|
||||
)
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(**query)
|
||||
model.save()
|
||||
_update_cached_tags(company_id, project=model.project, fields=fields)
|
||||
|
||||
allowed_states = [TaskStatus.created, TaskStatus.in_progress]
|
||||
if task.status not in allowed_states:
|
||||
raise errors.bad_request.InvalidTaskStatus(
|
||||
f"model can only be updated for tasks in the {allowed_states} states",
|
||||
**query,
|
||||
TaskBLL.update_statistics(
|
||||
task_id=task_id,
|
||||
company_id=company_id,
|
||||
last_iteration_max=iteration,
|
||||
models__output=[
|
||||
ModelItem(
|
||||
model=model.id,
|
||||
name=TaskModelNames[TaskModelTypes.output],
|
||||
updated=datetime.utcnow(),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
if override_model_id:
|
||||
model = ModelBLL.get_company_model_by_id(
|
||||
company_id=company_id, model_id=override_model_id
|
||||
)
|
||||
else:
|
||||
if "name" not in call.data:
|
||||
# use task name if name not provided
|
||||
call.data["name"] = task.name
|
||||
|
||||
if "comment" not in call.data:
|
||||
call.data["comment"] = f"Created by task `{task.name}` ({task.id})"
|
||||
|
||||
if task.models and task.models.output:
|
||||
# model exists, update
|
||||
model_id = task.models.output[-1].model
|
||||
res = _update_model(call, company_id, model_id=model_id).to_struct()
|
||||
res.update({"id": model_id, "created": False})
|
||||
call.result.data = res
|
||||
return
|
||||
|
||||
# new model, create
|
||||
fields = parse_model_fields(call, create_fields)
|
||||
|
||||
# create and save model
|
||||
now = datetime.utcnow()
|
||||
model = Model(
|
||||
id=database.utils.id(),
|
||||
created=now,
|
||||
last_update=now,
|
||||
user=call.identity.user,
|
||||
company=company_id,
|
||||
project=task.project,
|
||||
framework=task.execution.framework,
|
||||
parent=task.models.input[0].model
|
||||
if task.models and task.models.input
|
||||
else None,
|
||||
design=task.execution.model_desc,
|
||||
labels=task.execution.model_labels,
|
||||
ready=(task.status == TaskStatus.published),
|
||||
**fields,
|
||||
)
|
||||
model.save()
|
||||
_update_cached_tags(company_id, project=model.project, fields=fields)
|
||||
|
||||
TaskBLL.update_statistics(
|
||||
task_id=task_id,
|
||||
company_id=company_id,
|
||||
last_iteration_max=iteration,
|
||||
models__output=[
|
||||
ModelItem(
|
||||
model=model.id,
|
||||
name=TaskModelNames[TaskModelTypes.output],
|
||||
updated=datetime.utcnow(),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
call.result.data = {"id": model.id, "created": True}
|
||||
call.result.data = {"id": model.id, "created": True}
|
||||
|
||||
|
||||
@endpoint(
|
||||
@@ -319,36 +326,33 @@ def create(call: APICall, company_id, req_model: CreateModelRequest):
|
||||
if req_model.public:
|
||||
company_id = ""
|
||||
|
||||
with translate_errors_context():
|
||||
project = req_model.project
|
||||
if project:
|
||||
validate_id(Project, company=company_id, project=project)
|
||||
|
||||
project = req_model.project
|
||||
if project:
|
||||
validate_id(Project, company=company_id, project=project)
|
||||
task = req_model.task
|
||||
req_data = req_model.to_struct()
|
||||
if task:
|
||||
validate_task(company_id, req_data)
|
||||
|
||||
task = req_model.task
|
||||
req_data = req_model.to_struct()
|
||||
if task:
|
||||
validate_task(company_id, req_data)
|
||||
fields = filter_fields(Model, req_data)
|
||||
conform_tag_fields(call, fields, validate=True)
|
||||
escape_metadata(fields)
|
||||
|
||||
fields = filter_fields(Model, req_data)
|
||||
conform_tag_fields(call, fields, validate=True)
|
||||
# create and save model
|
||||
now = datetime.utcnow()
|
||||
model = Model(
|
||||
id=database.utils.id(),
|
||||
user=call.identity.user,
|
||||
company=company_id,
|
||||
created=now,
|
||||
last_update=now,
|
||||
**fields,
|
||||
)
|
||||
model.save()
|
||||
_update_cached_tags(company_id, project=model.project, fields=fields)
|
||||
|
||||
validate_metadata(fields.get("metadata"))
|
||||
|
||||
# create and save model
|
||||
now = datetime.utcnow()
|
||||
model = Model(
|
||||
id=database.utils.id(),
|
||||
user=call.identity.user,
|
||||
company=company_id,
|
||||
created=now,
|
||||
last_update=now,
|
||||
**fields,
|
||||
)
|
||||
model.save()
|
||||
_update_cached_tags(company_id, project=model.project, fields=fields)
|
||||
|
||||
call.result.data_model = CreateModelResponse(id=model.id, created=True)
|
||||
call.result.data_model = CreateModelResponse(id=model.id, created=True)
|
||||
|
||||
|
||||
def prepare_update_fields(call, company_id, fields: dict):
|
||||
@@ -383,6 +387,7 @@ def prepare_update_fields(call, company_id, fields: dict):
|
||||
)
|
||||
|
||||
conform_tag_fields(call, fields, validate=True)
|
||||
escape_metadata(fields)
|
||||
return fields
|
||||
|
||||
|
||||
@@ -394,89 +399,85 @@ def validate_task(company_id, fields: dict):
|
||||
def edit(call: APICall, company_id, _):
|
||||
model_id = call.data["model"]
|
||||
|
||||
with translate_errors_context():
|
||||
model = ModelBLL.get_company_model_by_id(
|
||||
company_id=company_id, model_id=model_id
|
||||
model = ModelBLL.get_company_model_by_id(
|
||||
company_id=company_id, model_id=model_id
|
||||
)
|
||||
|
||||
fields = parse_model_fields(call, create_fields)
|
||||
fields = prepare_update_fields(call, company_id, fields)
|
||||
|
||||
for key in fields:
|
||||
field = getattr(model, key, None)
|
||||
value = fields[key]
|
||||
if (
|
||||
field
|
||||
and isinstance(value, dict)
|
||||
and isinstance(field, EmbeddedDocument)
|
||||
):
|
||||
d = field.to_mongo(use_db_field=False).to_dict()
|
||||
d.update(value)
|
||||
fields[key] = d
|
||||
|
||||
iteration = call.data.get("iteration")
|
||||
task_id = model.task or fields.get("task")
|
||||
if task_id and iteration is not None:
|
||||
TaskBLL.update_statistics(
|
||||
task_id=task_id, company_id=company_id, last_iteration_max=iteration,
|
||||
)
|
||||
|
||||
fields = parse_model_fields(call, create_fields)
|
||||
fields = prepare_update_fields(call, company_id, fields)
|
||||
if fields:
|
||||
if any(uf in fields for uf in last_update_fields):
|
||||
fields.update(last_update=datetime.utcnow())
|
||||
|
||||
for key in fields:
|
||||
field = getattr(model, key, None)
|
||||
value = fields[key]
|
||||
if (
|
||||
field
|
||||
and isinstance(value, dict)
|
||||
and isinstance(field, EmbeddedDocument)
|
||||
):
|
||||
d = field.to_mongo(use_db_field=False).to_dict()
|
||||
d.update(value)
|
||||
fields[key] = d
|
||||
|
||||
iteration = call.data.get("iteration")
|
||||
task_id = model.task or fields.get("task")
|
||||
if task_id and iteration is not None:
|
||||
TaskBLL.update_statistics(
|
||||
task_id=task_id, company_id=company_id, last_iteration_max=iteration,
|
||||
)
|
||||
|
||||
if fields:
|
||||
if any(uf in fields for uf in last_update_fields):
|
||||
fields.update(last_update=datetime.utcnow())
|
||||
|
||||
updated = model.update(upsert=False, **fields)
|
||||
if updated:
|
||||
new_project = fields.get("project", model.project)
|
||||
if new_project != model.project:
|
||||
_reset_cached_tags(
|
||||
company_id, projects=[new_project, model.project]
|
||||
)
|
||||
else:
|
||||
_update_cached_tags(
|
||||
company_id, project=model.project, fields=fields
|
||||
)
|
||||
conform_output_tags(call, fields)
|
||||
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
|
||||
else:
|
||||
call.result.data_model = UpdateResponse(updated=0)
|
||||
updated = model.update(upsert=False, **fields)
|
||||
if updated:
|
||||
new_project = fields.get("project", model.project)
|
||||
if new_project != model.project:
|
||||
_reset_cached_tags(
|
||||
company_id, projects=[new_project, model.project]
|
||||
)
|
||||
else:
|
||||
_update_cached_tags(
|
||||
company_id, project=model.project, fields=fields
|
||||
)
|
||||
conform_output_tags(call, fields)
|
||||
unescape_metadata(call, fields)
|
||||
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
|
||||
else:
|
||||
call.result.data_model = UpdateResponse(updated=0)
|
||||
|
||||
|
||||
def _update_model(call: APICall, company_id, model_id=None):
|
||||
model_id = model_id or call.data["model"]
|
||||
|
||||
with translate_errors_context():
|
||||
model = ModelBLL.get_company_model_by_id(
|
||||
company_id=company_id, model_id=model_id
|
||||
model = ModelBLL.get_company_model_by_id(
|
||||
company_id=company_id, model_id=model_id
|
||||
)
|
||||
|
||||
data = prepare_update_fields(call, company_id, call.data)
|
||||
|
||||
task_id = data.get("task")
|
||||
iteration = data.get("iteration")
|
||||
if task_id and iteration is not None:
|
||||
TaskBLL.update_statistics(
|
||||
task_id=task_id, company_id=company_id, last_iteration_max=iteration,
|
||||
)
|
||||
|
||||
data = prepare_update_fields(call, company_id, call.data)
|
||||
updated_count, updated_fields = Model.safe_update(company_id, model.id, data)
|
||||
if updated_count:
|
||||
if any(uf in updated_fields for uf in last_update_fields):
|
||||
model.update(upsert=False, last_update=datetime.utcnow())
|
||||
|
||||
task_id = data.get("task")
|
||||
iteration = data.get("iteration")
|
||||
if task_id and iteration is not None:
|
||||
TaskBLL.update_statistics(
|
||||
task_id=task_id, company_id=company_id, last_iteration_max=iteration,
|
||||
new_project = updated_fields.get("project", model.project)
|
||||
if new_project != model.project:
|
||||
_reset_cached_tags(company_id, projects=[new_project, model.project])
|
||||
else:
|
||||
_update_cached_tags(
|
||||
company_id, project=model.project, fields=updated_fields
|
||||
)
|
||||
|
||||
metadata = data.get("metadata")
|
||||
if metadata:
|
||||
validate_metadata(metadata)
|
||||
|
||||
updated_count, updated_fields = Model.safe_update(company_id, model.id, data)
|
||||
if updated_count:
|
||||
if any(uf in updated_fields for uf in last_update_fields):
|
||||
model.update(upsert=False, last_update=datetime.utcnow())
|
||||
|
||||
new_project = updated_fields.get("project", model.project)
|
||||
if new_project != model.project:
|
||||
_reset_cached_tags(company_id, projects=[new_project, model.project])
|
||||
else:
|
||||
_update_cached_tags(
|
||||
company_id, project=model.project, fields=updated_fields
|
||||
)
|
||||
conform_output_tags(call, updated_fields)
|
||||
return UpdateResponse(updated=updated_count, fields=updated_fields)
|
||||
conform_output_tags(call, updated_fields)
|
||||
unescape_metadata(call, updated_fields)
|
||||
return UpdateResponse(updated=updated_count, fields=updated_fields)
|
||||
|
||||
|
||||
@endpoint(
|
||||
@@ -641,26 +642,25 @@ def add_or_update_metadata(
|
||||
_: APICall, company_id: str, request: AddOrUpdateMetadataRequest
|
||||
):
|
||||
model_id = request.model
|
||||
ModelBLL.get_company_model_by_id(company_id=company_id, model_id=model_id)
|
||||
|
||||
updated = metadata_add_or_update(
|
||||
cls=Model, _id=model_id, items=get_metadata_from_api(request.metadata),
|
||||
)
|
||||
if updated:
|
||||
Model.objects(id=model_id).update_one(last_update=datetime.utcnow())
|
||||
|
||||
return {"updated": updated}
|
||||
model = ModelBLL.get_company_model_by_id(company_id=company_id, model_id=model_id)
|
||||
return {
|
||||
"updated": Metadata.edit_metadata(
|
||||
model,
|
||||
items=request.metadata,
|
||||
replace_metadata=request.replace_metadata,
|
||||
last_update=datetime.utcnow(),
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@endpoint("models.delete_metadata", min_version="2.13")
|
||||
def delete_metadata(_: APICall, company_id: str, request: DeleteMetadataRequest):
|
||||
model_id = request.model
|
||||
ModelBLL.get_company_model_by_id(
|
||||
model = ModelBLL.get_company_model_by_id(
|
||||
company_id=company_id, model_id=model_id, only_fields=("id",)
|
||||
)
|
||||
|
||||
updated = metadata_delete(cls=Model, _id=model_id, keys=request.keys)
|
||||
if updated:
|
||||
Model.objects(id=model_id).update_one(last_update=datetime.utcnow())
|
||||
|
||||
return {"updated": updated}
|
||||
return {
|
||||
"updated": Metadata.delete_metadata(
|
||||
model, keys=request.keys, last_update=datetime.utcnow()
|
||||
)
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@ from apiserver.apimodels.organization import TagsRequest
|
||||
from apiserver.bll.organization import OrgBLL, Tags
|
||||
from apiserver.database.model import User
|
||||
from apiserver.service_repo import endpoint, APICall
|
||||
from apiserver.services.utils import get_tags_filter_dictionary, get_tags_response
|
||||
from apiserver.services.utils import get_tags_filter_dictionary, sort_tags_response
|
||||
|
||||
org_bll = OrgBLL()
|
||||
|
||||
@@ -21,17 +21,13 @@ def get_tags(call: APICall, company, request: TagsRequest):
|
||||
for field, vals in tags.items():
|
||||
ret[field] |= vals
|
||||
|
||||
call.result.data = get_tags_response(ret)
|
||||
call.result.data = sort_tags_response(ret)
|
||||
|
||||
|
||||
@endpoint("organization.get_user_companies")
|
||||
def get_user_companies(call: APICall, company_id: str, _):
|
||||
users = [
|
||||
{
|
||||
"id": u.id,
|
||||
"name": u.name,
|
||||
"avatar": u.avatar,
|
||||
}
|
||||
{"id": u.id, "name": u.name, "avatar": u.avatar}
|
||||
for u in User.objects(company=company_id).only("avatar", "name", "company")
|
||||
]
|
||||
|
||||
|
||||
68
apiserver/services/pipelines.py
Normal file
68
apiserver/services/pipelines.py
Normal file
@@ -0,0 +1,68 @@
|
||||
import re
|
||||
|
||||
from apiserver.apimodels.pipelines import StartPipelineResponse, StartPipelineRequest
|
||||
from apiserver.bll.organization import OrgBLL
|
||||
from apiserver.bll.project import ProjectBLL
|
||||
from apiserver.bll.task import TaskBLL
|
||||
from apiserver.bll.task.task_operations import enqueue_task
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.service_repo import APICall, endpoint
|
||||
|
||||
org_bll = OrgBLL()
|
||||
project_bll = ProjectBLL()
|
||||
task_bll = TaskBLL()
|
||||
|
||||
|
||||
def _update_task_name(task: Task):
|
||||
if not task or not task.project:
|
||||
return
|
||||
|
||||
project = Project.objects(id=task.project).only("name").first()
|
||||
if not project:
|
||||
return
|
||||
|
||||
_, _, name_prefix = project.name.rpartition("/")
|
||||
name_mask = re.compile(rf"{re.escape(name_prefix)}( #\d+)?$")
|
||||
count = Task.objects(
|
||||
project=task.project, system_tags__in=["pipeline"], name=name_mask
|
||||
).count()
|
||||
new_name = f"{name_prefix} #{count}" if count > 0 else name_prefix
|
||||
task.update(name=new_name)
|
||||
|
||||
|
||||
@endpoint(
|
||||
"pipelines.start_pipeline", response_data_model=StartPipelineResponse,
|
||||
)
|
||||
def start_pipeline(call: APICall, company_id: str, request: StartPipelineRequest):
|
||||
hyperparams = None
|
||||
if request.args:
|
||||
hyperparams = {
|
||||
"Args": {
|
||||
str(arg.name): {
|
||||
"section": "Args",
|
||||
"name": str(arg.name),
|
||||
"value": str(arg.value),
|
||||
}
|
||||
for arg in request.args or []
|
||||
}
|
||||
}
|
||||
|
||||
task, _ = task_bll.clone_task(
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
task_id=request.task,
|
||||
hyperparams=hyperparams,
|
||||
)
|
||||
|
||||
_update_task_name(task)
|
||||
|
||||
queued, res = enqueue_task(
|
||||
task_id=task.id,
|
||||
company_id=company_id,
|
||||
queue_id=request.queue,
|
||||
status_message="Starting pipeline",
|
||||
status_reason="",
|
||||
)
|
||||
|
||||
return StartPipelineResponse(pipeline=task.id, enqueued=bool(queued))
|
||||
@@ -17,6 +17,7 @@ from apiserver.apimodels.projects import (
|
||||
MergeRequest,
|
||||
ProjectOrNoneRequest,
|
||||
ProjectRequest,
|
||||
ProjectModelMetadataValuesRequest,
|
||||
)
|
||||
from apiserver.bll.organization import OrgBLL, Tags
|
||||
from apiserver.bll.project import ProjectBLL, ProjectQueries
|
||||
@@ -25,6 +26,7 @@ from apiserver.bll.project.project_cleanup import (
|
||||
validate_project_delete,
|
||||
)
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.utils import (
|
||||
parse_from_call,
|
||||
@@ -35,7 +37,7 @@ from apiserver.services.utils import (
|
||||
conform_tag_fields,
|
||||
conform_output_tags,
|
||||
get_tags_filter_dictionary,
|
||||
get_tags_response,
|
||||
sort_tags_response,
|
||||
)
|
||||
from apiserver.timing_context import TimingContext
|
||||
|
||||
@@ -72,6 +74,16 @@ def get_by_id(call):
|
||||
call.result.data = {"project": project_dict}
|
||||
|
||||
|
||||
def _hidden_query(search_hidden: bool, ids: Sequence) -> Q:
|
||||
"""
|
||||
1. Add only non-hidden tasks search condition (unless specifically specified differently)
|
||||
"""
|
||||
if search_hidden or ids:
|
||||
return Q()
|
||||
|
||||
return Q(system_tags__ne=EntityVisibility.hidden.value)
|
||||
|
||||
|
||||
def _adjust_search_parameters(data: dict, shallow_search: bool):
|
||||
"""
|
||||
1. Make sure that there is no external query on path
|
||||
@@ -90,12 +102,14 @@ def _adjust_search_parameters(data: dict, shallow_search: bool):
|
||||
|
||||
@endpoint("projects.get_all_ex", request_data_model=ProjectsGetRequest)
|
||||
def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
|
||||
conform_tag_fields(call, call.data)
|
||||
allow_public = not request.non_public
|
||||
data = call.data
|
||||
conform_tag_fields(call, data)
|
||||
allow_public = not request.non_public
|
||||
requested_ids = data.get("id")
|
||||
_adjust_search_parameters(
|
||||
data, shallow_search=request.shallow_search,
|
||||
)
|
||||
with TimingContext("mongo", "projects_get_all"):
|
||||
data = call.data
|
||||
if request.active_users:
|
||||
ids = project_bll.get_projects_with_active_user(
|
||||
company=company_id,
|
||||
@@ -104,16 +118,14 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
|
||||
allow_public=allow_public,
|
||||
)
|
||||
if not ids:
|
||||
call.result.data = {"projects": []}
|
||||
return
|
||||
return {"projects": []}
|
||||
data["id"] = ids
|
||||
|
||||
_adjust_search_parameters(data, shallow_search=request.shallow_search)
|
||||
|
||||
ret_params = {}
|
||||
projects = Project.get_many_with_join(
|
||||
projects: Sequence[dict] = Project.get_many_with_join(
|
||||
company=company_id,
|
||||
query_dict=data,
|
||||
query=_hidden_query(search_hidden=request.search_hidden, ids=requested_ids),
|
||||
allow_public=allow_public,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
@@ -124,7 +136,9 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
|
||||
}
|
||||
if existing_requested_ids:
|
||||
contents = project_bll.calc_own_contents(
|
||||
company=company_id, project_ids=list(existing_requested_ids)
|
||||
company=company_id,
|
||||
project_ids=list(existing_requested_ids),
|
||||
filter_=request.include_stats_filter,
|
||||
)
|
||||
for project in projects:
|
||||
project.update(**contents.get(project["id"], {}))
|
||||
@@ -140,6 +154,8 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
|
||||
project_ids=list(project_ids),
|
||||
specific_state=request.stats_for_state,
|
||||
include_children=request.stats_with_children,
|
||||
return_hidden_children=request.search_hidden,
|
||||
filter_=request.include_stats_filter,
|
||||
)
|
||||
|
||||
for project in projects:
|
||||
@@ -151,20 +167,24 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
|
||||
|
||||
@endpoint("projects.get_all")
|
||||
def get_all(call: APICall):
|
||||
conform_tag_fields(call, call.data)
|
||||
data = call.data
|
||||
_adjust_search_parameters(data, shallow_search=data.get("shallow_search", False))
|
||||
with translate_errors_context(), TimingContext("mongo", "projects_get_all"):
|
||||
conform_tag_fields(call, data)
|
||||
_adjust_search_parameters(
|
||||
data, shallow_search=data.get("shallow_search", False),
|
||||
)
|
||||
with TimingContext("mongo", "projects_get_all"):
|
||||
ret_params = {}
|
||||
projects = Project.get_many(
|
||||
company=call.identity.company,
|
||||
query_dict=data,
|
||||
query=_hidden_query(
|
||||
search_hidden=data.get("search_hidden"), ids=data.get("id")
|
||||
),
|
||||
parameters=data,
|
||||
allow_public=True,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
conform_output_tags(call, projects)
|
||||
|
||||
call.result.data = {"projects": projects, **ret_params}
|
||||
|
||||
|
||||
@@ -275,6 +295,40 @@ def get_unique_metric_variants(
|
||||
call.result.data = {"metrics": metrics}
|
||||
|
||||
|
||||
@endpoint("projects.get_model_metadata_keys",)
|
||||
def get_model_metadata_keys(call: APICall, company_id: str, request: GetParamsRequest):
|
||||
total, remaining, keys = project_queries.get_model_metadata_keys(
|
||||
company_id,
|
||||
project_ids=[request.project] if request.project else None,
|
||||
include_subprojects=request.include_subprojects,
|
||||
page=request.page,
|
||||
page_size=request.page_size,
|
||||
)
|
||||
|
||||
call.result.data = {
|
||||
"total": total,
|
||||
"remaining": remaining,
|
||||
"keys": keys,
|
||||
}
|
||||
|
||||
|
||||
@endpoint("projects.get_model_metadata_values")
|
||||
def get_model_metadata_values(
|
||||
call: APICall, company_id: str, request: ProjectModelMetadataValuesRequest
|
||||
):
|
||||
total, values = project_queries.get_model_metadata_distinct_values(
|
||||
company_id,
|
||||
project_ids=request.projects,
|
||||
key=request.key,
|
||||
include_subprojects=request.include_subprojects,
|
||||
allow_public=request.allow_public,
|
||||
)
|
||||
call.result.data = {
|
||||
"total": total,
|
||||
"values": values,
|
||||
}
|
||||
|
||||
|
||||
@endpoint(
|
||||
"projects.get_hyper_parameters",
|
||||
min_version="2.9",
|
||||
@@ -305,7 +359,7 @@ def get_hyper_parameters(call: APICall, company_id: str, request: GetParamsReque
|
||||
def get_hyperparam_values(
|
||||
call: APICall, company_id: str, request: ProjectHyperparamValuesRequest
|
||||
):
|
||||
total, values = project_queries.get_hyperparam_distinct_values(
|
||||
total, values = project_queries.get_task_hyperparam_distinct_values(
|
||||
company_id,
|
||||
project_ids=request.projects,
|
||||
section=request.section,
|
||||
@@ -319,6 +373,17 @@ def get_hyperparam_values(
|
||||
}
|
||||
|
||||
|
||||
@endpoint("projects.get_project_tags")
|
||||
def get_tags(call: APICall, company, request: ProjectTagsRequest):
|
||||
tags, system_tags = project_bll.get_project_tags(
|
||||
company,
|
||||
include_system=request.include_system,
|
||||
filter_=get_tags_filter_dictionary(request.filter),
|
||||
projects=request.projects,
|
||||
)
|
||||
call.result.data = sort_tags_response({"tags": tags, "system_tags": system_tags})
|
||||
|
||||
|
||||
@endpoint(
|
||||
"projects.get_task_tags", min_version="2.8", request_data_model=ProjectTagsRequest
|
||||
)
|
||||
@@ -330,7 +395,7 @@ def get_tags(call: APICall, company, request: ProjectTagsRequest):
|
||||
filter_=get_tags_filter_dictionary(request.filter),
|
||||
projects=request.projects,
|
||||
)
|
||||
call.result.data = get_tags_response(ret)
|
||||
call.result.data = sort_tags_response(ret)
|
||||
|
||||
|
||||
@endpoint(
|
||||
@@ -344,7 +409,7 @@ def get_tags(call: APICall, company, request: ProjectTagsRequest):
|
||||
filter_=get_tags_filter_dictionary(request.filter),
|
||||
projects=request.projects,
|
||||
)
|
||||
call.result.data = get_tags_response(ret)
|
||||
call.result.data = sort_tags_response(ret)
|
||||
|
||||
|
||||
@endpoint(
|
||||
|
||||
@@ -13,17 +13,19 @@ from apiserver.apimodels.queues import (
|
||||
QueueMetrics,
|
||||
AddOrUpdateMetadataRequest,
|
||||
DeleteMetadataRequest,
|
||||
GetNextTaskRequest,
|
||||
)
|
||||
from apiserver.bll.model import Metadata
|
||||
from apiserver.bll.queue import QueueBLL
|
||||
from apiserver.bll.workers import WorkerBLL
|
||||
from apiserver.database.model.metadata import metadata_add_or_update, metadata_delete
|
||||
from apiserver.database.model.queue import Queue
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.service_repo import APICall, endpoint
|
||||
from apiserver.services.utils import (
|
||||
conform_tag_fields,
|
||||
conform_output_tags,
|
||||
conform_tags,
|
||||
get_metadata_from_api,
|
||||
escape_metadata,
|
||||
unescape_metadata,
|
||||
)
|
||||
from apiserver.utilities import extract_properties_to_lists
|
||||
|
||||
@@ -36,6 +38,7 @@ def get_by_id(call: APICall, company_id, req_model: QueueRequest):
|
||||
queue = queue_bll.get_by_id(company_id, req_model.queue)
|
||||
queue_dict = queue.to_proper_dict()
|
||||
conform_output_tags(call, queue_dict)
|
||||
unescape_metadata(call, queue_dict)
|
||||
call.result.data = {"queue": queue_dict}
|
||||
|
||||
|
||||
@@ -49,13 +52,13 @@ def get_by_id(call: APICall):
|
||||
def get_all_ex(call: APICall):
|
||||
conform_tag_fields(call, call.data)
|
||||
ret_params = {}
|
||||
|
||||
Metadata.escape_query_parameters(call)
|
||||
queues = queue_bll.get_queue_infos(
|
||||
company_id=call.identity.company,
|
||||
query_dict=call.data,
|
||||
ret_params=ret_params,
|
||||
company_id=call.identity.company, query_dict=call.data, ret_params=ret_params,
|
||||
)
|
||||
conform_output_tags(call, queues)
|
||||
|
||||
unescape_metadata(call, queues)
|
||||
call.result.data = {"queues": queues, **ret_params}
|
||||
|
||||
|
||||
@@ -63,13 +66,12 @@ def get_all_ex(call: APICall):
|
||||
def get_all(call: APICall):
|
||||
conform_tag_fields(call, call.data)
|
||||
ret_params = {}
|
||||
Metadata.escape_query_parameters(call)
|
||||
queues = queue_bll.get_all(
|
||||
company_id=call.identity.company,
|
||||
query_dict=call.data,
|
||||
ret_params=ret_params,
|
||||
company_id=call.identity.company, query_dict=call.data, ret_params=ret_params,
|
||||
)
|
||||
conform_output_tags(call, queues)
|
||||
|
||||
unescape_metadata(call, queues)
|
||||
call.result.data = {"queues": queues, **ret_params}
|
||||
|
||||
|
||||
@@ -83,7 +85,7 @@ def create(call: APICall, company_id, request: CreateRequest):
|
||||
name=request.name,
|
||||
tags=tags,
|
||||
system_tags=system_tags,
|
||||
metadata=get_metadata_from_api(request.metadata),
|
||||
metadata=Metadata.metadata_from_api(request.metadata),
|
||||
)
|
||||
call.result.data = {"id": queue.id}
|
||||
|
||||
@@ -97,10 +99,12 @@ def create(call: APICall, company_id, request: CreateRequest):
|
||||
def update(call: APICall, company_id, req_model: UpdateRequest):
|
||||
data = call.data_model_for_partial_update
|
||||
conform_tag_fields(call, data, validate=True)
|
||||
escape_metadata(data)
|
||||
updated, fields = queue_bll.update(
|
||||
company_id=company_id, queue_id=req_model.queue, **data
|
||||
)
|
||||
conform_output_tags(call, fields)
|
||||
unescape_metadata(call, fields)
|
||||
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
|
||||
|
||||
|
||||
@@ -121,11 +125,19 @@ def add_task(call: APICall, company_id, req_model: TaskRequest):
|
||||
}
|
||||
|
||||
|
||||
@endpoint("queues.get_next_task", min_version="2.4", request_data_model=QueueRequest)
|
||||
def get_next_task(call: APICall, company_id, req_model: QueueRequest):
|
||||
task = queue_bll.get_next_task(company_id=company_id, queue_id=req_model.queue)
|
||||
if task:
|
||||
call.result.data = {"entry": task.to_proper_dict()}
|
||||
@endpoint("queues.get_next_task", request_data_model=GetNextTaskRequest)
|
||||
def get_next_task(call: APICall, company_id, req_model: GetNextTaskRequest):
|
||||
entry = queue_bll.get_next_task(
|
||||
company_id=company_id, queue_id=req_model.queue
|
||||
)
|
||||
if entry:
|
||||
data = {"entry": entry.to_proper_dict()}
|
||||
if req_model.get_task_info:
|
||||
task = Task.objects(id=entry.task).first()
|
||||
if task:
|
||||
data["task_info"] = {"company": task.company, "user": task.user}
|
||||
|
||||
call.result.data = data
|
||||
|
||||
|
||||
@endpoint("queues.remove_task", min_version="2.4", request_data_model=TaskRequest)
|
||||
@@ -245,21 +257,19 @@ def get_queue_metrics(
|
||||
|
||||
@endpoint("queues.add_or_update_metadata", min_version="2.13")
|
||||
def add_or_update_metadata(
|
||||
_: APICall, company_id: str, request: AddOrUpdateMetadataRequest
|
||||
call: APICall, company_id: str, request: AddOrUpdateMetadataRequest
|
||||
):
|
||||
queue_id = request.queue
|
||||
queue_bll.get_by_id(company_id=company_id, queue_id=queue_id, only=("id",))
|
||||
|
||||
queue = queue_bll.get_by_id(company_id=company_id, queue_id=queue_id, only=("id",))
|
||||
return {
|
||||
"updated": metadata_add_or_update(
|
||||
cls=Queue, _id=queue_id, items=get_metadata_from_api(request.metadata),
|
||||
"updated": Metadata.edit_metadata(
|
||||
queue, items=request.metadata, replace_metadata=request.replace_metadata
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@endpoint("queues.delete_metadata", min_version="2.13")
|
||||
def delete_metadata(_: APICall, company_id: str, request: DeleteMetadataRequest):
|
||||
def delete_metadata(call: APICall, company_id: str, request: DeleteMetadataRequest):
|
||||
queue_id = request.queue
|
||||
queue_bll.get_by_id(company_id=company_id, queue_id=queue_id, only=("id",))
|
||||
|
||||
return {"updated": metadata_delete(cls=Queue, _id=queue_id, keys=request.keys)}
|
||||
queue = queue_bll.get_by_id(company_id=company_id, queue_id=queue_id, only=("id",))
|
||||
return {"updated": Metadata.delete_metadata(queue, keys=request.keys)}
|
||||
|
||||
@@ -98,6 +98,7 @@ from apiserver.bll.task.task_operations import (
|
||||
from apiserver.bll.task.utils import update_task, get_task_for_update, deleted_prefix
|
||||
from apiserver.bll.util import SetFieldsResolver, run_batch_operation
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.model.task.output import Output
|
||||
from apiserver.database.model.task.task import (
|
||||
Task,
|
||||
@@ -213,6 +214,16 @@ def _process_include_subprojects(call_data: dict):
|
||||
call_data["project"] = project_ids_with_children(project_ids)
|
||||
|
||||
|
||||
def _hidden_query(data: dict) -> Q:
|
||||
"""
|
||||
1. Add only non-hidden tasks search condition (unless specifically specified differently)
|
||||
"""
|
||||
if data.get("search_hidden") or data.get("id"):
|
||||
return Q()
|
||||
|
||||
return Q(system_tags__ne=EntityVisibility.hidden.value)
|
||||
|
||||
|
||||
@endpoint("tasks.get_all_ex", required_fields=[])
|
||||
def get_all_ex(call: APICall, company_id, _):
|
||||
conform_tag_fields(call, call.data)
|
||||
@@ -225,6 +236,7 @@ def get_all_ex(call: APICall, company_id, _):
|
||||
tasks = Task.get_many_with_join(
|
||||
company=company_id,
|
||||
query_dict=call_data,
|
||||
query=_hidden_query(call_data),
|
||||
allow_public=True,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
@@ -259,6 +271,7 @@ def get_all(call: APICall, company_id, _):
|
||||
company=company_id,
|
||||
parameters=call_data,
|
||||
query_dict=call_data,
|
||||
query=_hidden_query(call_data),
|
||||
allow_public=True,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
|
||||
@@ -2,7 +2,6 @@ from datetime import datetime
|
||||
from typing import Union, Sequence, Tuple
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apimodels.metadata import MetadataItem as ApiMetadataItem
|
||||
from apiserver.apimodels.organization import Filter
|
||||
from apiserver.database.model.base import GetMixin
|
||||
from apiserver.database.model.task.task import TaskModelTypes, TaskModelNames
|
||||
@@ -24,7 +23,7 @@ def get_tags_filter_dictionary(input_: Filter) -> dict:
|
||||
}
|
||||
|
||||
|
||||
def get_tags_response(ret: dict) -> dict:
|
||||
def sort_tags_response(ret: dict) -> dict:
|
||||
return {field: sorted(vals) for field, vals in ret.items()}
|
||||
|
||||
|
||||
@@ -222,22 +221,38 @@ class DockerCmdBackwardsCompatibility:
|
||||
nested_set(task, cls.field, docker_cmd)
|
||||
|
||||
|
||||
def validate_metadata(metadata: Sequence[dict]):
|
||||
def escape_metadata(document: dict):
|
||||
"""
|
||||
Escape special characters in metadata keys
|
||||
"""
|
||||
metadata = document.get("metadata")
|
||||
if not metadata:
|
||||
return
|
||||
|
||||
keys = [m.get("key") for m in metadata]
|
||||
unique_keys = set(keys)
|
||||
unique_keys.discard(None)
|
||||
if len(keys) != len(set(keys)):
|
||||
raise errors.bad_request.ValidationError("Metadata keys should be unique")
|
||||
document["metadata"] = {
|
||||
ParameterKeyEscaper.escape(k): v
|
||||
for k, v in metadata.items()
|
||||
}
|
||||
|
||||
|
||||
def get_metadata_from_api(api_metadata: Sequence[ApiMetadataItem]) -> Sequence:
|
||||
if not api_metadata:
|
||||
return api_metadata
|
||||
def unescape_metadata(call: APICall, documents: Union[dict, Sequence[dict]]):
|
||||
"""
|
||||
Unescape special characters in metadata keys
|
||||
"""
|
||||
if isinstance(documents, dict):
|
||||
documents = [documents]
|
||||
|
||||
metadata = [m.to_struct() for m in api_metadata]
|
||||
validate_metadata(metadata)
|
||||
old_client = call.requested_endpoint_version <= PartialVersion("2.16")
|
||||
for doc in documents:
|
||||
if old_client and "metadata" in doc:
|
||||
doc["metadata"] = []
|
||||
continue
|
||||
|
||||
return metadata
|
||||
metadata = doc.get("metadata")
|
||||
if not metadata:
|
||||
continue
|
||||
|
||||
doc["metadata"] = {
|
||||
ParameterKeyEscaper.unescape(k): v
|
||||
for k, v in metadata.items()
|
||||
}
|
||||
|
||||
@@ -4,10 +4,29 @@ from apiserver.tests.automated import TestService
|
||||
|
||||
|
||||
class TestProjectTags(TestService):
|
||||
def setUp(self, version="2.12"):
|
||||
super().setUp(version=version)
|
||||
def test_project_own_tags(self):
|
||||
p1_tags = ["Tag 1", "Tag 2"]
|
||||
p1 = self.create_temp(
|
||||
"projects", name="Test project tags1", description="test", tags=p1_tags
|
||||
)
|
||||
p2_tags = ["Tag 1", "Tag 3"]
|
||||
p2 = self.create_temp(
|
||||
"projects",
|
||||
name="Test project tags2",
|
||||
description="test",
|
||||
tags=p2_tags,
|
||||
system_tags=["hidden"],
|
||||
)
|
||||
|
||||
def test_project_tags(self):
|
||||
res = self.api.projects.get_project_tags(projects=[p1, p2])
|
||||
self.assertEqual(set(res.tags), set(p1_tags) | set(p2_tags))
|
||||
|
||||
res = self.api.projects.get_project_tags(
|
||||
projects=[p1, p2], filter={"system_tags": ["__$not", "hidden"]}
|
||||
)
|
||||
self.assertEqual(res.tags, p1_tags)
|
||||
|
||||
def test_project_entities_tags(self):
|
||||
tags_1 = ["Test tag 1", "Test tag 2"]
|
||||
tags_2 = ["Test tag 3", "Test tag 4"]
|
||||
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
from functools import partial
|
||||
from typing import Sequence
|
||||
|
||||
from apiserver.tests.api_client import APIClient
|
||||
from apiserver.tests.automated import TestService
|
||||
|
||||
|
||||
class TestQueueAndModelMetadata(TestService):
|
||||
meta1 = [{"key": "test_key", "type": "str", "value": "test_value"}]
|
||||
meta1 = {"test_key": {"key": "test_key", "type": "str", "value": "test_value"}}
|
||||
|
||||
def test_queue_metas(self):
|
||||
queue_id = self._temp_queue("TestMetadata", metadata=self.meta1)
|
||||
@@ -23,20 +22,51 @@ class TestQueueAndModelMetadata(TestService):
|
||||
)
|
||||
|
||||
model_id = self._temp_model("TestMetadata1")
|
||||
self.api.models.edit(model=model_id, metadata=[self.meta1[0]])
|
||||
self.api.models.edit(model=model_id, metadata=self.meta1)
|
||||
self._assertMeta(service=service, entity=entity, _id=model_id, meta=self.meta1)
|
||||
|
||||
def test_project_meta_query(self):
|
||||
self._temp_model("TestMetadata", metadata=self.meta1)
|
||||
project = self.temp_project(name="MetaParent")
|
||||
test_key = "test_key"
|
||||
test_key2 = "test_key2"
|
||||
test_value = "test_value"
|
||||
test_value2 = "test_value2"
|
||||
model_id = self._temp_model(
|
||||
"TestMetadata2",
|
||||
project=project,
|
||||
metadata={
|
||||
test_key: {"key": test_key, "type": "str", "value": test_value},
|
||||
test_key2: {"key": test_key2, "type": "str", "value": test_value2},
|
||||
},
|
||||
)
|
||||
res = self.api.projects.get_model_metadata_keys()
|
||||
self.assertTrue({test_key, test_key2}.issubset(set(res["keys"])))
|
||||
res = self.api.projects.get_model_metadata_keys(include_subprojects=False)
|
||||
self.assertTrue(test_key in res["keys"])
|
||||
self.assertFalse(test_key2 in res["keys"])
|
||||
|
||||
model = self.api.models.get_all_ex(
|
||||
id=[model_id], only_fields=["metadata.test_key"]
|
||||
).models[0]
|
||||
self.assertTrue(test_key in model.metadata)
|
||||
self.assertFalse(test_key2 in model.metadata)
|
||||
|
||||
res = self.api.projects.get_model_metadata_values(key=test_key)
|
||||
self.assertEqual(res.total, 1)
|
||||
self.assertEqual(res["values"], [test_value])
|
||||
|
||||
def _test_meta_operations(
|
||||
self, service: APIClient.Service, entity: str, _id: str,
|
||||
):
|
||||
assert_meta = partial(self._assertMeta, service=service, entity=entity)
|
||||
assert_meta(_id=_id, meta=self.meta1)
|
||||
|
||||
meta2 = [
|
||||
{"key": "test1", "type": "str", "value": "data1"},
|
||||
{"key": "test2", "type": "str", "value": "data2"},
|
||||
{"key": "test3", "type": "str", "value": "data3"},
|
||||
]
|
||||
meta2 = {
|
||||
"test1": {"key": "test1", "type": "str", "value": "data1"},
|
||||
"test2": {"key": "test2", "type": "str", "value": "data2"},
|
||||
"test3": {"key": "test3", "type": "str", "value": "data3"},
|
||||
}
|
||||
service.update(**{entity: _id, "metadata": meta2})
|
||||
assert_meta(_id=_id, meta=meta2)
|
||||
|
||||
@@ -48,16 +78,17 @@ class TestQueueAndModelMetadata(TestService):
|
||||
]
|
||||
res = service.add_or_update_metadata(**{entity: _id, "metadata": updates})
|
||||
self.assertEqual(res.updated, 1)
|
||||
assert_meta(_id=_id, meta=[meta2[0], *updates])
|
||||
assert_meta(_id=_id, meta={**meta2, **{u["key"]: u for u in updates}})
|
||||
|
||||
res = service.delete_metadata(
|
||||
**{entity: _id, "keys": [f"test{idx}" for idx in range(2, 6)]}
|
||||
)
|
||||
self.assertEqual(res.updated, 1)
|
||||
assert_meta(_id=_id, meta=meta2[:1])
|
||||
# noinspection PyTypeChecker
|
||||
assert_meta(_id=_id, meta=dict(list(meta2.items())[:1]))
|
||||
|
||||
def _assertMeta(
|
||||
self, service: APIClient.Service, entity: str, _id: str, meta: Sequence[dict]
|
||||
self, service: APIClient.Service, entity: str, _id: str, meta: dict
|
||||
):
|
||||
res = service.get_all_ex(id=[_id])[f"{entity}s"][0]
|
||||
self.assertEqual(res.metadata, meta)
|
||||
|
||||
@@ -199,10 +199,10 @@ class TestSubProjects(TestService):
|
||||
res1 = next(p for p in res if p.id == project1)
|
||||
self.assertEqual(res1.stats["active"]["status_count"]["created"], 0)
|
||||
self.assertEqual(res1.stats["active"]["status_count"]["stopped"], 2)
|
||||
self.assertEqual(res1.stats["active"]["status_count"]["in_progress"], 0)
|
||||
self.assertEqual(res1.stats["active"]["total_runtime"], 2)
|
||||
self.assertEqual(res1.stats["active"]["completed_tasks"], 2)
|
||||
self.assertEqual(res1.stats["active"]["completed_tasks_24h"], 2)
|
||||
self.assertEqual(res1.stats["active"]["total_tasks"], 2)
|
||||
self.assertEqual(res1.stats["active"]["running_tasks"], 0)
|
||||
self.assertEqual(
|
||||
{sp.name for sp in res1.sub_projects},
|
||||
{
|
||||
@@ -214,10 +214,10 @@ class TestSubProjects(TestService):
|
||||
res2 = next(p for p in res if p.id == project2)
|
||||
self.assertEqual(res2.stats["active"]["status_count"]["created"], 0)
|
||||
self.assertEqual(res2.stats["active"]["status_count"]["stopped"], 0)
|
||||
self.assertEqual(res2.stats["active"]["status_count"]["in_progress"], 0)
|
||||
self.assertEqual(res2.stats["active"]["status_count"]["completed"], 0)
|
||||
self.assertEqual(res2.stats["active"]["total_runtime"], 0)
|
||||
self.assertEqual(res2.stats["active"]["completed_tasks"], 0)
|
||||
self.assertEqual(res2.stats["active"]["total_tasks"], 0)
|
||||
self.assertEqual(res2.stats["active"]["running_tasks"], 0)
|
||||
self.assertEqual(res2.sub_projects, [])
|
||||
|
||||
def _run_tasks(self, *tasks):
|
||||
|
||||
@@ -133,6 +133,32 @@ class TestTags(TestService):
|
||||
).models
|
||||
self.assertFound(model_id, [], models)
|
||||
|
||||
def testQueueTags(self):
|
||||
q_id = self._temp_queue(system_tags=["default"])
|
||||
queues = self.api.queues.get_all_ex(
|
||||
name="Test tags", system_tags=["default"]
|
||||
).queues
|
||||
self.assertFound(q_id, ["default"], queues)
|
||||
|
||||
queues = self.api.queues.get_all_ex(
|
||||
name="Test tags", system_tags=["-default"]
|
||||
).queues
|
||||
self.assertNotFound(q_id, queues)
|
||||
|
||||
self.api.queues.update(queue=q_id, system_tags=[])
|
||||
queues = self.api.queues.get_all_ex(
|
||||
name="Test tags", system_tags=["-default"]
|
||||
).queues
|
||||
self.assertFound(q_id, [], queues)
|
||||
|
||||
# test default queue
|
||||
queues = self.api.queues.get_all(system_tags=["default"]).queues
|
||||
if queues:
|
||||
self.assertEqual(queues[0].id, self.api.queues.get_default().id)
|
||||
else:
|
||||
self.api.queues.update(queue=q_id, system_tags=["default"])
|
||||
self.assertEqual(q_id, self.api.queues.get_default().id)
|
||||
|
||||
def testTaskTags(self):
|
||||
task_id = self._temp_task(
|
||||
name="Test tags", system_tags=["active"]
|
||||
@@ -169,38 +195,11 @@ class TestTags(TestService):
|
||||
task = self.api.tasks.get_by_id(task=task_id).task
|
||||
self.assertEqual(task.status, "stopped")
|
||||
|
||||
def testQueueTags(self):
|
||||
q_id = self._temp_queue(system_tags=["default"])
|
||||
queues = self.api.queues.get_all_ex(
|
||||
name="Test tags", system_tags=["default"]
|
||||
).queues
|
||||
self.assertFound(q_id, ["default"], queues)
|
||||
|
||||
queues = self.api.queues.get_all_ex(
|
||||
name="Test tags", system_tags=["-default"]
|
||||
).queues
|
||||
self.assertNotFound(q_id, queues)
|
||||
|
||||
self.api.queues.update(queue=q_id, system_tags=[])
|
||||
queues = self.api.queues.get_all_ex(
|
||||
name="Test tags", system_tags=["-default"]
|
||||
).queues
|
||||
self.assertFound(q_id, [], queues)
|
||||
|
||||
# test default queue
|
||||
queues = self.api.queues.get_all(system_tags=["default"]).queues
|
||||
if queues:
|
||||
self.assertEqual(queues[0].id, self.api.queues.get_default().id)
|
||||
else:
|
||||
self.api.queues.update(queue=q_id, system_tags=["default"])
|
||||
self.assertEqual(q_id, self.api.queues.get_default().id)
|
||||
|
||||
def assertProjectStats(self, project: AttrDict):
|
||||
self.assertEqual(set(project.stats.keys()), {"active"})
|
||||
self.assertAlmostEqual(project.stats.active.total_runtime, 1, places=0)
|
||||
self.assertEqual(project.stats.active.completed_tasks, 1)
|
||||
self.assertEqual(project.stats.active.completed_tasks_24h, 1)
|
||||
self.assertEqual(project.stats.active.total_tasks, 1)
|
||||
self.assertEqual(project.stats.active.running_tasks, 0)
|
||||
for status, count in project.stats.active.status_count.items():
|
||||
self.assertEqual(count, 1 if status == "stopped" else 0)
|
||||
|
||||
|
||||
14
apiserver/utilities/env.py
Normal file
14
apiserver/utilities/env.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from distutils.util import strtobool
|
||||
from os import getenv
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def get_bool(*keys: str, default: bool = None) -> Optional[bool]:
|
||||
try:
|
||||
value = next(env for env in (getenv(key) for key in keys) if env is not None)
|
||||
except StopIteration:
|
||||
return default
|
||||
try:
|
||||
return bool(strtobool(value))
|
||||
except ValueError:
|
||||
return bool(value)
|
||||
@@ -1 +1 @@
|
||||
__version__ = "1.2.0"
|
||||
__version__ = "1.4.0"
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
FROM centos/nodejs-12-centos7 AS webapp
|
||||
|
||||
ARG CLEARML_WEB_GIT_URL=https://github.com/allegroai/clearml-web.git
|
||||
|
||||
USER root
|
||||
WORKDIR /opt
|
||||
|
||||
RUN git clone https://github.com/allegroai/clearml-web.git
|
||||
RUN git clone ${CLEARML_WEB_GIT_URL} clearml-web
|
||||
RUN mv clearml-web /opt/open-webapp
|
||||
COPY --chmod=744 docker/build/internal_files/build_webapp.sh /tmp/internal_files/
|
||||
RUN /bin/bash -c '/tmp/internal_files/build_webapp.sh'
|
||||
@@ -18,6 +20,7 @@ COPY --from=staging_image /opt/clearml/ /opt/clearml/
|
||||
|
||||
COPY --chmod=744 docker/build/internal_files/final_image_preparation.sh /tmp/internal_files/
|
||||
COPY docker/build/internal_files/clearml.conf.template /tmp/internal_files/
|
||||
COPY docker/build/internal_files/clearml_subpath.conf.template /tmp/internal_files/
|
||||
RUN /bin/bash -c '/tmp/internal_files/final_image_preparation.sh'
|
||||
|
||||
COPY --from=webapp /opt/open-webapp/build /usr/share/nginx/html
|
||||
|
||||
@@ -41,6 +41,7 @@ http {
|
||||
server_name _;
|
||||
root /usr/share/nginx/html;
|
||||
proxy_http_version 1.1;
|
||||
client_max_body_size 0;
|
||||
|
||||
# comppression
|
||||
gzip on;
|
||||
|
||||
21
docker/build/internal_files/clearml_subpath.conf.template
Executable file
21
docker/build/internal_files/clearml_subpath.conf.template
Executable file
@@ -0,0 +1,21 @@
|
||||
location /${CLEARML_SERVER_SUB_PATH} {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header Host $host;
|
||||
proxy_pass http://localhost:80;
|
||||
rewrite /${CLEARML_SERVER_SUB_PATH}/(.*) /$1 break;
|
||||
}
|
||||
|
||||
location /${CLEARML_SERVER_SUB_PATH}/api {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header Host $host;
|
||||
proxy_pass http://localhost:80/api;
|
||||
rewrite /${CLEARML_SERVER_SUB_PATH}/api/(.*) /api/$1 break;
|
||||
}
|
||||
|
||||
location /${CLEARML_SERVER_SUB_PATH}/files {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header Host $host;
|
||||
proxy_pass http://localhost:80/files;
|
||||
rewrite /${CLEARML_SERVER_SUB_PATH}/files/(.*) /files/$1 break;
|
||||
rewrite /${CLEARML_SERVER_SUB_PATH}/files /files/ break;
|
||||
}
|
||||
@@ -48,9 +48,16 @@ EOF
|
||||
|
||||
export NGINX_APISERVER_ADDR=${NGINX_APISERVER_ADDRESS:-http://apiserver:8008}
|
||||
export NGINX_FILESERVER_ADDR=${NGINX_FILESERVER_ADDRESS:-http://fileserver:8081}
|
||||
|
||||
envsubst '${NGINX_APISERVER_ADDR} ${NGINX_FILESERVER_ADDR}' < /etc/nginx/clearml.conf.template > /etc/nginx/nginx.conf
|
||||
|
||||
if [[ -n "${CLEARML_SERVER_SUB_PATH}" ]]; then
|
||||
envsubst '${CLEARML_SERVER_SUB_PATH}' < /etc/nginx/clearml_subpath.conf.template > /etc/nginx/default.d/clearml_subpath.conf
|
||||
cp /usr/share/nginx/html/env.js /usr/share/nginx/html/env.js.origin
|
||||
envsubst '${CLEARML_SERVER_SUB_PATH}' < /usr/share/nginx/html/env.js.origin > /usr/share/nginx/html/env.js
|
||||
cp /usr/share/nginx/html/index.html /usr/share/nginx/html/index.html.origin
|
||||
sed 's/href="\/"/href="\/'${CLEARML_SERVER_SUB_PATH}'\/"/' /usr/share/nginx/html/index.html.origin > /usr/share/nginx/html/index.html
|
||||
fi
|
||||
|
||||
#start the server
|
||||
/usr/sbin/nginx -g "daemon off;"
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ set -o pipefail
|
||||
|
||||
yum update -y
|
||||
yum install -y https://dl.fedoraproject.org/pub/epel/epel-release-latest-7.noarch.rpm
|
||||
yum install -y python36 python36-pip nginx gcc python3-devel gettext
|
||||
yum install -y python36 python36-pip nginx gcc gcc-c++ python3-devel gettext
|
||||
yum -y upgrade
|
||||
python3 -m pip install -r /opt/clearml/fileserver/requirements.txt
|
||||
python3 -m pip install -r /opt/clearml/apiserver/requirements.txt
|
||||
@@ -15,4 +15,5 @@ ln -s /dev/stdout /var/log/nginx/access.log
|
||||
ln -s /dev/stderr /var/log/nginx/error.log
|
||||
mv /etc/nginx/nginx.conf /etc/nginx/nginx.conf.orig
|
||||
mv /tmp/internal_files/clearml.conf.template /etc/nginx/clearml.conf.template
|
||||
mv /tmp/internal_files/clearml_subpath.conf.template /etc/nginx/clearml_subpath.conf.template
|
||||
yum clean all
|
||||
@@ -89,12 +89,12 @@ services:
|
||||
networks:
|
||||
- backend
|
||||
container_name: clearml-mongo
|
||||
image: mongo:3.6.23
|
||||
image: mongo:4.4.9
|
||||
restart: unless-stopped
|
||||
command: --setParameter internalQueryExecMaxBlockingSortBytes=196100200
|
||||
command: --setParameter internalQueryMaxBlockingSortMemoryUsageBytes=196100200
|
||||
volumes:
|
||||
- c:/opt/clearml/data/mongo/db:/data/db
|
||||
- c:/opt/clearml/data/mongo/configdb:/data/configdb
|
||||
- c:/opt/clearml/data/mongo_4/db:/data/db
|
||||
- c:/opt/clearml/data/mongo_4/configdb:/data/configdb
|
||||
|
||||
redis:
|
||||
networks:
|
||||
|
||||
@@ -88,12 +88,12 @@ services:
|
||||
networks:
|
||||
- backend
|
||||
container_name: clearml-mongo
|
||||
image: mongo:3.6.23
|
||||
image: mongo:4.4.9
|
||||
restart: unless-stopped
|
||||
command: --setParameter internalQueryExecMaxBlockingSortBytes=196100200
|
||||
command: --setParameter internalQueryMaxBlockingSortMemoryUsageBytes=196100200
|
||||
volumes:
|
||||
- /opt/clearml/data/mongo/db:/data/db
|
||||
- /opt/clearml/data/mongo/configdb:/data/configdb
|
||||
- /opt/clearml/data/mongo_4/db:/data/db
|
||||
- /opt/clearml/data/mongo_4/configdb:/data/configdb
|
||||
|
||||
redis:
|
||||
networks:
|
||||
@@ -108,6 +108,8 @@ services:
|
||||
command:
|
||||
- webserver
|
||||
container_name: clearml-webserver
|
||||
# environment:
|
||||
# CLEARML_SERVER_SUB_PATH : clearml-web # Allow Clearml to be served with a URL path prefix.
|
||||
image: allegroai/clearml:latest
|
||||
restart: unless-stopped
|
||||
depends_on:
|
||||
@@ -152,6 +154,8 @@ services:
|
||||
- /opt/clearml/agent:/root/.clearml
|
||||
depends_on:
|
||||
- apiserver
|
||||
entrypoint: >
|
||||
bash -c "curl --retry 10 --retry-delay 10 --retry-connrefused 'http://apiserver:8008/debug.ping' && /usr/agent/entrypoint.sh"
|
||||
|
||||
networks:
|
||||
backend:
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
# trains-server FAQ
|
||||
|
||||
|
||||
## **NOTE**: This page's information is deprecated. See the [ClearML documentation](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server) for up-to-date deployment instructions
|
||||
|
||||
Launching **trains-server**
|
||||
|
||||
* How do I launch **trains-server** on:
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
# Deploying **trains-server** on AWS
|
||||
|
||||
## **NOTE**: These instructions are deprecated. See the [ClearML documentation](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server) for up-to-date deployment instructions
|
||||
|
||||
To easily deploy **trains-server** on AWS, use one of our pre-built Amazon Machine Images (AMIs).
|
||||
We provide AMIs per region for each released version of **trains-server**, see [Released versions](#released-versions) below.
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
# Deploying Trains Server on Google Cloud Platform
|
||||
|
||||
# **NOTE**: These instructions are deprecated. See the [ClearML documentation](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server) for up-to-date deployment instructions
|
||||
|
||||
To easily deploy Trains Server on GCP, use one of our pre-built GCP Custom Images.
|
||||
We provide Custom Images for each released version of Trains Server, see [Released versions](#released-versions) below.
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
# Launching the **trains-server** Docker in Linux or macOS
|
||||
|
||||
## **NOTE**: These instructions are deprecated. See the [ClearML documentation](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server) for up-to-date deployment instructions
|
||||
|
||||
For Linux or macOS, use our pre-built Docker image for easy deployment. The latest Docker images can be found [here](https://hub.docker.com/r/allegroai/trains).
|
||||
|
||||
For Linux users:
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
# Launching the **trains-server** Docker in Windows 10
|
||||
|
||||
## **NOTE**: These instructions are deprecated. See the [ClearML documentation](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server) for up-to-date deployment instructions
|
||||
|
||||
For Windows, we recommend launching our pre-built Docker image on a Linux virtual machine.
|
||||
However, you can launch **trains-server** on Windows 10 using Docker Desktop for Windows (see the Docker [System Requirements](https://docs.docker.com/docker-for-windows/install/#system-requirements)).
|
||||
|
||||
|
||||
@@ -13,12 +13,15 @@ from werkzeug.exceptions import NotFound
|
||||
from werkzeug.security import safe_join
|
||||
|
||||
from config import config
|
||||
from utils import get_env_bool
|
||||
|
||||
DEFAULT_UPLOAD_FOLDER = "/mnt/fileserver"
|
||||
|
||||
app = Flask(__name__)
|
||||
CORS(app, **config.get("fileserver.cors"))
|
||||
Compress(app)
|
||||
|
||||
if get_env_bool("CLEARML_COMPRESS_RESP", default=True):
|
||||
Compress(app)
|
||||
|
||||
app.config["UPLOAD_FOLDER"] = first(
|
||||
(os.environ.get(f"{prefix}_UPLOAD_FOLDER") for prefix in ("CLEARML", "TRAINS")),
|
||||
@@ -29,6 +32,20 @@ app.config["SEND_FILE_MAX_AGE_DEFAULT"] = config.get(
|
||||
)
|
||||
|
||||
|
||||
@app.before_request
|
||||
def before_request():
|
||||
if request.content_encoding:
|
||||
return f"Content encoding is not supported ({request.content_encoding})", 415
|
||||
|
||||
|
||||
@app.after_request
|
||||
def after_request(response):
|
||||
response.headers["server"] = config.get(
|
||||
"fileserver.response.headers.server", "clearml"
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
@app.route("/", methods=["POST"])
|
||||
def upload():
|
||||
results = []
|
||||
@@ -54,7 +71,10 @@ def download(path):
|
||||
mimetype = "application/octet-stream" if encoding == "gzip" else None
|
||||
|
||||
response = send_from_directory(
|
||||
app.config["UPLOAD_FOLDER"], path, as_attachment=as_attachment, mimetype=mimetype
|
||||
app.config["UPLOAD_FOLDER"],
|
||||
path,
|
||||
as_attachment=as_attachment,
|
||||
mimetype=mimetype,
|
||||
)
|
||||
if config.get("fileserver.download.disable_browser_caching", False):
|
||||
headers = response.headers
|
||||
@@ -68,12 +88,7 @@ def download(path):
|
||||
|
||||
@app.route("/<path:path>", methods=["DELETE"])
|
||||
def delete(path):
|
||||
real_path = Path(
|
||||
safe_join(
|
||||
os.fspath(app.config["UPLOAD_FOLDER"]),
|
||||
os.fspath(path)
|
||||
)
|
||||
)
|
||||
real_path = Path(safe_join(os.fspath(app.config["UPLOAD_FOLDER"]), os.fspath(path)))
|
||||
if not real_path.exists() or not real_path.is_file():
|
||||
abort(Response(f"File {str(path)} not found", 404))
|
||||
|
||||
|
||||
14
fileserver/utils.py
Normal file
14
fileserver/utils.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from distutils.util import strtobool
|
||||
from os import getenv
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def get_env_bool(*keys: str, default: bool = None) -> Optional[bool]:
|
||||
try:
|
||||
value = next(env for env in (getenv(key) for key in keys) if env is not None)
|
||||
except StopIteration:
|
||||
return default
|
||||
try:
|
||||
return bool(strtobool(value))
|
||||
except ValueError:
|
||||
return bool(value)
|
||||
Reference in New Issue
Block a user