Add support for queue and model metadata

This commit is contained in:
allegroai 2021-05-03 17:50:25 +03:00
parent 2e7f418ee2
commit 29de110abb
17 changed files with 1038 additions and 583 deletions

View File

@ -0,0 +1,23 @@
from typing import Sequence
from jsonmodels import validators
from jsonmodels.fields import StringField
from jsonmodels.models import Base
from apiserver.apimodels import ListField
class MetadataItem(Base):
key = StringField(required=True)
type = StringField(required=True)
value = StringField(required=True)
class DeleteMetadata(Base):
keys: Sequence[str] = ListField(str, validators=validators.Length(minimum_value=1))
class AddOrUpdateMetadata(Base):
metadata: Sequence[MetadataItem] = ListField(
[MetadataItem], validators=validators.Length(minimum_value=1)
)

View File

@ -3,6 +3,11 @@ from six import string_types
from apiserver.apimodels import ListField, DictField
from apiserver.apimodels.base import UpdateResponse
from apiserver.apimodels.metadata import (
MetadataItem,
DeleteMetadata,
AddOrUpdateMetadata,
)
from apiserver.apimodels.tasks import PublishResponse as TaskPublishResponse
@ -13,7 +18,7 @@ class GetFrameworksRequest(models.Base):
class CreateModelRequest(models.Base):
name = fields.StringField(required=True)
uri = fields.StringField(required=True)
labels = DictField(value_types=string_types+(int,))
labels = DictField(value_types=string_types + (int,))
tags = ListField(items_types=string_types)
system_tags = ListField(items_types=string_types)
comment = fields.StringField()
@ -25,6 +30,7 @@ class CreateModelRequest(models.Base):
ready = fields.BoolField(default=True)
ui_cache = DictField()
task = fields.StringField()
metadata = ListField(items_types=[MetadataItem])
class CreateModelResponse(models.Base):
@ -53,3 +59,11 @@ class ModelTaskPublishResponse(models.Base):
class PublishModelResponse(UpdateResponse):
published_task = fields.EmbeddedField(ModelTaskPublishResponse)
updated = fields.IntField()
class DeleteMetadataRequest(DeleteMetadata):
model = fields.StringField(required=True)
class AddOrUpdateMetadataRequest(AddOrUpdateMetadata):
model = fields.StringField(required=True)

View File

@ -3,6 +3,11 @@ from jsonmodels.fields import StringField, IntField, BoolField, FloatField
from jsonmodels.models import Base
from apiserver.apimodels import ListField
from apiserver.apimodels.metadata import (
MetadataItem,
DeleteMetadata,
AddOrUpdateMetadata,
)
class GetDefaultResp(Base):
@ -14,6 +19,7 @@ class CreateRequest(Base):
name = StringField(required=True)
tags = ListField(items_types=[str])
system_tags = ListField(items_types=[str])
metadata = ListField(items_types=[MetadataItem])
class QueueRequest(Base):
@ -28,6 +34,7 @@ class UpdateRequest(QueueRequest):
name = StringField()
tags = ListField(items_types=[str])
system_tags = ListField(items_types=[str])
metadata = ListField(items_types=[MetadataItem])
class TaskRequest(QueueRequest):
@ -58,3 +65,11 @@ class QueueMetrics(Base):
class GetMetricsResponse(Base):
queues = ListField(QueueMetrics)
class DeleteMetadataRequest(DeleteMetadata):
queue = StringField(required=True)
class AddOrUpdateMetadataRequest(AddOrUpdateMetadata):
queue = StringField(required=True)

View File

