mirror of
https://github.com/clearml/clearml-server
synced 2025-06-26 23:15:47 +00:00
Add metadata dict support for models, queues
Add more info for projects
This commit is contained in:
parent
04ea9018a3
commit
af09fba755
apiserver
apimodels
bll
database/model
mongo/migrations
schema/services
services
tests/automated
@ -1,7 +1,7 @@
|
|||||||
from typing import Sequence
|
from typing import Sequence
|
||||||
|
|
||||||
from jsonmodels import validators
|
from jsonmodels import validators
|
||||||
from jsonmodels.fields import StringField
|
from jsonmodels.fields import StringField, BoolField
|
||||||
from jsonmodels.models import Base
|
from jsonmodels.models import Base
|
||||||
|
|
||||||
from apiserver.apimodels import ListField
|
from apiserver.apimodels import ListField
|
||||||
@ -21,3 +21,4 @@ class AddOrUpdateMetadata(Base):
|
|||||||
metadata: Sequence[MetadataItem] = ListField(
|
metadata: Sequence[MetadataItem] = ListField(
|
||||||
[MetadataItem], validators=validators.Length(minimum_value=1)
|
[MetadataItem], validators=validators.Length(minimum_value=1)
|
||||||
)
|
)
|
||||||
|
replace_metadata = BoolField(default=False)
|
||||||
|
@ -30,7 +30,7 @@ class CreateModelRequest(models.Base):
|
|||||||
ready = fields.BoolField(default=True)
|
ready = fields.BoolField(default=True)
|
||||||
ui_cache = DictField()
|
ui_cache = DictField()
|
||||||
task = fields.StringField()
|
task = fields.StringField()
|
||||||
metadata = ListField(items_types=[MetadataItem])
|
metadata = DictField(value_types=[MetadataItem])
|
||||||
|
|
||||||
|
|
||||||
class CreateModelResponse(models.Base):
|
class CreateModelResponse(models.Base):
|
||||||
|
@ -2,7 +2,7 @@ from jsonmodels import validators
|
|||||||
from jsonmodels.fields import StringField, IntField, BoolField, FloatField
|
from jsonmodels.fields import StringField, IntField, BoolField, FloatField
|
||||||
from jsonmodels.models import Base
|
from jsonmodels.models import Base
|
||||||
|
|
||||||
from apiserver.apimodels import ListField
|
from apiserver.apimodels import ListField, DictField
|
||||||
from apiserver.apimodels.metadata import (
|
from apiserver.apimodels.metadata import (
|
||||||
MetadataItem,
|
MetadataItem,
|
||||||
DeleteMetadata,
|
DeleteMetadata,
|
||||||
@ -19,13 +19,18 @@ class CreateRequest(Base):
|
|||||||
name = StringField(required=True)
|
name = StringField(required=True)
|
||||||
tags = ListField(items_types=[str])
|
tags = ListField(items_types=[str])
|
||||||
system_tags = ListField(items_types=[str])
|
system_tags = ListField(items_types=[str])
|
||||||
metadata = ListField(items_types=[MetadataItem])
|
metadata = DictField(value_types=[MetadataItem])
|
||||||
|
|
||||||
|
|
||||||
class QueueRequest(Base):
|
class QueueRequest(Base):
|
||||||
queue = StringField(required=True)
|
queue = StringField(required=True)
|
||||||
|
|
||||||
|
|
||||||
|
class GetNextTaskRequest(QueueRequest):
|
||||||
|
queue = StringField(required=True)
|
||||||
|
get_task_info = BoolField(default=False)
|
||||||
|
|
||||||
|
|
||||||
class DeleteRequest(QueueRequest):
|
class DeleteRequest(QueueRequest):
|
||||||
force = BoolField(default=False)
|
force = BoolField(default=False)
|
||||||
|
|
||||||
@ -34,7 +39,7 @@ class UpdateRequest(QueueRequest):
|
|||||||
name = StringField()
|
name = StringField()
|
||||||
tags = ListField(items_types=[str])
|
tags = ListField(items_types=[str])
|
||||||
system_tags = ListField(items_types=[str])
|
system_tags = ListField(items_types=[str])
|
||||||
metadata = ListField(items_types=[MetadataItem])
|
metadata = DictField(value_types=[MetadataItem])
|
||||||
|
|
||||||
|
|
||||||
class TaskRequest(QueueRequest):
|
class TaskRequest(QueueRequest):
|
||||||
|
@ -7,6 +7,7 @@ from apiserver.bll.task.utils import deleted_prefix
|
|||||||
from apiserver.database.model import EntityVisibility
|
from apiserver.database.model import EntityVisibility
|
||||||
from apiserver.database.model.model import Model
|
from apiserver.database.model.model import Model
|
||||||
from apiserver.database.model.task.task import Task, TaskStatus
|
from apiserver.database.model.task.task import Task, TaskStatus
|
||||||
|
from .metadata import Metadata
|
||||||
|
|
||||||
|
|
||||||
class ModelBLL:
|
class ModelBLL:
|
||||||
|
111
apiserver/bll/model/metadata.py
Normal file
111
apiserver/bll/model/metadata.py
Normal file
@ -0,0 +1,111 @@
|
|||||||
|
from typing import Sequence, Union, Mapping
|
||||||
|
|
||||||
|
from mongoengine import Document
|
||||||
|
|
||||||
|
from apiserver.apierrors import errors
|
||||||
|
from apiserver.apimodels.metadata import MetadataItem
|
||||||
|
from apiserver.database.model.base import GetMixin
|
||||||
|
from apiserver.service_repo import APICall
|
||||||
|
from apiserver.utilities.parameter_key_escaper import (
|
||||||
|
ParameterKeyEscaper,
|
||||||
|
mongoengine_safe,
|
||||||
|
)
|
||||||
|
from apiserver.config_repo import config
|
||||||
|
from apiserver.timing_context import TimingContext
|
||||||
|
|
||||||
|
log = config.logger(__file__)
|
||||||
|
|
||||||
|
|
||||||
|
class Metadata:
|
||||||
|
@staticmethod
|
||||||
|
def metadata_from_api(
|
||||||
|
api_data: Union[Mapping[str, MetadataItem], Sequence[MetadataItem]]
|
||||||
|
) -> dict:
|
||||||
|
if not api_data:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
if isinstance(api_data, dict):
|
||||||
|
return {
|
||||||
|
ParameterKeyEscaper.escape(k): v.to_struct()
|
||||||
|
for k, v in api_data.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
ParameterKeyEscaper.escape(item.key): item.to_struct() for item in api_data
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def edit_metadata(
|
||||||
|
cls,
|
||||||
|
obj: Document,
|
||||||
|
items: Sequence[MetadataItem],
|
||||||
|
replace_metadata: bool,
|
||||||
|
**more_updates,
|
||||||
|
) -> int:
|
||||||
|
with TimingContext("mongo", "edit_metadata"):
|
||||||
|
update_cmds = dict()
|
||||||
|
metadata = cls.metadata_from_api(items)
|
||||||
|
if replace_metadata:
|
||||||
|
update_cmds["set__metadata"] = metadata
|
||||||
|
else:
|
||||||
|
for key, value in metadata.items():
|
||||||
|
update_cmds[f"set__metadata__{mongoengine_safe(key)}"] = value
|
||||||
|
|
||||||
|
return obj.update(**update_cmds, **more_updates)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def delete_metadata(cls, obj: Document, keys: Sequence[str], **more_updates) -> int:
|
||||||
|
with TimingContext("mongo", "delete_metadata"):
|
||||||
|
return obj.update(
|
||||||
|
**{
|
||||||
|
f"unset__metadata__{ParameterKeyEscaper.escape(key)}": 1
|
||||||
|
for key in set(keys)
|
||||||
|
},
|
||||||
|
**more_updates,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _process_path(path: str):
|
||||||
|
"""
|
||||||
|
Frontend does a partial escaping on the path so the all '.' in key names are escaped
|
||||||
|
Need to unescape and apply a full mongo escaping
|
||||||
|
"""
|
||||||
|
parts = path.split(".")
|
||||||
|
if len(parts) < 2 or len(parts) > 3:
|
||||||
|
raise errors.bad_request.ValidationError("invalid field", path=path)
|
||||||
|
return ".".join(
|
||||||
|
ParameterKeyEscaper.escape(ParameterKeyEscaper.unescape(p)) for p in parts
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def escape_paths(cls, paths: Sequence[str]) -> Sequence[str]:
|
||||||
|
for prefix in (
|
||||||
|
"metadata.",
|
||||||
|
"-metadata.",
|
||||||
|
):
|
||||||
|
paths = [
|
||||||
|
cls._process_path(path) if path.startswith(prefix) else path
|
||||||
|
for path in paths
|
||||||
|
]
|
||||||
|
return paths
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def escape_query_parameters(cls, call: APICall) -> dict:
|
||||||
|
if not call.data:
|
||||||
|
return call.data
|
||||||
|
|
||||||
|
keys = list(call.data)
|
||||||
|
call_data = {
|
||||||
|
safe_key: call.data[key]
|
||||||
|
for key, safe_key in zip(keys, Metadata.escape_paths(keys))
|
||||||
|
}
|
||||||
|
|
||||||
|
projection = GetMixin.get_projection(call_data)
|
||||||
|
if projection:
|
||||||
|
GetMixin.set_projection(call_data, Metadata.escape_paths(projection))
|
||||||
|
|
||||||
|
ordering = GetMixin.get_ordering(call_data)
|
||||||
|
if ordering:
|
||||||
|
GetMixin.set_ordering(call_data, Metadata.escape_paths(ordering))
|
||||||
|
|
||||||
|
return call_data
|
@ -388,6 +388,17 @@ class ProjectBLL:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def max_started_subquery(condition):
|
||||||
|
return {
|
||||||
|
"$max": {
|
||||||
|
"$cond": {
|
||||||
|
"if": condition,
|
||||||
|
"then": "$started",
|
||||||
|
"else": datetime.min,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
def runtime_subquery(additional_cond):
|
def runtime_subquery(additional_cond):
|
||||||
return {
|
return {
|
||||||
# the sum of
|
# the sum of
|
||||||
@ -431,14 +442,22 @@ class ProjectBLL:
|
|||||||
group_step[f"{state.value}_recently_completed"] = completed_after_subquery(
|
group_step[f"{state.value}_recently_completed"] = completed_after_subquery(
|
||||||
cond, time_thresh=time_thresh
|
cond, time_thresh=time_thresh
|
||||||
)
|
)
|
||||||
|
group_step[f"{state.value}_max_task_started"] = max_started_subquery(cond)
|
||||||
|
|
||||||
|
def get_state_filter() -> dict:
|
||||||
|
if not specific_state:
|
||||||
|
return {}
|
||||||
|
if specific_state == EntityVisibility.archived:
|
||||||
|
return {"system_tags": {"$eq": EntityVisibility.archived.value}}
|
||||||
|
return {"system_tags": {"$ne": EntityVisibility.archived.value}}
|
||||||
|
|
||||||
runtime_pipeline = [
|
runtime_pipeline = [
|
||||||
# only count run time for these types of tasks
|
# only count run time for these types of tasks
|
||||||
{
|
{
|
||||||
"$match": {
|
"$match": {
|
||||||
"company": {"$in": [None, "", company_id]},
|
"company": {"$in": [None, "", company_id]},
|
||||||
"type": {"$in": ["training", "testing", "annotation"]},
|
|
||||||
"project": {"$in": project_ids},
|
"project": {"$in": project_ids},
|
||||||
|
**get_state_filter(),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
ensure_valid_fields(),
|
ensure_valid_fields(),
|
||||||
@ -547,6 +566,8 @@ class ProjectBLL:
|
|||||||
) -> Dict[str, dict]:
|
) -> Dict[str, dict]:
|
||||||
return {
|
return {
|
||||||
section: a.get(section, 0) + b.get(section, 0)
|
section: a.get(section, 0) + b.get(section, 0)
|
||||||
|
if not section.endswith("max_task_started")
|
||||||
|
else max(a.get(section) or datetime.min, b.get(section) or datetime.min)
|
||||||
for section in set(a) | set(b)
|
for section in set(a) | set(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -562,6 +583,10 @@ class ProjectBLL:
|
|||||||
project_section_statuses = nested_get(
|
project_section_statuses = nested_get(
|
||||||
status_count, (project_id, section), default=default_counts
|
status_count, (project_id, section), default=default_counts
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_time_or_none(value):
|
||||||
|
return value if value != datetime.min else None
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status_count": project_section_statuses,
|
"status_count": project_section_statuses,
|
||||||
"running_tasks": project_section_statuses.get(TaskStatus.in_progress),
|
"running_tasks": project_section_statuses.get(TaskStatus.in_progress),
|
||||||
@ -570,6 +595,9 @@ class ProjectBLL:
|
|||||||
"completed_tasks": project_runtime.get(
|
"completed_tasks": project_runtime.get(
|
||||||
f"{section}_recently_completed", 0
|
f"{section}_recently_completed", 0
|
||||||
),
|
),
|
||||||
|
"last_task_run": get_time_or_none(
|
||||||
|
project_runtime.get(f"{section}_max_task_started", datetime.min)
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
report_for_states = [
|
report_for_states = [
|
||||||
@ -723,7 +751,9 @@ class ProjectBLL:
|
|||||||
return Model.objects(query).distinct(field="framework")
|
return Model.objects(query).distinct(field="framework")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def calc_own_contents(cls, company: str, project_ids: Sequence[str]) -> Dict[str, dict]:
|
def calc_own_contents(
|
||||||
|
cls, company: str, project_ids: Sequence[str]
|
||||||
|
) -> Dict[str, dict]:
|
||||||
"""
|
"""
|
||||||
Returns the amount of task/models per requested project
|
Returns the amount of task/models per requested project
|
||||||
Use separate aggregation calls on Task/Model instead of lookup
|
Use separate aggregation calls on Task/Model instead of lookup
|
||||||
@ -739,30 +769,17 @@ class ProjectBLL:
|
|||||||
"project": {"$in": project_ids},
|
"project": {"$in": project_ids},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{"$project": {"project": 1}},
|
||||||
"$project": {"project": 1}
|
{"$group": {"_id": "$project", "count": {"$sum": 1}}}
|
||||||
},
|
|
||||||
{
|
|
||||||
"$group": {
|
|
||||||
"_id": "$project",
|
|
||||||
"count": {"$sum": 1},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_agrregate_res(cls_: Type[AttributedDocument]) -> dict:
|
def get_agrregate_res(cls_: Type[AttributedDocument]) -> dict:
|
||||||
return {
|
return {data["_id"]: data["count"] for data in cls_.aggregate(pipeline)}
|
||||||
data["_id"]: data["count"]
|
|
||||||
for data in cls_.aggregate(pipeline)
|
|
||||||
}
|
|
||||||
|
|
||||||
with TimingContext("mongo", "get_security_groups"):
|
with TimingContext("mongo", "get_security_groups"):
|
||||||
tasks = get_agrregate_res(Task)
|
tasks = get_agrregate_res(Task)
|
||||||
models = get_agrregate_res(Model)
|
models = get_agrregate_res(Model)
|
||||||
return {
|
return {
|
||||||
pid: {
|
pid: {"own_tasks": tasks.get(pid, 0), "own_models": models.get(pid, 0)}
|
||||||
"own_tasks": tasks.get(pid, 0),
|
|
||||||
"own_models": models.get(pid, 0),
|
|
||||||
}
|
|
||||||
for pid in project_ids
|
for pid in project_ids
|
||||||
}
|
}
|
||||||
|
@ -10,6 +10,7 @@ from typing import (
|
|||||||
from redis import StrictRedis
|
from redis import StrictRedis
|
||||||
|
|
||||||
from apiserver.config_repo import config
|
from apiserver.config_repo import config
|
||||||
|
from apiserver.database.model.model import Model
|
||||||
from apiserver.database.model.task.task import Task
|
from apiserver.database.model.task.task import Task
|
||||||
from apiserver.redis_manager import redman
|
from apiserver.redis_manager import redman
|
||||||
from apiserver.utilities.dicts import nested_get
|
from apiserver.utilities.dicts import nested_get
|
||||||
@ -239,3 +240,53 @@ class ProjectQueries:
|
|||||||
|
|
||||||
result = Task.aggregate(pipeline)
|
result = Task.aggregate(pipeline)
|
||||||
return [r["metrics"][0] for r in result]
|
return [r["metrics"][0] for r in result]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_model_metadata_keys(
|
||||||
|
cls,
|
||||||
|
company_id,
|
||||||
|
project_ids: Sequence[str],
|
||||||
|
include_subprojects: bool,
|
||||||
|
page: int = 0,
|
||||||
|
page_size: int = 500,
|
||||||
|
) -> Tuple[int, int, Sequence[dict]]:
|
||||||
|
page = max(0, page)
|
||||||
|
page_size = max(1, page_size)
|
||||||
|
pipeline = [
|
||||||
|
{
|
||||||
|
"$match": {
|
||||||
|
**cls._get_company_constraint(company_id),
|
||||||
|
**cls._get_project_constraint(project_ids, include_subprojects),
|
||||||
|
"metadata": {"$exists": True, "$gt": {}},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{"$project": {"metadata": {"$objectToArray": "$metadata"}}},
|
||||||
|
{"$unwind": "$metadata"},
|
||||||
|
{"$group": {"_id": "$metadata.k"}},
|
||||||
|
{"$sort": {"_id": 1}},
|
||||||
|
{"$skip": page * page_size},
|
||||||
|
{"$limit": page_size},
|
||||||
|
{
|
||||||
|
"$group": {
|
||||||
|
"_id": 1,
|
||||||
|
"total": {"$sum": 1},
|
||||||
|
"results": {"$push": "$$ROOT"},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
result = next(Model.aggregate(pipeline), None)
|
||||||
|
|
||||||
|
total = 0
|
||||||
|
remaining = 0
|
||||||
|
results = []
|
||||||
|
|
||||||
|
if result:
|
||||||
|
total = int(result.get("total", -1))
|
||||||
|
results = [
|
||||||
|
ParameterKeyEscaper.unescape(r.get("_id"))
|
||||||
|
for r in result.get("results", [])
|
||||||
|
]
|
||||||
|
remaining = max(0, total - (len(results) + page * page_size))
|
||||||
|
|
||||||
|
return total, remaining, results
|
||||||
|
@ -32,7 +32,7 @@ class QueueBLL(object):
|
|||||||
name: str,
|
name: str,
|
||||||
tags: Optional[Sequence[str]] = None,
|
tags: Optional[Sequence[str]] = None,
|
||||||
system_tags: Optional[Sequence[str]] = None,
|
system_tags: Optional[Sequence[str]] = None,
|
||||||
metadata: Optional[Sequence[dict]] = None,
|
metadata: Optional[dict] = None,
|
||||||
) -> Queue:
|
) -> Queue:
|
||||||
"""Creates a queue"""
|
"""Creates a queue"""
|
||||||
with translate_errors_context():
|
with translate_errors_context():
|
||||||
|
@ -95,6 +95,7 @@ class GetMixin(PropsMixin):
|
|||||||
}
|
}
|
||||||
MultiFieldParameters = namedtuple("MultiFieldParameters", "pattern fields")
|
MultiFieldParameters = namedtuple("MultiFieldParameters", "pattern fields")
|
||||||
|
|
||||||
|
_numeric_locale = {"locale": "en_US", "numericOrdering": True}
|
||||||
_field_collation_overrides = {}
|
_field_collation_overrides = {}
|
||||||
|
|
||||||
class QueryParameterOptions(object):
|
class QueryParameterOptions(object):
|
||||||
@ -599,7 +600,7 @@ class GetMixin(PropsMixin):
|
|||||||
return size
|
return size
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_data_with_scroll_and_filter_support(
|
def get_data_with_scroll_support(
|
||||||
cls,
|
cls,
|
||||||
query_dict: dict,
|
query_dict: dict,
|
||||||
data_getter: Callable[[], Sequence[dict]],
|
data_getter: Callable[[], Sequence[dict]],
|
||||||
@ -629,15 +630,12 @@ class GetMixin(PropsMixin):
|
|||||||
if cls._start_key in query_dict:
|
if cls._start_key in query_dict:
|
||||||
query_dict[cls._start_key] = query_dict[cls._start_key] + len(data)
|
query_dict[cls._start_key] = query_dict[cls._start_key] + len(data)
|
||||||
|
|
||||||
def update_state(returned_len: int):
|
if state:
|
||||||
if not state:
|
|
||||||
return
|
|
||||||
state.position = query_dict[cls._start_key]
|
state.position = query_dict[cls._start_key]
|
||||||
cls.get_cache_manager().set_state(state)
|
cls.get_cache_manager().set_state(state)
|
||||||
if ret_params is not None:
|
if ret_params is not None:
|
||||||
ret_params["scroll_id"] = state.id
|
ret_params["scroll_id"] = state.id
|
||||||
|
|
||||||
update_state(len(data))
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -770,7 +768,7 @@ class GetMixin(PropsMixin):
|
|||||||
override_projection=override_projection,
|
override_projection=override_projection,
|
||||||
override_collation=override_collation,
|
override_collation=override_collation,
|
||||||
)
|
)
|
||||||
return cls.get_data_with_scroll_and_filter_support(
|
return cls.get_data_with_scroll_support(
|
||||||
query_dict=query_dict, data_getter=data_getter, ret_params=ret_params,
|
query_dict=query_dict, data_getter=data_getter, ret_params=ret_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1,10 +1,8 @@
|
|||||||
from typing import Sequence
|
|
||||||
|
|
||||||
from mongoengine import (
|
from mongoengine import (
|
||||||
StringField,
|
StringField,
|
||||||
DateTimeField,
|
DateTimeField,
|
||||||
BooleanField,
|
BooleanField,
|
||||||
EmbeddedDocumentListField,
|
EmbeddedDocumentField,
|
||||||
)
|
)
|
||||||
|
|
||||||
from apiserver.database import Database, strict
|
from apiserver.database import Database, strict
|
||||||
@ -12,6 +10,7 @@ from apiserver.database.fields import (
|
|||||||
StrippedStringField,
|
StrippedStringField,
|
||||||
SafeDictField,
|
SafeDictField,
|
||||||
SafeSortedListField,
|
SafeSortedListField,
|
||||||
|
SafeMapField,
|
||||||
)
|
)
|
||||||
from apiserver.database.model import AttributedDocument
|
from apiserver.database.model import AttributedDocument
|
||||||
from apiserver.database.model.base import GetMixin
|
from apiserver.database.model.base import GetMixin
|
||||||
@ -22,6 +21,10 @@ from apiserver.database.model.task.task import Task
|
|||||||
|
|
||||||
|
|
||||||
class Model(AttributedDocument):
|
class Model(AttributedDocument):
|
||||||
|
_field_collation_overrides = {
|
||||||
|
"metadata.": AttributedDocument._numeric_locale,
|
||||||
|
}
|
||||||
|
|
||||||
meta = {
|
meta = {
|
||||||
"db_alias": Database.backend,
|
"db_alias": Database.backend,
|
||||||
"strict": strict,
|
"strict": strict,
|
||||||
@ -30,8 +33,6 @@ class Model(AttributedDocument):
|
|||||||
"project",
|
"project",
|
||||||
"task",
|
"task",
|
||||||
"last_update",
|
"last_update",
|
||||||
"metadata.key",
|
|
||||||
"metadata.type",
|
|
||||||
("company", "framework"),
|
("company", "framework"),
|
||||||
("company", "name"),
|
("company", "name"),
|
||||||
("company", "user"),
|
("company", "user"),
|
||||||
@ -63,6 +64,7 @@ class Model(AttributedDocument):
|
|||||||
"project",
|
"project",
|
||||||
"task",
|
"task",
|
||||||
"parent",
|
"parent",
|
||||||
|
"metadata.*",
|
||||||
),
|
),
|
||||||
datetime_fields=("last_update",),
|
datetime_fields=("last_update",),
|
||||||
)
|
)
|
||||||
@ -86,6 +88,6 @@ class Model(AttributedDocument):
|
|||||||
default=dict, user_set_allowed=True, exclude_by_default=True
|
default=dict, user_set_allowed=True, exclude_by_default=True
|
||||||
)
|
)
|
||||||
company_origin = StringField(exclude_by_default=True)
|
company_origin = StringField(exclude_by_default=True)
|
||||||
metadata: Sequence[MetadataItem] = EmbeddedDocumentListField(
|
metadata = SafeMapField(
|
||||||
MetadataItem, default=list, user_set_allowed=True
|
field=EmbeddedDocumentField(MetadataItem), user_set_allowed=True
|
||||||
)
|
)
|
||||||
|
@ -1,16 +1,19 @@
|
|||||||
from typing import Sequence
|
|
||||||
|
|
||||||
from mongoengine import (
|
from mongoengine import (
|
||||||
Document,
|
Document,
|
||||||
EmbeddedDocument,
|
EmbeddedDocument,
|
||||||
StringField,
|
StringField,
|
||||||
DateTimeField,
|
DateTimeField,
|
||||||
EmbeddedDocumentListField,
|
EmbeddedDocumentListField,
|
||||||
|
EmbeddedDocumentField,
|
||||||
)
|
)
|
||||||
|
|
||||||
from apiserver.database import Database, strict
|
from apiserver.database import Database, strict
|
||||||
from apiserver.database.fields import StrippedStringField, SafeSortedListField
|
from apiserver.database.fields import (
|
||||||
from apiserver.database.model import DbModelMixin
|
StrippedStringField,
|
||||||
|
SafeSortedListField,
|
||||||
|
SafeMapField,
|
||||||
|
)
|
||||||
|
from apiserver.database.model import DbModelMixin, AttributedDocument
|
||||||
from apiserver.database.model.base import ProperDictMixin, GetMixin
|
from apiserver.database.model.base import ProperDictMixin, GetMixin
|
||||||
from apiserver.database.model.company import Company
|
from apiserver.database.model.company import Company
|
||||||
from apiserver.database.model.metadata import MetadataItem
|
from apiserver.database.model.metadata import MetadataItem
|
||||||
@ -19,23 +22,25 @@ from apiserver.database.model.task.task import Task
|
|||||||
|
|
||||||
class Entry(EmbeddedDocument, ProperDictMixin):
|
class Entry(EmbeddedDocument, ProperDictMixin):
|
||||||
""" Entry representing a task waiting in the queue """
|
""" Entry representing a task waiting in the queue """
|
||||||
|
|
||||||
task = StringField(required=True, reference_field=Task)
|
task = StringField(required=True, reference_field=Task)
|
||||||
''' Task ID '''
|
""" Task ID """
|
||||||
added = DateTimeField(required=True)
|
added = DateTimeField(required=True)
|
||||||
''' Added to the queue '''
|
""" Added to the queue """
|
||||||
|
|
||||||
|
|
||||||
class Queue(DbModelMixin, Document):
|
class Queue(DbModelMixin, Document):
|
||||||
|
_field_collation_overrides = {
|
||||||
|
"metadata.": AttributedDocument._numeric_locale,
|
||||||
|
}
|
||||||
|
|
||||||
get_all_query_options = GetMixin.QueryParameterOptions(
|
get_all_query_options = GetMixin.QueryParameterOptions(
|
||||||
pattern_fields=("name",),
|
pattern_fields=("name",), list_fields=("tags", "system_tags", "id", "metadata.*"),
|
||||||
list_fields=("tags", "system_tags", "id"),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
meta = {
|
meta = {
|
||||||
'db_alias': Database.backend,
|
"db_alias": Database.backend,
|
||||||
'strict': strict,
|
"strict": strict,
|
||||||
"indexes": ["metadata.key", "metadata.type"],
|
|
||||||
}
|
}
|
||||||
|
|
||||||
id = StringField(primary_key=True)
|
id = StringField(primary_key=True)
|
||||||
@ -44,10 +49,12 @@ class Queue(DbModelMixin, Document):
|
|||||||
)
|
)
|
||||||
company = StringField(required=True, reference_field=Company)
|
company = StringField(required=True, reference_field=Company)
|
||||||
created = DateTimeField(required=True)
|
created = DateTimeField(required=True)
|
||||||
tags = SafeSortedListField(StringField(required=True), default=list, user_set_allowed=True)
|
tags = SafeSortedListField(
|
||||||
|
StringField(required=True), default=list, user_set_allowed=True
|
||||||
|
)
|
||||||
system_tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
|
system_tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
|
||||||
entries = EmbeddedDocumentListField(Entry, default=list)
|
entries = EmbeddedDocumentListField(Entry, default=list)
|
||||||
last_update = DateTimeField()
|
last_update = DateTimeField()
|
||||||
metadata: Sequence[MetadataItem] = EmbeddedDocumentListField(
|
metadata = SafeMapField(
|
||||||
MetadataItem, default=list, user_set_allowed=True
|
field=EmbeddedDocumentField(MetadataItem), user_set_allowed=True
|
||||||
)
|
)
|
||||||
|
@ -159,11 +159,10 @@ external_task_types = set(get_options(TaskType))
|
|||||||
|
|
||||||
|
|
||||||
class Task(AttributedDocument):
|
class Task(AttributedDocument):
|
||||||
_numeric_locale = {"locale": "en_US", "numericOrdering": True}
|
|
||||||
_field_collation_overrides = {
|
_field_collation_overrides = {
|
||||||
"execution.parameters.": _numeric_locale,
|
"execution.parameters.": AttributedDocument._numeric_locale,
|
||||||
"last_metrics.": _numeric_locale,
|
"last_metrics.": AttributedDocument._numeric_locale,
|
||||||
"hyperparams.": _numeric_locale,
|
"hyperparams.": AttributedDocument._numeric_locale,
|
||||||
}
|
}
|
||||||
|
|
||||||
meta = {
|
meta = {
|
||||||
@ -184,7 +183,10 @@ class Task(AttributedDocument):
|
|||||||
("company", "type", "system_tags", "status"),
|
("company", "type", "system_tags", "status"),
|
||||||
("company", "project", "type", "system_tags", "status"),
|
("company", "project", "type", "system_tags", "status"),
|
||||||
("status", "last_update"), # for maintenance tasks
|
("status", "last_update"), # for maintenance tasks
|
||||||
{"fields": ["company", "project"], "collation": _numeric_locale},
|
{
|
||||||
|
"fields": ["company", "project"],
|
||||||
|
"collation": AttributedDocument._numeric_locale,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"name": "%s.task.main_text_index" % Database.backend,
|
"name": "%s.task.main_text_index" % Database.backend,
|
||||||
"fields": [
|
"fields": [
|
||||||
|
29
apiserver/mongo/migrations/1_3_0.py
Normal file
29
apiserver/mongo/migrations/1_3_0.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
from pymongo.collection import Collection
|
||||||
|
from pymongo.database import Database
|
||||||
|
|
||||||
|
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
|
||||||
|
from .utils import _drop_all_indices_from_collections
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_metadata(db: Database, name):
|
||||||
|
collection: Collection = db[name]
|
||||||
|
|
||||||
|
metadata_field = "metadata"
|
||||||
|
query = {metadata_field: {"$exists": True, "$type": 4}}
|
||||||
|
for doc in collection.find(filter=query, projection=(metadata_field,)):
|
||||||
|
metadata = {
|
||||||
|
ParameterKeyEscaper.escape(item["key"]): item
|
||||||
|
for item in doc.get(metadata_field, [])
|
||||||
|
if isinstance(item, dict) and "key" in item
|
||||||
|
}
|
||||||
|
collection.update_one(
|
||||||
|
{"_id": doc["_id"]}, {"$set": {"metadata": metadata}},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def migrate_backend(db: Database):
|
||||||
|
collections = ["model", "queue"]
|
||||||
|
for name in collections:
|
||||||
|
_convert_metadata(db, name)
|
||||||
|
|
||||||
|
_drop_all_indices_from_collections(db, collections)
|
@ -226,6 +226,12 @@ create_credentials {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
"999.0": ${create_credentials."2.1"} {
|
||||||
|
request.properties.label {
|
||||||
|
type: string
|
||||||
|
description: Optional credentials label
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
get_credentials {
|
get_credentials {
|
||||||
|
@ -61,14 +61,14 @@ _definitions {
|
|||||||
type: string
|
type: string
|
||||||
}
|
}
|
||||||
tags {
|
tags {
|
||||||
description: "User-defined tags list"
|
|
||||||
type: array
|
type: array
|
||||||
|
description: "User-defined tags"
|
||||||
items { type: string }
|
items { type: string }
|
||||||
}
|
}
|
||||||
system_tags {
|
system_tags {
|
||||||
description: "System tags list. This field is reserved for system use, please don't use it."
|
|
||||||
type: array
|
type: array
|
||||||
items {type: string}
|
description: "System tags. This field is reserved for system use, please don't use it."
|
||||||
|
items { type: string }
|
||||||
}
|
}
|
||||||
framework {
|
framework {
|
||||||
description: "Framework on which the model is based. Should be identical to the framework of the task which created the model"
|
description: "Framework on which the model is based. Should be identical to the framework of the task which created the model"
|
||||||
@ -98,9 +98,11 @@ _definitions {
|
|||||||
additionalProperties: true
|
additionalProperties: true
|
||||||
}
|
}
|
||||||
metadata {
|
metadata {
|
||||||
type: array
|
|
||||||
description: "Model metadata"
|
description: "Model metadata"
|
||||||
items {"$ref": "#/definitions/metadata_item"}
|
type: object
|
||||||
|
additionalProperties {
|
||||||
|
"$ref": "#/definitions/metadata_item"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -407,7 +409,7 @@ update_for_task {
|
|||||||
system_tags {
|
system_tags {
|
||||||
description: "System tags list. This field is reserved for system use, please don't use it."
|
description: "System tags list. This field is reserved for system use, please don't use it."
|
||||||
type: array
|
type: array
|
||||||
items {type: string}
|
items { type: string }
|
||||||
}
|
}
|
||||||
override_model_id {
|
override_model_id {
|
||||||
description: "Override model ID. If provided, this model is updated in the task. Exactly one of override_model_id or uri is required."
|
description: "Override model ID. If provided, this model is updated in the task. Exactly one of override_model_id or uri is required."
|
||||||
@ -473,7 +475,7 @@ create {
|
|||||||
system_tags {
|
system_tags {
|
||||||
description: "System tags list. This field is reserved for system use, please don't use it."
|
description: "System tags list. This field is reserved for system use, please don't use it."
|
||||||
type: array
|
type: array
|
||||||
items {type: string}
|
items { type: string }
|
||||||
}
|
}
|
||||||
framework {
|
framework {
|
||||||
description: "Framework on which the model is based. Case insensitive. Should be identical to the framework of the task which created the model."
|
description: "Framework on which the model is based. Case insensitive. Should be identical to the framework of the task which created the model."
|
||||||
@ -529,9 +531,11 @@ create {
|
|||||||
}
|
}
|
||||||
"2.13": ${create."2.1"} {
|
"2.13": ${create."2.1"} {
|
||||||
metadata {
|
metadata {
|
||||||
type: array
|
|
||||||
description: "Model metadata"
|
description: "Model metadata"
|
||||||
items {"$ref": "#/definitions/metadata_item"}
|
type: object
|
||||||
|
additionalProperties {
|
||||||
|
"$ref": "#/definitions/metadata_item"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -568,7 +572,7 @@ edit {
|
|||||||
system_tags {
|
system_tags {
|
||||||
description: "System tags list. This field is reserved for system use, please don't use it."
|
description: "System tags list. This field is reserved for system use, please don't use it."
|
||||||
type: array
|
type: array
|
||||||
items {type: string}
|
items { type: string }
|
||||||
}
|
}
|
||||||
framework {
|
framework {
|
||||||
description: "Framework on which the model is based. Case insensitive. Should be identical to the framework of the task which created the model."
|
description: "Framework on which the model is based. Case insensitive. Should be identical to the framework of the task which created the model."
|
||||||
@ -624,9 +628,11 @@ edit {
|
|||||||
}
|
}
|
||||||
"2.13": ${edit."2.1"} {
|
"2.13": ${edit."2.1"} {
|
||||||
metadata {
|
metadata {
|
||||||
type: array
|
|
||||||
description: "Model metadata"
|
description: "Model metadata"
|
||||||
items {"$ref": "#/definitions/metadata_item"}
|
type: object
|
||||||
|
additionalProperties {
|
||||||
|
"$ref": "#/definitions/metadata_item"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -657,7 +663,7 @@ update {
|
|||||||
system_tags {
|
system_tags {
|
||||||
description: "System tags list. This field is reserved for system use, please don't use it."
|
description: "System tags list. This field is reserved for system use, please don't use it."
|
||||||
type: array
|
type: array
|
||||||
items {type: string}
|
items { type: string }
|
||||||
}
|
}
|
||||||
ready {
|
ready {
|
||||||
description: "Indication if the model is final and can be used by other tasks Default is false."
|
description: "Indication if the model is final and can be used by other tasks Default is false."
|
||||||
@ -707,9 +713,11 @@ update {
|
|||||||
}
|
}
|
||||||
"2.13": ${update."2.1"} {
|
"2.13": ${update."2.1"} {
|
||||||
metadata {
|
metadata {
|
||||||
type: array
|
|
||||||
description: "Model metadata"
|
description: "Model metadata"
|
||||||
items {"$ref": "#/definitions/metadata_item"}
|
type: object
|
||||||
|
additionalProperties {
|
||||||
|
"$ref": "#/definitions/metadata_item"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -718,7 +726,7 @@ publish_many {
|
|||||||
description: Publish models
|
description: Publish models
|
||||||
request {
|
request {
|
||||||
properties {
|
properties {
|
||||||
ids.description: "IDs of models to publish"
|
ids.description: "IDs of the models to publish"
|
||||||
force_publish_task {
|
force_publish_task {
|
||||||
description: "Publish the associated tasks (if exist) even if they are not in the 'stopped' state. Optional, the default value is False."
|
description: "Publish the associated tasks (if exist) even if they are not in the 'stopped' state. Optional, the default value is False."
|
||||||
type: boolean
|
type: boolean
|
||||||
@ -779,7 +787,7 @@ archive_many {
|
|||||||
description: Archive models
|
description: Archive models
|
||||||
request {
|
request {
|
||||||
properties {
|
properties {
|
||||||
ids.description: "IDs of models to archive"
|
ids.description: "IDs of the models to archive"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
response {
|
response {
|
||||||
@ -815,10 +823,9 @@ delete_many {
|
|||||||
description: Delete models
|
description: Delete models
|
||||||
request {
|
request {
|
||||||
properties {
|
properties {
|
||||||
ids.description: "IDs of models to delete"
|
ids.description: "IDs of the models to delete"
|
||||||
force {
|
force {
|
||||||
description: """Force. Required if there are tasks that use the model as an execution model, or if the model's creating task is published.
|
description: "Force. Required if there are tasks that use the model as an execution model, or if the model's creating task is published."
|
||||||
"""
|
|
||||||
type: boolean
|
type: boolean
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -975,6 +982,11 @@ add_or_update_metadata {
|
|||||||
description: "Metadata items to add or update"
|
description: "Metadata items to add or update"
|
||||||
items {"$ref": "#/definitions/metadata_item"}
|
items {"$ref": "#/definitions/metadata_item"}
|
||||||
}
|
}
|
||||||
|
replace_metadata {
|
||||||
|
description: "If set then the all the metadata items will be replaced with the provided ones. Otherwise only the provided metadata items will be updated or added"
|
||||||
|
type: boolean
|
||||||
|
default: false
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
response {
|
response {
|
||||||
|
@ -42,15 +42,20 @@ _definitions {
|
|||||||
type: string
|
type: string
|
||||||
format: "date-time"
|
format: "date-time"
|
||||||
}
|
}
|
||||||
|
last_update {
|
||||||
|
description: "Last update time"
|
||||||
|
type: string
|
||||||
|
format: "date-time"
|
||||||
|
}
|
||||||
tags {
|
tags {
|
||||||
type: array
|
|
||||||
description: "User-defined tags"
|
description: "User-defined tags"
|
||||||
|
type: array
|
||||||
items { type: string }
|
items { type: string }
|
||||||
}
|
}
|
||||||
system_tags {
|
system_tags {
|
||||||
type: array
|
|
||||||
description: "System tags. This field is reserved for system use, please don't use it."
|
description: "System tags. This field is reserved for system use, please don't use it."
|
||||||
items {type: string}
|
type: array
|
||||||
|
items { type: string }
|
||||||
}
|
}
|
||||||
default_output_destination {
|
default_output_destination {
|
||||||
description: "The default output destination URL for new tasks under this project"
|
description: "The default output destination URL for new tasks under this project"
|
||||||
@ -70,6 +75,18 @@ _definitions {
|
|||||||
description: "Total run time of all tasks in project (in seconds)"
|
description: "Total run time of all tasks in project (in seconds)"
|
||||||
type: integer
|
type: integer
|
||||||
}
|
}
|
||||||
|
total_tasks {
|
||||||
|
description: "Number of tasks"
|
||||||
|
type: integer
|
||||||
|
}
|
||||||
|
completed_tasks_24h {
|
||||||
|
description: "Number of tasks completed in the last 24 hours"
|
||||||
|
type: integer
|
||||||
|
}
|
||||||
|
last_task_run {
|
||||||
|
description: "The most recent started time of a task"
|
||||||
|
type: integer
|
||||||
|
}
|
||||||
status_count {
|
status_count {
|
||||||
description: "Status counts"
|
description: "Status counts"
|
||||||
type: object
|
type: object
|
||||||
@ -78,6 +95,10 @@ _definitions {
|
|||||||
description: "Number of 'created' tasks in project"
|
description: "Number of 'created' tasks in project"
|
||||||
type: integer
|
type: integer
|
||||||
}
|
}
|
||||||
|
completed {
|
||||||
|
description: "Number of 'completed' tasks in project"
|
||||||
|
type: integer
|
||||||
|
}
|
||||||
queued {
|
queued {
|
||||||
description: "Number of 'queued' tasks in project"
|
description: "Number of 'queued' tasks in project"
|
||||||
type: integer
|
type: integer
|
||||||
@ -158,14 +179,14 @@ _definitions {
|
|||||||
format: "date-time"
|
format: "date-time"
|
||||||
}
|
}
|
||||||
tags {
|
tags {
|
||||||
type: array
|
|
||||||
description: "User-defined tags"
|
description: "User-defined tags"
|
||||||
|
type: array
|
||||||
items { type: string }
|
items { type: string }
|
||||||
}
|
}
|
||||||
system_tags {
|
system_tags {
|
||||||
type: array
|
|
||||||
description: "System tags. This field is reserved for system use, please don't use it."
|
description: "System tags. This field is reserved for system use, please don't use it."
|
||||||
items {type: string}
|
type: array
|
||||||
|
items { type: string }
|
||||||
}
|
}
|
||||||
default_output_destination {
|
default_output_destination {
|
||||||
description: "The default output destination URL for new tasks under this project"
|
description: "The default output destination URL for new tasks under this project"
|
||||||
@ -299,14 +320,14 @@ create {
|
|||||||
type: string
|
type: string
|
||||||
}
|
}
|
||||||
tags {
|
tags {
|
||||||
type: array
|
|
||||||
description: "User-defined tags"
|
description: "User-defined tags"
|
||||||
|
type: array
|
||||||
items { type: string }
|
items { type: string }
|
||||||
}
|
}
|
||||||
system_tags {
|
system_tags {
|
||||||
type: array
|
|
||||||
description: "System tags. This field is reserved for system use, please don't use it."
|
description: "System tags. This field is reserved for system use, please don't use it."
|
||||||
items {type: string}
|
type: array
|
||||||
|
items { type: string }
|
||||||
}
|
}
|
||||||
default_output_destination {
|
default_output_destination {
|
||||||
description: "The default output destination URL for new tasks under this project"
|
description: "The default output destination URL for new tasks under this project"
|
||||||
@ -419,7 +440,6 @@ get_all {
|
|||||||
description: "Projects list"
|
description: "Projects list"
|
||||||
type: array
|
type: array
|
||||||
items { "$ref": "#/definitions/projects_get_all_response_single" }
|
items { "$ref": "#/definitions/projects_get_all_response_single" }
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -545,42 +565,6 @@ get_all_ex {
|
|||||||
type: boolean
|
type: boolean
|
||||||
default: true
|
default: true
|
||||||
}
|
}
|
||||||
response {
|
|
||||||
properties {
|
|
||||||
stats {
|
|
||||||
properties {
|
|
||||||
active.properties {
|
|
||||||
total_tasks {
|
|
||||||
description: "Number of tasks"
|
|
||||||
type: integer
|
|
||||||
}
|
|
||||||
completed_tasks {
|
|
||||||
description: "Number of tasks completed in the last 24 hours"
|
|
||||||
type: integer
|
|
||||||
}
|
|
||||||
running_tasks {
|
|
||||||
description: "Number of running tasks"
|
|
||||||
type: integer
|
|
||||||
}
|
|
||||||
}
|
|
||||||
archived.properties {
|
|
||||||
total_tasks {
|
|
||||||
description: "Number of tasks"
|
|
||||||
type: integer
|
|
||||||
}
|
|
||||||
completed_tasks {
|
|
||||||
description: "Number of tasks completed in the last 24 hours"
|
|
||||||
type: integer
|
|
||||||
}
|
|
||||||
running_tasks {
|
|
||||||
description: "Number of running tasks"
|
|
||||||
type: integer
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
update {
|
update {
|
||||||
@ -603,14 +587,14 @@ update {
|
|||||||
type: string
|
type: string
|
||||||
}
|
}
|
||||||
tags {
|
tags {
|
||||||
|
description: "User-defined tags list"
|
||||||
type: array
|
type: array
|
||||||
description: "User-defined tags"
|
|
||||||
items { type: string }
|
items { type: string }
|
||||||
}
|
}
|
||||||
system_tags {
|
system_tags {
|
||||||
|
description: "System tags list. This field is reserved for system use, please don't use it."
|
||||||
type: array
|
type: array
|
||||||
description: "System tags. This field is reserved for system use, please don't use it."
|
items { type: string }
|
||||||
items {type: string}
|
|
||||||
}
|
}
|
||||||
default_output_destination {
|
default_output_destination {
|
||||||
description: "The default output destination URL for new tasks under this project"
|
description: "The default output destination URL for new tasks under this project"
|
||||||
@ -748,7 +732,6 @@ delete {
|
|||||||
type: boolean
|
type: boolean
|
||||||
default: false
|
default: false
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
response {
|
response {
|
||||||
@ -881,6 +864,7 @@ get_hyper_parameters {
|
|||||||
description: """Get a list of all hyper parameter sections and names used in tasks within the given project."""
|
description: """Get a list of all hyper parameter sections and names used in tasks within the given project."""
|
||||||
request {
|
request {
|
||||||
type: object
|
type: object
|
||||||
|
required: [project]
|
||||||
properties {
|
properties {
|
||||||
project {
|
project {
|
||||||
description: "Project ID"
|
description: "Project ID"
|
||||||
@ -929,6 +913,55 @@ get_hyper_parameters {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
get_model_metadata_keys {
|
||||||
|
"999.0" {
|
||||||
|
description: """Get a list of all metadata keys used in models within the given project."""
|
||||||
|
request {
|
||||||
|
type: object
|
||||||
|
required: [project]
|
||||||
|
properties {
|
||||||
|
project {
|
||||||
|
description: "Project ID"
|
||||||
|
type: string
|
||||||
|
}
|
||||||
|
include_subprojects {
|
||||||
|
description: "If set to 'true' and the project field is set then the result includes metadate keys from the subproject models"
|
||||||
|
type: boolean
|
||||||
|
default: true
|
||||||
|
}
|
||||||
|
|
||||||
|
page {
|
||||||
|
description: "Page number"
|
||||||
|
default: 0
|
||||||
|
type: integer
|
||||||
|
}
|
||||||
|
page_size {
|
||||||
|
description: "Page size"
|
||||||
|
default: 500
|
||||||
|
type: integer
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
response {
|
||||||
|
type: object
|
||||||
|
properties {
|
||||||
|
keys {
|
||||||
|
description: "A list of model keys"
|
||||||
|
type: array
|
||||||
|
items {type: string}
|
||||||
|
}
|
||||||
|
remaining {
|
||||||
|
description: "Remaining results"
|
||||||
|
type: integer
|
||||||
|
}
|
||||||
|
total {
|
||||||
|
description: "Total number of results"
|
||||||
|
type: integer
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
get_task_tags {
|
get_task_tags {
|
||||||
"2.8" {
|
"2.8" {
|
||||||
description: "Get user and system tags used for the tasks under the specified projects"
|
description: "Get user and system tags used for the tasks under the specified projects"
|
||||||
@ -936,7 +969,6 @@ get_task_tags {
|
|||||||
response = ${_definitions.tags_response}
|
response = ${_definitions.tags_response}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
get_model_tags {
|
get_model_tags {
|
||||||
"2.8" {
|
"2.8" {
|
||||||
description: "Get user and system tags used for the models under the specified projects"
|
description: "Get user and system tags used for the models under the specified projects"
|
||||||
@ -1058,4 +1090,4 @@ get_task_parents {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -79,9 +79,11 @@ _definitions {
|
|||||||
items { "$ref": "#/definitions/entry" }
|
items { "$ref": "#/definitions/entry" }
|
||||||
}
|
}
|
||||||
metadata {
|
metadata {
|
||||||
type: array
|
|
||||||
description: "Queue metadata"
|
description: "Queue metadata"
|
||||||
items {"$ref": "#/definitions/metadata_item"}
|
type: object
|
||||||
|
additionalProperties {
|
||||||
|
"$ref": "#/definitions/metadata_item"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -281,6 +283,15 @@ create {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
"2.13": ${create."2.4"} {
|
||||||
|
metadata {
|
||||||
|
description: "Queue metadata"
|
||||||
|
type: object
|
||||||
|
additionalProperties {
|
||||||
|
"$ref": "#/definitions/metadata_item"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
update {
|
update {
|
||||||
"2.4" {
|
"2.4" {
|
||||||
@ -322,7 +333,15 @@ update {
|
|||||||
type: object
|
type: object
|
||||||
additionalProperties: true
|
additionalProperties: true
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"2.13": ${update."2.4"} {
|
||||||
|
metadata {
|
||||||
|
description: "Queue metadata"
|
||||||
|
type: object
|
||||||
|
additionalProperties {
|
||||||
|
"$ref": "#/definitions/metadata_item"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -632,6 +651,11 @@ add_or_update_metadata {
|
|||||||
description: "Metadata items to add or update"
|
description: "Metadata items to add or update"
|
||||||
items {"$ref": "#/definitions/metadata_item"}
|
items {"$ref": "#/definitions/metadata_item"}
|
||||||
}
|
}
|
||||||
|
replace_metadata {
|
||||||
|
description: "If set then the all the metadata items will be replaced with the provided ones. Otherwise only the provided metadata items will be updated or added"
|
||||||
|
type: boolean
|
||||||
|
default: false
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
response {
|
response {
|
||||||
|
@ -21,16 +21,14 @@ from apiserver.apimodels.models import (
|
|||||||
ModelsPublishManyRequest,
|
ModelsPublishManyRequest,
|
||||||
ModelsDeleteManyRequest,
|
ModelsDeleteManyRequest,
|
||||||
)
|
)
|
||||||
from apiserver.bll.model import ModelBLL
|
from apiserver.bll.model import ModelBLL, Metadata
|
||||||
from apiserver.bll.organization import OrgBLL, Tags
|
from apiserver.bll.organization import OrgBLL, Tags
|
||||||
from apiserver.bll.project import ProjectBLL, project_ids_with_children
|
from apiserver.bll.project import ProjectBLL, project_ids_with_children
|
||||||
from apiserver.bll.task import TaskBLL
|
from apiserver.bll.task import TaskBLL
|
||||||
from apiserver.bll.task.task_operations import publish_task
|
from apiserver.bll.task.task_operations import publish_task
|
||||||
from apiserver.bll.util import run_batch_operation
|
from apiserver.bll.util import run_batch_operation
|
||||||
from apiserver.config_repo import config
|
from apiserver.config_repo import config
|
||||||
from apiserver.database.errors import translate_errors_context
|
|
||||||
from apiserver.database.model import validate_id
|
from apiserver.database.model import validate_id
|
||||||
from apiserver.database.model.metadata import metadata_add_or_update, metadata_delete
|
|
||||||
from apiserver.database.model.model import Model
|
from apiserver.database.model.model import Model
|
||||||
from apiserver.database.model.project import Project
|
from apiserver.database.model.project import Project
|
||||||
from apiserver.database.model.task.task import (
|
from apiserver.database.model.task.task import (
|
||||||
@ -50,8 +48,8 @@ from apiserver.services.utils import (
|
|||||||
conform_tag_fields,
|
conform_tag_fields,
|
||||||
conform_output_tags,
|
conform_output_tags,
|
||||||
ModelsBackwardsCompatibility,
|
ModelsBackwardsCompatibility,
|
||||||
validate_metadata,
|
unescape_metadata,
|
||||||
get_metadata_from_api,
|
escape_metadata,
|
||||||
)
|
)
|
||||||
from apiserver.timing_context import TimingContext
|
from apiserver.timing_context import TimingContext
|
||||||
|
|
||||||
@ -64,19 +62,20 @@ project_bll = ProjectBLL()
|
|||||||
def get_by_id(call: APICall, company_id, _):
|
def get_by_id(call: APICall, company_id, _):
|
||||||
model_id = call.data["model"]
|
model_id = call.data["model"]
|
||||||
|
|
||||||
with translate_errors_context():
|
Metadata.escape_query_parameters(call)
|
||||||
models = Model.get_many(
|
models = Model.get_many(
|
||||||
company=company_id,
|
company=company_id,
|
||||||
query_dict=call.data,
|
query_dict=call.data,
|
||||||
query=Q(id=model_id),
|
query=Q(id=model_id),
|
||||||
allow_public=True,
|
allow_public=True,
|
||||||
|
)
|
||||||
|
if not models:
|
||||||
|
raise errors.bad_request.InvalidModelId(
|
||||||
|
"no such public or company model", id=model_id, company=company_id,
|
||||||
)
|
)
|
||||||
if not models:
|
conform_output_tags(call, models[0])
|
||||||
raise errors.bad_request.InvalidModelId(
|
unescape_metadata(call, models[0])
|
||||||
"no such public or company model", id=model_id, company=company_id,
|
call.result.data = {"model": models[0]}
|
||||||
)
|
|
||||||
conform_output_tags(call, models[0])
|
|
||||||
call.result.data = {"model": models[0]}
|
|
||||||
|
|
||||||
|
|
||||||
@endpoint("models.get_by_task_id", required_fields=["task"])
|
@endpoint("models.get_by_task_id", required_fields=["task"])
|
||||||
@ -86,25 +85,25 @@ def get_by_task_id(call: APICall, company_id, _):
|
|||||||
|
|
||||||
task_id = call.data["task"]
|
task_id = call.data["task"]
|
||||||
|
|
||||||
with translate_errors_context():
|
query = dict(id=task_id, company=company_id)
|
||||||
query = dict(id=task_id, company=company_id)
|
task = Task.get(_only=["models"], **query)
|
||||||
task = Task.get(_only=["models"], **query)
|
if not task:
|
||||||
if not task:
|
raise errors.bad_request.InvalidTaskId(**query)
|
||||||
raise errors.bad_request.InvalidTaskId(**query)
|
if not task.models or not task.models.output:
|
||||||
if not task.models or not task.models.output:
|
raise errors.bad_request.MissingTaskFields(field="models.output")
|
||||||
raise errors.bad_request.MissingTaskFields(field="models.output")
|
|
||||||
|
|
||||||
model_id = task.models.output[-1].model
|
model_id = task.models.output[-1].model
|
||||||
model = Model.objects(
|
model = Model.objects(
|
||||||
Q(id=model_id) & get_company_or_none_constraint(company_id)
|
Q(id=model_id) & get_company_or_none_constraint(company_id)
|
||||||
).first()
|
).first()
|
||||||
if not model:
|
if not model:
|
||||||
raise errors.bad_request.InvalidModelId(
|
raise errors.bad_request.InvalidModelId(
|
||||||
"no such public or company model", id=model_id, company=company_id,
|
"no such public or company model", id=model_id, company=company_id,
|
||||||
)
|
)
|
||||||
model_dict = model.to_proper_dict()
|
model_dict = model.to_proper_dict()
|
||||||
conform_output_tags(call, model_dict)
|
conform_output_tags(call, model_dict)
|
||||||
call.result.data = {"model": model_dict}
|
unescape_metadata(call, model_dict)
|
||||||
|
call.result.data = {"model": model_dict}
|
||||||
|
|
||||||
|
|
||||||
def _process_include_subprojects(call_data: dict):
|
def _process_include_subprojects(call_data: dict):
|
||||||
@ -121,47 +120,50 @@ def _process_include_subprojects(call_data: dict):
|
|||||||
@endpoint("models.get_all_ex", required_fields=[])
|
@endpoint("models.get_all_ex", required_fields=[])
|
||||||
def get_all_ex(call: APICall, company_id, _):
|
def get_all_ex(call: APICall, company_id, _):
|
||||||
conform_tag_fields(call, call.data)
|
conform_tag_fields(call, call.data)
|
||||||
with translate_errors_context():
|
_process_include_subprojects(call.data)
|
||||||
_process_include_subprojects(call.data)
|
Metadata.escape_query_parameters(call)
|
||||||
with TimingContext("mongo", "models_get_all_ex"):
|
with TimingContext("mongo", "models_get_all_ex"):
|
||||||
ret_params = {}
|
ret_params = {}
|
||||||
models = Model.get_many_with_join(
|
models = Model.get_many_with_join(
|
||||||
company=company_id,
|
company=company_id,
|
||||||
query_dict=call.data,
|
query_dict=call.data,
|
||||||
allow_public=True,
|
allow_public=True,
|
||||||
ret_params=ret_params,
|
ret_params=ret_params,
|
||||||
)
|
)
|
||||||
conform_output_tags(call, models)
|
conform_output_tags(call, models)
|
||||||
call.result.data = {"models": models, **ret_params}
|
unescape_metadata(call, models)
|
||||||
|
call.result.data = {"models": models, **ret_params}
|
||||||
|
|
||||||
|
|
||||||
@endpoint("models.get_by_id_ex", required_fields=["id"])
|
@endpoint("models.get_by_id_ex", required_fields=["id"])
|
||||||
def get_by_id_ex(call: APICall, company_id, _):
|
def get_by_id_ex(call: APICall, company_id, _):
|
||||||
conform_tag_fields(call, call.data)
|
conform_tag_fields(call, call.data)
|
||||||
with translate_errors_context():
|
Metadata.escape_query_parameters(call)
|
||||||
with TimingContext("mongo", "models_get_by_id_ex"):
|
with TimingContext("mongo", "models_get_by_id_ex"):
|
||||||
models = Model.get_many_with_join(
|
models = Model.get_many_with_join(
|
||||||
company=company_id, query_dict=call.data, allow_public=True
|
company=company_id, query_dict=call.data, allow_public=True
|
||||||
)
|
)
|
||||||
conform_output_tags(call, models)
|
conform_output_tags(call, models)
|
||||||
call.result.data = {"models": models}
|
unescape_metadata(call, models)
|
||||||
|
call.result.data = {"models": models}
|
||||||
|
|
||||||
|
|
||||||
@endpoint("models.get_all", required_fields=[])
|
@endpoint("models.get_all", required_fields=[])
|
||||||
def get_all(call: APICall, company_id, _):
|
def get_all(call: APICall, company_id, _):
|
||||||
conform_tag_fields(call, call.data)
|
conform_tag_fields(call, call.data)
|
||||||
with translate_errors_context():
|
Metadata.escape_query_parameters(call)
|
||||||
with TimingContext("mongo", "models_get_all"):
|
with TimingContext("mongo", "models_get_all"):
|
||||||
ret_params = {}
|
ret_params = {}
|
||||||
models = Model.get_many(
|
models = Model.get_many(
|
||||||
company=company_id,
|
company=company_id,
|
||||||
parameters=call.data,
|
parameters=call.data,
|
||||||
query_dict=call.data,
|
query_dict=call.data,
|
||||||
allow_public=True,
|
allow_public=True,
|
||||||
ret_params=ret_params,
|
ret_params=ret_params,
|
||||||
)
|
)
|
||||||
conform_output_tags(call, models)
|
conform_output_tags(call, models)
|
||||||
call.result.data = {"models": models, **ret_params}
|
unescape_metadata(call, models)
|
||||||
|
call.result.data = {"models": models, **ret_params}
|
||||||
|
|
||||||
|
|
||||||
@endpoint("models.get_frameworks", request_data_model=GetFrameworksRequest)
|
@endpoint("models.get_frameworks", request_data_model=GetFrameworksRequest)
|
||||||
@ -189,15 +191,22 @@ create_fields = {
|
|||||||
"metadata": list,
|
"metadata": list,
|
||||||
}
|
}
|
||||||
|
|
||||||
last_update_fields = ("uri", "framework", "design", "labels", "ready", "metadata", "system_tags", "tags")
|
last_update_fields = (
|
||||||
|
"uri",
|
||||||
|
"framework",
|
||||||
|
"design",
|
||||||
|
"labels",
|
||||||
|
"ready",
|
||||||
|
"metadata",
|
||||||
|
"system_tags",
|
||||||
|
"tags",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def parse_model_fields(call, valid_fields):
|
def parse_model_fields(call, valid_fields):
|
||||||
fields = parse_from_call(call.data, valid_fields, Model.get_fields())
|
fields = parse_from_call(call.data, valid_fields, Model.get_fields())
|
||||||
conform_tag_fields(call, fields, validate=True)
|
conform_tag_fields(call, fields, validate=True)
|
||||||
metadata = fields.get("metadata")
|
escape_metadata(fields)
|
||||||
if metadata:
|
|
||||||
validate_metadata(metadata)
|
|
||||||
return fields
|
return fields
|
||||||
|
|
||||||
|
|
||||||
@ -231,82 +240,80 @@ def update_for_task(call: APICall, company_id, _):
|
|||||||
"exactly one field is required", fields=("uri", "override_model_id")
|
"exactly one field is required", fields=("uri", "override_model_id")
|
||||||
)
|
)
|
||||||
|
|
||||||
with translate_errors_context():
|
query = dict(id=task_id, company=company_id)
|
||||||
|
task = Task.get_for_writing(
|
||||||
|
id=task_id,
|
||||||
|
company=company_id,
|
||||||
|
_only=["models", "execution", "name", "status", "project"],
|
||||||
|
)
|
||||||
|
if not task:
|
||||||
|
raise errors.bad_request.InvalidTaskId(**query)
|
||||||
|
|
||||||
query = dict(id=task_id, company=company_id)
|
allowed_states = [TaskStatus.created, TaskStatus.in_progress]
|
||||||
task = Task.get_for_writing(
|
if task.status not in allowed_states:
|
||||||
id=task_id,
|
raise errors.bad_request.InvalidTaskStatus(
|
||||||
|
f"model can only be updated for tasks in the {allowed_states} states",
|
||||||
|
**query,
|
||||||
|
)
|
||||||
|
|
||||||
|
if override_model_id:
|
||||||
|
model = ModelBLL.get_company_model_by_id(
|
||||||
|
company_id=company_id, model_id=override_model_id
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if "name" not in call.data:
|
||||||
|
# use task name if name not provided
|
||||||
|
call.data["name"] = task.name
|
||||||
|
|
||||||
|
if "comment" not in call.data:
|
||||||
|
call.data["comment"] = f"Created by task `{task.name}` ({task.id})"
|
||||||
|
|
||||||
|
if task.models and task.models.output:
|
||||||
|
# model exists, update
|
||||||
|
model_id = task.models.output[-1].model
|
||||||
|
res = _update_model(call, company_id, model_id=model_id).to_struct()
|
||||||
|
res.update({"id": model_id, "created": False})
|
||||||
|
call.result.data = res
|
||||||
|
return
|
||||||
|
|
||||||
|
# new model, create
|
||||||
|
fields = parse_model_fields(call, create_fields)
|
||||||
|
|
||||||
|
# create and save model
|
||||||
|
now = datetime.utcnow()
|
||||||
|
model = Model(
|
||||||
|
id=database.utils.id(),
|
||||||
|
created=now,
|
||||||
|
last_update=now,
|
||||||
|
user=call.identity.user,
|
||||||
company=company_id,
|
company=company_id,
|
||||||
_only=["models", "execution", "name", "status", "project"],
|
project=task.project,
|
||||||
|
framework=task.execution.framework,
|
||||||
|
parent=task.models.input[0].model
|
||||||
|
if task.models and task.models.input
|
||||||
|
else None,
|
||||||
|
design=task.execution.model_desc,
|
||||||
|
labels=task.execution.model_labels,
|
||||||
|
ready=(task.status == TaskStatus.published),
|
||||||
|
**fields,
|
||||||
)
|
)
|
||||||
if not task:
|
model.save()
|
||||||
raise errors.bad_request.InvalidTaskId(**query)
|
_update_cached_tags(company_id, project=model.project, fields=fields)
|
||||||
|
|
||||||
allowed_states = [TaskStatus.created, TaskStatus.in_progress]
|
TaskBLL.update_statistics(
|
||||||
if task.status not in allowed_states:
|
task_id=task_id,
|
||||||
raise errors.bad_request.InvalidTaskStatus(
|
company_id=company_id,
|
||||||
f"model can only be updated for tasks in the {allowed_states} states",
|
last_iteration_max=iteration,
|
||||||
**query,
|
models__output=[
|
||||||
|
ModelItem(
|
||||||
|
model=model.id,
|
||||||
|
name=TaskModelNames[TaskModelTypes.output],
|
||||||
|
updated=datetime.utcnow(),
|
||||||
)
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
if override_model_id:
|
call.result.data = {"id": model.id, "created": True}
|
||||||
model = ModelBLL.get_company_model_by_id(
|
|
||||||
company_id=company_id, model_id=override_model_id
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
if "name" not in call.data:
|
|
||||||
# use task name if name not provided
|
|
||||||
call.data["name"] = task.name
|
|
||||||
|
|
||||||
if "comment" not in call.data:
|
|
||||||
call.data["comment"] = f"Created by task `{task.name}` ({task.id})"
|
|
||||||
|
|
||||||
if task.models and task.models.output:
|
|
||||||
# model exists, update
|
|
||||||
model_id = task.models.output[-1].model
|
|
||||||
res = _update_model(call, company_id, model_id=model_id).to_struct()
|
|
||||||
res.update({"id": model_id, "created": False})
|
|
||||||
call.result.data = res
|
|
||||||
return
|
|
||||||
|
|
||||||
# new model, create
|
|
||||||
fields = parse_model_fields(call, create_fields)
|
|
||||||
|
|
||||||
# create and save model
|
|
||||||
now = datetime.utcnow()
|
|
||||||
model = Model(
|
|
||||||
id=database.utils.id(),
|
|
||||||
created=now,
|
|
||||||
last_update=now,
|
|
||||||
user=call.identity.user,
|
|
||||||
company=company_id,
|
|
||||||
project=task.project,
|
|
||||||
framework=task.execution.framework,
|
|
||||||
parent=task.models.input[0].model
|
|
||||||
if task.models and task.models.input
|
|
||||||
else None,
|
|
||||||
design=task.execution.model_desc,
|
|
||||||
labels=task.execution.model_labels,
|
|
||||||
ready=(task.status == TaskStatus.published),
|
|
||||||
**fields,
|
|
||||||
)
|
|
||||||
model.save()
|
|
||||||
_update_cached_tags(company_id, project=model.project, fields=fields)
|
|
||||||
|
|
||||||
TaskBLL.update_statistics(
|
|
||||||
task_id=task_id,
|
|
||||||
company_id=company_id,
|
|
||||||
last_iteration_max=iteration,
|
|
||||||
models__output=[
|
|
||||||
ModelItem(
|
|
||||||
model=model.id,
|
|
||||||
name=TaskModelNames[TaskModelTypes.output],
|
|
||||||
updated=datetime.utcnow(),
|
|
||||||
)
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
call.result.data = {"id": model.id, "created": True}
|
|
||||||
|
|
||||||
|
|
||||||
@endpoint(
|
@endpoint(
|
||||||
@ -319,36 +326,33 @@ def create(call: APICall, company_id, req_model: CreateModelRequest):
|
|||||||
if req_model.public:
|
if req_model.public:
|
||||||
company_id = ""
|
company_id = ""
|
||||||
|
|
||||||
with translate_errors_context():
|
project = req_model.project
|
||||||
|
if project:
|
||||||
|
validate_id(Project, company=company_id, project=project)
|
||||||
|
|
||||||
project = req_model.project
|
task = req_model.task
|
||||||
if project:
|
req_data = req_model.to_struct()
|
||||||
validate_id(Project, company=company_id, project=project)
|
if task:
|
||||||
|
validate_task(company_id, req_data)
|
||||||
|
|
||||||
task = req_model.task
|
fields = filter_fields(Model, req_data)
|
||||||
req_data = req_model.to_struct()
|
conform_tag_fields(call, fields, validate=True)
|
||||||
if task:
|
escape_metadata(fields)
|
||||||
validate_task(company_id, req_data)
|
|
||||||
|
|
||||||
fields = filter_fields(Model, req_data)
|
# create and save model
|
||||||
conform_tag_fields(call, fields, validate=True)
|
now = datetime.utcnow()
|
||||||
|
model = Model(
|
||||||
|
id=database.utils.id(),
|
||||||
|
user=call.identity.user,
|
||||||
|
company=company_id,
|
||||||
|
created=now,
|
||||||
|
last_update=now,
|
||||||
|
**fields,
|
||||||
|
)
|
||||||
|
model.save()
|
||||||
|
_update_cached_tags(company_id, project=model.project, fields=fields)
|
||||||
|
|
||||||
validate_metadata(fields.get("metadata"))
|
call.result.data_model = CreateModelResponse(id=model.id, created=True)
|
||||||
|
|
||||||
# create and save model
|
|
||||||
now = datetime.utcnow()
|
|
||||||
model = Model(
|
|
||||||
id=database.utils.id(),
|
|
||||||
user=call.identity.user,
|
|
||||||
company=company_id,
|
|
||||||
created=now,
|
|
||||||
last_update=now,
|
|
||||||
**fields,
|
|
||||||
)
|
|
||||||
model.save()
|
|
||||||
_update_cached_tags(company_id, project=model.project, fields=fields)
|
|
||||||
|
|
||||||
call.result.data_model = CreateModelResponse(id=model.id, created=True)
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_update_fields(call, company_id, fields: dict):
|
def prepare_update_fields(call, company_id, fields: dict):
|
||||||
@ -383,6 +387,7 @@ def prepare_update_fields(call, company_id, fields: dict):
|
|||||||
)
|
)
|
||||||
|
|
||||||
conform_tag_fields(call, fields, validate=True)
|
conform_tag_fields(call, fields, validate=True)
|
||||||
|
escape_metadata(fields)
|
||||||
return fields
|
return fields
|
||||||
|
|
||||||
|
|
||||||
@ -394,89 +399,85 @@ def validate_task(company_id, fields: dict):
|
|||||||
def edit(call: APICall, company_id, _):
|
def edit(call: APICall, company_id, _):
|
||||||
model_id = call.data["model"]
|
model_id = call.data["model"]
|
||||||
|
|
||||||
with translate_errors_context():
|
model = ModelBLL.get_company_model_by_id(
|
||||||
model = ModelBLL.get_company_model_by_id(
|
company_id=company_id, model_id=model_id
|
||||||
company_id=company_id, model_id=model_id
|
)
|
||||||
|
|
||||||
|
fields = parse_model_fields(call, create_fields)
|
||||||
|
fields = prepare_update_fields(call, company_id, fields)
|
||||||
|
|
||||||
|
for key in fields:
|
||||||
|
field = getattr(model, key, None)
|
||||||
|
value = fields[key]
|
||||||
|
if (
|
||||||
|
field
|
||||||
|
and isinstance(value, dict)
|
||||||
|
and isinstance(field, EmbeddedDocument)
|
||||||
|
):
|
||||||
|
d = field.to_mongo(use_db_field=False).to_dict()
|
||||||
|
d.update(value)
|
||||||
|
fields[key] = d
|
||||||
|
|
||||||
|
iteration = call.data.get("iteration")
|
||||||
|
task_id = model.task or fields.get("task")
|
||||||
|
if task_id and iteration is not None:
|
||||||
|
TaskBLL.update_statistics(
|
||||||
|
task_id=task_id, company_id=company_id, last_iteration_max=iteration,
|
||||||
)
|
)
|
||||||
|
|
||||||
fields = parse_model_fields(call, create_fields)
|
if fields:
|
||||||
fields = prepare_update_fields(call, company_id, fields)
|
if any(uf in fields for uf in last_update_fields):
|
||||||
|
fields.update(last_update=datetime.utcnow())
|
||||||
|
|
||||||
for key in fields:
|
updated = model.update(upsert=False, **fields)
|
||||||
field = getattr(model, key, None)
|
if updated:
|
||||||
value = fields[key]
|
new_project = fields.get("project", model.project)
|
||||||
if (
|
if new_project != model.project:
|
||||||
field
|
_reset_cached_tags(
|
||||||
and isinstance(value, dict)
|
company_id, projects=[new_project, model.project]
|
||||||
and isinstance(field, EmbeddedDocument)
|
)
|
||||||
):
|
else:
|
||||||
d = field.to_mongo(use_db_field=False).to_dict()
|
_update_cached_tags(
|
||||||
d.update(value)
|
company_id, project=model.project, fields=fields
|
||||||
fields[key] = d
|
)
|
||||||
|
conform_output_tags(call, fields)
|
||||||
iteration = call.data.get("iteration")
|
unescape_metadata(call, fields)
|
||||||
task_id = model.task or fields.get("task")
|
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
|
||||||
if task_id and iteration is not None:
|
else:
|
||||||
TaskBLL.update_statistics(
|
call.result.data_model = UpdateResponse(updated=0)
|
||||||
task_id=task_id, company_id=company_id, last_iteration_max=iteration,
|
|
||||||
)
|
|
||||||
|
|
||||||
if fields:
|
|
||||||
if any(uf in fields for uf in last_update_fields):
|
|
||||||
fields.update(last_update=datetime.utcnow())
|
|
||||||
|
|
||||||
updated = model.update(upsert=False, **fields)
|
|
||||||
if updated:
|
|
||||||
new_project = fields.get("project", model.project)
|
|
||||||
if new_project != model.project:
|
|
||||||
_reset_cached_tags(
|
|
||||||
company_id, projects=[new_project, model.project]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
_update_cached_tags(
|
|
||||||
company_id, project=model.project, fields=fields
|
|
||||||
)
|
|
||||||
conform_output_tags(call, fields)
|
|
||||||
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
|
|
||||||
else:
|
|
||||||
call.result.data_model = UpdateResponse(updated=0)
|
|
||||||
|
|
||||||
|
|
||||||
def _update_model(call: APICall, company_id, model_id=None):
|
def _update_model(call: APICall, company_id, model_id=None):
|
||||||
model_id = model_id or call.data["model"]
|
model_id = model_id or call.data["model"]
|
||||||
|
|
||||||
with translate_errors_context():
|
model = ModelBLL.get_company_model_by_id(
|
||||||
model = ModelBLL.get_company_model_by_id(
|
company_id=company_id, model_id=model_id
|
||||||
company_id=company_id, model_id=model_id
|
)
|
||||||
|
|
||||||
|
data = prepare_update_fields(call, company_id, call.data)
|
||||||
|
|
||||||
|
task_id = data.get("task")
|
||||||
|
iteration = data.get("iteration")
|
||||||
|
if task_id and iteration is not None:
|
||||||
|
TaskBLL.update_statistics(
|
||||||
|
task_id=task_id, company_id=company_id, last_iteration_max=iteration,
|
||||||
)
|
)
|
||||||
|
|
||||||
data = prepare_update_fields(call, company_id, call.data)
|
updated_count, updated_fields = Model.safe_update(company_id, model.id, data)
|
||||||
|
if updated_count:
|
||||||
|
if any(uf in updated_fields for uf in last_update_fields):
|
||||||
|
model.update(upsert=False, last_update=datetime.utcnow())
|
||||||
|
|
||||||
task_id = data.get("task")
|
new_project = updated_fields.get("project", model.project)
|
||||||
iteration = data.get("iteration")
|
if new_project != model.project:
|
||||||
if task_id and iteration is not None:
|
_reset_cached_tags(company_id, projects=[new_project, model.project])
|
||||||
TaskBLL.update_statistics(
|
else:
|
||||||
task_id=task_id, company_id=company_id, last_iteration_max=iteration,
|
_update_cached_tags(
|
||||||
|
company_id, project=model.project, fields=updated_fields
|
||||||
)
|
)
|
||||||
|
conform_output_tags(call, updated_fields)
|
||||||
metadata = data.get("metadata")
|
unescape_metadata(call, updated_fields)
|
||||||
if metadata:
|
return UpdateResponse(updated=updated_count, fields=updated_fields)
|
||||||
validate_metadata(metadata)
|
|
||||||
|
|
||||||
updated_count, updated_fields = Model.safe_update(company_id, model.id, data)
|
|
||||||
if updated_count:
|
|
||||||
if any(uf in updated_fields for uf in last_update_fields):
|
|
||||||
model.update(upsert=False, last_update=datetime.utcnow())
|
|
||||||
|
|
||||||
new_project = updated_fields.get("project", model.project)
|
|
||||||
if new_project != model.project:
|
|
||||||
_reset_cached_tags(company_id, projects=[new_project, model.project])
|
|
||||||
else:
|
|
||||||
_update_cached_tags(
|
|
||||||
company_id, project=model.project, fields=updated_fields
|
|
||||||
)
|
|
||||||
conform_output_tags(call, updated_fields)
|
|
||||||
return UpdateResponse(updated=updated_count, fields=updated_fields)
|
|
||||||
|
|
||||||
|
|
||||||
@endpoint(
|
@endpoint(
|
||||||
@ -641,26 +642,25 @@ def add_or_update_metadata(
|
|||||||
_: APICall, company_id: str, request: AddOrUpdateMetadataRequest
|
_: APICall, company_id: str, request: AddOrUpdateMetadataRequest
|
||||||
):
|
):
|
||||||
model_id = request.model
|
model_id = request.model
|
||||||
ModelBLL.get_company_model_by_id(company_id=company_id, model_id=model_id)
|
model = ModelBLL.get_company_model_by_id(company_id=company_id, model_id=model_id)
|
||||||
|
return {
|
||||||
updated = metadata_add_or_update(
|
"updated": Metadata.edit_metadata(
|
||||||
cls=Model, _id=model_id, items=get_metadata_from_api(request.metadata),
|
model,
|
||||||
)
|
items=request.metadata,
|
||||||
if updated:
|
replace_metadata=request.replace_metadata,
|
||||||
Model.objects(id=model_id).update_one(last_update=datetime.utcnow())
|
last_update=datetime.utcnow(),
|
||||||
|
)
|
||||||
return {"updated": updated}
|
}
|
||||||
|
|
||||||
|
|
||||||
@endpoint("models.delete_metadata", min_version="2.13")
|
@endpoint("models.delete_metadata", min_version="2.13")
|
||||||
def delete_metadata(_: APICall, company_id: str, request: DeleteMetadataRequest):
|
def delete_metadata(_: APICall, company_id: str, request: DeleteMetadataRequest):
|
||||||
model_id = request.model
|
model_id = request.model
|
||||||
ModelBLL.get_company_model_by_id(
|
model = ModelBLL.get_company_model_by_id(
|
||||||
company_id=company_id, model_id=model_id, only_fields=("id",)
|
company_id=company_id, model_id=model_id, only_fields=("id",)
|
||||||
)
|
)
|
||||||
|
return {
|
||||||
updated = metadata_delete(cls=Model, _id=model_id, keys=request.keys)
|
"updated": Metadata.delete_metadata(
|
||||||
if updated:
|
model, keys=request.keys, last_update=datetime.utcnow()
|
||||||
Model.objects(id=model_id).update_one(last_update=datetime.utcnow())
|
)
|
||||||
|
}
|
||||||
return {"updated": updated}
|
|
||||||
|
@ -275,6 +275,23 @@ def get_unique_metric_variants(
|
|||||||
call.result.data = {"metrics": metrics}
|
call.result.data = {"metrics": metrics}
|
||||||
|
|
||||||
|
|
||||||
|
@endpoint("projects.get_model_metadata_keys",)
|
||||||
|
def get_model_metadata_keys(call: APICall, company_id: str, request: GetParamsRequest):
|
||||||
|
total, remaining, keys = project_queries.get_model_metadata_keys(
|
||||||
|
company_id,
|
||||||
|
project_ids=[request.project] if request.project else None,
|
||||||
|
include_subprojects=request.include_subprojects,
|
||||||
|
page=request.page,
|
||||||
|
page_size=request.page_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
call.result.data = {
|
||||||
|
"total": total,
|
||||||
|
"remaining": remaining,
|
||||||
|
"keys": keys,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@endpoint(
|
@endpoint(
|
||||||
"projects.get_hyper_parameters",
|
"projects.get_hyper_parameters",
|
||||||
min_version="2.9",
|
min_version="2.9",
|
||||||
|
@ -13,17 +13,19 @@ from apiserver.apimodels.queues import (
|
|||||||
QueueMetrics,
|
QueueMetrics,
|
||||||
AddOrUpdateMetadataRequest,
|
AddOrUpdateMetadataRequest,
|
||||||
DeleteMetadataRequest,
|
DeleteMetadataRequest,
|
||||||
|
GetNextTaskRequest,
|
||||||
)
|
)
|
||||||
|
from apiserver.bll.model import Metadata
|
||||||
from apiserver.bll.queue import QueueBLL
|
from apiserver.bll.queue import QueueBLL
|
||||||
from apiserver.bll.workers import WorkerBLL
|
from apiserver.bll.workers import WorkerBLL
|
||||||
from apiserver.database.model.metadata import metadata_add_or_update, metadata_delete
|
from apiserver.database.model.task.task import Task
|
||||||
from apiserver.database.model.queue import Queue
|
|
||||||
from apiserver.service_repo import APICall, endpoint
|
from apiserver.service_repo import APICall, endpoint
|
||||||
from apiserver.services.utils import (
|
from apiserver.services.utils import (
|
||||||
conform_tag_fields,
|
conform_tag_fields,
|
||||||
conform_output_tags,
|
conform_output_tags,
|
||||||
conform_tags,
|
conform_tags,
|
||||||
get_metadata_from_api,
|
escape_metadata,
|
||||||
|
unescape_metadata,
|
||||||
)
|
)
|
||||||
from apiserver.utilities import extract_properties_to_lists
|
from apiserver.utilities import extract_properties_to_lists
|
||||||
|
|
||||||
@ -36,6 +38,7 @@ def get_by_id(call: APICall, company_id, req_model: QueueRequest):
|
|||||||
queue = queue_bll.get_by_id(company_id, req_model.queue)
|
queue = queue_bll.get_by_id(company_id, req_model.queue)
|
||||||
queue_dict = queue.to_proper_dict()
|
queue_dict = queue.to_proper_dict()
|
||||||
conform_output_tags(call, queue_dict)
|
conform_output_tags(call, queue_dict)
|
||||||
|
unescape_metadata(call, queue_dict)
|
||||||
call.result.data = {"queue": queue_dict}
|
call.result.data = {"queue": queue_dict}
|
||||||
|
|
||||||
|
|
||||||
@ -49,13 +52,13 @@ def get_by_id(call: APICall):
|
|||||||
def get_all_ex(call: APICall):
|
def get_all_ex(call: APICall):
|
||||||
conform_tag_fields(call, call.data)
|
conform_tag_fields(call, call.data)
|
||||||
ret_params = {}
|
ret_params = {}
|
||||||
|
|
||||||
|
Metadata.escape_query_parameters(call)
|
||||||
queues = queue_bll.get_queue_infos(
|
queues = queue_bll.get_queue_infos(
|
||||||
company_id=call.identity.company,
|
company_id=call.identity.company, query_dict=call.data, ret_params=ret_params,
|
||||||
query_dict=call.data,
|
|
||||||
ret_params=ret_params,
|
|
||||||
)
|
)
|
||||||
conform_output_tags(call, queues)
|
conform_output_tags(call, queues)
|
||||||
|
unescape_metadata(call, queues)
|
||||||
call.result.data = {"queues": queues, **ret_params}
|
call.result.data = {"queues": queues, **ret_params}
|
||||||
|
|
||||||
|
|
||||||
@ -63,13 +66,12 @@ def get_all_ex(call: APICall):
|
|||||||
def get_all(call: APICall):
|
def get_all(call: APICall):
|
||||||
conform_tag_fields(call, call.data)
|
conform_tag_fields(call, call.data)
|
||||||
ret_params = {}
|
ret_params = {}
|
||||||
|
Metadata.escape_query_parameters(call)
|
||||||
queues = queue_bll.get_all(
|
queues = queue_bll.get_all(
|
||||||
company_id=call.identity.company,
|
company_id=call.identity.company, query_dict=call.data, ret_params=ret_params,
|
||||||
query_dict=call.data,
|
|
||||||
ret_params=ret_params,
|
|
||||||
)
|
)
|
||||||
conform_output_tags(call, queues)
|
conform_output_tags(call, queues)
|
||||||
|
unescape_metadata(call, queues)
|
||||||
call.result.data = {"queues": queues, **ret_params}
|
call.result.data = {"queues": queues, **ret_params}
|
||||||
|
|
||||||
|
|
||||||
@ -83,7 +85,7 @@ def create(call: APICall, company_id, request: CreateRequest):
|
|||||||
name=request.name,
|
name=request.name,
|
||||||
tags=tags,
|
tags=tags,
|
||||||
system_tags=system_tags,
|
system_tags=system_tags,
|
||||||
metadata=get_metadata_from_api(request.metadata),
|
metadata=Metadata.metadata_from_api(request.metadata),
|
||||||
)
|
)
|
||||||
call.result.data = {"id": queue.id}
|
call.result.data = {"id": queue.id}
|
||||||
|
|
||||||
@ -97,10 +99,12 @@ def create(call: APICall, company_id, request: CreateRequest):
|
|||||||
def update(call: APICall, company_id, req_model: UpdateRequest):
|
def update(call: APICall, company_id, req_model: UpdateRequest):
|
||||||
data = call.data_model_for_partial_update
|
data = call.data_model_for_partial_update
|
||||||
conform_tag_fields(call, data, validate=True)
|
conform_tag_fields(call, data, validate=True)
|
||||||
|
escape_metadata(data)
|
||||||
updated, fields = queue_bll.update(
|
updated, fields = queue_bll.update(
|
||||||
company_id=company_id, queue_id=req_model.queue, **data
|
company_id=company_id, queue_id=req_model.queue, **data
|
||||||
)
|
)
|
||||||
conform_output_tags(call, fields)
|
conform_output_tags(call, fields)
|
||||||
|
unescape_metadata(call, fields)
|
||||||
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
|
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
|
||||||
|
|
||||||
|
|
||||||
@ -121,11 +125,19 @@ def add_task(call: APICall, company_id, req_model: TaskRequest):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@endpoint("queues.get_next_task", min_version="2.4", request_data_model=QueueRequest)
|
@endpoint("queues.get_next_task", request_data_model=GetNextTaskRequest)
|
||||||
def get_next_task(call: APICall, company_id, req_model: QueueRequest):
|
def get_next_task(call: APICall, company_id, req_model: GetNextTaskRequest):
|
||||||
task = queue_bll.get_next_task(company_id=company_id, queue_id=req_model.queue)
|
entry = queue_bll.get_next_task(
|
||||||
if task:
|
company_id=company_id, queue_id=req_model.queue
|
||||||
call.result.data = {"entry": task.to_proper_dict()}
|
)
|
||||||
|
if entry:
|
||||||
|
data = {"entry": entry.to_proper_dict()}
|
||||||
|
if req_model.get_task_info:
|
||||||
|
task = Task.objects(id=entry.task).first()
|
||||||
|
if task:
|
||||||
|
data["task_info"] = {"company": task.company, "user": task.user}
|
||||||
|
|
||||||
|
call.result.data = data
|
||||||
|
|
||||||
|
|
||||||
@endpoint("queues.remove_task", min_version="2.4", request_data_model=TaskRequest)
|
@endpoint("queues.remove_task", min_version="2.4", request_data_model=TaskRequest)
|
||||||
@ -245,21 +257,19 @@ def get_queue_metrics(
|
|||||||
|
|
||||||
@endpoint("queues.add_or_update_metadata", min_version="2.13")
|
@endpoint("queues.add_or_update_metadata", min_version="2.13")
|
||||||
def add_or_update_metadata(
|
def add_or_update_metadata(
|
||||||
_: APICall, company_id: str, request: AddOrUpdateMetadataRequest
|
call: APICall, company_id: str, request: AddOrUpdateMetadataRequest
|
||||||
):
|
):
|
||||||
queue_id = request.queue
|
queue_id = request.queue
|
||||||
queue_bll.get_by_id(company_id=company_id, queue_id=queue_id, only=("id",))
|
queue = queue_bll.get_by_id(company_id=company_id, queue_id=queue_id, only=("id",))
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"updated": metadata_add_or_update(
|
"updated": Metadata.edit_metadata(
|
||||||
cls=Queue, _id=queue_id, items=get_metadata_from_api(request.metadata),
|
queue, items=request.metadata, replace_metadata=request.replace_metadata
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@endpoint("queues.delete_metadata", min_version="2.13")
|
@endpoint("queues.delete_metadata", min_version="2.13")
|
||||||
def delete_metadata(_: APICall, company_id: str, request: DeleteMetadataRequest):
|
def delete_metadata(call: APICall, company_id: str, request: DeleteMetadataRequest):
|
||||||
queue_id = request.queue
|
queue_id = request.queue
|
||||||
queue_bll.get_by_id(company_id=company_id, queue_id=queue_id, only=("id",))
|
queue = queue_bll.get_by_id(company_id=company_id, queue_id=queue_id, only=("id",))
|
||||||
|
return {"updated": Metadata.delete_metadata(queue, keys=request.keys)}
|
||||||
return {"updated": metadata_delete(cls=Queue, _id=queue_id, keys=request.keys)}
|
|
||||||
|
@ -2,7 +2,6 @@ from datetime import datetime
|
|||||||
from typing import Union, Sequence, Tuple
|
from typing import Union, Sequence, Tuple
|
||||||
|
|
||||||
from apiserver.apierrors import errors
|
from apiserver.apierrors import errors
|
||||||
from apiserver.apimodels.metadata import MetadataItem as ApiMetadataItem
|
|
||||||
from apiserver.apimodels.organization import Filter
|
from apiserver.apimodels.organization import Filter
|
||||||
from apiserver.database.model.base import GetMixin
|
from apiserver.database.model.base import GetMixin
|
||||||
from apiserver.database.model.task.task import TaskModelTypes, TaskModelNames
|
from apiserver.database.model.task.task import TaskModelTypes, TaskModelNames
|
||||||
@ -222,22 +221,38 @@ class DockerCmdBackwardsCompatibility:
|
|||||||
nested_set(task, cls.field, docker_cmd)
|
nested_set(task, cls.field, docker_cmd)
|
||||||
|
|
||||||
|
|
||||||
def validate_metadata(metadata: Sequence[dict]):
|
def escape_metadata(document: dict):
|
||||||
|
"""
|
||||||
|
Escape special characters in metadata keys
|
||||||
|
"""
|
||||||
|
metadata = document.get("metadata")
|
||||||
if not metadata:
|
if not metadata:
|
||||||
return
|
return
|
||||||
|
|
||||||
keys = [m.get("key") for m in metadata]
|
document["metadata"] = {
|
||||||
unique_keys = set(keys)
|
ParameterKeyEscaper.escape(k): v
|
||||||
unique_keys.discard(None)
|
for k, v in metadata.items()
|
||||||
if len(keys) != len(set(keys)):
|
}
|
||||||
raise errors.bad_request.ValidationError("Metadata keys should be unique")
|
|
||||||
|
|
||||||
|
|
||||||
def get_metadata_from_api(api_metadata: Sequence[ApiMetadataItem]) -> Sequence:
|
def unescape_metadata(call: APICall, documents: Union[dict, Sequence[dict]]):
|
||||||
if not api_metadata:
|
"""
|
||||||
return api_metadata
|
Unescape special characters in metadata keys
|
||||||
|
"""
|
||||||
|
if isinstance(documents, dict):
|
||||||
|
documents = [documents]
|
||||||
|
|
||||||
metadata = [m.to_struct() for m in api_metadata]
|
old_client = call.requested_endpoint_version <= PartialVersion("2.16")
|
||||||
validate_metadata(metadata)
|
for doc in documents:
|
||||||
|
if old_client and "metadata" in doc:
|
||||||
|
doc["metadata"] = []
|
||||||
|
continue
|
||||||
|
|
||||||
return metadata
|
metadata = doc.get("metadata")
|
||||||
|
if not metadata:
|
||||||
|
continue
|
||||||
|
|
||||||
|
doc["metadata"] = {
|
||||||
|
ParameterKeyEscaper.unescape(k): v
|
||||||
|
for k, v in metadata.items()
|
||||||
|
}
|
||||||
|
@ -1,12 +1,11 @@
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Sequence
|
|
||||||
|
|
||||||
from apiserver.tests.api_client import APIClient
|
from apiserver.tests.api_client import APIClient
|
||||||
from apiserver.tests.automated import TestService
|
from apiserver.tests.automated import TestService
|
||||||
|
|
||||||
|
|
||||||
class TestQueueAndModelMetadata(TestService):
|
class TestQueueAndModelMetadata(TestService):
|
||||||
meta1 = [{"key": "test_key", "type": "str", "value": "test_value"}]
|
meta1 = {"test_key": {"key": "test_key", "type": "str", "value": "test_value"}}
|
||||||
|
|
||||||
def test_queue_metas(self):
|
def test_queue_metas(self):
|
||||||
queue_id = self._temp_queue("TestMetadata", metadata=self.meta1)
|
queue_id = self._temp_queue("TestMetadata", metadata=self.meta1)
|
||||||
@ -23,20 +22,43 @@ class TestQueueAndModelMetadata(TestService):
|
|||||||
)
|
)
|
||||||
|
|
||||||
model_id = self._temp_model("TestMetadata1")
|
model_id = self._temp_model("TestMetadata1")
|
||||||
self.api.models.edit(model=model_id, metadata=[self.meta1[0]])
|
self.api.models.edit(model=model_id, metadata=self.meta1)
|
||||||
self._assertMeta(service=service, entity=entity, _id=model_id, meta=self.meta1)
|
self._assertMeta(service=service, entity=entity, _id=model_id, meta=self.meta1)
|
||||||
|
|
||||||
|
def test_project_meta_query(self):
|
||||||
|
self._temp_model("TestMetadata", metadata=self.meta1)
|
||||||
|
project = self.temp_project(name="MetaParent")
|
||||||
|
model_id = self._temp_model(
|
||||||
|
"TestMetadata2",
|
||||||
|
project=project,
|
||||||
|
metadata={
|
||||||
|
"test_key": {"key": "test_key", "type": "str", "value": "test_value"},
|
||||||
|
"test_key2": {"key": "test_key2", "type": "str", "value": "test_value"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
res = self.api.projects.get_model_metadata_keys()
|
||||||
|
self.assertTrue({"test_key", "test_key2"}.issubset(set(res["keys"])))
|
||||||
|
res = self.api.projects.get_model_metadata_keys(include_subprojects=False)
|
||||||
|
self.assertTrue("test_key" in res["keys"])
|
||||||
|
self.assertFalse("test_key2" in res["keys"])
|
||||||
|
|
||||||
|
model = self.api.models.get_all_ex(
|
||||||
|
id=[model_id], only_fields=["metadata.test_key"]
|
||||||
|
).models[0]
|
||||||
|
self.assertTrue("test_key" in model.metadata)
|
||||||
|
self.assertFalse("test_key2" in model.metadata)
|
||||||
|
|
||||||
def _test_meta_operations(
|
def _test_meta_operations(
|
||||||
self, service: APIClient.Service, entity: str, _id: str,
|
self, service: APIClient.Service, entity: str, _id: str,
|
||||||
):
|
):
|
||||||
assert_meta = partial(self._assertMeta, service=service, entity=entity)
|
assert_meta = partial(self._assertMeta, service=service, entity=entity)
|
||||||
assert_meta(_id=_id, meta=self.meta1)
|
assert_meta(_id=_id, meta=self.meta1)
|
||||||
|
|
||||||
meta2 = [
|
meta2 = {
|
||||||
{"key": "test1", "type": "str", "value": "data1"},
|
"test1": {"key": "test1", "type": "str", "value": "data1"},
|
||||||
{"key": "test2", "type": "str", "value": "data2"},
|
"test2": {"key": "test2", "type": "str", "value": "data2"},
|
||||||
{"key": "test3", "type": "str", "value": "data3"},
|
"test3": {"key": "test3", "type": "str", "value": "data3"},
|
||||||
]
|
}
|
||||||
service.update(**{entity: _id, "metadata": meta2})
|
service.update(**{entity: _id, "metadata": meta2})
|
||||||
assert_meta(_id=_id, meta=meta2)
|
assert_meta(_id=_id, meta=meta2)
|
||||||
|
|
||||||
@ -48,16 +70,17 @@ class TestQueueAndModelMetadata(TestService):
|
|||||||
]
|
]
|
||||||
res = service.add_or_update_metadata(**{entity: _id, "metadata": updates})
|
res = service.add_or_update_metadata(**{entity: _id, "metadata": updates})
|
||||||
self.assertEqual(res.updated, 1)
|
self.assertEqual(res.updated, 1)
|
||||||
assert_meta(_id=_id, meta=[meta2[0], *updates])
|
assert_meta(_id=_id, meta={**meta2, **{u["key"]: u for u in updates}})
|
||||||
|
|
||||||
res = service.delete_metadata(
|
res = service.delete_metadata(
|
||||||
**{entity: _id, "keys": [f"test{idx}" for idx in range(2, 6)]}
|
**{entity: _id, "keys": [f"test{idx}" for idx in range(2, 6)]}
|
||||||
)
|
)
|
||||||
self.assertEqual(res.updated, 1)
|
self.assertEqual(res.updated, 1)
|
||||||
assert_meta(_id=_id, meta=meta2[:1])
|
# noinspection PyTypeChecker
|
||||||
|
assert_meta(_id=_id, meta=dict(list(meta2.items())[:1]))
|
||||||
|
|
||||||
def _assertMeta(
|
def _assertMeta(
|
||||||
self, service: APIClient.Service, entity: str, _id: str, meta: Sequence[dict]
|
self, service: APIClient.Service, entity: str, _id: str, meta: dict
|
||||||
):
|
):
|
||||||
res = service.get_all_ex(id=[_id])[f"{entity}s"][0]
|
res = service.get_all_ex(id=[_id])[f"{entity}s"][0]
|
||||||
self.assertEqual(res.metadata, meta)
|
self.assertEqual(res.metadata, meta)
|
||||||
|
Loading…
Reference in New Issue
Block a user