mirror of
https://github.com/clearml/clearml-server
synced 2025-05-20 19:24:55 +00:00
Add support for queue and model metadata
This commit is contained in:
parent
2e7f418ee2
commit
29de110abb
23
apiserver/apimodels/metadata.py
Normal file
23
apiserver/apimodels/metadata.py
Normal file
@ -0,0 +1,23 @@
|
||||
from typing import Sequence
|
||||
|
||||
from jsonmodels import validators
|
||||
from jsonmodels.fields import StringField
|
||||
from jsonmodels.models import Base
|
||||
|
||||
from apiserver.apimodels import ListField
|
||||
|
||||
|
||||
class MetadataItem(Base):
|
||||
key = StringField(required=True)
|
||||
type = StringField(required=True)
|
||||
value = StringField(required=True)
|
||||
|
||||
|
||||
class DeleteMetadata(Base):
|
||||
keys: Sequence[str] = ListField(str, validators=validators.Length(minimum_value=1))
|
||||
|
||||
|
||||
class AddOrUpdateMetadata(Base):
|
||||
metadata: Sequence[MetadataItem] = ListField(
|
||||
[MetadataItem], validators=validators.Length(minimum_value=1)
|
||||
)
|
@ -3,6 +3,11 @@ from six import string_types
|
||||
|
||||
from apiserver.apimodels import ListField, DictField
|
||||
from apiserver.apimodels.base import UpdateResponse
|
||||
from apiserver.apimodels.metadata import (
|
||||
MetadataItem,
|
||||
DeleteMetadata,
|
||||
AddOrUpdateMetadata,
|
||||
)
|
||||
from apiserver.apimodels.tasks import PublishResponse as TaskPublishResponse
|
||||
|
||||
|
||||
@ -13,7 +18,7 @@ class GetFrameworksRequest(models.Base):
|
||||
class CreateModelRequest(models.Base):
|
||||
name = fields.StringField(required=True)
|
||||
uri = fields.StringField(required=True)
|
||||
labels = DictField(value_types=string_types+(int,))
|
||||
labels = DictField(value_types=string_types + (int,))
|
||||
tags = ListField(items_types=string_types)
|
||||
system_tags = ListField(items_types=string_types)
|
||||
comment = fields.StringField()
|
||||
@ -25,6 +30,7 @@ class CreateModelRequest(models.Base):
|
||||
ready = fields.BoolField(default=True)
|
||||
ui_cache = DictField()
|
||||
task = fields.StringField()
|
||||
metadata = ListField(items_types=[MetadataItem])
|
||||
|
||||
|
||||
class CreateModelResponse(models.Base):
|
||||
@ -53,3 +59,11 @@ class ModelTaskPublishResponse(models.Base):
|
||||
class PublishModelResponse(UpdateResponse):
|
||||
published_task = fields.EmbeddedField(ModelTaskPublishResponse)
|
||||
updated = fields.IntField()
|
||||
|
||||
|
||||
class DeleteMetadataRequest(DeleteMetadata):
|
||||
model = fields.StringField(required=True)
|
||||
|
||||
|
||||
class AddOrUpdateMetadataRequest(AddOrUpdateMetadata):
|
||||
model = fields.StringField(required=True)
|
||||
|
@ -3,6 +3,11 @@ from jsonmodels.fields import StringField, IntField, BoolField, FloatField
|
||||
from jsonmodels.models import Base
|
||||
|
||||
from apiserver.apimodels import ListField
|
||||
from apiserver.apimodels.metadata import (
|
||||
MetadataItem,
|
||||
DeleteMetadata,
|
||||
AddOrUpdateMetadata,
|
||||
)
|
||||
|
||||
|
||||
class GetDefaultResp(Base):
|
||||
@ -14,6 +19,7 @@ class CreateRequest(Base):
|
||||
name = StringField(required=True)
|
||||
tags = ListField(items_types=[str])
|
||||
system_tags = ListField(items_types=[str])
|
||||
metadata = ListField(items_types=[MetadataItem])
|
||||
|
||||
|
||||
class QueueRequest(Base):
|
||||
@ -28,6 +34,7 @@ class UpdateRequest(QueueRequest):
|
||||
name = StringField()
|
||||
tags = ListField(items_types=[str])
|
||||
system_tags = ListField(items_types=[str])
|
||||
metadata = ListField(items_types=[MetadataItem])
|
||||
|
||||
|
||||
class TaskRequest(QueueRequest):
|
||||
@ -58,3 +65,11 @@ class QueueMetrics(Base):
|
||||
|
||||
class GetMetricsResponse(Base):
|
||||
queues = ListField(QueueMetrics)
|
||||
|
||||
|
||||
class DeleteMetadataRequest(DeleteMetadata):
|
||||
queue = StringField(required=True)
|
||||
|
||||
|
||||
class AddOrUpdateMetadataRequest(AddOrUpdateMetadata):
|
||||
queue = StringField(required=True)
|
||||
|
@ -60,10 +60,13 @@ class TaskRequest(models.Base):
|
||||
task = StringField(required=True)
|
||||
|
||||
|
||||
class UpdateRequest(TaskRequest):
|
||||
class TaskUpdateRequest(TaskRequest):
|
||||
force = BoolField(default=False)
|
||||
|
||||
|
||||
class UpdateRequest(TaskUpdateRequest):
|
||||
status_reason = StringField(default="")
|
||||
status_message = StringField(default="")
|
||||
force = BoolField(default=False)
|
||||
|
||||
|
||||
class EnqueueRequest(UpdateRequest):
|
||||
@ -128,9 +131,8 @@ class CloneRequest(TaskRequest):
|
||||
new_project_name = StringField()
|
||||
|
||||
|
||||
class AddOrUpdateArtifactsRequest(TaskRequest):
|
||||
class AddOrUpdateArtifactsRequest(TaskUpdateRequest):
|
||||
artifacts = ListField([Artifact], validators=Length(minimum_value=1))
|
||||
force = BoolField(default=False)
|
||||
|
||||
|
||||
class ArtifactId(models.Base):
|
||||
@ -140,9 +142,8 @@ class ArtifactId(models.Base):
|
||||
)
|
||||
|
||||
|
||||
class DeleteArtifactsRequest(TaskRequest):
|
||||
class DeleteArtifactsRequest(TaskUpdateRequest):
|
||||
artifacts = ListField([ArtifactId], validators=Length(minimum_value=1))
|
||||
force = BoolField(default=False)
|
||||
|
||||
|
||||
class ResetRequest(UpdateRequest):
|
||||
@ -173,7 +174,7 @@ class ReplaceHyperparams(object):
|
||||
all = "all"
|
||||
|
||||
|
||||
class EditHyperParamsRequest(TaskRequest):
|
||||
class EditHyperParamsRequest(TaskUpdateRequest):
|
||||
hyperparams: Sequence[HyperParamItem] = ListField(
|
||||
[HyperParamItem], validators=Length(minimum_value=1)
|
||||
)
|
||||
@ -181,7 +182,6 @@ class EditHyperParamsRequest(TaskRequest):
|
||||
validators=Enum(*get_options(ReplaceHyperparams)),
|
||||
default=ReplaceHyperparams.none,
|
||||
)
|
||||
force = BoolField(default=False)
|
||||
|
||||
|
||||
class HyperParamKey(models.Base):
|
||||
@ -189,11 +189,10 @@ class HyperParamKey(models.Base):
|
||||
name = StringField(nullable=True)
|
||||
|
||||
|
||||
class DeleteHyperParamsRequest(TaskRequest):
|
||||
class DeleteHyperParamsRequest(TaskUpdateRequest):
|
||||
hyperparams: Sequence[HyperParamKey] = ListField(
|
||||
[HyperParamKey], validators=Length(minimum_value=1)
|
||||
)
|
||||
force = BoolField(default=False)
|
||||
|
||||
|
||||
class GetConfigurationsRequest(MultiTaskRequest):
|
||||
@ -211,17 +210,15 @@ class Configuration(models.Base):
|
||||
description = StringField()
|
||||
|
||||
|
||||
class EditConfigurationRequest(TaskRequest):
|
||||
class EditConfigurationRequest(TaskUpdateRequest):
|
||||
configuration: Sequence[Configuration] = ListField(
|
||||
[Configuration], validators=Length(minimum_value=1)
|
||||
)
|
||||
replace_configuration = BoolField(default=False)
|
||||
force = BoolField(default=False)
|
||||
|
||||
|
||||
class DeleteConfigurationRequest(TaskRequest):
|
||||
class DeleteConfigurationRequest(TaskUpdateRequest):
|
||||
configuration: Sequence[str] = ListField([str], validators=Length(minimum_value=1))
|
||||
force = BoolField(default=False)
|
||||
|
||||
|
||||
class ArchiveRequest(MultiTaskRequest):
|
||||
|
@ -32,6 +32,7 @@ class QueueBLL(object):
|
||||
name: str,
|
||||
tags: Optional[Sequence[str]] = None,
|
||||
system_tags: Optional[Sequence[str]] = None,
|
||||
metadata: Optional[Sequence[dict]] = None,
|
||||
) -> Queue:
|
||||
"""Creates a queue"""
|
||||
with translate_errors_context():
|
||||
@ -43,6 +44,7 @@ class QueueBLL(object):
|
||||
name=name,
|
||||
tags=tags or [],
|
||||
system_tags=system_tags or [],
|
||||
metadata=metadata,
|
||||
last_update=now,
|
||||
)
|
||||
queue.save()
|
||||
|
@ -1,10 +1,10 @@
|
||||
from hashlib import md5
|
||||
from operator import itemgetter
|
||||
from typing import Sequence
|
||||
|
||||
from apiserver.apimodels.tasks import Artifact as ApiArtifact, ArtifactId
|
||||
from apiserver.bll.task.utils import get_task_for_update, update_task
|
||||
from apiserver.database.model.task.task import DEFAULT_ARTIFACT_MODE, Artifact
|
||||
from apiserver.database.utils import hash_field_name
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.utilities.dicts import nested_get, nested_set
|
||||
from apiserver.utilities.parameter_key_escaper import mongoengine_safe
|
||||
@ -15,7 +15,7 @@ def get_artifact_id(artifact: dict):
|
||||
Calculate id from 'key' and 'mode' fields
|
||||
Return hash on on the id so that it will not contain mongo illegal characters
|
||||
"""
|
||||
key_hash: str = md5(artifact["key"].encode()).hexdigest()
|
||||
key_hash: str = hash_field_name(artifact["key"])
|
||||
mode: str = artifact.get("mode", DEFAULT_ARTIFACT_MODE)
|
||||
return f"{key_hash}_{mode}"
|
||||
|
||||
|
44
apiserver/database/model/metadata.py
Normal file
44
apiserver/database/model/metadata.py
Normal file
@ -0,0 +1,44 @@
|
||||
from typing import Sequence, Type
|
||||
|
||||
from mongoengine import EmbeddedDocument, StringField, Document
|
||||
from pymongo import UpdateOne
|
||||
from pymongo.collection import Collection
|
||||
|
||||
from apiserver.database.model.base import ProperDictMixin
|
||||
|
||||
|
||||
class MetadataItem(EmbeddedDocument, ProperDictMixin):
|
||||
key = StringField(required=True)
|
||||
type = StringField(required=True)
|
||||
value = StringField(required=True)
|
||||
|
||||
|
||||
def metadata_add_or_update(cls: Type[Document], _id: str, items: Sequence[dict]) -> int:
|
||||
collection: Collection = cls._get_collection()
|
||||
res = collection.update_one(
|
||||
filter={"_id": _id},
|
||||
update={
|
||||
"$set": {f"metadata.$[elem{idx}]": item for idx, item in enumerate(items)}
|
||||
},
|
||||
array_filters=[
|
||||
{f"elem{idx}.key": item["key"]} for idx, item in enumerate(items)
|
||||
],
|
||||
upsert=False,
|
||||
)
|
||||
if len(items) == 1 and res.modified_count == 1:
|
||||
return res.modified_count
|
||||
|
||||
requests = [
|
||||
UpdateOne(
|
||||
filter={"_id": _id, "metadata.key": {"$ne": item["key"]}},
|
||||
update={"$push": {"metadata": item}},
|
||||
)
|
||||
for item in items
|
||||
]
|
||||
res = collection.bulk_write(requests)
|
||||
|
||||
return 1 if res.modified_count else 0
|
||||
|
||||
|
||||
def metadata_delete(cls: Type[Document], _id: str, keys: Sequence[str]) -> int:
|
||||
return cls.objects(id=_id).update_one(pull__metadata__key__in=keys)
|
@ -1,9 +1,22 @@
|
||||
from mongoengine import Document, StringField, DateTimeField, BooleanField
|
||||
from typing import Sequence
|
||||
|
||||
from mongoengine import (
|
||||
Document,
|
||||
StringField,
|
||||
DateTimeField,
|
||||
BooleanField,
|
||||
EmbeddedDocumentListField,
|
||||
)
|
||||
|
||||
from apiserver.database import Database, strict
|
||||
from apiserver.database.fields import StrippedStringField, SafeDictField, SafeSortedListField
|
||||
from apiserver.database.fields import (
|
||||
StrippedStringField,
|
||||
SafeDictField,
|
||||
SafeSortedListField,
|
||||
)
|
||||
from apiserver.database.model import DbModelMixin
|
||||
from apiserver.database.model.base import GetMixin
|
||||
from apiserver.database.model.metadata import MetadataItem
|
||||
from apiserver.database.model.model_labels import ModelLabels
|
||||
from apiserver.database.model.company import Company
|
||||
from apiserver.database.model.project import Project
|
||||
@ -19,6 +32,8 @@ class Model(DbModelMixin, Document):
|
||||
"parent",
|
||||
"project",
|
||||
"task",
|
||||
"metadata.key",
|
||||
"metadata.type",
|
||||
("company", "framework"),
|
||||
("company", "name"),
|
||||
("company", "user"),
|
||||
@ -73,3 +88,6 @@ class Model(DbModelMixin, Document):
|
||||
default=dict, user_set_allowed=True, exclude_by_default=True
|
||||
)
|
||||
company_origin = StringField(exclude_by_default=True)
|
||||
metadata: Sequence[MetadataItem] = EmbeddedDocumentListField(
|
||||
MetadataItem, default=list, user_set_allowed=True
|
||||
)
|
||||
|
@ -1,3 +1,5 @@
|
||||
from typing import Sequence
|
||||
|
||||
from mongoengine import (
|
||||
Document,
|
||||
EmbeddedDocument,
|
||||
@ -11,6 +13,7 @@ from apiserver.database.fields import StrippedStringField, SafeSortedListField
|
||||
from apiserver.database.model import DbModelMixin
|
||||
from apiserver.database.model.base import ProperDictMixin, GetMixin
|
||||
from apiserver.database.model.company import Company
|
||||
from apiserver.database.model.metadata import MetadataItem
|
||||
from apiserver.database.model.task.task import Task
|
||||
|
||||
|
||||
@ -32,6 +35,7 @@ class Queue(DbModelMixin, Document):
|
||||
meta = {
|
||||
'db_alias': Database.backend,
|
||||
'strict': strict,
|
||||
"indexes": ["metadata.key", "metadata.type"],
|
||||
}
|
||||
|
||||
id = StringField(primary_key=True)
|
||||
@ -44,3 +48,6 @@ class Queue(DbModelMixin, Document):
|
||||
system_tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
|
||||
entries = EmbeddedDocumentListField(Entry, default=list)
|
||||
last_update = DateTimeField()
|
||||
metadata: Sequence[MetadataItem] = EmbeddedDocumentListField(
|
||||
MetadataItem, default=list, user_set_allowed=True
|
||||
)
|
||||
|
@ -1,3 +1,23 @@
|
||||
metadata {
|
||||
type: array
|
||||
items {
|
||||
type: object
|
||||
properties {
|
||||
key {
|
||||
type: string
|
||||
description: The key uniquely identifying the metadata item inside the given entity
|
||||
}
|
||||
tyoe {
|
||||
type: string
|
||||
description: The type of the metadata item
|
||||
}
|
||||
value {
|
||||
type: string
|
||||
description: The value stored in the metadata item
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
credentials {
|
||||
type: object
|
||||
properties {
|
||||
|
@ -1,5 +1,6 @@
|
||||
_description: """This service provides a management interface for models (results of training tasks) stored in the system."""
|
||||
_definitions {
|
||||
include "_common.conf"
|
||||
multi_field_pattern_data {
|
||||
type: object
|
||||
properties {
|
||||
@ -444,6 +445,12 @@ create {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.13": ${create."2.1"} {
|
||||
metadata {
|
||||
description: "Model metadata"
|
||||
"$ref": "#/definitions/metadata"
|
||||
}
|
||||
}
|
||||
}
|
||||
edit {
|
||||
"2.1" {
|
||||
@ -532,6 +539,12 @@ edit {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.13": ${edit."2.1"} {
|
||||
metadata {
|
||||
description: "Model metadata"
|
||||
"$ref": "#/definitions/metadata"
|
||||
}
|
||||
}
|
||||
}
|
||||
update {
|
||||
"2.1" {
|
||||
@ -608,6 +621,12 @@ update {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.13": ${update."2.1"} {
|
||||
metadata {
|
||||
description: "Model metadata"
|
||||
"$ref": "#/definitions/metadata"
|
||||
}
|
||||
}
|
||||
}
|
||||
set_ready {
|
||||
"2.1" {
|
||||
@ -798,4 +817,63 @@ move {
|
||||
}
|
||||
}
|
||||
}
|
||||
add_or_update_metadata {
|
||||
"2.13" {
|
||||
description: "Add or update model metadata"
|
||||
request {
|
||||
type: object
|
||||
required: [model, metadata]
|
||||
properties {
|
||||
model {
|
||||
description: "ID of the model"
|
||||
type: string
|
||||
}
|
||||
metadata {
|
||||
description: "Metadata items to add or update"
|
||||
"$ref": "#/definitions/metadata"
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
updated {
|
||||
description: "Number of models updated (0 or 1)"
|
||||
type: integer
|
||||
enum: [0, 1]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
delete_metadata {
|
||||
"2.13" {
|
||||
description: "Delete metadata from model"
|
||||
request {
|
||||
type: object
|
||||
required: [ model, keys ]
|
||||
properties {
|
||||
model {
|
||||
description: "ID of the model"
|
||||
type: string
|
||||
}
|
||||
keys {
|
||||
description: "The list of metadata keys to delete"
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
updated {
|
||||
description: "Number of models updated (0 or 1)"
|
||||
type: integer
|
||||
enum: [0, 1]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -15,6 +15,8 @@ from apiserver.apimodels.models import (
|
||||
ModelTaskPublishResponse,
|
||||
GetFrameworksRequest,
|
||||
DeleteModelRequest,
|
||||
DeleteMetadataRequest,
|
||||
AddOrUpdateMetadataRequest,
|
||||
)
|
||||
from apiserver.bll.organization import OrgBLL, Tags
|
||||
from apiserver.bll.project import ProjectBLL, project_ids_with_children
|
||||
@ -23,6 +25,7 @@ from apiserver.bll.task.utils import deleted_prefix
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model import validate_id
|
||||
from apiserver.database.model.metadata import metadata_add_or_update, metadata_delete
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.task.task import Task, TaskStatus, ModelItem
|
||||
@ -32,7 +35,13 @@ from apiserver.database.utils import (
|
||||
filter_fields,
|
||||
)
|
||||
from apiserver.service_repo import APICall, endpoint
|
||||
from apiserver.services.utils import conform_tag_fields, conform_output_tags, ModelsBackwardsCompatibility
|
||||
from apiserver.services.utils import (
|
||||
conform_tag_fields,
|
||||
conform_output_tags,
|
||||
ModelsBackwardsCompatibility,
|
||||
validate_metadata,
|
||||
get_metadata_from_api,
|
||||
)
|
||||
from apiserver.timing_context import TimingContext
|
||||
|
||||
log = config.logger(__file__)
|
||||
@ -160,12 +169,16 @@ create_fields = {
|
||||
"design": None,
|
||||
"labels": dict,
|
||||
"ready": None,
|
||||
"metadata": list,
|
||||
}
|
||||
|
||||
|
||||
def parse_model_fields(call, valid_fields):
|
||||
fields = parse_from_call(call.data, valid_fields, Model.get_fields())
|
||||
conform_tag_fields(call, fields, validate=True)
|
||||
metadata = fields.get("metadata")
|
||||
if metadata:
|
||||
validate_metadata(metadata)
|
||||
return fields
|
||||
|
||||
|
||||
@ -185,6 +198,17 @@ def _reset_cached_tags(company: str, projects: Sequence[str]):
|
||||
)
|
||||
|
||||
|
||||
def _get_company_model(company_id: str, model_id: str, only_fields=None) -> Model:
|
||||
query = dict(company=company_id, id=model_id)
|
||||
qs = Model.objects(**query)
|
||||
if only_fields:
|
||||
qs = qs.only(*only_fields)
|
||||
model = qs.first()
|
||||
if not model:
|
||||
raise errors.bad_request.InvalidModelId(**query)
|
||||
return model
|
||||
|
||||
|
||||
@endpoint("models.update_for_task", required_fields=["task"])
|
||||
def update_for_task(call: APICall, company_id, _):
|
||||
if call.requested_endpoint_version > ModelsBackwardsCompatibility.max_version:
|
||||
@ -218,10 +242,9 @@ def update_for_task(call: APICall, company_id, _):
|
||||
)
|
||||
|
||||
if override_model_id:
|
||||
query = dict(company=company_id, id=override_model_id)
|
||||
model = Model.objects(**query).first()
|
||||
if not model:
|
||||
raise errors.bad_request.InvalidModelId(**query)
|
||||
model = _get_company_model(
|
||||
company_id=company_id, model_id=override_model_id
|
||||
)
|
||||
else:
|
||||
if "name" not in call.data:
|
||||
# use task name if name not provided
|
||||
@ -294,6 +317,8 @@ def create(call: APICall, company_id, req_model: CreateModelRequest):
|
||||
fields = filter_fields(Model, req_data)
|
||||
conform_tag_fields(call, fields, validate=True)
|
||||
|
||||
validate_metadata(fields.get("metadata"))
|
||||
|
||||
# create and save model
|
||||
model = Model(
|
||||
id=database.utils.id(),
|
||||
@ -352,10 +377,7 @@ def edit(call: APICall, company_id, _):
|
||||
model_id = call.data["model"]
|
||||
|
||||
with translate_errors_context():
|
||||
query = dict(id=model_id, company=company_id)
|
||||
model = Model.objects(**query).first()
|
||||
if not model:
|
||||
raise errors.bad_request.InvalidModelId(**query)
|
||||
model = _get_company_model(company_id=company_id, model_id=model_id)
|
||||
|
||||
fields = parse_model_fields(call, create_fields)
|
||||
fields = prepare_update_fields(call, company_id, fields)
|
||||
@ -401,11 +423,7 @@ def _update_model(call: APICall, company_id, model_id=None):
|
||||
model_id = model_id or call.data["model"]
|
||||
|
||||
with translate_errors_context():
|
||||
# get model by id
|
||||
query = dict(id=model_id, company=company_id)
|
||||
model = Model.objects(**query).first()
|
||||
if not model:
|
||||
raise errors.bad_request.InvalidModelId(**query)
|
||||
model = _get_company_model(company_id=company_id, model_id=model_id)
|
||||
|
||||
data = prepare_update_fields(call, company_id, call.data)
|
||||
|
||||
@ -416,6 +434,10 @@ def _update_model(call: APICall, company_id, model_id=None):
|
||||
task_id=task_id, company_id=company_id, last_iteration_max=iteration,
|
||||
)
|
||||
|
||||
metadata = data.get("metadata")
|
||||
if metadata:
|
||||
validate_metadata(metadata)
|
||||
|
||||
updated_count, updated_fields = Model.safe_update(company_id, model.id, data)
|
||||
if updated_count:
|
||||
new_project = updated_fields.get("project", model.project)
|
||||
@ -463,11 +485,11 @@ def delete(call: APICall, company_id, request: DeleteModelRequest):
|
||||
force = request.force
|
||||
|
||||
with translate_errors_context():
|
||||
query = dict(id=model_id, company=company_id)
|
||||
model = Model.objects(**query).only("id", "task", "project", "uri").first()
|
||||
if not model:
|
||||
raise errors.bad_request.InvalidModelId(**query)
|
||||
|
||||
model = _get_company_model(
|
||||
company_id=company_id,
|
||||
model_id=model_id,
|
||||
only_fields=("id", "task", "project", "uri"),
|
||||
)
|
||||
deleted_model_id = f"{deleted_prefix}{model_id}"
|
||||
|
||||
using_tasks = Task.objects(models__input__model=model_id).only("id")
|
||||
@ -507,7 +529,7 @@ def delete(call: APICall, company_id, request: DeleteModelRequest):
|
||||
upsert=False,
|
||||
)
|
||||
|
||||
del_count = Model.objects(**query).delete()
|
||||
del_count = Model.objects(id=model_id, company=company_id).delete()
|
||||
if del_count:
|
||||
_reset_cached_tags(company_id, projects=[model.project])
|
||||
call.result.data = dict(deleted=del_count > 0, url=model.uri,)
|
||||
@ -549,3 +571,25 @@ def move(call: APICall, company_id: str, request: MoveRequest):
|
||||
project_name=request.project_name,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@endpoint("models.add_or_update_metadata", min_version="2.13")
|
||||
def add_or_update_metadata(
|
||||
_: APICall, company_id: str, request: AddOrUpdateMetadataRequest
|
||||
):
|
||||
model_id = request.model
|
||||
_get_company_model(company_id=company_id, model_id=model_id)
|
||||
|
||||
return {
|
||||
"updated": metadata_add_or_update(
|
||||
cls=Model, _id=model_id, items=get_metadata_from_api(request.metadata),
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@endpoint("models.delete_metadata", min_version="2.13")
|
||||
def delete_metadata(_: APICall, company_id: str, request: DeleteMetadataRequest):
|
||||
model_id = request.model
|
||||
_get_company_model(company_id=company_id, model_id=model_id, only_fields=("id",))
|
||||
|
||||
return {"updated": metadata_delete(cls=Model, _id=model_id, keys=request.keys)}
|
||||
|
@ -11,11 +11,20 @@ from apiserver.apimodels.queues import (
|
||||
GetMetricsRequest,
|
||||
GetMetricsResponse,
|
||||
QueueMetrics,
|
||||
AddOrUpdateMetadataRequest,
|
||||
DeleteMetadataRequest,
|
||||
)
|
||||
from apiserver.bll.queue import QueueBLL
|
||||
from apiserver.bll.workers import WorkerBLL
|
||||
from apiserver.database.model.metadata import metadata_add_or_update, metadata_delete
|
||||
from apiserver.database.model.queue import Queue
|
||||
from apiserver.service_repo import APICall, endpoint
|
||||
from apiserver.services.utils import conform_tag_fields, conform_output_tags, conform_tags
|
||||
from apiserver.services.utils import (
|
||||
conform_tag_fields,
|
||||
conform_output_tags,
|
||||
conform_tags,
|
||||
get_metadata_from_api,
|
||||
)
|
||||
from apiserver.utilities import extract_properties_to_lists
|
||||
|
||||
worker_bll = WorkerBLL()
|
||||
@ -62,7 +71,11 @@ def create(call: APICall, company_id, request: CreateRequest):
|
||||
call, request.tags, request.system_tags, validate=True
|
||||
)
|
||||
queue = queue_bll.create(
|
||||
company_id=company_id, name=request.name, tags=tags, system_tags=system_tags
|
||||
company_id=company_id,
|
||||
name=request.name,
|
||||
tags=tags,
|
||||
system_tags=system_tags,
|
||||
metadata=get_metadata_from_api(request.metadata),
|
||||
)
|
||||
call.result.data = {"id": queue.id}
|
||||
|
||||
@ -220,3 +233,25 @@ def get_queue_metrics(
|
||||
for queue, data in queue_dicts.items()
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@endpoint("queues.add_or_update_metadata", min_version="2.13")
|
||||
def add_or_update_metadata(
|
||||
_: APICall, company_id: str, request: AddOrUpdateMetadataRequest
|
||||
):
|
||||
queue_id = request.queue
|
||||
queue_bll.get_by_id(company_id=company_id, queue_id=queue_id, only=("id",))
|
||||
|
||||
return {
|
||||
"updated": metadata_add_or_update(
|
||||
cls=Queue, _id=queue_id, items=get_metadata_from_api(request.metadata),
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@endpoint("queues.delete_metadata", min_version="2.13")
|
||||
def delete_metadata(_: APICall, company_id: str, request: DeleteMetadataRequest):
|
||||
queue_id = request.queue
|
||||
queue_bll.get_by_id(company_id=company_id, queue_id=queue_id, only=("id",))
|
||||
|
||||
return {"updated": metadata_delete(cls=Queue, _id=queue_id, keys=request.keys)}
|
||||
|
@ -70,7 +70,7 @@ from apiserver.bll.task.param_utils import (
|
||||
escape_paths,
|
||||
)
|
||||
from apiserver.bll.task.task_cleanup import cleanup_task
|
||||
from apiserver.bll.task.utils import update_task, deleted_prefix
|
||||
from apiserver.bll.task.utils import update_task, deleted_prefix, get_task_for_update
|
||||
from apiserver.bll.util import SetFieldsResolver
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model import EntityVisibility
|
||||
@ -1160,9 +1160,7 @@ def move(call: APICall, company_id: str, request: MoveRequest):
|
||||
|
||||
@endpoint("tasks.add_or_update_model", min_version="2.13")
|
||||
def add_or_update_model(_: APICall, company_id: str, request: AddUpdateModelRequest):
|
||||
TaskBLL.get_task_with_access(
|
||||
request.task, company_id=company_id, requires_write_access=True, only=["id"]
|
||||
)
|
||||
get_task_for_update(company_id=company_id, task_id=request.task, force=True)
|
||||
|
||||
models_field = f"models__{request.type}"
|
||||
model = ModelItem(name=request.name, model=request.model, updated=datetime.utcnow())
|
||||
@ -1181,9 +1179,7 @@ def add_or_update_model(_: APICall, company_id: str, request: AddUpdateModelRequ
|
||||
|
||||
@endpoint("tasks.delete_models", min_version="2.13")
|
||||
def delete_models(_: APICall, company_id: str, request: DeleteModelsRequest):
|
||||
task = TaskBLL.get_task_with_access(
|
||||
request.task, company_id=company_id, requires_write_access=True, only=["id"]
|
||||
)
|
||||
task = get_task_for_update(company_id=company_id, task_id=request.task, force=True)
|
||||
|
||||
delete_names = {
|
||||
type_: [m.name for m in request.models if m.type == type_]
|
||||
|
@ -2,6 +2,7 @@ from datetime import datetime
|
||||
from typing import Union, Sequence, Tuple
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apimodels.metadata import MetadataItem as ApiMetadataItem
|
||||
from apiserver.apimodels.organization import Filter
|
||||
from apiserver.database.model.base import GetMixin
|
||||
from apiserver.database.utils import partition_tags
|
||||
@ -148,7 +149,9 @@ class DockerCmdBackwardsCompatibility:
|
||||
nested_delete(fields, cls.field)
|
||||
|
||||
@classmethod
|
||||
def unprepare_from_saved(cls, call: APICall, tasks_data: Union[Sequence[dict], dict]):
|
||||
def unprepare_from_saved(
|
||||
cls, call: APICall, tasks_data: Union[Sequence[dict], dict]
|
||||
):
|
||||
if call.requested_endpoint_version > cls.max_version:
|
||||
return
|
||||
|
||||
@ -160,6 +163,29 @@ class DockerCmdBackwardsCompatibility:
|
||||
if not container or not container.get("image"):
|
||||
continue
|
||||
|
||||
docker_cmd = " ".join(filter(None, map(container.get, ("image", "arguments"))))
|
||||
docker_cmd = " ".join(
|
||||
filter(None, map(container.get, ("image", "arguments")))
|
||||
)
|
||||
if docker_cmd:
|
||||
nested_set(task, cls.field, docker_cmd)
|
||||
|
||||
|
||||
def validate_metadata(metadata: Sequence[dict]):
|
||||
if not metadata:
|
||||
return
|
||||
|
||||
keys = [m.get("key") for m in metadata]
|
||||
unique_keys = set(keys)
|
||||
unique_keys.discard(None)
|
||||
if len(keys) != len(set(keys)):
|
||||
raise errors.bad_request.ValidationError("Metadata keys should be unique")
|
||||
|
||||
|
||||
def get_metadata_from_api(api_metadata: Sequence[ApiMetadataItem]) -> Sequence:
|
||||
if not api_metadata:
|
||||
return api_metadata
|
||||
|
||||
metadata = [m.to_struct() for m in api_metadata]
|
||||
validate_metadata(metadata)
|
||||
|
||||
return metadata
|
||||
|
74
apiserver/tests/automated/test_queue_model_metadata.py
Normal file
74
apiserver/tests/automated/test_queue_model_metadata.py
Normal file
@ -0,0 +1,74 @@
|
||||
from functools import partial
|
||||
from typing import Sequence
|
||||
|
||||
from apiserver.tests.api_client import APIClient
|
||||
from apiserver.tests.automated import TestService
|
||||
|
||||
|
||||
class TestQueueAndModelMetadata(TestService):
|
||||
def setUp(self, version="2.13"):
|
||||
super().setUp(version=version)
|
||||
|
||||
meta1 = [{"key": "test_key", "type": "str", "value": "test_value"}]
|
||||
|
||||
def test_queue_metas(self):
|
||||
queue_id = self._temp_queue("TestMetadata", metadata=self.meta1)
|
||||
self._test_meta_operations(
|
||||
service=self.api.queues, entity="queue", _id=queue_id
|
||||
)
|
||||
|
||||
def test_models_metas(self):
|
||||
service = self.api.models
|
||||
entity = "model"
|
||||
model_id = self._temp_model("TestMetadata", metadata=self.meta1)
|
||||
self._test_meta_operations(
|
||||
service=self.api.models, entity="model", _id=model_id
|
||||
)
|
||||
|
||||
model_id = self._temp_model("TestMetadata1")
|
||||
self.api.models.edit(model=model_id, metadata=[self.meta1[0]])
|
||||
self._assertMeta(service=service, entity=entity, _id=model_id, meta=self.meta1)
|
||||
|
||||
def _test_meta_operations(
|
||||
self, service: APIClient.Service, entity: str, _id: str,
|
||||
):
|
||||
assert_meta = partial(self._assertMeta, service=service, entity=entity)
|
||||
assert_meta(_id=_id, meta=self.meta1)
|
||||
|
||||
meta2 = [
|
||||
{"key": "test1", "type": "str", "value": "data1"},
|
||||
{"key": "test2", "type": "str", "value": "data2"},
|
||||
{"key": "test3", "type": "str", "value": "data3"},
|
||||
]
|
||||
service.update(**{entity: _id, "metadata": meta2})
|
||||
assert_meta(_id=_id, meta=meta2)
|
||||
|
||||
updates = [
|
||||
{"key": "test2", "type": "int", "value": "10"},
|
||||
{"key": "test3", "type": "int", "value": "20"},
|
||||
{"key": "test4", "type": "array", "value": "xxx,yyy"},
|
||||
{"key": "test5", "type": "array", "value": "zzz"},
|
||||
]
|
||||
res = service.add_or_update_metadata(**{entity: _id, "metadata": updates})
|
||||
self.assertEqual(res.updated, 1)
|
||||
assert_meta(_id=_id, meta=[meta2[0], *updates])
|
||||
|
||||
res = service.delete_metadata(
|
||||
**{entity: _id, "keys": [f"test{idx}" for idx in range(2, 6)]}
|
||||
)
|
||||
self.assertEqual(res.updated, 1)
|
||||
assert_meta(_id=_id, meta=meta2[:1])
|
||||
|
||||
def _assertMeta(
|
||||
self, service: APIClient.Service, entity: str, _id: str, meta: Sequence[dict]
|
||||
):
|
||||
res = service.get_all_ex(id=[_id])[f"{entity}s"][0]
|
||||
self.assertEqual(res.metadata, meta)
|
||||
|
||||
def _temp_queue(self, name, **kwargs):
|
||||
return self.create_temp("queues", name=name, **kwargs)
|
||||
|
||||
def _temp_model(self, name: str, **kwargs):
|
||||
return self.create_temp(
|
||||
"models", uri="file://test", name=name, labels={}, **kwargs
|
||||
)
|
Loading…
Reference in New Issue
Block a user