mirror of
https://github.com/clearml/clearml-server
synced 2025-05-22 03:56:33 +00:00
Add support for queue and model metadata
This commit is contained in:
parent
2e7f418ee2
commit
29de110abb
23
apiserver/apimodels/metadata.py
Normal file
23
apiserver/apimodels/metadata.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
from typing import Sequence
|
||||||
|
|
||||||
|
from jsonmodels import validators
|
||||||
|
from jsonmodels.fields import StringField
|
||||||
|
from jsonmodels.models import Base
|
||||||
|
|
||||||
|
from apiserver.apimodels import ListField
|
||||||
|
|
||||||
|
|
||||||
|
class MetadataItem(Base):
|
||||||
|
key = StringField(required=True)
|
||||||
|
type = StringField(required=True)
|
||||||
|
value = StringField(required=True)
|
||||||
|
|
||||||
|
|
||||||
|
class DeleteMetadata(Base):
|
||||||
|
keys: Sequence[str] = ListField(str, validators=validators.Length(minimum_value=1))
|
||||||
|
|
||||||
|
|
||||||
|
class AddOrUpdateMetadata(Base):
|
||||||
|
metadata: Sequence[MetadataItem] = ListField(
|
||||||
|
[MetadataItem], validators=validators.Length(minimum_value=1)
|
||||||
|
)
|
@ -3,6 +3,11 @@ from six import string_types
|
|||||||
|
|
||||||
from apiserver.apimodels import ListField, DictField
|
from apiserver.apimodels import ListField, DictField
|
||||||
from apiserver.apimodels.base import UpdateResponse
|
from apiserver.apimodels.base import UpdateResponse
|
||||||
|
from apiserver.apimodels.metadata import (
|
||||||
|
MetadataItem,
|
||||||
|
DeleteMetadata,
|
||||||
|
AddOrUpdateMetadata,
|
||||||
|
)
|
||||||
from apiserver.apimodels.tasks import PublishResponse as TaskPublishResponse
|
from apiserver.apimodels.tasks import PublishResponse as TaskPublishResponse
|
||||||
|
|
||||||
|
|
||||||
@ -25,6 +30,7 @@ class CreateModelRequest(models.Base):
|
|||||||
ready = fields.BoolField(default=True)
|
ready = fields.BoolField(default=True)
|
||||||
ui_cache = DictField()
|
ui_cache = DictField()
|
||||||
task = fields.StringField()
|
task = fields.StringField()
|
||||||
|
metadata = ListField(items_types=[MetadataItem])
|
||||||
|
|
||||||
|
|
||||||
class CreateModelResponse(models.Base):
|
class CreateModelResponse(models.Base):
|
||||||
@ -53,3 +59,11 @@ class ModelTaskPublishResponse(models.Base):
|
|||||||
class PublishModelResponse(UpdateResponse):
|
class PublishModelResponse(UpdateResponse):
|
||||||
published_task = fields.EmbeddedField(ModelTaskPublishResponse)
|
published_task = fields.EmbeddedField(ModelTaskPublishResponse)
|
||||||
updated = fields.IntField()
|
updated = fields.IntField()
|
||||||
|
|
||||||
|
|
||||||
|
class DeleteMetadataRequest(DeleteMetadata):
|
||||||
|
model = fields.StringField(required=True)
|
||||||
|
|
||||||
|
|
||||||
|
class AddOrUpdateMetadataRequest(AddOrUpdateMetadata):
|
||||||
|
model = fields.StringField(required=True)
|
||||||
|
@ -3,6 +3,11 @@ from jsonmodels.fields import StringField, IntField, BoolField, FloatField
|
|||||||
from jsonmodels.models import Base
|
from jsonmodels.models import Base
|
||||||
|
|
||||||
from apiserver.apimodels import ListField
|
from apiserver.apimodels import ListField
|
||||||
|
from apiserver.apimodels.metadata import (
|
||||||
|
MetadataItem,
|
||||||
|
DeleteMetadata,
|
||||||
|
AddOrUpdateMetadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class GetDefaultResp(Base):
|
class GetDefaultResp(Base):
|
||||||
@ -14,6 +19,7 @@ class CreateRequest(Base):
|
|||||||
name = StringField(required=True)
|
name = StringField(required=True)
|
||||||
tags = ListField(items_types=[str])
|
tags = ListField(items_types=[str])
|
||||||
system_tags = ListField(items_types=[str])
|
system_tags = ListField(items_types=[str])
|
||||||
|
metadata = ListField(items_types=[MetadataItem])
|
||||||
|
|
||||||
|
|
||||||
class QueueRequest(Base):
|
class QueueRequest(Base):
|
||||||
@ -28,6 +34,7 @@ class UpdateRequest(QueueRequest):
|
|||||||
name = StringField()
|
name = StringField()
|
||||||
tags = ListField(items_types=[str])
|
tags = ListField(items_types=[str])
|
||||||
system_tags = ListField(items_types=[str])
|
system_tags = ListField(items_types=[str])
|
||||||
|
metadata = ListField(items_types=[MetadataItem])
|
||||||
|
|
||||||
|
|
||||||
class TaskRequest(QueueRequest):
|
class TaskRequest(QueueRequest):
|
||||||
@ -58,3 +65,11 @@ class QueueMetrics(Base):
|
|||||||
|
|
||||||
class GetMetricsResponse(Base):
|
class GetMetricsResponse(Base):
|
||||||
queues = ListField(QueueMetrics)
|
queues = ListField(QueueMetrics)
|
||||||
|
|
||||||
|
|
||||||
|
class DeleteMetadataRequest(DeleteMetadata):
|
||||||
|
queue = StringField(required=True)
|
||||||
|
|
||||||
|
|
||||||
|
class AddOrUpdateMetadataRequest(AddOrUpdateMetadata):
|
||||||
|
queue = StringField(required=True)
|
||||||
|
@ -60,10 +60,13 @@ class TaskRequest(models.Base):
|
|||||||
task = StringField(required=True)
|
task = StringField(required=True)
|
||||||
|
|
||||||
|
|
||||||
class UpdateRequest(TaskRequest):
|
class TaskUpdateRequest(TaskRequest):
|
||||||
|
force = BoolField(default=False)
|
||||||
|
|
||||||
|
|
||||||
|
class UpdateRequest(TaskUpdateRequest):
|
||||||
status_reason = StringField(default="")
|
status_reason = StringField(default="")
|
||||||
status_message = StringField(default="")
|
status_message = StringField(default="")
|
||||||
force = BoolField(default=False)
|
|
||||||
|
|
||||||
|
|
||||||
class EnqueueRequest(UpdateRequest):
|
class EnqueueRequest(UpdateRequest):
|
||||||
@ -128,9 +131,8 @@ class CloneRequest(TaskRequest):
|
|||||||
new_project_name = StringField()
|
new_project_name = StringField()
|
||||||
|
|
||||||
|
|
||||||
class AddOrUpdateArtifactsRequest(TaskRequest):
|
class AddOrUpdateArtifactsRequest(TaskUpdateRequest):
|
||||||
artifacts = ListField([Artifact], validators=Length(minimum_value=1))
|
artifacts = ListField([Artifact], validators=Length(minimum_value=1))
|
||||||
force = BoolField(default=False)
|
|
||||||
|
|
||||||
|
|
||||||
class ArtifactId(models.Base):
|
class ArtifactId(models.Base):
|
||||||
@ -140,9 +142,8 @@ class ArtifactId(models.Base):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class DeleteArtifactsRequest(TaskRequest):
|
class DeleteArtifactsRequest(TaskUpdateRequest):
|
||||||
artifacts = ListField([ArtifactId], validators=Length(minimum_value=1))
|
artifacts = ListField([ArtifactId], validators=Length(minimum_value=1))
|
||||||
force = BoolField(default=False)
|
|
||||||
|
|
||||||
|
|
||||||
class ResetRequest(UpdateRequest):
|
class ResetRequest(UpdateRequest):
|
||||||
@ -173,7 +174,7 @@ class ReplaceHyperparams(object):
|
|||||||
all = "all"
|
all = "all"
|
||||||
|
|
||||||
|
|
||||||
class EditHyperParamsRequest(TaskRequest):
|
class EditHyperParamsRequest(TaskUpdateRequest):
|
||||||
hyperparams: Sequence[HyperParamItem] = ListField(
|
hyperparams: Sequence[HyperParamItem] = ListField(
|
||||||
[HyperParamItem], validators=Length(minimum_value=1)
|
[HyperParamItem], validators=Length(minimum_value=1)
|
||||||
)
|
)
|
||||||
@ -181,7 +182,6 @@ class EditHyperParamsRequest(TaskRequest):
|
|||||||
validators=Enum(*get_options(ReplaceHyperparams)),
|
validators=Enum(*get_options(ReplaceHyperparams)),
|
||||||
default=ReplaceHyperparams.none,
|
default=ReplaceHyperparams.none,
|
||||||
)
|
)
|
||||||
force = BoolField(default=False)
|
|
||||||
|
|
||||||
|
|
||||||
class HyperParamKey(models.Base):
|
class HyperParamKey(models.Base):
|
||||||
@ -189,11 +189,10 @@ class HyperParamKey(models.Base):
|
|||||||
name = StringField(nullable=True)
|
name = StringField(nullable=True)
|
||||||
|
|
||||||
|
|
||||||
class DeleteHyperParamsRequest(TaskRequest):
|
class DeleteHyperParamsRequest(TaskUpdateRequest):
|
||||||
hyperparams: Sequence[HyperParamKey] = ListField(
|
hyperparams: Sequence[HyperParamKey] = ListField(
|
||||||
[HyperParamKey], validators=Length(minimum_value=1)
|
[HyperParamKey], validators=Length(minimum_value=1)
|
||||||
)
|
)
|
||||||
force = BoolField(default=False)
|
|
||||||
|
|
||||||
|
|
||||||
class GetConfigurationsRequest(MultiTaskRequest):
|
class GetConfigurationsRequest(MultiTaskRequest):
|
||||||
@ -211,17 +210,15 @@ class Configuration(models.Base):
|
|||||||
description = StringField()
|
description = StringField()
|
||||||
|
|
||||||
|
|
||||||
class EditConfigurationRequest(TaskRequest):
|
class EditConfigurationRequest(TaskUpdateRequest):
|
||||||
configuration: Sequence[Configuration] = ListField(
|
configuration: Sequence[Configuration] = ListField(
|
||||||
[Configuration], validators=Length(minimum_value=1)
|
[Configuration], validators=Length(minimum_value=1)
|
||||||
)
|
)
|
||||||
replace_configuration = BoolField(default=False)
|
replace_configuration = BoolField(default=False)
|
||||||
force = BoolField(default=False)
|
|
||||||
|
|
||||||
|
|
||||||
class DeleteConfigurationRequest(TaskRequest):
|
class DeleteConfigurationRequest(TaskUpdateRequest):
|
||||||
configuration: Sequence[str] = ListField([str], validators=Length(minimum_value=1))
|
configuration: Sequence[str] = ListField([str], validators=Length(minimum_value=1))
|
||||||
force = BoolField(default=False)
|
|
||||||
|
|
||||||
|
|
||||||
class ArchiveRequest(MultiTaskRequest):
|
class ArchiveRequest(MultiTaskRequest):
|
||||||
|
@ -32,6 +32,7 @@ class QueueBLL(object):
|
|||||||
name: str,
|
name: str,
|
||||||
tags: Optional[Sequence[str]] = None,
|
tags: Optional[Sequence[str]] = None,
|
||||||
system_tags: Optional[Sequence[str]] = None,
|
system_tags: Optional[Sequence[str]] = None,
|
||||||
|
metadata: Optional[Sequence[dict]] = None,
|
||||||
) -> Queue:
|
) -> Queue:
|
||||||
"""Creates a queue"""
|
"""Creates a queue"""
|
||||||
with translate_errors_context():
|
with translate_errors_context():
|
||||||
@ -43,6 +44,7 @@ class QueueBLL(object):
|
|||||||
name=name,
|
name=name,
|
||||||
tags=tags or [],
|
tags=tags or [],
|
||||||
system_tags=system_tags or [],
|
system_tags=system_tags or [],
|
||||||
|
metadata=metadata,
|
||||||
last_update=now,
|
last_update=now,
|
||||||
)
|
)
|
||||||
queue.save()
|
queue.save()
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
from hashlib import md5
|
|
||||||
from operator import itemgetter
|
from operator import itemgetter
|
||||||
from typing import Sequence
|
from typing import Sequence
|
||||||
|
|
||||||
from apiserver.apimodels.tasks import Artifact as ApiArtifact, ArtifactId
|
from apiserver.apimodels.tasks import Artifact as ApiArtifact, ArtifactId
|
||||||
from apiserver.bll.task.utils import get_task_for_update, update_task
|
from apiserver.bll.task.utils import get_task_for_update, update_task
|
||||||
from apiserver.database.model.task.task import DEFAULT_ARTIFACT_MODE, Artifact
|
from apiserver.database.model.task.task import DEFAULT_ARTIFACT_MODE, Artifact
|
||||||
|
from apiserver.database.utils import hash_field_name
|
||||||
from apiserver.timing_context import TimingContext
|
from apiserver.timing_context import TimingContext
|
||||||
from apiserver.utilities.dicts import nested_get, nested_set
|
from apiserver.utilities.dicts import nested_get, nested_set
|
||||||
from apiserver.utilities.parameter_key_escaper import mongoengine_safe
|
from apiserver.utilities.parameter_key_escaper import mongoengine_safe
|
||||||
@ -15,7 +15,7 @@ def get_artifact_id(artifact: dict):
|
|||||||
Calculate id from 'key' and 'mode' fields
|
Calculate id from 'key' and 'mode' fields
|
||||||
Return hash on on the id so that it will not contain mongo illegal characters
|
Return hash on on the id so that it will not contain mongo illegal characters
|
||||||
"""
|
"""
|
||||||
key_hash: str = md5(artifact["key"].encode()).hexdigest()
|
key_hash: str = hash_field_name(artifact["key"])
|
||||||
mode: str = artifact.get("mode", DEFAULT_ARTIFACT_MODE)
|
mode: str = artifact.get("mode", DEFAULT_ARTIFACT_MODE)
|
||||||
return f"{key_hash}_{mode}"
|
return f"{key_hash}_{mode}"
|
||||||
|
|
||||||
|
44
apiserver/database/model/metadata.py
Normal file
44
apiserver/database/model/metadata.py
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
from typing import Sequence, Type
|
||||||
|
|
||||||
|
from mongoengine import EmbeddedDocument, StringField, Document
|
||||||
|
from pymongo import UpdateOne
|
||||||
|
from pymongo.collection import Collection
|
||||||
|
|
||||||
|
from apiserver.database.model.base import ProperDictMixin
|
||||||
|
|
||||||
|
|
||||||
|
class MetadataItem(EmbeddedDocument, ProperDictMixin):
|
||||||
|
key = StringField(required=True)
|
||||||
|
type = StringField(required=True)
|
||||||
|
value = StringField(required=True)
|
||||||
|
|
||||||
|
|
||||||
|
def metadata_add_or_update(cls: Type[Document], _id: str, items: Sequence[dict]) -> int:
|
||||||
|
collection: Collection = cls._get_collection()
|
||||||
|
res = collection.update_one(
|
||||||
|
filter={"_id": _id},
|
||||||
|
update={
|
||||||
|
"$set": {f"metadata.$[elem{idx}]": item for idx, item in enumerate(items)}
|
||||||
|
},
|
||||||
|
array_filters=[
|
||||||
|
{f"elem{idx}.key": item["key"]} for idx, item in enumerate(items)
|
||||||
|
],
|
||||||
|
upsert=False,
|
||||||
|
)
|
||||||
|
if len(items) == 1 and res.modified_count == 1:
|
||||||
|
return res.modified_count
|
||||||
|
|
||||||
|
requests = [
|
||||||
|
UpdateOne(
|
||||||
|
filter={"_id": _id, "metadata.key": {"$ne": item["key"]}},
|
||||||
|
update={"$push": {"metadata": item}},
|
||||||
|
)
|
||||||
|
for item in items
|
||||||
|
]
|
||||||
|
res = collection.bulk_write(requests)
|
||||||
|
|
||||||
|
return 1 if res.modified_count else 0
|
||||||
|
|
||||||
|
|
||||||
|
def metadata_delete(cls: Type[Document], _id: str, keys: Sequence[str]) -> int:
|
||||||
|
return cls.objects(id=_id).update_one(pull__metadata__key__in=keys)
|
@ -1,9 +1,22 @@
|
|||||||
from mongoengine import Document, StringField, DateTimeField, BooleanField
|
from typing import Sequence
|
||||||
|
|
||||||
|
from mongoengine import (
|
||||||
|
Document,
|
||||||
|
StringField,
|
||||||
|
DateTimeField,
|
||||||
|
BooleanField,
|
||||||
|
EmbeddedDocumentListField,
|
||||||
|
)
|
||||||
|
|
||||||
from apiserver.database import Database, strict
|
from apiserver.database import Database, strict
|
||||||
from apiserver.database.fields import StrippedStringField, SafeDictField, SafeSortedListField
|
from apiserver.database.fields import (
|
||||||
|
StrippedStringField,
|
||||||
|
SafeDictField,
|
||||||
|
SafeSortedListField,
|
||||||
|
)
|
||||||
from apiserver.database.model import DbModelMixin
|
from apiserver.database.model import DbModelMixin
|
||||||
from apiserver.database.model.base import GetMixin
|
from apiserver.database.model.base import GetMixin
|
||||||
|
from apiserver.database.model.metadata import MetadataItem
|
||||||
from apiserver.database.model.model_labels import ModelLabels
|
from apiserver.database.model.model_labels import ModelLabels
|
||||||
from apiserver.database.model.company import Company
|
from apiserver.database.model.company import Company
|
||||||
from apiserver.database.model.project import Project
|
from apiserver.database.model.project import Project
|
||||||
@ -19,6 +32,8 @@ class Model(DbModelMixin, Document):
|
|||||||
"parent",
|
"parent",
|
||||||
"project",
|
"project",
|
||||||
"task",
|
"task",
|
||||||
|
"metadata.key",
|
||||||
|
"metadata.type",
|
||||||
("company", "framework"),
|
("company", "framework"),
|
||||||
("company", "name"),
|
("company", "name"),
|
||||||
("company", "user"),
|
("company", "user"),
|
||||||
@ -73,3 +88,6 @@ class Model(DbModelMixin, Document):
|
|||||||
default=dict, user_set_allowed=True, exclude_by_default=True
|
default=dict, user_set_allowed=True, exclude_by_default=True
|
||||||
)
|
)
|
||||||
company_origin = StringField(exclude_by_default=True)
|
company_origin = StringField(exclude_by_default=True)
|
||||||
|
metadata: Sequence[MetadataItem] = EmbeddedDocumentListField(
|
||||||
|
MetadataItem, default=list, user_set_allowed=True
|
||||||
|
)
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
from typing import Sequence
|
||||||
|
|
||||||
from mongoengine import (
|
from mongoengine import (
|
||||||
Document,
|
Document,
|
||||||
EmbeddedDocument,
|
EmbeddedDocument,
|
||||||
@ -11,6 +13,7 @@ from apiserver.database.fields import StrippedStringField, SafeSortedListField
|
|||||||
from apiserver.database.model import DbModelMixin
|
from apiserver.database.model import DbModelMixin
|
||||||
from apiserver.database.model.base import ProperDictMixin, GetMixin
|
from apiserver.database.model.base import ProperDictMixin, GetMixin
|
||||||
from apiserver.database.model.company import Company
|
from apiserver.database.model.company import Company
|
||||||
|
from apiserver.database.model.metadata import MetadataItem
|
||||||
from apiserver.database.model.task.task import Task
|
from apiserver.database.model.task.task import Task
|
||||||
|
|
||||||
|
|
||||||
@ -32,6 +35,7 @@ class Queue(DbModelMixin, Document):
|
|||||||
meta = {
|
meta = {
|
||||||
'db_alias': Database.backend,
|
'db_alias': Database.backend,
|
||||||
'strict': strict,
|
'strict': strict,
|
||||||
|
"indexes": ["metadata.key", "metadata.type"],
|
||||||
}
|
}
|
||||||
|
|
||||||
id = StringField(primary_key=True)
|
id = StringField(primary_key=True)
|
||||||
@ -44,3 +48,6 @@ class Queue(DbModelMixin, Document):
|
|||||||
system_tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
|
system_tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
|
||||||
entries = EmbeddedDocumentListField(Entry, default=list)
|
entries = EmbeddedDocumentListField(Entry, default=list)
|
||||||
last_update = DateTimeField()
|
last_update = DateTimeField()
|
||||||
|
metadata: Sequence[MetadataItem] = EmbeddedDocumentListField(
|
||||||
|
MetadataItem, default=list, user_set_allowed=True
|
||||||
|
)
|
||||||
|
@ -1,3 +1,23 @@
|
|||||||
|
metadata {
|
||||||
|
type: array
|
||||||
|
items {
|
||||||
|
type: object
|
||||||
|
properties {
|
||||||
|
key {
|
||||||
|
type: string
|
||||||
|
description: The key uniquely identifying the metadata item inside the given entity
|
||||||
|
}
|
||||||
|
tyoe {
|
||||||
|
type: string
|
||||||
|
description: The type of the metadata item
|
||||||
|
}
|
||||||
|
value {
|
||||||
|
type: string
|
||||||
|
description: The value stored in the metadata item
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
credentials {
|
credentials {
|
||||||
type: object
|
type: object
|
||||||
properties {
|
properties {
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
_description: """This service provides a management interface for models (results of training tasks) stored in the system."""
|
_description: """This service provides a management interface for models (results of training tasks) stored in the system."""
|
||||||
_definitions {
|
_definitions {
|
||||||
|
include "_common.conf"
|
||||||
multi_field_pattern_data {
|
multi_field_pattern_data {
|
||||||
type: object
|
type: object
|
||||||
properties {
|
properties {
|
||||||
@ -444,6 +445,12 @@ create {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
"2.13": ${create."2.1"} {
|
||||||
|
metadata {
|
||||||
|
description: "Model metadata"
|
||||||
|
"$ref": "#/definitions/metadata"
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
edit {
|
edit {
|
||||||
"2.1" {
|
"2.1" {
|
||||||
@ -532,6 +539,12 @@ edit {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
"2.13": ${edit."2.1"} {
|
||||||
|
metadata {
|
||||||
|
description: "Model metadata"
|
||||||
|
"$ref": "#/definitions/metadata"
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
update {
|
update {
|
||||||
"2.1" {
|
"2.1" {
|
||||||
@ -608,6 +621,12 @@ update {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
"2.13": ${update."2.1"} {
|
||||||
|
metadata {
|
||||||
|
description: "Model metadata"
|
||||||
|
"$ref": "#/definitions/metadata"
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
set_ready {
|
set_ready {
|
||||||
"2.1" {
|
"2.1" {
|
||||||
@ -798,4 +817,63 @@ move {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
add_or_update_metadata {
|
||||||
|
"2.13" {
|
||||||
|
description: "Add or update model metadata"
|
||||||
|
request {
|
||||||
|
type: object
|
||||||
|
required: [model, metadata]
|
||||||
|
properties {
|
||||||
|
model {
|
||||||
|
description: "ID of the model"
|
||||||
|
type: string
|
||||||
|
}
|
||||||
|
metadata {
|
||||||
|
description: "Metadata items to add or update"
|
||||||
|
"$ref": "#/definitions/metadata"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
response {
|
||||||
|
type: object
|
||||||
|
properties {
|
||||||
|
updated {
|
||||||
|
description: "Number of models updated (0 or 1)"
|
||||||
|
type: integer
|
||||||
|
enum: [0, 1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
delete_metadata {
|
||||||
|
"2.13" {
|
||||||
|
description: "Delete metadata from model"
|
||||||
|
request {
|
||||||
|
type: object
|
||||||
|
required: [ model, keys ]
|
||||||
|
properties {
|
||||||
|
model {
|
||||||
|
description: "ID of the model"
|
||||||
|
type: string
|
||||||
|
}
|
||||||
|
keys {
|
||||||
|
description: "The list of metadata keys to delete"
|
||||||
|
type: array
|
||||||
|
items {type: string}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
response {
|
||||||
|
type: object
|
||||||
|
properties {
|
||||||
|
updated {
|
||||||
|
description: "Number of models updated (0 or 1)"
|
||||||
|
type: integer
|
||||||
|
enum: [0, 1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
{
|
|
||||||
_description: "Provides a management API for queues of tasks waiting to be executed by workers deployed anywhere (see Workers Service)."
|
_description: "Provides a management API for queues of tasks waiting to be executed by workers deployed anywhere (see Workers Service)."
|
||||||
_definitions {
|
_definitions {
|
||||||
|
include "_common.conf"
|
||||||
queue_metrics {
|
queue_metrics {
|
||||||
type: object
|
type: object
|
||||||
properties: {
|
properties: {
|
||||||
@ -78,6 +78,10 @@
|
|||||||
type: array
|
type: array
|
||||||
items { "$ref": "#/definitions/entry" }
|
items { "$ref": "#/definitions/entry" }
|
||||||
}
|
}
|
||||||
|
metadata {
|
||||||
|
description: "Queue metadata"
|
||||||
|
"$ref": "#/definitions/metadata"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -565,4 +569,62 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
add_or_update_metadata {
|
||||||
|
"2.13" {
|
||||||
|
description: "Add or update queue metadata"
|
||||||
|
request {
|
||||||
|
type: object
|
||||||
|
required: [queue, metadata]
|
||||||
|
properties {
|
||||||
|
queue {
|
||||||
|
description: "ID of the queue"
|
||||||
|
type: string
|
||||||
|
}
|
||||||
|
metadata {
|
||||||
|
description: "Metadata items to add or update"
|
||||||
|
"$ref": "#/definitions/metadata"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
response {
|
||||||
|
type: object
|
||||||
|
properties {
|
||||||
|
updated {
|
||||||
|
description: "Number of queues updated (0 or 1)"
|
||||||
|
type: integer
|
||||||
|
enum: [0, 1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
delete_metadata {
|
||||||
|
"2.13" {
|
||||||
|
description: "Delete metadata from queue"
|
||||||
|
request {
|
||||||
|
type: object
|
||||||
|
required: [ queue, keys ]
|
||||||
|
properties {
|
||||||
|
queue {
|
||||||
|
description: "ID of the queue"
|
||||||
|
type: string
|
||||||
|
}
|
||||||
|
keys {
|
||||||
|
description: "The list of metadata keys to delete"
|
||||||
|
type: array
|
||||||
|
items {type: string}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
response {
|
||||||
|
type: object
|
||||||
|
properties {
|
||||||
|
updated {
|
||||||
|
description: "Number of queues updated (0 or 1)"
|
||||||
|
type: integer
|
||||||
|
enum: [0, 1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -15,6 +15,8 @@ from apiserver.apimodels.models import (
|
|||||||
ModelTaskPublishResponse,
|
ModelTaskPublishResponse,
|
||||||
GetFrameworksRequest,
|
GetFrameworksRequest,
|
||||||
DeleteModelRequest,
|
DeleteModelRequest,
|
||||||
|
DeleteMetadataRequest,
|
||||||
|
AddOrUpdateMetadataRequest,
|
||||||
)
|
)
|
||||||
from apiserver.bll.organization import OrgBLL, Tags
|
from apiserver.bll.organization import OrgBLL, Tags
|
||||||
from apiserver.bll.project import ProjectBLL, project_ids_with_children
|
from apiserver.bll.project import ProjectBLL, project_ids_with_children
|
||||||
@ -23,6 +25,7 @@ from apiserver.bll.task.utils import deleted_prefix
|
|||||||
from apiserver.config_repo import config
|
from apiserver.config_repo import config
|
||||||
from apiserver.database.errors import translate_errors_context
|
from apiserver.database.errors import translate_errors_context
|
||||||
from apiserver.database.model import validate_id
|
from apiserver.database.model import validate_id
|
||||||
|
from apiserver.database.model.metadata import metadata_add_or_update, metadata_delete
|
||||||
from apiserver.database.model.model import Model
|
from apiserver.database.model.model import Model
|
||||||
from apiserver.database.model.project import Project
|
from apiserver.database.model.project import Project
|
||||||
from apiserver.database.model.task.task import Task, TaskStatus, ModelItem
|
from apiserver.database.model.task.task import Task, TaskStatus, ModelItem
|
||||||
@ -32,7 +35,13 @@ from apiserver.database.utils import (
|
|||||||
filter_fields,
|
filter_fields,
|
||||||
)
|
)
|
||||||
from apiserver.service_repo import APICall, endpoint
|
from apiserver.service_repo import APICall, endpoint
|
||||||
from apiserver.services.utils import conform_tag_fields, conform_output_tags, ModelsBackwardsCompatibility
|
from apiserver.services.utils import (
|
||||||
|
conform_tag_fields,
|
||||||
|
conform_output_tags,
|
||||||
|
ModelsBackwardsCompatibility,
|
||||||
|
validate_metadata,
|
||||||
|
get_metadata_from_api,
|
||||||
|
)
|
||||||
from apiserver.timing_context import TimingContext
|
from apiserver.timing_context import TimingContext
|
||||||
|
|
||||||
log = config.logger(__file__)
|
log = config.logger(__file__)
|
||||||
@ -160,12 +169,16 @@ create_fields = {
|
|||||||
"design": None,
|
"design": None,
|
||||||
"labels": dict,
|
"labels": dict,
|
||||||
"ready": None,
|
"ready": None,
|
||||||
|
"metadata": list,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def parse_model_fields(call, valid_fields):
|
def parse_model_fields(call, valid_fields):
|
||||||
fields = parse_from_call(call.data, valid_fields, Model.get_fields())
|
fields = parse_from_call(call.data, valid_fields, Model.get_fields())
|
||||||
conform_tag_fields(call, fields, validate=True)
|
conform_tag_fields(call, fields, validate=True)
|
||||||
|
metadata = fields.get("metadata")
|
||||||
|
if metadata:
|
||||||
|
validate_metadata(metadata)
|
||||||
return fields
|
return fields
|
||||||
|
|
||||||
|
|
||||||
@ -185,6 +198,17 @@ def _reset_cached_tags(company: str, projects: Sequence[str]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_company_model(company_id: str, model_id: str, only_fields=None) -> Model:
|
||||||
|
query = dict(company=company_id, id=model_id)
|
||||||
|
qs = Model.objects(**query)
|
||||||
|
if only_fields:
|
||||||
|
qs = qs.only(*only_fields)
|
||||||
|
model = qs.first()
|
||||||
|
if not model:
|
||||||
|
raise errors.bad_request.InvalidModelId(**query)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
@endpoint("models.update_for_task", required_fields=["task"])
|
@endpoint("models.update_for_task", required_fields=["task"])
|
||||||
def update_for_task(call: APICall, company_id, _):
|
def update_for_task(call: APICall, company_id, _):
|
||||||
if call.requested_endpoint_version > ModelsBackwardsCompatibility.max_version:
|
if call.requested_endpoint_version > ModelsBackwardsCompatibility.max_version:
|
||||||
@ -218,10 +242,9 @@ def update_for_task(call: APICall, company_id, _):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if override_model_id:
|
if override_model_id:
|
||||||
query = dict(company=company_id, id=override_model_id)
|
model = _get_company_model(
|
||||||
model = Model.objects(**query).first()
|
company_id=company_id, model_id=override_model_id
|
||||||
if not model:
|
)
|
||||||
raise errors.bad_request.InvalidModelId(**query)
|
|
||||||
else:
|
else:
|
||||||
if "name" not in call.data:
|
if "name" not in call.data:
|
||||||
# use task name if name not provided
|
# use task name if name not provided
|
||||||
@ -294,6 +317,8 @@ def create(call: APICall, company_id, req_model: CreateModelRequest):
|
|||||||
fields = filter_fields(Model, req_data)
|
fields = filter_fields(Model, req_data)
|
||||||
conform_tag_fields(call, fields, validate=True)
|
conform_tag_fields(call, fields, validate=True)
|
||||||
|
|
||||||
|
validate_metadata(fields.get("metadata"))
|
||||||
|
|
||||||
# create and save model
|
# create and save model
|
||||||
model = Model(
|
model = Model(
|
||||||
id=database.utils.id(),
|
id=database.utils.id(),
|
||||||
@ -352,10 +377,7 @@ def edit(call: APICall, company_id, _):
|
|||||||
model_id = call.data["model"]
|
model_id = call.data["model"]
|
||||||
|
|
||||||
with translate_errors_context():
|
with translate_errors_context():
|
||||||
query = dict(id=model_id, company=company_id)
|
model = _get_company_model(company_id=company_id, model_id=model_id)
|
||||||
model = Model.objects(**query).first()
|
|
||||||
if not model:
|
|
||||||
raise errors.bad_request.InvalidModelId(**query)
|
|
||||||
|
|
||||||
fields = parse_model_fields(call, create_fields)
|
fields = parse_model_fields(call, create_fields)
|
||||||
fields = prepare_update_fields(call, company_id, fields)
|
fields = prepare_update_fields(call, company_id, fields)
|
||||||
@ -401,11 +423,7 @@ def _update_model(call: APICall, company_id, model_id=None):
|
|||||||
model_id = model_id or call.data["model"]
|
model_id = model_id or call.data["model"]
|
||||||
|
|
||||||
with translate_errors_context():
|
with translate_errors_context():
|
||||||
# get model by id
|
model = _get_company_model(company_id=company_id, model_id=model_id)
|
||||||
query = dict(id=model_id, company=company_id)
|
|
||||||
model = Model.objects(**query).first()
|
|
||||||
if not model:
|
|
||||||
raise errors.bad_request.InvalidModelId(**query)
|
|
||||||
|
|
||||||
data = prepare_update_fields(call, company_id, call.data)
|
data = prepare_update_fields(call, company_id, call.data)
|
||||||
|
|
||||||
@ -416,6 +434,10 @@ def _update_model(call: APICall, company_id, model_id=None):
|
|||||||
task_id=task_id, company_id=company_id, last_iteration_max=iteration,
|
task_id=task_id, company_id=company_id, last_iteration_max=iteration,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
metadata = data.get("metadata")
|
||||||
|
if metadata:
|
||||||
|
validate_metadata(metadata)
|
||||||
|
|
||||||
updated_count, updated_fields = Model.safe_update(company_id, model.id, data)
|
updated_count, updated_fields = Model.safe_update(company_id, model.id, data)
|
||||||
if updated_count:
|
if updated_count:
|
||||||
new_project = updated_fields.get("project", model.project)
|
new_project = updated_fields.get("project", model.project)
|
||||||
@ -463,11 +485,11 @@ def delete(call: APICall, company_id, request: DeleteModelRequest):
|
|||||||
force = request.force
|
force = request.force
|
||||||
|
|
||||||
with translate_errors_context():
|
with translate_errors_context():
|
||||||
query = dict(id=model_id, company=company_id)
|
model = _get_company_model(
|
||||||
model = Model.objects(**query).only("id", "task", "project", "uri").first()
|
company_id=company_id,
|
||||||
if not model:
|
model_id=model_id,
|
||||||
raise errors.bad_request.InvalidModelId(**query)
|
only_fields=("id", "task", "project", "uri"),
|
||||||
|
)
|
||||||
deleted_model_id = f"{deleted_prefix}{model_id}"
|
deleted_model_id = f"{deleted_prefix}{model_id}"
|
||||||
|
|
||||||
using_tasks = Task.objects(models__input__model=model_id).only("id")
|
using_tasks = Task.objects(models__input__model=model_id).only("id")
|
||||||
@ -507,7 +529,7 @@ def delete(call: APICall, company_id, request: DeleteModelRequest):
|
|||||||
upsert=False,
|
upsert=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
del_count = Model.objects(**query).delete()
|
del_count = Model.objects(id=model_id, company=company_id).delete()
|
||||||
if del_count:
|
if del_count:
|
||||||
_reset_cached_tags(company_id, projects=[model.project])
|
_reset_cached_tags(company_id, projects=[model.project])
|
||||||
call.result.data = dict(deleted=del_count > 0, url=model.uri,)
|
call.result.data = dict(deleted=del_count > 0, url=model.uri,)
|
||||||
@ -549,3 +571,25 @@ def move(call: APICall, company_id: str, request: MoveRequest):
|
|||||||
project_name=request.project_name,
|
project_name=request.project_name,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@endpoint("models.add_or_update_metadata", min_version="2.13")
|
||||||
|
def add_or_update_metadata(
|
||||||
|
_: APICall, company_id: str, request: AddOrUpdateMetadataRequest
|
||||||
|
):
|
||||||
|
model_id = request.model
|
||||||
|
_get_company_model(company_id=company_id, model_id=model_id)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"updated": metadata_add_or_update(
|
||||||
|
cls=Model, _id=model_id, items=get_metadata_from_api(request.metadata),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@endpoint("models.delete_metadata", min_version="2.13")
|
||||||
|
def delete_metadata(_: APICall, company_id: str, request: DeleteMetadataRequest):
|
||||||
|
model_id = request.model
|
||||||
|
_get_company_model(company_id=company_id, model_id=model_id, only_fields=("id",))
|
||||||
|
|
||||||
|
return {"updated": metadata_delete(cls=Model, _id=model_id, keys=request.keys)}
|
||||||
|
@ -11,11 +11,20 @@ from apiserver.apimodels.queues import (
|
|||||||
GetMetricsRequest,
|
GetMetricsRequest,
|
||||||
GetMetricsResponse,
|
GetMetricsResponse,
|
||||||
QueueMetrics,
|
QueueMetrics,
|
||||||
|
AddOrUpdateMetadataRequest,
|
||||||
|
DeleteMetadataRequest,
|
||||||
)
|
)
|
||||||
from apiserver.bll.queue import QueueBLL
|
from apiserver.bll.queue import QueueBLL
|
||||||
from apiserver.bll.workers import WorkerBLL
|
from apiserver.bll.workers import WorkerBLL
|
||||||
|
from apiserver.database.model.metadata import metadata_add_or_update, metadata_delete
|
||||||
|
from apiserver.database.model.queue import Queue
|
||||||
from apiserver.service_repo import APICall, endpoint
|
from apiserver.service_repo import APICall, endpoint
|
||||||
from apiserver.services.utils import conform_tag_fields, conform_output_tags, conform_tags
|
from apiserver.services.utils import (
|
||||||
|
conform_tag_fields,
|
||||||
|
conform_output_tags,
|
||||||
|
conform_tags,
|
||||||
|
get_metadata_from_api,
|
||||||
|
)
|
||||||
from apiserver.utilities import extract_properties_to_lists
|
from apiserver.utilities import extract_properties_to_lists
|
||||||
|
|
||||||
worker_bll = WorkerBLL()
|
worker_bll = WorkerBLL()
|
||||||
@ -62,7 +71,11 @@ def create(call: APICall, company_id, request: CreateRequest):
|
|||||||
call, request.tags, request.system_tags, validate=True
|
call, request.tags, request.system_tags, validate=True
|
||||||
)
|
)
|
||||||
queue = queue_bll.create(
|
queue = queue_bll.create(
|
||||||
company_id=company_id, name=request.name, tags=tags, system_tags=system_tags
|
company_id=company_id,
|
||||||
|
name=request.name,
|
||||||
|
tags=tags,
|
||||||
|
system_tags=system_tags,
|
||||||
|
metadata=get_metadata_from_api(request.metadata),
|
||||||
)
|
)
|
||||||
call.result.data = {"id": queue.id}
|
call.result.data = {"id": queue.id}
|
||||||
|
|
||||||
@ -220,3 +233,25 @@ def get_queue_metrics(
|
|||||||
for queue, data in queue_dicts.items()
|
for queue, data in queue_dicts.items()
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@endpoint("queues.add_or_update_metadata", min_version="2.13")
|
||||||
|
def add_or_update_metadata(
|
||||||
|
_: APICall, company_id: str, request: AddOrUpdateMetadataRequest
|
||||||
|
):
|
||||||
|
queue_id = request.queue
|
||||||
|
queue_bll.get_by_id(company_id=company_id, queue_id=queue_id, only=("id",))
|
||||||
|
|
||||||
|
return {
|
||||||
|
"updated": metadata_add_or_update(
|
||||||
|
cls=Queue, _id=queue_id, items=get_metadata_from_api(request.metadata),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@endpoint("queues.delete_metadata", min_version="2.13")
|
||||||
|
def delete_metadata(_: APICall, company_id: str, request: DeleteMetadataRequest):
|
||||||
|
queue_id = request.queue
|
||||||
|
queue_bll.get_by_id(company_id=company_id, queue_id=queue_id, only=("id",))
|
||||||
|
|
||||||
|
return {"updated": metadata_delete(cls=Queue, _id=queue_id, keys=request.keys)}
|
||||||
|
@ -70,7 +70,7 @@ from apiserver.bll.task.param_utils import (
|
|||||||
escape_paths,
|
escape_paths,
|
||||||
)
|
)
|
||||||
from apiserver.bll.task.task_cleanup import cleanup_task
|
from apiserver.bll.task.task_cleanup import cleanup_task
|
||||||
from apiserver.bll.task.utils import update_task, deleted_prefix
|
from apiserver.bll.task.utils import update_task, deleted_prefix, get_task_for_update
|
||||||
from apiserver.bll.util import SetFieldsResolver
|
from apiserver.bll.util import SetFieldsResolver
|
||||||
from apiserver.database.errors import translate_errors_context
|
from apiserver.database.errors import translate_errors_context
|
||||||
from apiserver.database.model import EntityVisibility
|
from apiserver.database.model import EntityVisibility
|
||||||
@ -1160,9 +1160,7 @@ def move(call: APICall, company_id: str, request: MoveRequest):
|
|||||||
|
|
||||||
@endpoint("tasks.add_or_update_model", min_version="2.13")
|
@endpoint("tasks.add_or_update_model", min_version="2.13")
|
||||||
def add_or_update_model(_: APICall, company_id: str, request: AddUpdateModelRequest):
|
def add_or_update_model(_: APICall, company_id: str, request: AddUpdateModelRequest):
|
||||||
TaskBLL.get_task_with_access(
|
get_task_for_update(company_id=company_id, task_id=request.task, force=True)
|
||||||
request.task, company_id=company_id, requires_write_access=True, only=["id"]
|
|
||||||
)
|
|
||||||
|
|
||||||
models_field = f"models__{request.type}"
|
models_field = f"models__{request.type}"
|
||||||
model = ModelItem(name=request.name, model=request.model, updated=datetime.utcnow())
|
model = ModelItem(name=request.name, model=request.model, updated=datetime.utcnow())
|
||||||
@ -1181,9 +1179,7 @@ def add_or_update_model(_: APICall, company_id: str, request: AddUpdateModelRequ
|
|||||||
|
|
||||||
@endpoint("tasks.delete_models", min_version="2.13")
|
@endpoint("tasks.delete_models", min_version="2.13")
|
||||||
def delete_models(_: APICall, company_id: str, request: DeleteModelsRequest):
|
def delete_models(_: APICall, company_id: str, request: DeleteModelsRequest):
|
||||||
task = TaskBLL.get_task_with_access(
|
task = get_task_for_update(company_id=company_id, task_id=request.task, force=True)
|
||||||
request.task, company_id=company_id, requires_write_access=True, only=["id"]
|
|
||||||
)
|
|
||||||
|
|
||||||
delete_names = {
|
delete_names = {
|
||||||
type_: [m.name for m in request.models if m.type == type_]
|
type_: [m.name for m in request.models if m.type == type_]
|
||||||
|
@ -2,6 +2,7 @@ from datetime import datetime
|
|||||||
from typing import Union, Sequence, Tuple
|
from typing import Union, Sequence, Tuple
|
||||||
|
|
||||||
from apiserver.apierrors import errors
|
from apiserver.apierrors import errors
|
||||||
|
from apiserver.apimodels.metadata import MetadataItem as ApiMetadataItem
|
||||||
from apiserver.apimodels.organization import Filter
|
from apiserver.apimodels.organization import Filter
|
||||||
from apiserver.database.model.base import GetMixin
|
from apiserver.database.model.base import GetMixin
|
||||||
from apiserver.database.utils import partition_tags
|
from apiserver.database.utils import partition_tags
|
||||||
@ -148,7 +149,9 @@ class DockerCmdBackwardsCompatibility:
|
|||||||
nested_delete(fields, cls.field)
|
nested_delete(fields, cls.field)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def unprepare_from_saved(cls, call: APICall, tasks_data: Union[Sequence[dict], dict]):
|
def unprepare_from_saved(
|
||||||
|
cls, call: APICall, tasks_data: Union[Sequence[dict], dict]
|
||||||
|
):
|
||||||
if call.requested_endpoint_version > cls.max_version:
|
if call.requested_endpoint_version > cls.max_version:
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -160,6 +163,29 @@ class DockerCmdBackwardsCompatibility:
|
|||||||
if not container or not container.get("image"):
|
if not container or not container.get("image"):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
docker_cmd = " ".join(filter(None, map(container.get, ("image", "arguments"))))
|
docker_cmd = " ".join(
|
||||||
|
filter(None, map(container.get, ("image", "arguments")))
|
||||||
|
)
|
||||||
if docker_cmd:
|
if docker_cmd:
|
||||||
nested_set(task, cls.field, docker_cmd)
|
nested_set(task, cls.field, docker_cmd)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_metadata(metadata: Sequence[dict]):
|
||||||
|
if not metadata:
|
||||||
|
return
|
||||||
|
|
||||||
|
keys = [m.get("key") for m in metadata]
|
||||||
|
unique_keys = set(keys)
|
||||||
|
unique_keys.discard(None)
|
||||||
|
if len(keys) != len(set(keys)):
|
||||||
|
raise errors.bad_request.ValidationError("Metadata keys should be unique")
|
||||||
|
|
||||||
|
|
||||||
|
def get_metadata_from_api(api_metadata: Sequence[ApiMetadataItem]) -> Sequence:
|
||||||
|
if not api_metadata:
|
||||||
|
return api_metadata
|
||||||
|
|
||||||
|
metadata = [m.to_struct() for m in api_metadata]
|
||||||
|
validate_metadata(metadata)
|
||||||
|
|
||||||
|
return metadata
|
||||||
|
74
apiserver/tests/automated/test_queue_model_metadata.py
Normal file
74
apiserver/tests/automated/test_queue_model_metadata.py
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
from functools import partial
|
||||||
|
from typing import Sequence
|
||||||
|
|
||||||
|
from apiserver.tests.api_client import APIClient
|
||||||
|
from apiserver.tests.automated import TestService
|
||||||
|
|
||||||
|
|
||||||
|
class TestQueueAndModelMetadata(TestService):
|
||||||
|
def setUp(self, version="2.13"):
|
||||||
|
super().setUp(version=version)
|
||||||
|
|
||||||
|
meta1 = [{"key": "test_key", "type": "str", "value": "test_value"}]
|
||||||
|
|
||||||
|
def test_queue_metas(self):
|
||||||
|
queue_id = self._temp_queue("TestMetadata", metadata=self.meta1)
|
||||||
|
self._test_meta_operations(
|
||||||
|
service=self.api.queues, entity="queue", _id=queue_id
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_models_metas(self):
|
||||||
|
service = self.api.models
|
||||||
|
entity = "model"
|
||||||
|
model_id = self._temp_model("TestMetadata", metadata=self.meta1)
|
||||||
|
self._test_meta_operations(
|
||||||
|
service=self.api.models, entity="model", _id=model_id
|
||||||
|
)
|
||||||
|
|
||||||
|
model_id = self._temp_model("TestMetadata1")
|
||||||
|
self.api.models.edit(model=model_id, metadata=[self.meta1[0]])
|
||||||
|
self._assertMeta(service=service, entity=entity, _id=model_id, meta=self.meta1)
|
||||||
|
|
||||||
|
def _test_meta_operations(
|
||||||
|
self, service: APIClient.Service, entity: str, _id: str,
|
||||||
|
):
|
||||||
|
assert_meta = partial(self._assertMeta, service=service, entity=entity)
|
||||||
|
assert_meta(_id=_id, meta=self.meta1)
|
||||||
|
|
||||||
|
meta2 = [
|
||||||
|
{"key": "test1", "type": "str", "value": "data1"},
|
||||||
|
{"key": "test2", "type": "str", "value": "data2"},
|
||||||
|
{"key": "test3", "type": "str", "value": "data3"},
|
||||||
|
]
|
||||||
|
service.update(**{entity: _id, "metadata": meta2})
|
||||||
|
assert_meta(_id=_id, meta=meta2)
|
||||||
|
|
||||||
|
updates = [
|
||||||
|
{"key": "test2", "type": "int", "value": "10"},
|
||||||
|
{"key": "test3", "type": "int", "value": "20"},
|
||||||
|
{"key": "test4", "type": "array", "value": "xxx,yyy"},
|
||||||
|
{"key": "test5", "type": "array", "value": "zzz"},
|
||||||
|
]
|
||||||
|
res = service.add_or_update_metadata(**{entity: _id, "metadata": updates})
|
||||||
|
self.assertEqual(res.updated, 1)
|
||||||
|
assert_meta(_id=_id, meta=[meta2[0], *updates])
|
||||||
|
|
||||||
|
res = service.delete_metadata(
|
||||||
|
**{entity: _id, "keys": [f"test{idx}" for idx in range(2, 6)]}
|
||||||
|
)
|
||||||
|
self.assertEqual(res.updated, 1)
|
||||||
|
assert_meta(_id=_id, meta=meta2[:1])
|
||||||
|
|
||||||
|
def _assertMeta(
|
||||||
|
self, service: APIClient.Service, entity: str, _id: str, meta: Sequence[dict]
|
||||||
|
):
|
||||||
|
res = service.get_all_ex(id=[_id])[f"{entity}s"][0]
|
||||||
|
self.assertEqual(res.metadata, meta)
|
||||||
|
|
||||||
|
def _temp_queue(self, name, **kwargs):
|
||||||
|
return self.create_temp("queues", name=name, **kwargs)
|
||||||
|
|
||||||
|
def _temp_model(self, name: str, **kwargs):
|
||||||
|
return self.create_temp(
|
||||||
|
"models", uri="file://test", name=name, labels={}, **kwargs
|
||||||
|
)
|
Loading…
Reference in New Issue
Block a user