@ -60,10 +60,13 @@ class TaskRequest(models.Base):
task = StringField(required=True)
class UpdateRequest(TaskRequest):
class TaskUpdateRequest(TaskRequest):
force = BoolField(default=False)
class UpdateRequest(TaskUpdateRequest):
status_reason = StringField(default="")
status_message = StringField(default="")
force = BoolField(default=False)
class EnqueueRequest(UpdateRequest):
@ -128,9 +131,8 @@ class CloneRequest(TaskRequest):
new_project_name = StringField()
class AddOrUpdateArtifactsRequest(TaskRequest):
class AddOrUpdateArtifactsRequest(TaskUpdateRequest):
artifacts = ListField([Artifact], validators=Length(minimum_value=1))
force = BoolField(default=False)
class ArtifactId(models.Base):
@ -140,9 +142,8 @@ class ArtifactId(models.Base):
)
class DeleteArtifactsRequest(TaskRequest):
class DeleteArtifactsRequest(TaskUpdateRequest):
artifacts = ListField([ArtifactId], validators=Length(minimum_value=1))
force = BoolField(default=False)
class ResetRequest(UpdateRequest):
@ -173,7 +174,7 @@ class ReplaceHyperparams(object):
all = "all"
class EditHyperParamsRequest(TaskRequest):
class EditHyperParamsRequest(TaskUpdateRequest):
hyperparams: Sequence[HyperParamItem] = ListField(
[HyperParamItem], validators=Length(minimum_value=1)
)
@ -181,7 +182,6 @@ class EditHyperParamsRequest(TaskRequest):
validators=Enum(*get_options(ReplaceHyperparams)),
default=ReplaceHyperparams.none,
)
force = BoolField(default=False)
class HyperParamKey(models.Base):
@ -189,11 +189,10 @@ class HyperParamKey(models.Base):
name = StringField(nullable=True)
class DeleteHyperParamsRequest(TaskRequest):
class DeleteHyperParamsRequest(TaskUpdateRequest):
hyperparams: Sequence[HyperParamKey] = ListField(
[HyperParamKey], validators=Length(minimum_value=1)
)
force = BoolField(default=False)
class GetConfigurationsRequest(MultiTaskRequest):
@ -211,17 +210,15 @@ class Configuration(models.Base):
description = StringField()
class EditConfigurationRequest(TaskRequest):
class EditConfigurationRequest(TaskUpdateRequest):
configuration: Sequence[Configuration] = ListField(
[Configuration], validators=Length(minimum_value=1)
)
replace_configuration = BoolField(default=False)
force = BoolField(default=False)
class DeleteConfigurationRequest(TaskRequest):
class DeleteConfigurationRequest(TaskUpdateRequest):
configuration: Sequence[str] = ListField([str], validators=Length(minimum_value=1))
force = BoolField(default=False)
class ArchiveRequest(MultiTaskRequest):

View File

@ -32,6 +32,7 @@ class QueueBLL(object):
name: str,
tags: Optional[Sequence[str]] = None,
system_tags: Optional[Sequence[str]] = None,
metadata: Optional[Sequence[dict]] = None,
) -> Queue:
"""Creates a queue"""
with translate_errors_context():
@ -43,6 +44,7 @@ class QueueBLL(object):
name=name,
tags=tags or [],
system_tags=system_tags or [],
metadata=metadata,
last_update=now,
)
queue.save()

View File

