mirror of
https://github.com/clearml/clearml-server
synced 2025-04-22 15:16:11 +00:00
Add metadata dict support for models, queues
Add more info for projects
This commit is contained in:
parent
04ea9018a3
commit
af09fba755
apiserver
apimodels
bll
database/model
mongo/migrations
schema/services
services
tests/automated
@ -1,7 +1,7 @@
|
||||
from typing import Sequence
|
||||
|
||||
from jsonmodels import validators
|
||||
from jsonmodels.fields import StringField
|
||||
from jsonmodels.fields import StringField, BoolField
|
||||
from jsonmodels.models import Base
|
||||
|
||||
from apiserver.apimodels import ListField
|
||||
@ -21,3 +21,4 @@ class AddOrUpdateMetadata(Base):
|
||||
metadata: Sequence[MetadataItem] = ListField(
|
||||
[MetadataItem], validators=validators.Length(minimum_value=1)
|
||||
)
|
||||
replace_metadata = BoolField(default=False)
|
||||
|
@ -30,7 +30,7 @@ class CreateModelRequest(models.Base):
|
||||
ready = fields.BoolField(default=True)
|
||||
ui_cache = DictField()
|
||||
task = fields.StringField()
|
||||
metadata = ListField(items_types=[MetadataItem])
|
||||
metadata = DictField(value_types=[MetadataItem])
|
||||
|
||||
|
||||
class CreateModelResponse(models.Base):
|
||||
|
@ -2,7 +2,7 @@ from jsonmodels import validators
|
||||
from jsonmodels.fields import StringField, IntField, BoolField, FloatField
|
||||
from jsonmodels.models import Base
|
||||
|
||||
from apiserver.apimodels import ListField
|
||||
from apiserver.apimodels import ListField, DictField
|
||||
from apiserver.apimodels.metadata import (
|
||||
MetadataItem,
|
||||
DeleteMetadata,
|
||||
@ -19,13 +19,18 @@ class CreateRequest(Base):
|
||||
name = StringField(required=True)
|
||||
tags = ListField(items_types=[str])
|
||||
system_tags = ListField(items_types=[str])
|
||||
metadata = ListField(items_types=[MetadataItem])
|
||||
metadata = DictField(value_types=[MetadataItem])
|
||||
|
||||
|
||||
class QueueRequest(Base):
|
||||
queue = StringField(required=True)
|
||||
|
||||
|
||||
class GetNextTaskRequest(QueueRequest):
|
||||
queue = StringField(required=True)
|
||||
get_task_info = BoolField(default=False)
|
||||
|
||||
|
||||
class DeleteRequest(QueueRequest):
|
||||
force = BoolField(default=False)
|
||||
|
||||
@ -34,7 +39,7 @@ class UpdateRequest(QueueRequest):
|
||||
name = StringField()
|
||||
tags = ListField(items_types=[str])
|
||||
system_tags = ListField(items_types=[str])
|
||||
metadata = ListField(items_types=[MetadataItem])
|
||||
metadata = DictField(value_types=[MetadataItem])
|
||||
|
||||
|
||||
class TaskRequest(QueueRequest):
|
||||
|
@ -7,6 +7,7 @@ from apiserver.bll.task.utils import deleted_prefix
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.task.task import Task, TaskStatus
|
||||
from .metadata import Metadata
|
||||
|
||||
|
||||
class ModelBLL:
|
||||
|
111
apiserver/bll/model/metadata.py
Normal file
111
apiserver/bll/model/metadata.py
Normal file
@ -0,0 +1,111 @@
|
||||
from typing import Sequence, Union, Mapping
|
||||
|
||||
from mongoengine import Document
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apimodels.metadata import MetadataItem
|
||||
from apiserver.database.model.base import GetMixin
|
||||
from apiserver.service_repo import APICall
|
||||
from apiserver.utilities.parameter_key_escaper import (
|
||||
ParameterKeyEscaper,
|
||||
mongoengine_safe,
|
||||
)
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.timing_context import TimingContext
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class Metadata:
|
||||
@staticmethod
|
||||
def metadata_from_api(
|
||||
api_data: Union[Mapping[str, MetadataItem], Sequence[MetadataItem]]
|
||||
) -> dict:
|
||||
if not api_data:
|
||||
return {}
|
||||
|
||||
if isinstance(api_data, dict):
|
||||
return {
|
||||
ParameterKeyEscaper.escape(k): v.to_struct()
|
||||
for k, v in api_data.items()
|
||||
}
|
||||
|
||||
return {
|
||||
ParameterKeyEscaper.escape(item.key): item.to_struct() for item in api_data
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def edit_metadata(
|
||||
cls,
|
||||
obj: Document,
|
||||
items: Sequence[MetadataItem],
|
||||
replace_metadata: bool,
|
||||
**more_updates,
|
||||
) -> int:
|
||||
with TimingContext("mongo", "edit_metadata"):
|
||||
update_cmds = dict()
|
||||
metadata = cls.metadata_from_api(items)
|
||||
if replace_metadata:
|
||||
update_cmds["set__metadata"] = metadata
|
||||
else:
|
||||
for key, value in metadata.items():
|
||||
update_cmds[f"set__metadata__{mongoengine_safe(key)}"] = value
|
||||
|
||||
return obj.update(**update_cmds, **more_updates)
|
||||
|
||||
@classmethod
|
||||
def delete_metadata(cls, obj: Document, keys: Sequence[str], **more_updates) -> int:
|
||||
with TimingContext("mongo", "delete_metadata"):
|
||||
return obj.update(
|
||||
**{
|
||||
f"unset__metadata__{ParameterKeyEscaper.escape(key)}": 1
|
||||
for key in set(keys)
|
||||
},
|
||||
**more_updates,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _process_path(path: str):
|
||||
"""
|
||||
Frontend does a partial escaping on the path so the all '.' in key names are escaped
|
||||
Need to unescape and apply a full mongo escaping
|
||||
"""
|
||||
parts = path.split(".")
|
||||
if len(parts) < 2 or len(parts) > 3:
|
||||
raise errors.bad_request.ValidationError("invalid field", path=path)
|
||||
return ".".join(
|
||||
ParameterKeyEscaper.escape(ParameterKeyEscaper.unescape(p)) for p in parts
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def escape_paths(cls, paths: Sequence[str]) -> Sequence[str]:
|
||||
for prefix in (
|
||||
"metadata.",
|
||||
"-metadata.",
|
||||
):
|
||||
paths = [
|
||||
cls._process_path(path) if path.startswith(prefix) else path
|
||||
for path in paths
|
||||
]
|
||||
return paths
|
||||
|
||||
@classmethod
|
||||
def escape_query_parameters(cls, call: APICall) -> dict:
|
||||
if not call.data:
|
||||
return call.data
|
||||
|
||||
keys = list(call.data)
|
||||
call_data = {
|
||||
safe_key: call.data[key]
|
||||
for key, safe_key in zip(keys, Metadata.escape_paths(keys))
|
||||
}
|
||||
|
||||
projection = GetMixin.get_projection(call_data)
|
||||
if projection:
|
||||
GetMixin.set_projection(call_data, Metadata.escape_paths(projection))
|
||||
|
||||
ordering = GetMixin.get_ordering(call_data)
|
||||
if ordering:
|
||||
GetMixin.set_ordering(call_data, Metadata.escape_paths(ordering))
|
||||
|
||||
return call_data
|
@ -388,6 +388,17 @@ class ProjectBLL:
|
||||
}
|
||||
}
|
||||
|
||||
def max_started_subquery(condition):
|
||||
return {
|
||||
"$max": {
|
||||
"$cond": {
|
||||
"if": condition,
|
||||
"then": "$started",
|
||||
"else": datetime.min,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def runtime_subquery(additional_cond):
|
||||
return {
|
||||
# the sum of
|
||||
@ -431,14 +442,22 @@ class ProjectBLL:
|
||||
group_step[f"{state.value}_recently_completed"] = completed_after_subquery(
|
||||
cond, time_thresh=time_thresh
|
||||
)
|
||||
group_step[f"{state.value}_max_task_started"] = max_started_subquery(cond)
|
||||
|
||||
def get_state_filter() -> dict:
|
||||
if not specific_state:
|
||||
return {}
|
||||
if specific_state == EntityVisibility.archived:
|
||||
return {"system_tags": {"$eq": EntityVisibility.archived.value}}
|
||||
return {"system_tags": {"$ne": EntityVisibility.archived.value}}
|
||||
|
||||
runtime_pipeline = [
|
||||
# only count run time for these types of tasks
|
||||
{
|
||||
"$match": {
|
||||
"company": {"$in": [None, "", company_id]},
|
||||
"type": {"$in": ["training", "testing", "annotation"]},
|
||||
"project": {"$in": project_ids},
|
||||
**get_state_filter(),
|
||||
}
|
||||
},
|
||||
ensure_valid_fields(),
|
||||
@ -547,6 +566,8 @@ class ProjectBLL:
|
||||
) -> Dict[str, dict]:
|
||||
return {
|
||||
section: a.get(section, 0) + b.get(section, 0)
|
||||
if not section.endswith("max_task_started")
|
||||
else max(a.get(section) or datetime.min, b.get(section) or datetime.min)
|
||||
for section in set(a) | set(b)
|
||||
}
|
||||
|
||||
@ -562,6 +583,10 @@ class ProjectBLL:
|
||||
project_section_statuses = nested_get(
|
||||
status_count, (project_id, section), default=default_counts
|
||||
)
|
||||
|
||||
def get_time_or_none(value):
|
||||
return value if value != datetime.min else None
|
||||
|
||||
return {
|
||||
"status_count": project_section_statuses,
|
||||
"running_tasks": project_section_statuses.get(TaskStatus.in_progress),
|
||||
@ -570,6 +595,9 @@ class ProjectBLL:
|
||||
"completed_tasks": project_runtime.get(
|
||||
f"{section}_recently_completed", 0
|
||||
),
|
||||
"last_task_run": get_time_or_none(
|
||||
project_runtime.get(f"{section}_max_task_started", datetime.min)
|
||||
),
|
||||
}
|
||||
|
||||
report_for_states = [
|
||||
@ -723,7 +751,9 @@ class ProjectBLL:
|
||||
return Model.objects(query).distinct(field="framework")
|
||||
|
||||
@classmethod
|
||||
def calc_own_contents(cls, company: str, project_ids: Sequence[str]) -> Dict[str, dict]:
|
||||
def calc_own_contents(
|
||||
cls, company: str, project_ids: Sequence[str]
|
||||
) -> Dict[str, dict]:
|
||||
"""
|
||||
Returns the amount of task/models per requested project
|
||||
Use separate aggregation calls on Task/Model instead of lookup
|
||||
@ -739,30 +769,17 @@ class ProjectBLL:
|
||||
"project": {"$in": project_ids},
|
||||
}
|
||||
},
|
||||
{
|
||||
"$project": {"project": 1}
|
||||
},
|
||||
{
|
||||
"$group": {
|
||||
"_id": "$project",
|
||||
"count": {"$sum": 1},
|
||||
}
|
||||
}
|
||||
{"$project": {"project": 1}},
|
||||
{"$group": {"_id": "$project", "count": {"$sum": 1}}}
|
||||
]
|
||||
|
||||
def get_agrregate_res(cls_: Type[AttributedDocument]) -> dict:
|
||||
return {
|
||||
data["_id"]: data["count"]
|
||||
for data in cls_.aggregate(pipeline)
|
||||
}
|
||||
return {data["_id"]: data["count"] for data in cls_.aggregate(pipeline)}
|
||||
|
||||
with TimingContext("mongo", "get_security_groups"):
|
||||
tasks = get_agrregate_res(Task)
|
||||
models = get_agrregate_res(Model)
|
||||
return {
|
||||
pid: {
|
||||
"own_tasks": tasks.get(pid, 0),
|
||||
"own_models": models.get(pid, 0),
|
||||
}
|
||||
pid: {"own_tasks": tasks.get(pid, 0), "own_models": models.get(pid, 0)}
|
||||
for pid in project_ids
|
||||
}
|
||||
|
@ -10,6 +10,7 @@ from typing import (
|
||||
from redis import StrictRedis
|
||||
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.redis_manager import redman
|
||||
from apiserver.utilities.dicts import nested_get
|
||||
@ -239,3 +240,53 @@ class ProjectQueries:
|
||||
|
||||
result = Task.aggregate(pipeline)
|
||||
return [r["metrics"][0] for r in result]
|
||||
|
||||
@classmethod
|
||||
def get_model_metadata_keys(
|
||||
cls,
|
||||
company_id,
|
||||
project_ids: Sequence[str],
|
||||
include_subprojects: bool,
|
||||
page: int = 0,
|
||||
page_size: int = 500,
|
||||
) -> Tuple[int, int, Sequence[dict]]:
|
||||
page = max(0, page)
|
||||
page_size = max(1, page_size)
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
**cls._get_company_constraint(company_id),
|
||||
**cls._get_project_constraint(project_ids, include_subprojects),
|
||||
"metadata": {"$exists": True, "$gt": {}},
|
||||
}
|
||||
},
|
||||
{"$project": {"metadata": {"$objectToArray": "$metadata"}}},
|
||||
{"$unwind": "$metadata"},
|
||||
{"$group": {"_id": "$metadata.k"}},
|
||||
{"$sort": {"_id": 1}},
|
||||
{"$skip": page * page_size},
|
||||
{"$limit": page_size},
|
||||
{
|
||||
"$group": {
|
||||
"_id": 1,
|
||||
"total": {"$sum": 1},
|
||||
"results": {"$push": "$$ROOT"},
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
result = next(Model.aggregate(pipeline), None)
|
||||
|
||||
total = 0
|
||||
remaining = 0
|
||||
results = []
|
||||
|
||||
if result:
|
||||
total = int(result.get("total", -1))
|
||||
results = [
|
||||
ParameterKeyEscaper.unescape(r.get("_id"))
|
||||
for r in result.get("results", [])
|
||||
]
|
||||
remaining = max(0, total - (len(results) + page * page_size))
|
||||
|
||||
return total, remaining, results
|
||||
|
@ -32,7 +32,7 @@ class QueueBLL(object):
|
||||
name: str,
|
||||
tags: Optional[Sequence[str]] = None,
|
||||
system_tags: Optional[Sequence[str]] = None,
|
||||
metadata: Optional[Sequence[dict]] = None,
|
||||
metadata: Optional[dict] = None,
|
||||
) -> Queue:
|
||||
"""Creates a queue"""
|
||||
with translate_errors_context():
|
||||
|
@ -95,6 +95,7 @@ class GetMixin(PropsMixin):
|
||||
}
|
||||
MultiFieldParameters = namedtuple("MultiFieldParameters", "pattern fields")
|
||||
|
||||
_numeric_locale = {"locale": "en_US", "numericOrdering": True}
|
||||
_field_collation_overrides = {}
|
||||
|
||||
class QueryParameterOptions(object):
|
||||
@ -599,7 +600,7 @@ class GetMixin(PropsMixin):
|
||||
return size
|
||||
|
||||
@classmethod
|
||||
def get_data_with_scroll_and_filter_support(
|
||||
def get_data_with_scroll_support(
|
||||
cls,
|
||||
query_dict: dict,
|
||||
data_getter: Callable[[], Sequence[dict]],
|
||||
@ -629,15 +630,12 @@ class GetMixin(PropsMixin):
|
||||
if cls._start_key in query_dict:
|
||||
query_dict[cls._start_key] = query_dict[cls._start_key] + len(data)
|
||||
|
||||
def update_state(returned_len: int):
|
||||
if not state:
|
||||
return
|
||||
if state:
|
||||
state.position = query_dict[cls._start_key]
|
||||
cls.get_cache_manager().set_state(state)
|
||||
if ret_params is not None:
|
||||
ret_params["scroll_id"] = state.id
|
||||
|
||||
update_state(len(data))
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
@ -770,7 +768,7 @@ class GetMixin(PropsMixin):
|
||||
override_projection=override_projection,
|
||||
override_collation=override_collation,
|
||||
)
|
||||
return cls.get_data_with_scroll_and_filter_support(
|
||||
return cls.get_data_with_scroll_support(
|
||||
query_dict=query_dict, data_getter=data_getter, ret_params=ret_params,
|
||||
)
|
||||
|
||||
|
@ -1,10 +1,8 @@
|
||||
from typing import Sequence
|
||||
|
||||
from mongoengine import (
|
||||
StringField,
|
||||
DateTimeField,
|
||||
BooleanField,
|
||||
EmbeddedDocumentListField,
|
||||
EmbeddedDocumentField,
|
||||
)
|
||||
|
||||
from apiserver.database import Database, strict
|
||||
@ -12,6 +10,7 @@ from apiserver.database.fields import (
|
||||
StrippedStringField,
|
||||
SafeDictField,
|
||||
SafeSortedListField,
|
||||
SafeMapField,
|
||||
)
|
||||
from apiserver.database.model import AttributedDocument
|
||||
from apiserver.database.model.base import GetMixin
|
||||
@ -22,6 +21,10 @@ from apiserver.database.model.task.task import Task
|
||||
|
||||
|
||||
class Model(AttributedDocument):
|
||||
_field_collation_overrides = {
|
||||
"metadata.": AttributedDocument._numeric_locale,
|
||||
}
|
||||
|
||||
meta = {
|
||||
"db_alias": Database.backend,
|
||||
"strict": strict,
|
||||
@ -30,8 +33,6 @@ class Model(AttributedDocument):
|
||||
"project",
|
||||
"task",
|
||||
"last_update",
|
||||
"metadata.key",
|
||||
"metadata.type",
|
||||
("company", "framework"),
|
||||
("company", "name"),
|
||||
("company", "user"),
|
||||
@ -63,6 +64,7 @@ class Model(AttributedDocument):
|
||||
"project",
|
||||
"task",
|
||||
"parent",
|
||||
"metadata.*",
|
||||
),
|
||||
datetime_fields=("last_update",),
|
||||
)
|
||||
@ -86,6 +88,6 @@ class Model(AttributedDocument):
|
||||
default=dict, user_set_allowed=True, exclude_by_default=True
|
||||
)
|
||||
company_origin = StringField(exclude_by_default=True)
|
||||
metadata: Sequence[MetadataItem] = EmbeddedDocumentListField(
|
||||
MetadataItem, default=list, user_set_allowed=True
|
||||
metadata = SafeMapField(
|
||||
field=EmbeddedDocumentField(MetadataItem), user_set_allowed=True
|
||||
)
|
||||
|
@ -1,16 +1,19 @@
|
||||
from typing import Sequence
|
||||
|
||||
from mongoengine import (
|
||||
Document,
|
||||
EmbeddedDocument,
|
||||
StringField,
|
||||
DateTimeField,
|
||||
EmbeddedDocumentListField,
|
||||
EmbeddedDocumentField,
|
||||
)
|
||||
|
||||
from apiserver.database import Database, strict
|
||||
from apiserver.database.fields import StrippedStringField, SafeSortedListField
|
||||
from apiserver.database.model import DbModelMixin
|
||||
from apiserver.database.fields import (
|
||||
StrippedStringField,
|
||||
SafeSortedListField,
|
||||
SafeMapField,
|
||||
)
|
||||
from apiserver.database.model import DbModelMixin, AttributedDocument
|
||||
from apiserver.database.model.base import ProperDictMixin, GetMixin
|
||||
from apiserver.database.model.company import Company
|
||||
from apiserver.database.model.metadata import MetadataItem
|
||||
@ -19,23 +22,25 @@ from apiserver.database.model.task.task import Task
|
||||
|
||||
class Entry(EmbeddedDocument, ProperDictMixin):
|
||||
""" Entry representing a task waiting in the queue """
|
||||
|
||||
task = StringField(required=True, reference_field=Task)
|
||||
''' Task ID '''
|
||||
""" Task ID """
|
||||
added = DateTimeField(required=True)
|
||||
''' Added to the queue '''
|
||||
""" Added to the queue """
|
||||
|
||||
|
||||
class Queue(DbModelMixin, Document):
|
||||
_field_collation_overrides = {
|
||||
"metadata.": AttributedDocument._numeric_locale,
|
||||
}
|
||||
|
||||
get_all_query_options = GetMixin.QueryParameterOptions(
|
||||
pattern_fields=("name",),
|
||||
list_fields=("tags", "system_tags", "id"),
|
||||
pattern_fields=("name",), list_fields=("tags", "system_tags", "id", "metadata.*"),
|
||||
)
|
||||
|
||||
meta = {
|
||||
'db_alias': Database.backend,
|
||||
'strict': strict,
|
||||
"indexes": ["metadata.key", "metadata.type"],
|
||||
"db_alias": Database.backend,
|
||||
"strict": strict,
|
||||
}
|
||||
|
||||
id = StringField(primary_key=True)
|
||||
@ -44,10 +49,12 @@ class Queue(DbModelMixin, Document):
|
||||
)
|
||||
company = StringField(required=True, reference_field=Company)
|
||||
created = DateTimeField(required=True)
|
||||
tags = SafeSortedListField(StringField(required=True), default=list, user_set_allowed=True)
|
||||
tags = SafeSortedListField(
|
||||
StringField(required=True), default=list, user_set_allowed=True
|
||||
)
|
||||
system_tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
|
||||
entries = EmbeddedDocumentListField(Entry, default=list)
|
||||
last_update = DateTimeField()
|
||||
metadata: Sequence[MetadataItem] = EmbeddedDocumentListField(
|
||||
MetadataItem, default=list, user_set_allowed=True
|
||||
metadata = SafeMapField(
|
||||
field=EmbeddedDocumentField(MetadataItem), user_set_allowed=True
|
||||
)
|
||||
|
@ -159,11 +159,10 @@ external_task_types = set(get_options(TaskType))
|
||||
|
||||
|
||||
class Task(AttributedDocument):
|
||||
_numeric_locale = {"locale": "en_US", "numericOrdering": True}
|
||||
_field_collation_overrides = {
|
||||
"execution.parameters.": _numeric_locale,
|
||||
"last_metrics.": _numeric_locale,
|
||||
"hyperparams.": _numeric_locale,
|
||||
"execution.parameters.": AttributedDocument._numeric_locale,
|
||||
"last_metrics.": AttributedDocument._numeric_locale,
|
||||
"hyperparams.": AttributedDocument._numeric_locale,
|
||||
}
|
||||
|
||||
meta = {
|
||||
@ -184,7 +183,10 @@ class Task(AttributedDocument):
|
||||
("company", "type", "system_tags", "status"),
|
||||
("company", "project", "type", "system_tags", "status"),
|
||||
("status", "last_update"), # for maintenance tasks
|
||||
{"fields": ["company", "project"], "collation": _numeric_locale},
|
||||
{
|
||||
"fields": ["company", "project"],
|
||||
"collation": AttributedDocument._numeric_locale,
|
||||
},
|
||||
{
|
||||
"name": "%s.task.main_text_index" % Database.backend,
|
||||
"fields": [
|
||||
|
29
apiserver/mongo/migrations/1_3_0.py
Normal file
29
apiserver/mongo/migrations/1_3_0.py
Normal file
@ -0,0 +1,29 @@
|
||||
from pymongo.collection import Collection
|
||||
from pymongo.database import Database
|
||||
|
||||
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
|
||||
from .utils import _drop_all_indices_from_collections
|
||||
|
||||
|
||||
def _convert_metadata(db: Database, name):
|
||||
collection: Collection = db[name]
|
||||
|
||||
metadata_field = "metadata"
|
||||
query = {metadata_field: {"$exists": True, "$type": 4}}
|
||||
for doc in collection.find(filter=query, projection=(metadata_field,)):
|
||||
metadata = {
|
||||
ParameterKeyEscaper.escape(item["key"]): item
|
||||
for item in doc.get(metadata_field, [])
|
||||
if isinstance(item, dict) and "key" in item
|
||||
}
|
||||
collection.update_one(
|
||||
{"_id": doc["_id"]}, {"$set": {"metadata": metadata}},
|
||||
)
|
||||
|
||||
|
||||
def migrate_backend(db: Database):
|
||||
collections = ["model", "queue"]
|
||||
for name in collections:
|
||||
_convert_metadata(db, name)
|
||||
|
||||
_drop_all_indices_from_collections(db, collections)
|
@ -226,6 +226,12 @@ create_credentials {
|
||||
}
|
||||
}
|
||||
}
|
||||
"999.0": ${create_credentials."2.1"} {
|
||||
request.properties.label {
|
||||
type: string
|
||||
description: Optional credentials label
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
get_credentials {
|
||||
|
@ -61,14 +61,14 @@ _definitions {
|
||||
type: string
|
||||
}
|
||||
tags {
|
||||
description: "User-defined tags list"
|
||||
type: array
|
||||
description: "User-defined tags"
|
||||
items { type: string }
|
||||
}
|
||||
system_tags {
|
||||
description: "System tags list. This field is reserved for system use, please don't use it."
|
||||
type: array
|
||||
items {type: string}
|
||||
description: "System tags. This field is reserved for system use, please don't use it."
|
||||
items { type: string }
|
||||
}
|
||||
framework {
|
||||
description: "Framework on which the model is based. Should be identical to the framework of the task which created the model"
|
||||
@ -98,9 +98,11 @@ _definitions {
|
||||
additionalProperties: true
|
||||
}
|
||||
metadata {
|
||||
type: array
|
||||
description: "Model metadata"
|
||||
items {"$ref": "#/definitions/metadata_item"}
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/metadata_item"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -407,7 +409,7 @@ update_for_task {
|
||||
system_tags {
|
||||
description: "System tags list. This field is reserved for system use, please don't use it."
|
||||
type: array
|
||||
items {type: string}
|
||||
items { type: string }
|
||||
}
|
||||
override_model_id {
|
||||
description: "Override model ID. If provided, this model is updated in the task. Exactly one of override_model_id or uri is required."
|
||||
@ -473,7 +475,7 @@ create {
|
||||
system_tags {
|
||||
description: "System tags list. This field is reserved for system use, please don't use it."
|
||||
type: array
|
||||
items {type: string}
|
||||
items { type: string }
|
||||
}
|
||||
framework {
|
||||
description: "Framework on which the model is based. Case insensitive. Should be identical to the framework of the task which created the model."
|
||||
@ -529,9 +531,11 @@ create {
|
||||
}
|
||||
"2.13": ${create."2.1"} {
|
||||
metadata {
|
||||
type: array
|
||||
description: "Model metadata"
|
||||
items {"$ref": "#/definitions/metadata_item"}
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/metadata_item"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -568,7 +572,7 @@ edit {
|
||||
system_tags {
|
||||
description: "System tags list. This field is reserved for system use, please don't use it."
|
||||
type: array
|
||||
items {type: string}
|
||||
items { type: string }
|
||||
}
|
||||
framework {
|
||||
description: "Framework on which the model is based. Case insensitive. Should be identical to the framework of the task which created the model."
|
||||
@ -624,9 +628,11 @@ edit {
|
||||
}
|
||||
"2.13": ${edit."2.1"} {
|
||||
metadata {
|
||||
type: array
|
||||
description: "Model metadata"
|
||||
items {"$ref": "#/definitions/metadata_item"}
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/metadata_item"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -657,7 +663,7 @@ update {
|
||||
system_tags {
|
||||
description: "System tags list. This field is reserved for system use, please don't use it."
|
||||
type: array
|
||||
items {type: string}
|
||||
items { type: string }
|
||||
}
|
||||
ready {
|
||||
description: "Indication if the model is final and can be used by other tasks Default is false."
|
||||
@ -707,9 +713,11 @@ update {
|
||||
}
|
||||
"2.13": ${update."2.1"} {
|
||||
metadata {
|
||||
type: array
|
||||
description: "Model metadata"
|
||||
items {"$ref": "#/definitions/metadata_item"}
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/metadata_item"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -718,7 +726,7 @@ publish_many {
|
||||
description: Publish models
|
||||
request {
|
||||
properties {
|
||||
ids.description: "IDs of models to publish"
|
||||
ids.description: "IDs of the models to publish"
|
||||
force_publish_task {
|
||||
description: "Publish the associated tasks (if exist) even if they are not in the 'stopped' state. Optional, the default value is False."
|
||||
type: boolean
|
||||
@ -779,7 +787,7 @@ archive_many {
|
||||
description: Archive models
|
||||
request {
|
||||
properties {
|
||||
ids.description: "IDs of models to archive"
|
||||
ids.description: "IDs of the models to archive"
|
||||
}
|
||||
}
|
||||
response {
|
||||
@ -815,10 +823,9 @@ delete_many {
|
||||
description: Delete models
|
||||
request {
|
||||
properties {
|
||||
ids.description: "IDs of models to delete"
|
||||
ids.description: "IDs of the models to delete"
|
||||
force {
|
||||
description: """Force. Required if there are tasks that use the model as an execution model, or if the model's creating task is published.
|
||||
"""
|
||||
description: "Force. Required if there are tasks that use the model as an execution model, or if the model's creating task is published."
|
||||
type: boolean
|
||||
}
|
||||
}
|
||||
@ -975,6 +982,11 @@ add_or_update_metadata {
|
||||
description: "Metadata items to add or update"
|
||||
items {"$ref": "#/definitions/metadata_item"}
|
||||
}
|
||||
replace_metadata {
|
||||
description: "If set then the all the metadata items will be replaced with the provided ones. Otherwise only the provided metadata items will be updated or added"
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
|
@ -42,15 +42,20 @@ _definitions {
|
||||
type: string
|
||||
format: "date-time"
|
||||
}
|
||||
last_update {
|
||||
description: "Last update time"
|
||||
type: string
|
||||
format: "date-time"
|
||||
}
|
||||
tags {
|
||||
type: array
|
||||
description: "User-defined tags"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
system_tags {
|
||||
type: array
|
||||
description: "System tags. This field is reserved for system use, please don't use it."
|
||||
items {type: string}
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
default_output_destination {
|
||||
description: "The default output destination URL for new tasks under this project"
|
||||
@ -70,6 +75,18 @@ _definitions {
|
||||
description: "Total run time of all tasks in project (in seconds)"
|
||||
type: integer
|
||||
}
|
||||
total_tasks {
|
||||
description: "Number of tasks"
|
||||
type: integer
|
||||
}
|
||||
completed_tasks_24h {
|
||||
description: "Number of tasks completed in the last 24 hours"
|
||||
type: integer
|
||||
}
|
||||
last_task_run {
|
||||
description: "The most recent started time of a task"
|
||||
type: integer
|
||||
}
|
||||
status_count {
|
||||
description: "Status counts"
|
||||
type: object
|
||||
@ -78,6 +95,10 @@ _definitions {
|
||||
description: "Number of 'created' tasks in project"
|
||||
type: integer
|
||||
}
|
||||
completed {
|
||||
description: "Number of 'completed' tasks in project"
|
||||
type: integer
|
||||
}
|
||||
queued {
|
||||
description: "Number of 'queued' tasks in project"
|
||||
type: integer
|
||||
@ -158,14 +179,14 @@ _definitions {
|
||||
format: "date-time"
|
||||
}
|
||||
tags {
|
||||
type: array
|
||||
description: "User-defined tags"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
system_tags {
|
||||
type: array
|
||||
description: "System tags. This field is reserved for system use, please don't use it."
|
||||
items {type: string}
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
default_output_destination {
|
||||
description: "The default output destination URL for new tasks under this project"
|
||||
@ -299,14 +320,14 @@ create {
|
||||
type: string
|
||||
}
|
||||
tags {
|
||||
type: array
|
||||
description: "User-defined tags"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
system_tags {
|
||||
type: array
|
||||
description: "System tags. This field is reserved for system use, please don't use it."
|
||||
items {type: string}
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
default_output_destination {
|
||||
description: "The default output destination URL for new tasks under this project"
|
||||
@ -419,7 +440,6 @@ get_all {
|
||||
description: "Projects list"
|
||||
type: array
|
||||
items { "$ref": "#/definitions/projects_get_all_response_single" }
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -545,42 +565,6 @@ get_all_ex {
|
||||
type: boolean
|
||||
default: true
|
||||
}
|
||||
response {
|
||||
properties {
|
||||
stats {
|
||||
properties {
|
||||
active.properties {
|
||||
total_tasks {
|
||||
description: "Number of tasks"
|
||||
type: integer
|
||||
}
|
||||
completed_tasks {
|
||||
description: "Number of tasks completed in the last 24 hours"
|
||||
type: integer
|
||||
}
|
||||
running_tasks {
|
||||
description: "Number of running tasks"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
archived.properties {
|
||||
total_tasks {
|
||||
description: "Number of tasks"
|
||||
type: integer
|
||||
}
|
||||
completed_tasks {
|
||||
description: "Number of tasks completed in the last 24 hours"
|
||||
type: integer
|
||||
}
|
||||
running_tasks {
|
||||
description: "Number of running tasks"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
update {
|
||||
@ -603,14 +587,14 @@ update {
|
||||
type: string
|
||||
}
|
||||
tags {
|
||||
description: "User-defined tags list"
|
||||
type: array
|
||||
description: "User-defined tags"
|
||||
items { type: string }
|
||||
}
|
||||
system_tags {
|
||||
description: "System tags list. This field is reserved for system use, please don't use it."
|
||||
type: array
|
||||
description: "System tags. This field is reserved for system use, please don't use it."
|
||||
items {type: string}
|
||||
items { type: string }
|
||||
}
|
||||
default_output_destination {
|
||||
description: "The default output destination URL for new tasks under this project"
|
||||
@ -748,7 +732,6 @@ delete {
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
response {
|
||||
@ -881,6 +864,7 @@ get_hyper_parameters {
|
||||
description: """Get a list of all hyper parameter sections and names used in tasks within the given project."""
|
||||
request {
|
||||
type: object
|
||||
required: [project]
|
||||
properties {
|
||||
project {
|
||||
description: "Project ID"
|
||||
@ -929,6 +913,55 @@ get_hyper_parameters {
|
||||
}
|
||||
}
|
||||
}
|
||||
get_model_metadata_keys {
|
||||
"999.0" {
|
||||
description: """Get a list of all metadata keys used in models within the given project."""
|
||||
request {
|
||||
type: object
|
||||
required: [project]
|
||||
properties {
|
||||
project {
|
||||
description: "Project ID"
|
||||
type: string
|
||||
}
|
||||
include_subprojects {
|
||||
description: "If set to 'true' and the project field is set then the result includes metadate keys from the subproject models"
|
||||
type: boolean
|
||||
default: true
|
||||
}
|
||||
|
||||
page {
|
||||
description: "Page number"
|
||||
default: 0
|
||||
type: integer
|
||||
}
|
||||
page_size {
|
||||
description: "Page size"
|
||||
default: 500
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
keys {
|
||||
description: "A list of model keys"
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
remaining {
|
||||
description: "Remaining results"
|
||||
type: integer
|
||||
}
|
||||
total {
|
||||
description: "Total number of results"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
get_task_tags {
|
||||
"2.8" {
|
||||
description: "Get user and system tags used for the tasks under the specified projects"
|
||||
@ -936,7 +969,6 @@ get_task_tags {
|
||||
response = ${_definitions.tags_response}
|
||||
}
|
||||
}
|
||||
|
||||
get_model_tags {
|
||||
"2.8" {
|
||||
description: "Get user and system tags used for the models under the specified projects"
|
||||
@ -1058,4 +1090,4 @@ get_task_parents {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -79,9 +79,11 @@ _definitions {
|
||||
items { "$ref": "#/definitions/entry" }
|
||||
}
|
||||
metadata {
|
||||
type: array
|
||||
description: "Queue metadata"
|
||||
items {"$ref": "#/definitions/metadata_item"}
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/metadata_item"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -281,6 +283,15 @@ create {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.13": ${create."2.4"} {
|
||||
metadata {
|
||||
description: "Queue metadata"
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/metadata_item"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
update {
|
||||
"2.4" {
|
||||
@ -322,7 +333,15 @@ update {
|
||||
type: object
|
||||
additionalProperties: true
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.13": ${update."2.4"} {
|
||||
metadata {
|
||||
description: "Queue metadata"
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/metadata_item"
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -632,6 +651,11 @@ add_or_update_metadata {
|
||||
description: "Metadata items to add or update"
|
||||
items {"$ref": "#/definitions/metadata_item"}
|
||||
}
|
||||
replace_metadata {
|
||||
description: "If set then the all the metadata items will be replaced with the provided ones. Otherwise only the provided metadata items will be updated or added"
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
|
@ -21,16 +21,14 @@ from apiserver.apimodels.models import (
|
||||
ModelsPublishManyRequest,
|
||||
ModelsDeleteManyRequest,
|
||||
)
|
||||
from apiserver.bll.model import ModelBLL
|
||||
from apiserver.bll.model import ModelBLL, Metadata
|
||||
from apiserver.bll.organization import OrgBLL, Tags
|
||||
from apiserver.bll.project import ProjectBLL, project_ids_with_children
|
||||
from apiserver.bll.task import TaskBLL
|
||||
from apiserver.bll.task.task_operations import publish_task
|
||||
from apiserver.bll.util import run_batch_operation
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model import validate_id
|
||||
from apiserver.database.model.metadata import metadata_add_or_update, metadata_delete
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.task.task import (
|
||||
@ -50,8 +48,8 @@ from apiserver.services.utils import (
|
||||
conform_tag_fields,
|
||||
conform_output_tags,
|
||||
ModelsBackwardsCompatibility,
|
||||
validate_metadata,
|
||||
get_metadata_from_api,
|
||||
unescape_metadata,
|
||||
escape_metadata,
|
||||
)
|
||||
from apiserver.timing_context import TimingContext
|
||||
|
||||
@ -64,19 +62,20 @@ project_bll = ProjectBLL()
|
||||
def get_by_id(call: APICall, company_id, _):
|
||||
model_id = call.data["model"]
|
||||
|
||||
with translate_errors_context():
|
||||
models = Model.get_many(
|
||||
company=company_id,
|
||||
query_dict=call.data,
|
||||
query=Q(id=model_id),
|
||||
allow_public=True,
|
||||
Metadata.escape_query_parameters(call)
|
||||
models = Model.get_many(
|
||||
company=company_id,
|
||||
query_dict=call.data,
|
||||
query=Q(id=model_id),
|
||||
allow_public=True,
|
||||
)
|
||||
if not models:
|
||||
raise errors.bad_request.InvalidModelId(
|
||||
"no such public or company model", id=model_id, company=company_id,
|
||||
)
|
||||
if not models:
|
||||
raise errors.bad_request.InvalidModelId(
|
||||
"no such public or company model", id=model_id, company=company_id,
|
||||
)
|
||||
conform_output_tags(call, models[0])
|
||||
call.result.data = {"model": models[0]}
|
||||
conform_output_tags(call, models[0])
|
||||
unescape_metadata(call, models[0])
|
||||
call.result.data = {"model": models[0]}
|
||||
|
||||
|
||||
@endpoint("models.get_by_task_id", required_fields=["task"])
|
||||
@ -86,25 +85,25 @@ def get_by_task_id(call: APICall, company_id, _):
|
||||
|
||||
task_id = call.data["task"]
|
||||
|
||||
with translate_errors_context():
|
||||
query = dict(id=task_id, company=company_id)
|
||||
task = Task.get(_only=["models"], **query)
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(**query)
|
||||
if not task.models or not task.models.output:
|
||||
raise errors.bad_request.MissingTaskFields(field="models.output")
|
||||
query = dict(id=task_id, company=company_id)
|
||||
task = Task.get(_only=["models"], **query)
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(**query)
|
||||
if not task.models or not task.models.output:
|
||||
raise errors.bad_request.MissingTaskFields(field="models.output")
|
||||
|
||||
model_id = task.models.output[-1].model
|
||||
model = Model.objects(
|
||||
Q(id=model_id) & get_company_or_none_constraint(company_id)
|
||||
).first()
|
||||
if not model:
|
||||
raise errors.bad_request.InvalidModelId(
|
||||
"no such public or company model", id=model_id, company=company_id,
|
||||
)
|
||||
model_dict = model.to_proper_dict()
|
||||
conform_output_tags(call, model_dict)
|
||||
call.result.data = {"model": model_dict}
|
||||
model_id = task.models.output[-1].model
|
||||
model = Model.objects(
|
||||
Q(id=model_id) & get_company_or_none_constraint(company_id)
|
||||
).first()
|
||||
if not model:
|
||||
raise errors.bad_request.InvalidModelId(
|
||||
"no such public or company model", id=model_id, company=company_id,
|
||||
)
|
||||
model_dict = model.to_proper_dict()
|
||||
conform_output_tags(call, model_dict)
|
||||
unescape_metadata(call, model_dict)
|
||||
call.result.data = {"model": model_dict}
|
||||
|
||||
|
||||
def _process_include_subprojects(call_data: dict):
|
||||
@ -121,47 +120,50 @@ def _process_include_subprojects(call_data: dict):
|
||||
@endpoint("models.get_all_ex", required_fields=[])
|
||||
def get_all_ex(call: APICall, company_id, _):
|
||||
conform_tag_fields(call, call.data)
|
||||
with translate_errors_context():
|
||||
_process_include_subprojects(call.data)
|
||||
with TimingContext("mongo", "models_get_all_ex"):
|
||||
ret_params = {}
|
||||
models = Model.get_many_with_join(
|
||||
company=company_id,
|
||||
query_dict=call.data,
|
||||
allow_public=True,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
conform_output_tags(call, models)
|
||||
call.result.data = {"models": models, **ret_params}
|
||||
_process_include_subprojects(call.data)
|
||||
Metadata.escape_query_parameters(call)
|
||||
with TimingContext("mongo", "models_get_all_ex"):
|
||||
ret_params = {}
|
||||
models = Model.get_many_with_join(
|
||||
company=company_id,
|
||||
query_dict=call.data,
|
||||
allow_public=True,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
conform_output_tags(call, models)
|
||||
unescape_metadata(call, models)
|
||||
call.result.data = {"models": models, **ret_params}
|
||||
|
||||
|
||||
@endpoint("models.get_by_id_ex", required_fields=["id"])
|
||||
def get_by_id_ex(call: APICall, company_id, _):
|
||||
conform_tag_fields(call, call.data)
|
||||
with translate_errors_context():
|
||||
with TimingContext("mongo", "models_get_by_id_ex"):
|
||||
models = Model.get_many_with_join(
|
||||
company=company_id, query_dict=call.data, allow_public=True
|
||||
)
|
||||
conform_output_tags(call, models)
|
||||
call.result.data = {"models": models}
|
||||
Metadata.escape_query_parameters(call)
|
||||
with TimingContext("mongo", "models_get_by_id_ex"):
|
||||
models = Model.get_many_with_join(
|
||||
company=company_id, query_dict=call.data, allow_public=True
|
||||
)
|
||||
conform_output_tags(call, models)
|
||||
unescape_metadata(call, models)
|
||||
call.result.data = {"models": models}
|
||||
|
||||
|
||||
@endpoint("models.get_all", required_fields=[])
|
||||
def get_all(call: APICall, company_id, _):
|
||||
conform_tag_fields(call, call.data)
|
||||
with translate_errors_context():
|
||||
with TimingContext("mongo", "models_get_all"):
|
||||
ret_params = {}
|
||||
models = Model.get_many(
|
||||
company=company_id,
|
||||
parameters=call.data,
|
||||
query_dict=call.data,
|
||||
allow_public=True,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
conform_output_tags(call, models)
|
||||
call.result.data = {"models": models, **ret_params}
|
||||
Metadata.escape_query_parameters(call)
|
||||
with TimingContext("mongo", "models_get_all"):
|
||||
ret_params = {}
|
||||
models = Model.get_many(
|
||||
company=company_id,
|
||||
parameters=call.data,
|
||||
query_dict=call.data,
|
||||
allow_public=True,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
conform_output_tags(call, models)
|
||||
unescape_metadata(call, models)
|
||||
call.result.data = {"models": models, **ret_params}
|
||||
|
||||
|
||||
@endpoint("models.get_frameworks", request_data_model=GetFrameworksRequest)
|
||||
@ -189,15 +191,22 @@ create_fields = {
|
||||
"metadata": list,
|
||||
}
|
||||
|
||||
last_update_fields = ("uri", "framework", "design", "labels", "ready", "metadata", "system_tags", "tags")
|
||||
last_update_fields = (
|
||||
"uri",
|
||||
"framework",
|
||||
"design",
|
||||
"labels",
|
||||
"ready",
|
||||
"metadata",
|
||||
"system_tags",
|
||||
"tags",
|
||||
)
|
||||
|
||||
|
||||
def parse_model_fields(call, valid_fields):
|
||||
fields = parse_from_call(call.data, valid_fields, Model.get_fields())
|
||||
conform_tag_fields(call, fields, validate=True)
|
||||
metadata = fields.get("metadata")
|
||||
if metadata:
|
||||
validate_metadata(metadata)
|
||||
escape_metadata(fields)
|
||||
return fields
|
||||
|
||||
|
||||
@ -231,82 +240,80 @@ def update_for_task(call: APICall, company_id, _):
|
||||
"exactly one field is required", fields=("uri", "override_model_id")
|
||||
)
|
||||
|
||||
with translate_errors_context():
|
||||
query = dict(id=task_id, company=company_id)
|
||||
task = Task.get_for_writing(
|
||||
id=task_id,
|
||||
company=company_id,
|
||||
_only=["models", "execution", "name", "status", "project"],
|
||||
)
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(**query)
|
||||
|
||||
query = dict(id=task_id, company=company_id)
|
||||
task = Task.get_for_writing(
|
||||
id=task_id,
|
||||
allowed_states = [TaskStatus.created, TaskStatus.in_progress]
|
||||
if task.status not in allowed_states:
|
||||
raise errors.bad_request.InvalidTaskStatus(
|
||||
f"model can only be updated for tasks in the {allowed_states} states",
|
||||
**query,
|
||||
)
|
||||
|
||||
if override_model_id:
|
||||
model = ModelBLL.get_company_model_by_id(
|
||||
company_id=company_id, model_id=override_model_id
|
||||
)
|
||||
else:
|
||||
if "name" not in call.data:
|
||||
# use task name if name not provided
|
||||
call.data["name"] = task.name
|
||||
|
||||
if "comment" not in call.data:
|
||||
call.data["comment"] = f"Created by task `{task.name}` ({task.id})"
|
||||
|
||||
if task.models and task.models.output:
|
||||
# model exists, update
|
||||
model_id = task.models.output[-1].model
|
||||
res = _update_model(call, company_id, model_id=model_id).to_struct()
|
||||
res.update({"id": model_id, "created": False})
|
||||
call.result.data = res
|
||||
return
|
||||
|
||||
# new model, create
|
||||
fields = parse_model_fields(call, create_fields)
|
||||
|
||||
# create and save model
|
||||
now = datetime.utcnow()
|
||||
model = Model(
|
||||
id=database.utils.id(),
|
||||
created=now,
|
||||
last_update=now,
|
||||
user=call.identity.user,
|
||||
company=company_id,
|
||||
_only=["models", "execution", "name", "status", "project"],
|
||||
project=task.project,
|
||||
framework=task.execution.framework,
|
||||
parent=task.models.input[0].model
|
||||
if task.models and task.models.input
|
||||
else None,
|
||||
design=task.execution.model_desc,
|
||||
labels=task.execution.model_labels,
|
||||
ready=(task.status == TaskStatus.published),
|
||||
**fields,
|
||||
)
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(**query)
|
||||
model.save()
|
||||
_update_cached_tags(company_id, project=model.project, fields=fields)
|
||||
|
||||
allowed_states = [TaskStatus.created, TaskStatus.in_progress]
|
||||
if task.status not in allowed_states:
|
||||
raise errors.bad_request.InvalidTaskStatus(
|
||||
f"model can only be updated for tasks in the {allowed_states} states",
|
||||
**query,
|
||||
TaskBLL.update_statistics(
|
||||
task_id=task_id,
|
||||
company_id=company_id,
|
||||
last_iteration_max=iteration,
|
||||
models__output=[
|
||||
ModelItem(
|
||||
model=model.id,
|
||||
name=TaskModelNames[TaskModelTypes.output],
|
||||
updated=datetime.utcnow(),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
if override_model_id:
|
||||
model = ModelBLL.get_company_model_by_id(
|
||||
company_id=company_id, model_id=override_model_id
|
||||
)
|
||||
else:
|
||||
if "name" not in call.data:
|
||||
# use task name if name not provided
|
||||
call.data["name"] = task.name
|
||||
|
||||
if "comment" not in call.data:
|
||||
call.data["comment"] = f"Created by task `{task.name}` ({task.id})"
|
||||
|
||||
if task.models and task.models.output:
|
||||
# model exists, update
|
||||
model_id = task.models.output[-1].model
|
||||
res = _update_model(call, company_id, model_id=model_id).to_struct()
|
||||
res.update({"id": model_id, "created": False})
|
||||
call.result.data = res
|
||||
return
|
||||
|
||||
# new model, create
|
||||
fields = parse_model_fields(call, create_fields)
|
||||
|
||||
# create and save model
|
||||
now = datetime.utcnow()
|
||||
model = Model(
|
||||
id=database.utils.id(),
|
||||
created=now,
|
||||
last_update=now,
|
||||
user=call.identity.user,
|
||||
company=company_id,
|
||||
project=task.project,
|
||||
framework=task.execution.framework,
|
||||
parent=task.models.input[0].model
|
||||
if task.models and task.models.input
|
||||
else None,
|
||||
design=task.execution.model_desc,
|
||||
labels=task.execution.model_labels,
|
||||
ready=(task.status == TaskStatus.published),
|
||||
**fields,
|
||||
)
|
||||
model.save()
|
||||
_update_cached_tags(company_id, project=model.project, fields=fields)
|
||||
|
||||
TaskBLL.update_statistics(
|
||||
task_id=task_id,
|
||||
company_id=company_id,
|
||||
last_iteration_max=iteration,
|
||||
models__output=[
|
||||
ModelItem(
|
||||
model=model.id,
|
||||
name=TaskModelNames[TaskModelTypes.output],
|
||||
updated=datetime.utcnow(),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
call.result.data = {"id": model.id, "created": True}
|
||||
call.result.data = {"id": model.id, "created": True}
|
||||
|
||||
|
||||
@endpoint(
|
||||
@ -319,36 +326,33 @@ def create(call: APICall, company_id, req_model: CreateModelRequest):
|
||||
if req_model.public:
|
||||
company_id = ""
|
||||
|
||||
with translate_errors_context():
|
||||
project = req_model.project
|
||||
if project:
|
||||
validate_id(Project, company=company_id, project=project)
|
||||
|
||||
project = req_model.project
|
||||
if project:
|
||||
validate_id(Project, company=company_id, project=project)
|
||||
task = req_model.task
|
||||
req_data = req_model.to_struct()
|
||||
if task:
|
||||
validate_task(company_id, req_data)
|
||||
|
||||
task = req_model.task
|
||||
req_data = req_model.to_struct()
|
||||
if task:
|
||||
validate_task(company_id, req_data)
|
||||
fields = filter_fields(Model, req_data)
|
||||
conform_tag_fields(call, fields, validate=True)
|
||||
escape_metadata(fields)
|
||||
|
||||
fields = filter_fields(Model, req_data)
|
||||
conform_tag_fields(call, fields, validate=True)
|
||||
# create and save model
|
||||
now = datetime.utcnow()
|
||||
model = Model(
|
||||
id=database.utils.id(),
|
||||
user=call.identity.user,
|
||||
company=company_id,
|
||||
created=now,
|
||||
last_update=now,
|
||||
**fields,
|
||||
)
|
||||
model.save()
|
||||
_update_cached_tags(company_id, project=model.project, fields=fields)
|
||||
|
||||
validate_metadata(fields.get("metadata"))
|
||||
|
||||
# create and save model
|
||||
now = datetime.utcnow()
|
||||
model = Model(
|
||||
id=database.utils.id(),
|
||||
user=call.identity.user,
|
||||
company=company_id,
|
||||
created=now,
|
||||
last_update=now,
|
||||
**fields,
|
||||
)
|
||||
model.save()
|
||||
_update_cached_tags(company_id, project=model.project, fields=fields)
|
||||
|
||||
call.result.data_model = CreateModelResponse(id=model.id, created=True)
|
||||
call.result.data_model = CreateModelResponse(id=model.id, created=True)
|
||||
|
||||
|
||||
def prepare_update_fields(call, company_id, fields: dict):
|
||||
@ -383,6 +387,7 @@ def prepare_update_fields(call, company_id, fields: dict):
|
||||
)
|
||||
|
||||
conform_tag_fields(call, fields, validate=True)
|
||||
escape_metadata(fields)
|
||||
return fields
|
||||
|
||||
|
||||
@ -394,89 +399,85 @@ def validate_task(company_id, fields: dict):
|
||||
def edit(call: APICall, company_id, _):
|
||||
model_id = call.data["model"]
|
||||
|
||||
with translate_errors_context():
|
||||
model = ModelBLL.get_company_model_by_id(
|
||||
company_id=company_id, model_id=model_id
|
||||
model = ModelBLL.get_company_model_by_id(
|
||||
company_id=company_id, model_id=model_id
|
||||
)
|
||||
|
||||
fields = parse_model_fields(call, create_fields)
|
||||
fields = prepare_update_fields(call, company_id, fields)
|
||||
|
||||
for key in fields:
|
||||
field = getattr(model, key, None)
|
||||
value = fields[key]
|
||||
if (
|
||||
field
|
||||
and isinstance(value, dict)
|
||||
and isinstance(field, EmbeddedDocument)
|
||||
):
|
||||
d = field.to_mongo(use_db_field=False).to_dict()
|
||||
d.update(value)
|
||||
fields[key] = d
|
||||
|
||||
iteration = call.data.get("iteration")
|
||||
task_id = model.task or fields.get("task")
|
||||
if task_id and iteration is not None:
|
||||
TaskBLL.update_statistics(
|
||||
task_id=task_id, company_id=company_id, last_iteration_max=iteration,
|
||||
)
|
||||
|
||||
fields = parse_model_fields(call, create_fields)
|
||||
fields = prepare_update_fields(call, company_id, fields)
|
||||
if fields:
|
||||
if any(uf in fields for uf in last_update_fields):
|
||||
fields.update(last_update=datetime.utcnow())
|
||||
|
||||
for key in fields:
|
||||
field = getattr(model, key, None)
|
||||
value = fields[key]
|
||||
if (
|
||||
field
|
||||
and isinstance(value, dict)
|
||||
and isinstance(field, EmbeddedDocument)
|
||||
):
|
||||
d = field.to_mongo(use_db_field=False).to_dict()
|
||||
d.update(value)
|
||||
fields[key] = d
|
||||
|
||||
iteration = call.data.get("iteration")
|
||||
task_id = model.task or fields.get("task")
|
||||
if task_id and iteration is not None:
|
||||
TaskBLL.update_statistics(
|
||||
task_id=task_id, company_id=company_id, last_iteration_max=iteration,
|
||||
)
|
||||
|
||||
if fields:
|
||||
if any(uf in fields for uf in last_update_fields):
|
||||
fields.update(last_update=datetime.utcnow())
|
||||
|
||||
updated = model.update(upsert=False, **fields)
|
||||
if updated:
|
||||
new_project = fields.get("project", model.project)
|
||||
if new_project != model.project:
|
||||
_reset_cached_tags(
|
||||
company_id, projects=[new_project, model.project]
|
||||
)
|
||||
else:
|
||||
_update_cached_tags(
|
||||
company_id, project=model.project, fields=fields
|
||||
)
|
||||
conform_output_tags(call, fields)
|
||||
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
|
||||
else:
|
||||
call.result.data_model = UpdateResponse(updated=0)
|
||||
updated = model.update(upsert=False, **fields)
|
||||
if updated:
|
||||
new_project = fields.get("project", model.project)
|
||||
if new_project != model.project:
|
||||
_reset_cached_tags(
|
||||
company_id, projects=[new_project, model.project]
|
||||
)
|
||||
else:
|
||||
_update_cached_tags(
|
||||
company_id, project=model.project, fields=fields
|
||||
)
|
||||
conform_output_tags(call, fields)
|
||||
unescape_metadata(call, fields)
|
||||
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
|
||||
else:
|
||||
call.result.data_model = UpdateResponse(updated=0)
|
||||
|
||||
|
||||
def _update_model(call: APICall, company_id, model_id=None):
|
||||
model_id = model_id or call.data["model"]
|
||||
|
||||
with translate_errors_context():
|
||||
model = ModelBLL.get_company_model_by_id(
|
||||
company_id=company_id, model_id=model_id
|
||||
model = ModelBLL.get_company_model_by_id(
|
||||
company_id=company_id, model_id=model_id
|
||||
)
|
||||
|
||||
data = prepare_update_fields(call, company_id, call.data)
|
||||
|
||||
task_id = data.get("task")
|
||||
iteration = data.get("iteration")
|
||||
if task_id and iteration is not None:
|
||||
TaskBLL.update_statistics(
|
||||
task_id=task_id, company_id=company_id, last_iteration_max=iteration,
|
||||
)
|
||||
|
||||
data = prepare_update_fields(call, company_id, call.data)
|
||||
updated_count, updated_fields = Model.safe_update(company_id, model.id, data)
|
||||
if updated_count:
|
||||
if any(uf in updated_fields for uf in last_update_fields):
|
||||
model.update(upsert=False, last_update=datetime.utcnow())
|
||||
|
||||
task_id = data.get("task")
|
||||
iteration = data.get("iteration")
|
||||
if task_id and iteration is not None:
|
||||
TaskBLL.update_statistics(
|
||||
task_id=task_id, company_id=company_id, last_iteration_max=iteration,
|
||||
new_project = updated_fields.get("project", model.project)
|
||||
if new_project != model.project:
|
||||
_reset_cached_tags(company_id, projects=[new_project, model.project])
|
||||
else:
|
||||
_update_cached_tags(
|
||||
company_id, project=model.project, fields=updated_fields
|
||||
)
|
||||
|
||||
metadata = data.get("metadata")
|
||||
if metadata:
|
||||
validate_metadata(metadata)
|
||||
|
||||
updated_count, updated_fields = Model.safe_update(company_id, model.id, data)
|
||||
if updated_count:
|
||||
if any(uf in updated_fields for uf in last_update_fields):
|
||||
model.update(upsert=False, last_update=datetime.utcnow())
|
||||
|
||||
new_project = updated_fields.get("project", model.project)
|
||||
if new_project != model.project:
|
||||
_reset_cached_tags(company_id, projects=[new_project, model.project])
|
||||
else:
|
||||
_update_cached_tags(
|
||||
company_id, project=model.project, fields=updated_fields
|
||||
)
|
||||
conform_output_tags(call, updated_fields)
|
||||
return UpdateResponse(updated=updated_count, fields=updated_fields)
|
||||
conform_output_tags(call, updated_fields)
|
||||
unescape_metadata(call, updated_fields)
|
||||
return UpdateResponse(updated=updated_count, fields=updated_fields)
|
||||
|
||||
|
||||
@endpoint(
|
||||
@ -641,26 +642,25 @@ def add_or_update_metadata(
|
||||
_: APICall, company_id: str, request: AddOrUpdateMetadataRequest
|
||||
):
|
||||
model_id = request.model
|
||||
ModelBLL.get_company_model_by_id(company_id=company_id, model_id=model_id)
|
||||
|
||||
updated = metadata_add_or_update(
|
||||
cls=Model, _id=model_id, items=get_metadata_from_api(request.metadata),
|
||||
)
|
||||
if updated:
|
||||
Model.objects(id=model_id).update_one(last_update=datetime.utcnow())
|
||||
|
||||
return {"updated": updated}
|
||||
model = ModelBLL.get_company_model_by_id(company_id=company_id, model_id=model_id)
|
||||
return {
|
||||
"updated": Metadata.edit_metadata(
|
||||
model,
|
||||
items=request.metadata,
|
||||
replace_metadata=request.replace_metadata,
|
||||
last_update=datetime.utcnow(),
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@endpoint("models.delete_metadata", min_version="2.13")
|
||||
def delete_metadata(_: APICall, company_id: str, request: DeleteMetadataRequest):
|
||||
model_id = request.model
|
||||
ModelBLL.get_company_model_by_id(
|
||||
model = ModelBLL.get_company_model_by_id(
|
||||
company_id=company_id, model_id=model_id, only_fields=("id",)
|
||||
)
|
||||
|
||||
updated = metadata_delete(cls=Model, _id=model_id, keys=request.keys)
|
||||
if updated:
|
||||
Model.objects(id=model_id).update_one(last_update=datetime.utcnow())
|
||||
|
||||
return {"updated": updated}
|
||||
return {
|
||||
"updated": Metadata.delete_metadata(
|
||||
model, keys=request.keys, last_update=datetime.utcnow()
|
||||
)
|
||||
}
|
||||
|
@ -275,6 +275,23 @@ def get_unique_metric_variants(
|
||||
call.result.data = {"metrics": metrics}
|
||||
|
||||
|
||||
@endpoint("projects.get_model_metadata_keys",)
|
||||
def get_model_metadata_keys(call: APICall, company_id: str, request: GetParamsRequest):
|
||||
total, remaining, keys = project_queries.get_model_metadata_keys(
|
||||
company_id,
|
||||
project_ids=[request.project] if request.project else None,
|
||||
include_subprojects=request.include_subprojects,
|
||||
page=request.page,
|
||||
page_size=request.page_size,
|
||||
)
|
||||
|
||||
call.result.data = {
|
||||
"total": total,
|
||||
"remaining": remaining,
|
||||
"keys": keys,
|
||||
}
|
||||
|
||||
|
||||
@endpoint(
|
||||
"projects.get_hyper_parameters",
|
||||
min_version="2.9",
|
||||
|
@ -13,17 +13,19 @@ from apiserver.apimodels.queues import (
|
||||
QueueMetrics,
|
||||
AddOrUpdateMetadataRequest,
|
||||
DeleteMetadataRequest,
|
||||
GetNextTaskRequest,
|
||||
)
|
||||
from apiserver.bll.model import Metadata
|
||||
from apiserver.bll.queue import QueueBLL
|
||||
from apiserver.bll.workers import WorkerBLL
|
||||
from apiserver.database.model.metadata import metadata_add_or_update, metadata_delete
|
||||
from apiserver.database.model.queue import Queue
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.service_repo import APICall, endpoint
|
||||
from apiserver.services.utils import (
|
||||
conform_tag_fields,
|
||||
conform_output_tags,
|
||||
conform_tags,
|
||||
get_metadata_from_api,
|
||||
escape_metadata,
|
||||
unescape_metadata,
|
||||
)
|
||||
from apiserver.utilities import extract_properties_to_lists
|
||||
|
||||
@ -36,6 +38,7 @@ def get_by_id(call: APICall, company_id, req_model: QueueRequest):
|
||||
queue = queue_bll.get_by_id(company_id, req_model.queue)
|
||||
queue_dict = queue.to_proper_dict()
|
||||
conform_output_tags(call, queue_dict)
|
||||
unescape_metadata(call, queue_dict)
|
||||
call.result.data = {"queue": queue_dict}
|
||||
|
||||
|
||||
@ -49,13 +52,13 @@ def get_by_id(call: APICall):
|
||||
def get_all_ex(call: APICall):
|
||||
conform_tag_fields(call, call.data)
|
||||
ret_params = {}
|
||||
|
||||
Metadata.escape_query_parameters(call)
|
||||
queues = queue_bll.get_queue_infos(
|
||||
company_id=call.identity.company,
|
||||
query_dict=call.data,
|
||||
ret_params=ret_params,
|
||||
company_id=call.identity.company, query_dict=call.data, ret_params=ret_params,
|
||||
)
|
||||
conform_output_tags(call, queues)
|
||||
|
||||
unescape_metadata(call, queues)
|
||||
call.result.data = {"queues": queues, **ret_params}
|
||||
|
||||
|
||||
@ -63,13 +66,12 @@ def get_all_ex(call: APICall):
|
||||
def get_all(call: APICall):
|
||||
conform_tag_fields(call, call.data)
|
||||
ret_params = {}
|
||||
Metadata.escape_query_parameters(call)
|
||||
queues = queue_bll.get_all(
|
||||
company_id=call.identity.company,
|
||||
query_dict=call.data,
|
||||
ret_params=ret_params,
|
||||
company_id=call.identity.company, query_dict=call.data, ret_params=ret_params,
|
||||
)
|
||||
conform_output_tags(call, queues)
|
||||
|
||||
unescape_metadata(call, queues)
|
||||
call.result.data = {"queues": queues, **ret_params}
|
||||
|
||||
|
||||
@ -83,7 +85,7 @@ def create(call: APICall, company_id, request: CreateRequest):
|
||||
name=request.name,
|
||||
tags=tags,
|
||||
system_tags=system_tags,
|
||||
metadata=get_metadata_from_api(request.metadata),
|
||||
metadata=Metadata.metadata_from_api(request.metadata),
|
||||
)
|
||||
call.result.data = {"id": queue.id}
|
||||
|
||||
@ -97,10 +99,12 @@ def create(call: APICall, company_id, request: CreateRequest):
|
||||
def update(call: APICall, company_id, req_model: UpdateRequest):
|
||||
data = call.data_model_for_partial_update
|
||||
conform_tag_fields(call, data, validate=True)
|
||||
escape_metadata(data)
|
||||
updated, fields = queue_bll.update(
|
||||
company_id=company_id, queue_id=req_model.queue, **data
|
||||
)
|
||||
conform_output_tags(call, fields)
|
||||
unescape_metadata(call, fields)
|
||||
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
|
||||
|
||||
|
||||
@ -121,11 +125,19 @@ def add_task(call: APICall, company_id, req_model: TaskRequest):
|
||||
}
|
||||
|
||||
|
||||
@endpoint("queues.get_next_task", min_version="2.4", request_data_model=QueueRequest)
|
||||
def get_next_task(call: APICall, company_id, req_model: QueueRequest):
|
||||
task = queue_bll.get_next_task(company_id=company_id, queue_id=req_model.queue)
|
||||
if task:
|
||||
call.result.data = {"entry": task.to_proper_dict()}
|
||||
@endpoint("queues.get_next_task", request_data_model=GetNextTaskRequest)
|
||||
def get_next_task(call: APICall, company_id, req_model: GetNextTaskRequest):
|
||||
entry = queue_bll.get_next_task(
|
||||
company_id=company_id, queue_id=req_model.queue
|
||||
)
|
||||
if entry:
|
||||
data = {"entry": entry.to_proper_dict()}
|
||||
if req_model.get_task_info:
|
||||
task = Task.objects(id=entry.task).first()
|
||||
if task:
|
||||
data["task_info"] = {"company": task.company, "user": task.user}
|
||||
|
||||
call.result.data = data
|
||||
|
||||
|
||||
@endpoint("queues.remove_task", min_version="2.4", request_data_model=TaskRequest)
|
||||
@ -245,21 +257,19 @@ def get_queue_metrics(
|
||||
|
||||
@endpoint("queues.add_or_update_metadata", min_version="2.13")
|
||||
def add_or_update_metadata(
|
||||
_: APICall, company_id: str, request: AddOrUpdateMetadataRequest
|
||||
call: APICall, company_id: str, request: AddOrUpdateMetadataRequest
|
||||
):
|
||||
queue_id = request.queue
|
||||
queue_bll.get_by_id(company_id=company_id, queue_id=queue_id, only=("id",))
|
||||
|
||||
queue = queue_bll.get_by_id(company_id=company_id, queue_id=queue_id, only=("id",))
|
||||
return {
|
||||
"updated": metadata_add_or_update(
|
||||
cls=Queue, _id=queue_id, items=get_metadata_from_api(request.metadata),
|
||||
"updated": Metadata.edit_metadata(
|
||||
queue, items=request.metadata, replace_metadata=request.replace_metadata
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@endpoint("queues.delete_metadata", min_version="2.13")
|
||||
def delete_metadata(_: APICall, company_id: str, request: DeleteMetadataRequest):
|
||||
def delete_metadata(call: APICall, company_id: str, request: DeleteMetadataRequest):
|
||||
queue_id = request.queue
|
||||
queue_bll.get_by_id(company_id=company_id, queue_id=queue_id, only=("id",))
|
||||
|
||||
return {"updated": metadata_delete(cls=Queue, _id=queue_id, keys=request.keys)}
|
||||
queue = queue_bll.get_by_id(company_id=company_id, queue_id=queue_id, only=("id",))
|
||||
return {"updated": Metadata.delete_metadata(queue, keys=request.keys)}
|
||||
|
@ -2,7 +2,6 @@ from datetime import datetime
|
||||
from typing import Union, Sequence, Tuple
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apimodels.metadata import MetadataItem as ApiMetadataItem
|
||||
from apiserver.apimodels.organization import Filter
|
||||
from apiserver.database.model.base import GetMixin
|
||||
from apiserver.database.model.task.task import TaskModelTypes, TaskModelNames
|
||||
@ -222,22 +221,38 @@ class DockerCmdBackwardsCompatibility:
|
||||
nested_set(task, cls.field, docker_cmd)
|
||||
|
||||
|
||||
def validate_metadata(metadata: Sequence[dict]):
|
||||
def escape_metadata(document: dict):
|
||||
"""
|
||||
Escape special characters in metadata keys
|
||||
"""
|
||||
metadata = document.get("metadata")
|
||||
if not metadata:
|
||||
return
|
||||
|
||||
keys = [m.get("key") for m in metadata]
|
||||
unique_keys = set(keys)
|
||||
unique_keys.discard(None)
|
||||
if len(keys) != len(set(keys)):
|
||||
raise errors.bad_request.ValidationError("Metadata keys should be unique")
|
||||
document["metadata"] = {
|
||||
ParameterKeyEscaper.escape(k): v
|
||||
for k, v in metadata.items()
|
||||
}
|
||||
|
||||
|
||||
def get_metadata_from_api(api_metadata: Sequence[ApiMetadataItem]) -> Sequence:
|
||||
if not api_metadata:
|
||||
return api_metadata
|
||||
def unescape_metadata(call: APICall, documents: Union[dict, Sequence[dict]]):
|
||||
"""
|
||||
Unescape special characters in metadata keys
|
||||
"""
|
||||
if isinstance(documents, dict):
|
||||
documents = [documents]
|
||||
|
||||
metadata = [m.to_struct() for m in api_metadata]
|
||||
validate_metadata(metadata)
|
||||
old_client = call.requested_endpoint_version <= PartialVersion("2.16")
|
||||
for doc in documents:
|
||||
if old_client and "metadata" in doc:
|
||||
doc["metadata"] = []
|
||||
continue
|
||||
|
||||
return metadata
|
||||
metadata = doc.get("metadata")
|
||||
if not metadata:
|
||||
continue
|
||||
|
||||
doc["metadata"] = {
|
||||
ParameterKeyEscaper.unescape(k): v
|
||||
for k, v in metadata.items()
|
||||
}
|
||||
|
@ -1,12 +1,11 @@
|
||||
from functools import partial
|
||||
from typing import Sequence
|
||||
|
||||
from apiserver.tests.api_client import APIClient
|
||||
from apiserver.tests.automated import TestService
|
||||
|
||||
|
||||
class TestQueueAndModelMetadata(TestService):
|
||||
meta1 = [{"key": "test_key", "type": "str", "value": "test_value"}]
|
||||
meta1 = {"test_key": {"key": "test_key", "type": "str", "value": "test_value"}}
|
||||
|
||||
def test_queue_metas(self):
|
||||
queue_id = self._temp_queue("TestMetadata", metadata=self.meta1)
|
||||
@ -23,20 +22,43 @@ class TestQueueAndModelMetadata(TestService):
|
||||
)
|
||||
|
||||
model_id = self._temp_model("TestMetadata1")
|
||||
self.api.models.edit(model=model_id, metadata=[self.meta1[0]])
|
||||
self.api.models.edit(model=model_id, metadata=self.meta1)
|
||||
self._assertMeta(service=service, entity=entity, _id=model_id, meta=self.meta1)
|
||||
|
||||
def test_project_meta_query(self):
|
||||
self._temp_model("TestMetadata", metadata=self.meta1)
|
||||
project = self.temp_project(name="MetaParent")
|
||||
model_id = self._temp_model(
|
||||
"TestMetadata2",
|
||||
project=project,
|
||||
metadata={
|
||||
"test_key": {"key": "test_key", "type": "str", "value": "test_value"},
|
||||
"test_key2": {"key": "test_key2", "type": "str", "value": "test_value"},
|
||||
},
|
||||
)
|
||||
res = self.api.projects.get_model_metadata_keys()
|
||||
self.assertTrue({"test_key", "test_key2"}.issubset(set(res["keys"])))
|
||||
res = self.api.projects.get_model_metadata_keys(include_subprojects=False)
|
||||
self.assertTrue("test_key" in res["keys"])
|
||||
self.assertFalse("test_key2" in res["keys"])
|
||||
|
||||
model = self.api.models.get_all_ex(
|
||||
id=[model_id], only_fields=["metadata.test_key"]
|
||||
).models[0]
|
||||
self.assertTrue("test_key" in model.metadata)
|
||||
self.assertFalse("test_key2" in model.metadata)
|
||||
|
||||
def _test_meta_operations(
|
||||
self, service: APIClient.Service, entity: str, _id: str,
|
||||
):
|
||||
assert_meta = partial(self._assertMeta, service=service, entity=entity)
|
||||
assert_meta(_id=_id, meta=self.meta1)
|
||||
|
||||
meta2 = [
|
||||
{"key": "test1", "type": "str", "value": "data1"},
|
||||
{"key": "test2", "type": "str", "value": "data2"},
|
||||
{"key": "test3", "type": "str", "value": "data3"},
|
||||
]
|
||||
meta2 = {
|
||||
"test1": {"key": "test1", "type": "str", "value": "data1"},
|
||||
"test2": {"key": "test2", "type": "str", "value": "data2"},
|
||||
"test3": {"key": "test3", "type": "str", "value": "data3"},
|
||||
}
|
||||
service.update(**{entity: _id, "metadata": meta2})
|
||||
assert_meta(_id=_id, meta=meta2)
|
||||
|
||||
@ -48,16 +70,17 @@ class TestQueueAndModelMetadata(TestService):
|
||||
]
|
||||
res = service.add_or_update_metadata(**{entity: _id, "metadata": updates})
|
||||
self.assertEqual(res.updated, 1)
|
||||
assert_meta(_id=_id, meta=[meta2[0], *updates])
|
||||
assert_meta(_id=_id, meta={**meta2, **{u["key"]: u for u in updates}})
|
||||
|
||||
res = service.delete_metadata(
|
||||
**{entity: _id, "keys": [f"test{idx}" for idx in range(2, 6)]}
|
||||
)
|
||||
self.assertEqual(res.updated, 1)
|
||||
assert_meta(_id=_id, meta=meta2[:1])
|
||||
# noinspection PyTypeChecker
|
||||
assert_meta(_id=_id, meta=dict(list(meta2.items())[:1]))
|
||||
|
||||
def _assertMeta(
|
||||
self, service: APIClient.Service, entity: str, _id: str, meta: Sequence[dict]
|
||||
self, service: APIClient.Service, entity: str, _id: str, meta: dict
|
||||
):
|
||||
res = service.get_all_ex(id=[_id])[f"{entity}s"][0]
|
||||
self.assertEqual(res.metadata, meta)
|
||||
|
Loading…
Reference in New Issue
Block a user