@ -1,10 +1,10 @@
from hashlib import md5
from operator import itemgetter
from typing import Sequence
from apiserver.apimodels.tasks import Artifact as ApiArtifact, ArtifactId
from apiserver.bll.task.utils import get_task_for_update, update_task
from apiserver.database.model.task.task import DEFAULT_ARTIFACT_MODE, Artifact
from apiserver.database.utils import hash_field_name
from apiserver.timing_context import TimingContext
from apiserver.utilities.dicts import nested_get, nested_set
from apiserver.utilities.parameter_key_escaper import mongoengine_safe
@ -15,7 +15,7 @@ def get_artifact_id(artifact: dict):
Calculate id from 'key' and 'mode' fields
Return hash on on the id so that it will not contain mongo illegal characters
"""
key_hash: str = md5(artifact["key"].encode()).hexdigest()
key_hash: str = hash_field_name(artifact["key"])
mode: str = artifact.get("mode", DEFAULT_ARTIFACT_MODE)
return f"{key_hash}_{mode}"

View File

@ -0,0 +1,44 @@
from typing import Sequence, Type
from mongoengine import EmbeddedDocument, StringField, Document
from pymongo import UpdateOne
from pymongo.collection import Collection
from apiserver.database.model.base import ProperDictMixin
class MetadataItem(EmbeddedDocument, ProperDictMixin):
key = StringField(required=True)
type = StringField(required=True)
value = StringField(required=True)
def metadata_add_or_update(cls: Type[Document], _id: str, items: Sequence[dict]) -> int:
collection: Collection = cls._get_collection()
res = collection.update_one(
filter={"_id": _id},
update={
"$set": {f"metadata.$[elem{idx}]": item for idx, item in enumerate(items)}
},
array_filters=[
{f"elem{idx}.key": item["key"]} for idx, item in enumerate(items)
],
upsert=False,
)
if len(items) == 1 and res.modified_count == 1:
return res.modified_count
requests = [
UpdateOne(
filter={"_id": _id, "metadata.key": {"$ne": item["key"]}},
update={"$push": {"metadata": item}},
)
for item in items
]
res = collection.bulk_write(requests)
return 1 if res.modified_count else 0
def metadata_delete(cls: Type[Document], _id: str, keys: Sequence[str]) -> int:
return cls.objects(id=_id).update_one(pull__metadata__key__in=keys)

View File

@ -1,9 +1,22 @@
from mongoengine import Document, StringField, DateTimeField, BooleanField
from typing import Sequence
from mongoengine import (
Document,
StringField,
DateTimeField,
BooleanField,
EmbeddedDocumentListField,
)
from apiserver.database import Database, strict
from apiserver.database.fields import StrippedStringField, SafeDictField, SafeSortedListField
from apiserver.database.fields import (
StrippedStringField,
SafeDictField,
SafeSortedListField,
)
from apiserver.database.model import DbModelMixin
from apiserver.database.model.base import GetMixin
from apiserver.database.model.metadata import MetadataItem
from apiserver.database.model.model_labels import ModelLabels
from apiserver.database.model.company import Company
from apiserver.database.model.project import Project
@ -19,6 +32,8 @@ class Model(DbModelMixin, Document):
"parent",
"project",
"task",
"metadata.key",
"metadata.type",
("company", "framework"),
("company", "name"),
("company", "user"),
@ -73,3 +88,6 @@ class Model(DbModelMixin, Document):
default=dict, user_set_allowed=True, exclude_by_default=True
)
company_origin = StringField(exclude_by_default=True)
metadata: Sequence[MetadataItem] = EmbeddedDocumentListField(
MetadataItem, default=list, user_set_allowed=True
)

View File

@ -1,3 +1,5 @@
from typing import Sequence
from mongoengine import (
Document,
EmbeddedDocument,
@ -11,6 +13,7 @@ from apiserver.database.fields import StrippedStringField, SafeSortedListField
from apiserver.database.model import DbModelMixin
from apiserver.database.model.base import ProperDictMixin, GetMixin
from apiserver.database.model.company import Company
from apiserver.database.model.metadata import MetadataItem
from apiserver.database.model.task.task import Task
@ -32,6 +35,7 @@ class Queue(DbModelMixin, Document):
meta = {
'db_alias': Database.backend,
'strict': strict,
"indexes": ["metadata.key", "metadata.type"],
}
id = StringField(primary_key=True)
@ -44,3 +48,6 @@ class Queue(DbModelMixin, Document):
system_tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
entries = EmbeddedDocumentListField(Entry, default=list)
last_update = DateTimeField()
metadata: Sequence[MetadataItem] = EmbeddedDocumentListField(
MetadataItem, default=list, user_set_allowed=True
)

View File

@ -1,3 +1,23 @@
metadata {
type: array
items {
type: object
properties {
key {
type: string
description: The key uniquely identifying the metadata item inside the given entity
}
tyoe {
type: string
description: The type of the metadata item
}
value {
type: string
description: The value stored in the metadata item
}
}
}
}
credentials {
type: object
properties {

View File

@ -1,5 +1,6 @@
_description: """This service provides a management interface for models (results of training tasks) stored in the system."""
_definitions {
include "_common.conf"
multi_field_pattern_data {
type: object
properties {
@ -444,6 +445,12 @@ create {
}
}
}
"2.13": ${create."2.1"} {
metadata {
description: "Model metadata"
"$ref": "#/definitions/metadata"
}
}
}
edit {
"2.1" {
@ -532,6 +539,12 @@ edit {
}
}
}
"2.13": ${edit."2.1"} {
metadata {
description: "Model metadata"
"$ref": "#/definitions/metadata"
}
}
}
update {
"2.1" {
@ -608,6 +621,12 @@ update {
}
}
}
"2.13": ${update."2.1"} {
metadata {
description: "Model metadata"
"$ref": "#/definitions/metadata"
}
}
}
set_ready {
"2.1" {
@ -798,4 +817,63 @@ move {
}
}
}
add_or_update_metadata {
"2.13" {
description: "Add or update model metadata"
request {
type: object
required: [model, metadata]
properties {
model {
description: "ID of the model"
type: string
}
metadata {
description: "Metadata items to add or update"
"$ref": "#/definitions/metadata"
}
}
}
response {
type: object
properties {
updated {
description: "Number of models updated (0 or 1)"
type: integer
enum: [0, 1]
}
}
}
}
}
delete_metadata {
"2.13" {
description: "Delete metadata from model"
request {
type: object
required: [ model, keys ]
properties {
model {
description: "ID of the model"
type: string
}
keys {
description: "The list of metadata keys to delete"
type: array
items {type: string}
}
}
}
response {
type: object
properties {
updated {
description: "Number of models updated (0 or 1)"
type: integer
enum: [0, 1]
}
}
}
}
}

File diff suppressed because it is too large Load Diff

View File

@ -15,6 +15,8 @@ from apiserver.apimodels.models import (
ModelTaskPublishResponse,
GetFrameworksRequest,
DeleteModelRequest,
DeleteMetadataRequest,
AddOrUpdateMetadataRequest,
)
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL, project_ids_with_children
@ -23,6 +25,7 @@ from apiserver.bll.task.utils import deleted_prefix
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model import validate_id
from apiserver.database.model.metadata import metadata_add_or_update, metadata_delete
from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task, TaskStatus, ModelItem
@ -32,7 +35,13 @@ from apiserver.database.utils import (
filter_fields,
)
from apiserver.service_repo import APICall, endpoint
from apiserver.services.utils import conform_tag_fields, conform_output_tags, ModelsBackwardsCompatibility
from apiserver.services.utils import (
conform_tag_fields,
conform_output_tags,
ModelsBackwardsCompatibility,
validate_metadata,
get_metadata_from_api,
)
from apiserver.timing_context import TimingContext
log = config.logger(__file__)
@ -160,12 +169,16 @@ create_fields = {
"design": None,
"labels": dict,
"ready": None,
"metadata": list,
}
def parse_model_fields(call, valid_fields):
fields = parse_from_call(call.data, valid_fields, Model.get_fields())
conform_tag_fields(call, fields, validate=True)
metadata = fields.get("metadata")
if metadata:
validate_metadata(metadata)
return fields
@ -185,6 +198,17 @@ def _reset_cached_tags(company: str, projects: Sequence[str]):
)
def _get_company_model(company_id: str, model_id: str, only_fields=None) -> Model:
query = dict(company=company_id, id=model_id)
qs = Model.objects(**query)
if only_fields:
qs = qs.only(*only_fields)
model = qs.first()
if not model:
raise errors.bad_request.InvalidModelId(**query)
return model
@endpoint("models.update_for_task", required_fields=["task"])
def update_for_task(call: APICall, company_id, _):
if call.requested_endpoint_version > ModelsBackwardsCompatibility.max_version:
@ -218,10 +242,9 @@ def update_for_task(call: APICall, company_id, _):
)
if override_model_id:
query = dict(company=company_id, id=override_model_id)
model = Model.objects(**query).first()
if not model:
raise errors.bad_request.InvalidModelId(**query)
model = _get_company_model(
company_id=company_id, model_id=override_model_id
)
else:
if "name" not in call.data:
# use task name if name not provided
@ -294,6 +317,8 @@ def create(call: APICall, company_id, req_model: CreateModelRequest):
fields = filter_fields(Model, req_data)
conform_tag_fields(call, fields, validate=True)
validate_metadata(fields.get("metadata"))
# create and save model
model = Model(
id=database.utils.id(),
@ -352,10 +377,7 @@ def edit(call: APICall, company_id, _):
model_id = call.data["model"]
with translate_errors_context():
query = dict(id=model_id, company=company_id)
model = Model.objects(**query).first()
if not model:
raise errors.bad_request.InvalidModelId(**query)
model = _get_company_model(company_id=company_id, model_id=model_id)
fields = parse_model_fields(call, create_fields)
fields = prepare_update_fields(call, company_id, fields)
@ -401,11 +423,7 @@ def _update_model(call: APICall, company_id, model_id=None):
model_id = model_id or call.data["model"]
with translate_errors_context():
# get model by id
query = dict(id=model_id, company=company_id)
model = Model.objects(**query).first()
if not model:
raise errors.bad_request.InvalidModelId(**query)
model = _get_company_model(company_id=company_id, model_id=model_id)
data = prepare_update_fields(call, company_id, call.data)
@ -416,6 +434,10 @@ def _update_model(call: APICall, company_id, model_id=None):
task_id=task_id, company_id=company_id, last_iteration_max=iteration,
)
metadata = data.get("metadata")
if metadata:
validate_metadata(metadata)
updated_count, updated_fields = Model.safe_update(company_id, model.id, data)
if updated_count:
new_project = updated_fields.get("project", model.project)
@ -463,11 +485,11 @@ def delete(call: APICall, company_id, request: DeleteModelRequest):
force = request.force
with translate_errors_context():
query = dict(id=model_id, company=company_id)
model = Model.objects(**query).only("id", "task", "project", "uri").first()
if not model:
raise errors.bad_request.InvalidModelId(**query)
model = _get_company_model(
company_id=company_id,
model_id=model_id,
only_fields=("id", "task", "project", "uri"),
)
deleted_model_id = f"{deleted_prefix}{model_id}"
using_tasks = Task.objects(models__input__model=model_id).only("id")
@ -507,7 +529,7 @@ def delete(call: APICall, company_id, request: DeleteModelRequest):
upsert=False,
)
del_count = Model.objects(**query).delete()
del_count = Model.objects(id=model_id, company=company_id).delete()
if del_count:
_reset_cached_tags(company_id, projects=[model.project])
call.result.data = dict(deleted=del_count > 0, url=model.uri,)
@ -549,3 +571,25 @@ def move(call: APICall, company_id: str, request: MoveRequest):
project_name=request.project_name,
)
}
@endpoint("models.add_or_update_metadata", min_version="2.13")
def add_or_update_metadata(
_: APICall, company_id: str, request: AddOrUpdateMetadataRequest
):
model_id = request.model
_get_company_model(company_id=company_id, model_id=model_id)
return {
"updated": metadata_add_or_update(
cls=Model, _id=model_id, items=get_metadata_from_api(request.metadata),
)
}
@endpoint("models.delete_metadata", min_version="2.13")
def delete_metadata(_: APICall, company_id: str, request: DeleteMetadataRequest):
model_id = request.model
_get_company_model(company_id=company_id, model_id=model_id, only_fields=("id",))
return {"updated": metadata_delete(cls=Model, _id=model_id, keys=request.keys)}

View File

@ -11,11 +11,20 @@ from apiserver.apimodels.queues import (
GetMetricsRequest,
GetMetricsResponse,
QueueMetrics,
AddOrUpdateMetadataRequest,
DeleteMetadataRequest,
)
from apiserver.bll.queue import QueueBLL
from apiserver.bll.workers import WorkerBLL
from apiserver.database.model.metadata import metadata_add_or_update, metadata_delete
from apiserver.database.model.queue import Queue
from apiserver.service_repo import APICall, endpoint
from apiserver.services.utils import conform_tag_fields, conform_output_tags, conform_tags
from apiserver.services.utils import (
conform_tag_fields,
conform_output_tags,
conform_tags,
get_metadata_from_api,
)
from apiserver.utilities import extract_properties_to_lists
worker_bll = WorkerBLL()
@ -62,7 +71,11 @@ def create(call: APICall, company_id, request: CreateRequest):
call, request.tags, request.system_tags, validate=True
)
queue = queue_bll.create(
company_id=company_id, name=request.name, tags=tags, system_tags=system_tags
company_id=company_id,
name=request.name,
tags=tags,
system_tags=system_tags,
metadata=get_metadata_from_api(request.metadata),
)
call.result.data = {"id": queue.id}
@ -220,3 +233,25 @@ def get_queue_metrics(
for queue, data in queue_dicts.items()
]
)
@endpoint("queues.add_or_update_metadata", min_version="2.13")
def add_or_update_metadata(
_: APICall, company_id: str, request: AddOrUpdateMetadataRequest
):
queue_id = request.queue
queue_bll.get_by_id(company_id=company_id, queue_id=queue_id, only=("id",))
return {
"updated": metadata_add_or_update(
cls=Queue, _id=queue_id, items=get_metadata_from_api(request.metadata),
)
}
@endpoint("queues.delete_metadata", min_version="2.13")
def delete_metadata(_: APICall, company_id: str, request: DeleteMetadataRequest):
queue_id = request.queue
queue_bll.get_by_id(company_id=company_id, queue_id=queue_id, only=("id",))
return {"updated": metadata_delete(cls=Queue, _id=queue_id, keys=request.keys)}

View File

@ -70,7 +70,7 @@ from apiserver.bll.task.param_utils import (
escape_paths,
)
from apiserver.bll.task.task_cleanup import cleanup_task
from apiserver.bll.task.utils import update_task, deleted_prefix
from apiserver.bll.task.utils import update_task, deleted_prefix, get_task_for_update
from apiserver.bll.util import SetFieldsResolver
from apiserver.database.errors import translate_errors_context
from apiserver.database.model import EntityVisibility
@ -1160,9 +1160,7 @@ def move(call: APICall, company_id: str, request: MoveRequest):
@endpoint("tasks.add_or_update_model", min_version="2.13")
def add_or_update_model(_: APICall, company_id: str, request: AddUpdateModelRequest):
TaskBLL.get_task_with_access(
request.task, company_id=company_id, requires_write_access=True, only=["id"]
)
get_task_for_update(company_id=company_id, task_id=request.task, force=True)
models_field = f"models__{request.type}"
model = ModelItem(name=request.name, model=request.model, updated=datetime.utcnow())
@ -1181,9 +1179,7 @@ def add_or_update_model(_: APICall, company_id: str, request: AddUpdateModelRequ
@endpoint("tasks.delete_models", min_version="2.13")
def delete_models(_: APICall, company_id: str, request: DeleteModelsRequest):
task = TaskBLL.get_task_with_access(
request.task, company_id=company_id, requires_write_access=True, only=["id"]
)
task = get_task_for_update(company_id=company_id, task_id=request.task, force=True)
delete_names = {
type_: [m.name for m in request.models if m.type == type_]

View File

@ -2,6 +2,7 @@ from datetime import datetime
from typing import Union, Sequence, Tuple
from apiserver.apierrors import errors
from apiserver.apimodels.metadata import MetadataItem as ApiMetadataItem
from apiserver.apimodels.organization import Filter
from apiserver.database.model.base import GetMixin
from apiserver.database.utils import partition_tags
@ -148,7 +149,9 @@ class DockerCmdBackwardsCompatibility:
nested_delete(fields, cls.field)
@classmethod
def unprepare_from_saved(cls, call: APICall, tasks_data: Union[Sequence[dict], dict]):
def unprepare_from_saved(
cls, call: APICall, tasks_data: Union[Sequence[dict], dict]
):
if call.requested_endpoint_version > cls.max_version:
return
@ -160,6 +163,29 @@ class DockerCmdBackwardsCompatibility:
if not container or not container.get("image"):
continue
docker_cmd = " ".join(filter(None, map(container.get, ("image", "arguments"))))
docker_cmd = " ".join(
filter(None, map(container.get, ("image", "arguments")))
)
if docker_cmd:
nested_set(task, cls.field, docker_cmd)
def validate_metadata(metadata: Sequence[dict]):
if not metadata:
return
keys = [m.get("key") for m in metadata]
unique_keys = set(keys)
unique_keys.discard(None)
if len(keys) != len(set(keys)):
raise errors.bad_request.ValidationError("Metadata keys should be unique")
def get_metadata_from_api(api_metadata: Sequence[ApiMetadataItem]) -> Sequence:
if not api_metadata:
return api_metadata
metadata = [m.to_struct() for m in api_metadata]
validate_metadata(metadata)
return metadata

View File

@ -0,0 +1,74 @@
from functools import partial
from typing import Sequence
from apiserver.tests.api_client import APIClient
from apiserver.tests.automated import TestService
class TestQueueAndModelMetadata(TestService):
def setUp(self, version="2.13"):
super().setUp(version=version)
meta1 = [{"key": "test_key", "type": "str", "value": "test_value"}]
def test_queue_metas(self):
queue_id = self._temp_queue("TestMetadata", metadata=self.meta1)
self._test_meta_operations(
service=self.api.queues, entity="queue", _id=queue_id
)
def test_models_metas(self):
service = self.api.models
entity = "model"
model_id = self._temp_model("TestMetadata", metadata=self.meta1)
self._test_meta_operations(
service=self.api.models, entity="model", _id=model_id
)
model_id = self._temp_model("TestMetadata1")
self.api.models.edit(model=model_id, metadata=[self.meta1[0]])
self._assertMeta(service=service, entity=entity, _id=model_id, meta=self.meta1)
def _test_meta_operations(
self, service: APIClient.Service, entity: str, _id: str,
):
assert_meta = partial(self._assertMeta, service=service, entity=entity)
assert_meta(_id=_id, meta=self.meta1)
meta2 = [
{"key": "test1", "type": "str", "value": "data1"},
{"key": "test2", "type": "str", "value": "data2"},
{"key": "test3", "type": "str", "value": "data3"},
]
service.update(**{entity: _id, "metadata": meta2})
assert_meta(_id=_id, meta=meta2)
updates = [
{"key": "test2", "type": "int", "value": "10"},
{"key": "test3", "type": "int", "value": "20"},
{"key": "test4", "type": "array", "value": "xxx,yyy"},
{"key": "test5", "type": "array", "value": "zzz"},
]
res = service.add_or_update_metadata(**{entity: _id, "metadata": updates})
self.assertEqual(res.updated, 1)
assert_meta(_id=_id, meta=[meta2[0], *updates])
res = service.delete_metadata(
**{entity: _id, "keys": [f"test{idx}" for idx in range(2, 6)]}
)
self.assertEqual(res.updated, 1)
assert_meta(_id=_id, meta=meta2[:1])
def _assertMeta(
self, service: APIClient.Service, entity: str, _id: str, meta: Sequence[dict]
):
res = service.get_all_ex(id=[_id])[f"{entity}s"][0]
self.assertEqual(res.metadata, meta)
def _temp_queue(self, name, **kwargs):
return self.create_temp("queues", name=name, **kwargs)
def _temp_model(self, name: str, **kwargs):
return self.create_temp(
"models", uri="file://test", name=name, labels={}, **kwargs
)