Add metadata dict support for models, queues

Add more info for projects
This commit is contained in:
allegroai 2022-03-15 16:18:57 +02:00
parent 04ea9018a3
commit af09fba755
22 changed files with 798 additions and 435 deletions

View File

@ -1,7 +1,7 @@
from typing import Sequence
from jsonmodels import validators
from jsonmodels.fields import StringField
from jsonmodels.fields import StringField, BoolField
from jsonmodels.models import Base
from apiserver.apimodels import ListField
@ -21,3 +21,4 @@ class AddOrUpdateMetadata(Base):
metadata: Sequence[MetadataItem] = ListField(
[MetadataItem], validators=validators.Length(minimum_value=1)
)
replace_metadata = BoolField(default=False)

View File

@ -30,7 +30,7 @@ class CreateModelRequest(models.Base):
ready = fields.BoolField(default=True)
ui_cache = DictField()
task = fields.StringField()
metadata = ListField(items_types=[MetadataItem])
metadata = DictField(value_types=[MetadataItem])
class CreateModelResponse(models.Base):

View File

@ -2,7 +2,7 @@ from jsonmodels import validators
from jsonmodels.fields import StringField, IntField, BoolField, FloatField
from jsonmodels.models import Base
from apiserver.apimodels import ListField
from apiserver.apimodels import ListField, DictField
from apiserver.apimodels.metadata import (
MetadataItem,
DeleteMetadata,
@ -19,13 +19,18 @@ class CreateRequest(Base):
name = StringField(required=True)
tags = ListField(items_types=[str])
system_tags = ListField(items_types=[str])
metadata = ListField(items_types=[MetadataItem])
metadata = DictField(value_types=[MetadataItem])
class QueueRequest(Base):
queue = StringField(required=True)
class GetNextTaskRequest(QueueRequest):
queue = StringField(required=True)
get_task_info = BoolField(default=False)
class DeleteRequest(QueueRequest):
force = BoolField(default=False)
@ -34,7 +39,7 @@ class UpdateRequest(QueueRequest):
name = StringField()
tags = ListField(items_types=[str])
system_tags = ListField(items_types=[str])
metadata = ListField(items_types=[MetadataItem])
metadata = DictField(value_types=[MetadataItem])
class TaskRequest(QueueRequest):

View File

@ -7,6 +7,7 @@ from apiserver.bll.task.utils import deleted_prefix
from apiserver.database.model import EntityVisibility
from apiserver.database.model.model import Model
from apiserver.database.model.task.task import Task, TaskStatus
from .metadata import Metadata
class ModelBLL:

View File

@ -0,0 +1,111 @@
from typing import Sequence, Union, Mapping
from mongoengine import Document
from apiserver.apierrors import errors
from apiserver.apimodels.metadata import MetadataItem
from apiserver.database.model.base import GetMixin
from apiserver.service_repo import APICall
from apiserver.utilities.parameter_key_escaper import (
ParameterKeyEscaper,
mongoengine_safe,
)
from apiserver.config_repo import config
from apiserver.timing_context import TimingContext
log = config.logger(__file__)
class Metadata:
    """
    Helpers for converting, storing and querying entity metadata that is
    persisted as a dict keyed by mongo-escaped item keys.
    """

    @staticmethod
    def metadata_from_api(
        api_data: Union[Mapping[str, MetadataItem], Sequence[MetadataItem]]
    ) -> dict:
        """
        Convert API metadata (either the new dict form or the legacy list of
        items) into the DB representation: a dict mapping the mongo-escaped
        key to the item's struct. Returns {} for empty/missing input.
        """
        if not api_data:
            return {}

        if isinstance(api_data, dict):
            return {
                ParameterKeyEscaper.escape(k): v.to_struct()
                for k, v in api_data.items()
            }

        # legacy form: a sequence of items, each carrying its own key
        return {
            ParameterKeyEscaper.escape(item.key): item.to_struct() for item in api_data
        }

    @classmethod
    def edit_metadata(
        cls,
        obj: Document,
        items: Sequence[MetadataItem],
        replace_metadata: bool,
        **more_updates,
    ) -> int:
        """
        Add or update metadata items on the given document.

        :param obj: the mongoengine document to update
        :param items: metadata items to write
        :param replace_metadata: when True the whole metadata dict is replaced
            with the provided items; otherwise only the provided keys are set
        :param more_updates: extra mongoengine update kwargs applied atomically
            in the same update call
        :return: the result of obj.update (number of updated documents)
        """
        with TimingContext("mongo", "edit_metadata"):
            update_cmds = dict()
            metadata = cls.metadata_from_api(items)
            if replace_metadata:
                update_cmds["set__metadata"] = metadata
            else:
                for key, value in metadata.items():
                    # mongoengine_safe guards keys that would otherwise clash
                    # with mongoengine update-operator suffixes
                    update_cmds[f"set__metadata__{mongoengine_safe(key)}"] = value

            return obj.update(**update_cmds, **more_updates)

    @classmethod
    def delete_metadata(cls, obj: Document, keys: Sequence[str], **more_updates) -> int:
        """
        Unset the given metadata keys (deduplicated) on the document.
        Returns the result of obj.update.
        """
        with TimingContext("mongo", "delete_metadata"):
            return obj.update(
                **{
                    f"unset__metadata__{ParameterKeyEscaper.escape(key)}": 1
                    for key in set(keys)
                },
                **more_updates,
            )

    @staticmethod
    def _process_path(path: str):
        """
        Frontend does a partial escaping on the path so that all '.' in key
        names are escaped. Need to unescape and apply a full mongo escaping.
        Raises ValidationError unless the path has 2 or 3 dot-separated parts.
        """
        parts = path.split(".")
        if len(parts) < 2 or len(parts) > 3:
            raise errors.bad_request.ValidationError("invalid field", path=path)
        return ".".join(
            ParameterKeyEscaper.escape(ParameterKeyEscaper.unescape(p)) for p in parts
        )

    @classmethod
    def escape_paths(cls, paths: Sequence[str]) -> Sequence[str]:
        # handle both plain ("metadata.") and descending-order ("-metadata.")
        # field references
        for prefix in (
            "metadata.",
            "-metadata.",
        ):
            paths = [
                cls._process_path(path) if path.startswith(prefix) else path
                for path in paths
            ]
        return paths

    @classmethod
    def escape_query_parameters(cls, call: APICall) -> dict:
        """
        Escape metadata field references in the call data keys, projection
        and ordering so queries match the escaped DB representation.

        NOTE(review): builds and returns a new dict — call.data itself is not
        modified; visible callers discard the return value, so confirm whether
        GetMixin.set_projection/set_ordering mutate shared state as intended.
        """
        if not call.data:
            return call.data

        keys = list(call.data)
        call_data = {
            safe_key: call.data[key]
            for key, safe_key in zip(keys, Metadata.escape_paths(keys))
        }

        projection = GetMixin.get_projection(call_data)
        if projection:
            GetMixin.set_projection(call_data, Metadata.escape_paths(projection))

        ordering = GetMixin.get_ordering(call_data)
        if ordering:
            GetMixin.set_ordering(call_data, Metadata.escape_paths(ordering))

        return call_data

View File

@ -388,6 +388,17 @@ class ProjectBLL:
}
}
def max_started_subquery(condition):
    # Build a $max accumulator over the "started" field of tasks matching
    # the condition; datetime.min stands in for "no matching task".
    conditional_started = {
        "$cond": {
            "if": condition,
            "then": "$started",
            "else": datetime.min,
        }
    }
    return {"$max": conditional_started}
def runtime_subquery(additional_cond):
return {
# the sum of
@ -431,14 +442,22 @@ class ProjectBLL:
group_step[f"{state.value}_recently_completed"] = completed_after_subquery(
cond, time_thresh=time_thresh
)
group_step[f"{state.value}_max_task_started"] = max_started_subquery(cond)
def get_state_filter() -> dict:
    # No specific state requested: no extra match clause
    if not specific_state:
        return {}
    # Archived projects are selected by presence of the archived system tag;
    # any other state excludes it
    archived_tag = EntityVisibility.archived.value
    if specific_state == EntityVisibility.archived:
        return {"system_tags": {"$eq": archived_tag}}
    return {"system_tags": {"$ne": archived_tag}}
runtime_pipeline = [
# only count run time for these types of tasks
{
"$match": {
"company": {"$in": [None, "", company_id]},
"type": {"$in": ["training", "testing", "annotation"]},
"project": {"$in": project_ids},
**get_state_filter(),
}
},
ensure_valid_fields(),
@ -547,6 +566,8 @@ class ProjectBLL:
) -> Dict[str, dict]:
return {
section: a.get(section, 0) + b.get(section, 0)
if not section.endswith("max_task_started")
else max(a.get(section) or datetime.min, b.get(section) or datetime.min)
for section in set(a) | set(b)
}
@ -562,6 +583,10 @@ class ProjectBLL:
project_section_statuses = nested_get(
status_count, (project_id, section), default=default_counts
)
def get_time_or_none(value):
    # datetime.min is the sentinel meaning "no task ever started"
    if value == datetime.min:
        return None
    return value
return {
"status_count": project_section_statuses,
"running_tasks": project_section_statuses.get(TaskStatus.in_progress),
@ -570,6 +595,9 @@ class ProjectBLL:
"completed_tasks": project_runtime.get(
f"{section}_recently_completed", 0
),
"last_task_run": get_time_or_none(
project_runtime.get(f"{section}_max_task_started", datetime.min)
),
}
report_for_states = [
@ -723,7 +751,9 @@ class ProjectBLL:
return Model.objects(query).distinct(field="framework")
@classmethod
def calc_own_contents(cls, company: str, project_ids: Sequence[str]) -> Dict[str, dict]:
def calc_own_contents(
cls, company: str, project_ids: Sequence[str]
) -> Dict[str, dict]:
"""
Returns the amount of task/models per requested project
Use separate aggregation calls on Task/Model instead of lookup
@ -739,30 +769,17 @@ class ProjectBLL:
"project": {"$in": project_ids},
}
},
{
"$project": {"project": 1}
},
{
"$group": {
"_id": "$project",
"count": {"$sum": 1},
}
}
{"$project": {"project": 1}},
{"$group": {"_id": "$project", "count": {"$sum": 1}}}
]
def get_agrregate_res(cls_: Type[AttributedDocument]) -> dict:
return {
data["_id"]: data["count"]
for data in cls_.aggregate(pipeline)
}
return {data["_id"]: data["count"] for data in cls_.aggregate(pipeline)}
with TimingContext("mongo", "get_security_groups"):
tasks = get_agrregate_res(Task)
models = get_agrregate_res(Model)
return {
pid: {
"own_tasks": tasks.get(pid, 0),
"own_models": models.get(pid, 0),
}
pid: {"own_tasks": tasks.get(pid, 0), "own_models": models.get(pid, 0)}
for pid in project_ids
}

View File

@ -10,6 +10,7 @@ from typing import (
from redis import StrictRedis
from apiserver.config_repo import config
from apiserver.database.model.model import Model
from apiserver.database.model.task.task import Task
from apiserver.redis_manager import redman
from apiserver.utilities.dicts import nested_get
@ -239,3 +240,53 @@ class ProjectQueries:
result = Task.aggregate(pipeline)
return [r["metrics"][0] for r in result]
@classmethod
def get_model_metadata_keys(
    cls,
    company_id,
    project_ids: Sequence[str],
    include_subprojects: bool,
    page: int = 0,
    page_size: int = 500,
) -> Tuple[int, int, Sequence[dict]]:
    """
    Return one page of distinct metadata keys used by models in the given
    projects.

    :param company_id: company scoping for the query
    :param project_ids: projects to search; subprojects included when
        include_subprojects is set
    :param page: zero-based page number (clamped to >= 0)
    :param page_size: page size (clamped to >= 1)
    :return: (total, remaining, keys) where keys are unescaped key names
    """
    page = max(0, page)
    page_size = max(1, page_size)
    pipeline = [
        {
            "$match": {
                **cls._get_company_constraint(company_id),
                **cls._get_project_constraint(project_ids, include_subprojects),
                # only models with a non-empty metadata dict
                "metadata": {"$exists": True, "$gt": {}},
            }
        },
        # convert the metadata dict to an array so keys can be unwound
        {"$project": {"metadata": {"$objectToArray": "$metadata"}}},
        {"$unwind": "$metadata"},
        # one group per distinct key
        {"$group": {"_id": "$metadata.k"}},
        {"$sort": {"_id": 1}},
        {"$skip": page * page_size},
        {"$limit": page_size},
        # NOTE(review): this count runs AFTER $skip/$limit, so "total" counts
        # only this page's keys, not all distinct keys — with that, "remaining"
        # below appears to always come out 0; confirm intended (a $facet stage
        # would count before pagination)
        {
            "$group": {
                "_id": 1,
                "total": {"$sum": 1},
                "results": {"$push": "$$ROOT"},
            }
        },
    ]

    result = next(Model.aggregate(pipeline), None)

    total = 0
    remaining = 0
    results = []

    if result:
        # -1 default flags an unexpected missing "total"; presumably never
        # happens since the $group always emits it — TODO confirm
        total = int(result.get("total", -1))
        results = [
            ParameterKeyEscaper.unescape(r.get("_id"))
            for r in result.get("results", [])
        ]
        remaining = max(0, total - (len(results) + page * page_size))

    return total, remaining, results

View File

@ -32,7 +32,7 @@ class QueueBLL(object):
name: str,
tags: Optional[Sequence[str]] = None,
system_tags: Optional[Sequence[str]] = None,
metadata: Optional[Sequence[dict]] = None,
metadata: Optional[dict] = None,
) -> Queue:
"""Creates a queue"""
with translate_errors_context():

View File

@ -95,6 +95,7 @@ class GetMixin(PropsMixin):
}
MultiFieldParameters = namedtuple("MultiFieldParameters", "pattern fields")
_numeric_locale = {"locale": "en_US", "numericOrdering": True}
_field_collation_overrides = {}
class QueryParameterOptions(object):
@ -599,7 +600,7 @@ class GetMixin(PropsMixin):
return size
@classmethod
def get_data_with_scroll_and_filter_support(
def get_data_with_scroll_support(
cls,
query_dict: dict,
data_getter: Callable[[], Sequence[dict]],
@ -629,15 +630,12 @@ class GetMixin(PropsMixin):
if cls._start_key in query_dict:
query_dict[cls._start_key] = query_dict[cls._start_key] + len(data)
def update_state(returned_len: int):
if not state:
return
if state:
state.position = query_dict[cls._start_key]
cls.get_cache_manager().set_state(state)
if ret_params is not None:
ret_params["scroll_id"] = state.id
update_state(len(data))
return data
@classmethod
@ -770,7 +768,7 @@ class GetMixin(PropsMixin):
override_projection=override_projection,
override_collation=override_collation,
)
return cls.get_data_with_scroll_and_filter_support(
return cls.get_data_with_scroll_support(
query_dict=query_dict, data_getter=data_getter, ret_params=ret_params,
)

View File

@ -1,10 +1,8 @@
from typing import Sequence
from mongoengine import (
StringField,
DateTimeField,
BooleanField,
EmbeddedDocumentListField,
EmbeddedDocumentField,
)
from apiserver.database import Database, strict
@ -12,6 +10,7 @@ from apiserver.database.fields import (
StrippedStringField,
SafeDictField,
SafeSortedListField,
SafeMapField,
)
from apiserver.database.model import AttributedDocument
from apiserver.database.model.base import GetMixin
@ -22,6 +21,10 @@ from apiserver.database.model.task.task import Task
class Model(AttributedDocument):
_field_collation_overrides = {
"metadata.": AttributedDocument._numeric_locale,
}
meta = {
"db_alias": Database.backend,
"strict": strict,
@ -30,8 +33,6 @@ class Model(AttributedDocument):
"project",
"task",
"last_update",
"metadata.key",
"metadata.type",
("company", "framework"),
("company", "name"),
("company", "user"),
@ -63,6 +64,7 @@ class Model(AttributedDocument):
"project",
"task",
"parent",
"metadata.*",
),
datetime_fields=("last_update",),
)
@ -86,6 +88,6 @@ class Model(AttributedDocument):
default=dict, user_set_allowed=True, exclude_by_default=True
)
company_origin = StringField(exclude_by_default=True)
metadata: Sequence[MetadataItem] = EmbeddedDocumentListField(
MetadataItem, default=list, user_set_allowed=True
metadata = SafeMapField(
field=EmbeddedDocumentField(MetadataItem), user_set_allowed=True
)

View File

@ -1,16 +1,19 @@
from typing import Sequence
from mongoengine import (
Document,
EmbeddedDocument,
StringField,
DateTimeField,
EmbeddedDocumentListField,
EmbeddedDocumentField,
)
from apiserver.database import Database, strict
from apiserver.database.fields import StrippedStringField, SafeSortedListField
from apiserver.database.model import DbModelMixin
from apiserver.database.fields import (
StrippedStringField,
SafeSortedListField,
SafeMapField,
)
from apiserver.database.model import DbModelMixin, AttributedDocument
from apiserver.database.model.base import ProperDictMixin, GetMixin
from apiserver.database.model.company import Company
from apiserver.database.model.metadata import MetadataItem
@ -19,23 +22,25 @@ from apiserver.database.model.task.task import Task
class Entry(EmbeddedDocument, ProperDictMixin):
""" Entry representing a task waiting in the queue """
task = StringField(required=True, reference_field=Task)
''' Task ID '''
""" Task ID """
added = DateTimeField(required=True)
''' Added to the queue '''
""" Added to the queue """
class Queue(DbModelMixin, Document):
_field_collation_overrides = {
"metadata.": AttributedDocument._numeric_locale,
}
get_all_query_options = GetMixin.QueryParameterOptions(
pattern_fields=("name",),
list_fields=("tags", "system_tags", "id"),
pattern_fields=("name",), list_fields=("tags", "system_tags", "id", "metadata.*"),
)
meta = {
'db_alias': Database.backend,
'strict': strict,
"indexes": ["metadata.key", "metadata.type"],
"db_alias": Database.backend,
"strict": strict,
}
id = StringField(primary_key=True)
@ -44,10 +49,12 @@ class Queue(DbModelMixin, Document):
)
company = StringField(required=True, reference_field=Company)
created = DateTimeField(required=True)
tags = SafeSortedListField(StringField(required=True), default=list, user_set_allowed=True)
tags = SafeSortedListField(
StringField(required=True), default=list, user_set_allowed=True
)
system_tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
entries = EmbeddedDocumentListField(Entry, default=list)
last_update = DateTimeField()
metadata: Sequence[MetadataItem] = EmbeddedDocumentListField(
MetadataItem, default=list, user_set_allowed=True
metadata = SafeMapField(
field=EmbeddedDocumentField(MetadataItem), user_set_allowed=True
)

View File

@ -159,11 +159,10 @@ external_task_types = set(get_options(TaskType))
class Task(AttributedDocument):
_numeric_locale = {"locale": "en_US", "numericOrdering": True}
_field_collation_overrides = {
"execution.parameters.": _numeric_locale,
"last_metrics.": _numeric_locale,
"hyperparams.": _numeric_locale,
"execution.parameters.": AttributedDocument._numeric_locale,
"last_metrics.": AttributedDocument._numeric_locale,
"hyperparams.": AttributedDocument._numeric_locale,
}
meta = {
@ -184,7 +183,10 @@ class Task(AttributedDocument):
("company", "type", "system_tags", "status"),
("company", "project", "type", "system_tags", "status"),
("status", "last_update"), # for maintenance tasks
{"fields": ["company", "project"], "collation": _numeric_locale},
{
"fields": ["company", "project"],
"collation": AttributedDocument._numeric_locale,
},
{
"name": "%s.task.main_text_index" % Database.backend,
"fields": [

View File

@ -0,0 +1,29 @@
from pymongo.collection import Collection
from pymongo.database import Database
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
from .utils import _drop_all_indices_from_collections
def _convert_metadata(db: Database, name):
    """
    Migrate legacy list-based 'metadata' fields in the named collection to
    the new dict form keyed by the mongo-escaped item key.
    """
    coll: Collection = db[name]
    field = "metadata"
    # BSON $type 4 == array: only documents still storing metadata as a list
    legacy_docs = coll.find(
        filter={field: {"$exists": True, "$type": 4}}, projection=(field,)
    )
    for doc in legacy_docs:
        converted = {}
        for item in doc.get(field, []):
            # skip malformed entries that are not dicts or lack a key
            if not isinstance(item, dict) or "key" not in item:
                continue
            converted[ParameterKeyEscaper.escape(item["key"])] = item
        coll.update_one(
            {"_id": doc["_id"]}, {"$set": {"metadata": converted}},
        )
def migrate_backend(db: Database):
    """
    Convert model and queue metadata to the new dict format, then drop the
    now-stale indices so they get rebuilt against the new schema.
    """
    target_collections = ["model", "queue"]
    for collection_name in target_collections:
        _convert_metadata(db, collection_name)
    _drop_all_indices_from_collections(db, target_collections)

View File

@ -226,6 +226,12 @@ create_credentials {
}
}
}
"999.0": ${create_credentials."2.1"} {
request.properties.label {
type: string
description: Optional credentials label
}
}
}
get_credentials {

View File

@ -61,14 +61,14 @@ _definitions {
type: string
}
tags {
description: "User-defined tags list"
type: array
description: "User-defined tags"
items { type: string }
}
system_tags {
description: "System tags list. This field is reserved for system use, please don't use it."
type: array
items {type: string}
description: "System tags. This field is reserved for system use, please don't use it."
items { type: string }
}
framework {
description: "Framework on which the model is based. Should be identical to the framework of the task which created the model"
@ -98,9 +98,11 @@ _definitions {
additionalProperties: true
}
metadata {
type: array
description: "Model metadata"
items {"$ref": "#/definitions/metadata_item"}
type: object
additionalProperties {
"$ref": "#/definitions/metadata_item"
}
}
}
}
@ -407,7 +409,7 @@ update_for_task {
system_tags {
description: "System tags list. This field is reserved for system use, please don't use it."
type: array
items {type: string}
items { type: string }
}
override_model_id {
description: "Override model ID. If provided, this model is updated in the task. Exactly one of override_model_id or uri is required."
@ -473,7 +475,7 @@ create {
system_tags {
description: "System tags list. This field is reserved for system use, please don't use it."
type: array
items {type: string}
items { type: string }
}
framework {
description: "Framework on which the model is based. Case insensitive. Should be identical to the framework of the task which created the model."
@ -529,9 +531,11 @@ create {
}
"2.13": ${create."2.1"} {
metadata {
type: array
description: "Model metadata"
items {"$ref": "#/definitions/metadata_item"}
type: object
additionalProperties {
"$ref": "#/definitions/metadata_item"
}
}
}
}
@ -568,7 +572,7 @@ edit {
system_tags {
description: "System tags list. This field is reserved for system use, please don't use it."
type: array
items {type: string}
items { type: string }
}
framework {
description: "Framework on which the model is based. Case insensitive. Should be identical to the framework of the task which created the model."
@ -624,9 +628,11 @@ edit {
}
"2.13": ${edit."2.1"} {
metadata {
type: array
description: "Model metadata"
items {"$ref": "#/definitions/metadata_item"}
type: object
additionalProperties {
"$ref": "#/definitions/metadata_item"
}
}
}
}
@ -657,7 +663,7 @@ update {
system_tags {
description: "System tags list. This field is reserved for system use, please don't use it."
type: array
items {type: string}
items { type: string }
}
ready {
description: "Indication if the model is final and can be used by other tasks Default is false."
@ -707,9 +713,11 @@ update {
}
"2.13": ${update."2.1"} {
metadata {
type: array
description: "Model metadata"
items {"$ref": "#/definitions/metadata_item"}
type: object
additionalProperties {
"$ref": "#/definitions/metadata_item"
}
}
}
}
@ -718,7 +726,7 @@ publish_many {
description: Publish models
request {
properties {
ids.description: "IDs of models to publish"
ids.description: "IDs of the models to publish"
force_publish_task {
description: "Publish the associated tasks (if exist) even if they are not in the 'stopped' state. Optional, the default value is False."
type: boolean
@ -779,7 +787,7 @@ archive_many {
description: Archive models
request {
properties {
ids.description: "IDs of models to archive"
ids.description: "IDs of the models to archive"
}
}
response {
@ -815,10 +823,9 @@ delete_many {
description: Delete models
request {
properties {
ids.description: "IDs of models to delete"
ids.description: "IDs of the models to delete"
force {
description: """Force. Required if there are tasks that use the model as an execution model, or if the model's creating task is published.
"""
description: "Force. Required if there are tasks that use the model as an execution model, or if the model's creating task is published."
type: boolean
}
}
@ -975,6 +982,11 @@ add_or_update_metadata {
description: "Metadata items to add or update"
items {"$ref": "#/definitions/metadata_item"}
}
replace_metadata {
description: "If set then all the metadata items will be replaced with the provided ones. Otherwise only the provided metadata items will be updated or added"
type: boolean
default: false
}
}
}
response {

View File

@ -42,15 +42,20 @@ _definitions {
type: string
format: "date-time"
}
last_update {
description: "Last update time"
type: string
format: "date-time"
}
tags {
type: array
description: "User-defined tags"
type: array
items { type: string }
}
system_tags {
type: array
description: "System tags. This field is reserved for system use, please don't use it."
items {type: string}
type: array
items { type: string }
}
default_output_destination {
description: "The default output destination URL for new tasks under this project"
@ -70,6 +75,18 @@ _definitions {
description: "Total run time of all tasks in project (in seconds)"
type: integer
}
total_tasks {
description: "Number of tasks"
type: integer
}
completed_tasks_24h {
description: "Number of tasks completed in the last 24 hours"
type: integer
}
last_task_run {
description: "The most recent started time of a task"
type: integer
}
status_count {
description: "Status counts"
type: object
@ -78,6 +95,10 @@ _definitions {
description: "Number of 'created' tasks in project"
type: integer
}
completed {
description: "Number of 'completed' tasks in project"
type: integer
}
queued {
description: "Number of 'queued' tasks in project"
type: integer
@ -158,14 +179,14 @@ _definitions {
format: "date-time"
}
tags {
type: array
description: "User-defined tags"
type: array
items { type: string }
}
system_tags {
type: array
description: "System tags. This field is reserved for system use, please don't use it."
items {type: string}
type: array
items { type: string }
}
default_output_destination {
description: "The default output destination URL for new tasks under this project"
@ -299,14 +320,14 @@ create {
type: string
}
tags {
type: array
description: "User-defined tags"
type: array
items { type: string }
}
system_tags {
type: array
description: "System tags. This field is reserved for system use, please don't use it."
items {type: string}
type: array
items { type: string }
}
default_output_destination {
description: "The default output destination URL for new tasks under this project"
@ -419,7 +440,6 @@ get_all {
description: "Projects list"
type: array
items { "$ref": "#/definitions/projects_get_all_response_single" }
}
}
}
@ -545,42 +565,6 @@ get_all_ex {
type: boolean
default: true
}
response {
properties {
stats {
properties {
active.properties {
total_tasks {
description: "Number of tasks"
type: integer
}
completed_tasks {
description: "Number of tasks completed in the last 24 hours"
type: integer
}
running_tasks {
description: "Number of running tasks"
type: integer
}
}
archived.properties {
total_tasks {
description: "Number of tasks"
type: integer
}
completed_tasks {
description: "Number of tasks completed in the last 24 hours"
type: integer
}
running_tasks {
description: "Number of running tasks"
type: integer
}
}
}
}
}
}
}
}
update {
@ -603,14 +587,14 @@ update {
type: string
}
tags {
description: "User-defined tags list"
type: array
description: "User-defined tags"
items { type: string }
}
system_tags {
description: "System tags list. This field is reserved for system use, please don't use it."
type: array
description: "System tags. This field is reserved for system use, please don't use it."
items {type: string}
items { type: string }
}
default_output_destination {
description: "The default output destination URL for new tasks under this project"
@ -748,7 +732,6 @@ delete {
type: boolean
default: false
}
}
}
response {
@ -881,6 +864,7 @@ get_hyper_parameters {
description: """Get a list of all hyper parameter sections and names used in tasks within the given project."""
request {
type: object
required: [project]
properties {
project {
description: "Project ID"
@ -929,6 +913,55 @@ get_hyper_parameters {
}
}
}
get_model_metadata_keys {
"999.0" {
description: """Get a list of all metadata keys used in models within the given project."""
request {
type: object
required: [project]
properties {
project {
description: "Project ID"
type: string
}
include_subprojects {
description: "If set to 'true' and the project field is set then the result includes metadata keys from the subproject models"
type: boolean
default: true
}
page {
description: "Page number"
default: 0
type: integer
}
page_size {
description: "Page size"
default: 500
type: integer
}
}
}
response {
type: object
properties {
keys {
description: "A list of model metadata keys"
type: array
items {type: string}
}
remaining {
description: "Remaining results"
type: integer
}
total {
description: "Total number of results"
type: integer
}
}
}
}
}
get_task_tags {
"2.8" {
description: "Get user and system tags used for the tasks under the specified projects"
@ -936,7 +969,6 @@ get_task_tags {
response = ${_definitions.tags_response}
}
}
get_model_tags {
"2.8" {
description: "Get user and system tags used for the models under the specified projects"
@ -1058,4 +1090,4 @@ get_task_parents {
}
}
}
}
}

View File

@ -79,9 +79,11 @@ _definitions {
items { "$ref": "#/definitions/entry" }
}
metadata {
type: array
description: "Queue metadata"
items {"$ref": "#/definitions/metadata_item"}
type: object
additionalProperties {
"$ref": "#/definitions/metadata_item"
}
}
}
}
@ -281,6 +283,15 @@ create {
}
}
}
"2.13": ${create."2.4"} {
metadata {
description: "Queue metadata"
type: object
additionalProperties {
"$ref": "#/definitions/metadata_item"
}
}
}
}
update {
"2.4" {
@ -322,7 +333,15 @@ update {
type: object
additionalProperties: true
}
}
}
}
"2.13": ${update."2.4"} {
metadata {
description: "Queue metadata"
type: object
additionalProperties {
"$ref": "#/definitions/metadata_item"
}
}
}
@ -632,6 +651,11 @@ add_or_update_metadata {
description: "Metadata items to add or update"
items {"$ref": "#/definitions/metadata_item"}
}
replace_metadata {
description: "If set then all the metadata items will be replaced with the provided ones. Otherwise only the provided metadata items will be updated or added"
type: boolean
default: false
}
}
}
response {

View File

@ -21,16 +21,14 @@ from apiserver.apimodels.models import (
ModelsPublishManyRequest,
ModelsDeleteManyRequest,
)
from apiserver.bll.model import ModelBLL
from apiserver.bll.model import ModelBLL, Metadata
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL, project_ids_with_children
from apiserver.bll.task import TaskBLL
from apiserver.bll.task.task_operations import publish_task
from apiserver.bll.util import run_batch_operation
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model import validate_id
from apiserver.database.model.metadata import metadata_add_or_update, metadata_delete
from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import (
@ -50,8 +48,8 @@ from apiserver.services.utils import (
conform_tag_fields,
conform_output_tags,
ModelsBackwardsCompatibility,
validate_metadata,
get_metadata_from_api,
unescape_metadata,
escape_metadata,
)
from apiserver.timing_context import TimingContext
@ -64,19 +62,20 @@ project_bll = ProjectBLL()
def get_by_id(call: APICall, company_id, _):
model_id = call.data["model"]
with translate_errors_context():
models = Model.get_many(
company=company_id,
query_dict=call.data,
query=Q(id=model_id),
allow_public=True,
Metadata.escape_query_parameters(call)
models = Model.get_many(
company=company_id,
query_dict=call.data,
query=Q(id=model_id),
allow_public=True,
)
if not models:
raise errors.bad_request.InvalidModelId(
"no such public or company model", id=model_id, company=company_id,
)
if not models:
raise errors.bad_request.InvalidModelId(
"no such public or company model", id=model_id, company=company_id,
)
conform_output_tags(call, models[0])
call.result.data = {"model": models[0]}
conform_output_tags(call, models[0])
unescape_metadata(call, models[0])
call.result.data = {"model": models[0]}
@endpoint("models.get_by_task_id", required_fields=["task"])
@ -86,25 +85,25 @@ def get_by_task_id(call: APICall, company_id, _):
task_id = call.data["task"]
with translate_errors_context():
query = dict(id=task_id, company=company_id)
task = Task.get(_only=["models"], **query)
if not task:
raise errors.bad_request.InvalidTaskId(**query)
if not task.models or not task.models.output:
raise errors.bad_request.MissingTaskFields(field="models.output")
query = dict(id=task_id, company=company_id)
task = Task.get(_only=["models"], **query)
if not task:
raise errors.bad_request.InvalidTaskId(**query)
if not task.models or not task.models.output:
raise errors.bad_request.MissingTaskFields(field="models.output")
model_id = task.models.output[-1].model
model = Model.objects(
Q(id=model_id) & get_company_or_none_constraint(company_id)
).first()
if not model:
raise errors.bad_request.InvalidModelId(
"no such public or company model", id=model_id, company=company_id,
)
model_dict = model.to_proper_dict()
conform_output_tags(call, model_dict)
call.result.data = {"model": model_dict}
model_id = task.models.output[-1].model
model = Model.objects(
Q(id=model_id) & get_company_or_none_constraint(company_id)
).first()
if not model:
raise errors.bad_request.InvalidModelId(
"no such public or company model", id=model_id, company=company_id,
)
model_dict = model.to_proper_dict()
conform_output_tags(call, model_dict)
unescape_metadata(call, model_dict)
call.result.data = {"model": model_dict}
def _process_include_subprojects(call_data: dict):
@ -121,47 +120,50 @@ def _process_include_subprojects(call_data: dict):
@endpoint("models.get_all_ex", required_fields=[])
def get_all_ex(call: APICall, company_id, _):
conform_tag_fields(call, call.data)
with translate_errors_context():
_process_include_subprojects(call.data)
with TimingContext("mongo", "models_get_all_ex"):
ret_params = {}
models = Model.get_many_with_join(
company=company_id,
query_dict=call.data,
allow_public=True,
ret_params=ret_params,
)
conform_output_tags(call, models)
call.result.data = {"models": models, **ret_params}
_process_include_subprojects(call.data)
Metadata.escape_query_parameters(call)
with TimingContext("mongo", "models_get_all_ex"):
ret_params = {}
models = Model.get_many_with_join(
company=company_id,
query_dict=call.data,
allow_public=True,
ret_params=ret_params,
)
conform_output_tags(call, models)
unescape_metadata(call, models)
call.result.data = {"models": models, **ret_params}
@endpoint("models.get_by_id_ex", required_fields=["id"])
def get_by_id_ex(call: APICall, company_id, _):
conform_tag_fields(call, call.data)
with translate_errors_context():
with TimingContext("mongo", "models_get_by_id_ex"):
models = Model.get_many_with_join(
company=company_id, query_dict=call.data, allow_public=True
)
conform_output_tags(call, models)
call.result.data = {"models": models}
Metadata.escape_query_parameters(call)
with TimingContext("mongo", "models_get_by_id_ex"):
models = Model.get_many_with_join(
company=company_id, query_dict=call.data, allow_public=True
)
conform_output_tags(call, models)
unescape_metadata(call, models)
call.result.data = {"models": models}
@endpoint("models.get_all", required_fields=[])
def get_all(call: APICall, company_id, _):
    """Return models matching the query (flat form, no joins)."""
    conform_tag_fields(call, call.data)
    # escape metadata keys in the query to match their stored form
    Metadata.escape_query_parameters(call)
    with TimingContext("mongo", "models_get_all"):
        ret_params = {}
        models = Model.get_many(
            company=company_id,
            parameters=call.data,
            query_dict=call.data,
            allow_public=True,
            ret_params=ret_params,
        )
        conform_output_tags(call, models)
        unescape_metadata(call, models)
        call.result.data = {"models": models, **ret_params}
@endpoint("models.get_frameworks", request_data_model=GetFrameworksRequest)
@ -189,15 +191,22 @@ create_fields = {
"metadata": list,
}
# Fields whose modification bumps the model's `last_update` timestamp
last_update_fields = (
    "uri",
    "framework",
    "design",
    "labels",
    "ready",
    "metadata",
    "system_tags",
    "tags",
)
def parse_model_fields(call, valid_fields):
    """
    Parse model fields from the call data, conform/validate tag fields
    and escape metadata keys for storage.
    """
    fields = parse_from_call(call.data, valid_fields, Model.get_fields())
    conform_tag_fields(call, fields, validate=True)
    escape_metadata(fields)
    return fields
@ -231,82 +240,80 @@ def update_for_task(call: APICall, company_id, _):
"exactly one field is required", fields=("uri", "override_model_id")
)
with translate_errors_context():
query = dict(id=task_id, company=company_id)
task = Task.get_for_writing(
id=task_id,
company=company_id,
_only=["models", "execution", "name", "status", "project"],
)
if not task:
raise errors.bad_request.InvalidTaskId(**query)
query = dict(id=task_id, company=company_id)
task = Task.get_for_writing(
id=task_id,
allowed_states = [TaskStatus.created, TaskStatus.in_progress]
if task.status not in allowed_states:
raise errors.bad_request.InvalidTaskStatus(
f"model can only be updated for tasks in the {allowed_states} states",
**query,
)
if override_model_id:
model = ModelBLL.get_company_model_by_id(
company_id=company_id, model_id=override_model_id
)
else:
if "name" not in call.data:
# use task name if name not provided
call.data["name"] = task.name
if "comment" not in call.data:
call.data["comment"] = f"Created by task `{task.name}` ({task.id})"
if task.models and task.models.output:
# model exists, update
model_id = task.models.output[-1].model
res = _update_model(call, company_id, model_id=model_id).to_struct()
res.update({"id": model_id, "created": False})
call.result.data = res
return
# new model, create
fields = parse_model_fields(call, create_fields)
# create and save model
now = datetime.utcnow()
model = Model(
id=database.utils.id(),
created=now,
last_update=now,
user=call.identity.user,
company=company_id,
_only=["models", "execution", "name", "status", "project"],
project=task.project,
framework=task.execution.framework,
parent=task.models.input[0].model
if task.models and task.models.input
else None,
design=task.execution.model_desc,
labels=task.execution.model_labels,
ready=(task.status == TaskStatus.published),
**fields,
)
if not task:
raise errors.bad_request.InvalidTaskId(**query)
model.save()
_update_cached_tags(company_id, project=model.project, fields=fields)
allowed_states = [TaskStatus.created, TaskStatus.in_progress]
if task.status not in allowed_states:
raise errors.bad_request.InvalidTaskStatus(
f"model can only be updated for tasks in the {allowed_states} states",
**query,
TaskBLL.update_statistics(
task_id=task_id,
company_id=company_id,
last_iteration_max=iteration,
models__output=[
ModelItem(
model=model.id,
name=TaskModelNames[TaskModelTypes.output],
updated=datetime.utcnow(),
)
],
)
if override_model_id:
model = ModelBLL.get_company_model_by_id(
company_id=company_id, model_id=override_model_id
)
else:
if "name" not in call.data:
# use task name if name not provided
call.data["name"] = task.name
if "comment" not in call.data:
call.data["comment"] = f"Created by task `{task.name}` ({task.id})"
if task.models and task.models.output:
# model exists, update
model_id = task.models.output[-1].model
res = _update_model(call, company_id, model_id=model_id).to_struct()
res.update({"id": model_id, "created": False})
call.result.data = res
return
# new model, create
fields = parse_model_fields(call, create_fields)
# create and save model
now = datetime.utcnow()
model = Model(
id=database.utils.id(),
created=now,
last_update=now,
user=call.identity.user,
company=company_id,
project=task.project,
framework=task.execution.framework,
parent=task.models.input[0].model
if task.models and task.models.input
else None,
design=task.execution.model_desc,
labels=task.execution.model_labels,
ready=(task.status == TaskStatus.published),
**fields,
)
model.save()
_update_cached_tags(company_id, project=model.project, fields=fields)
TaskBLL.update_statistics(
task_id=task_id,
company_id=company_id,
last_iteration_max=iteration,
models__output=[
ModelItem(
model=model.id,
name=TaskModelNames[TaskModelTypes.output],
updated=datetime.utcnow(),
)
],
)
call.result.data = {"id": model.id, "created": True}
call.result.data = {"id": model.id, "created": True}
@endpoint(
@ -319,36 +326,33 @@ def create(call: APICall, company_id, req_model: CreateModelRequest):
if req_model.public:
company_id = ""
with translate_errors_context():
project = req_model.project
if project:
validate_id(Project, company=company_id, project=project)
project = req_model.project
if project:
validate_id(Project, company=company_id, project=project)
task = req_model.task
req_data = req_model.to_struct()
if task:
validate_task(company_id, req_data)
task = req_model.task
req_data = req_model.to_struct()
if task:
validate_task(company_id, req_data)
fields = filter_fields(Model, req_data)
conform_tag_fields(call, fields, validate=True)
escape_metadata(fields)
fields = filter_fields(Model, req_data)
conform_tag_fields(call, fields, validate=True)
# create and save model
now = datetime.utcnow()
model = Model(
id=database.utils.id(),
user=call.identity.user,
company=company_id,
created=now,
last_update=now,
**fields,
)
model.save()
_update_cached_tags(company_id, project=model.project, fields=fields)
validate_metadata(fields.get("metadata"))
# create and save model
now = datetime.utcnow()
model = Model(
id=database.utils.id(),
user=call.identity.user,
company=company_id,
created=now,
last_update=now,
**fields,
)
model.save()
_update_cached_tags(company_id, project=model.project, fields=fields)
call.result.data_model = CreateModelResponse(id=model.id, created=True)
call.result.data_model = CreateModelResponse(id=model.id, created=True)
def prepare_update_fields(call, company_id, fields: dict):
@ -383,6 +387,7 @@ def prepare_update_fields(call, company_id, fields: dict):
)
conform_tag_fields(call, fields, validate=True)
escape_metadata(fields)
return fields
@ -394,89 +399,85 @@ def validate_task(company_id, fields: dict):
def edit(call: APICall, company_id, _):
    """
    Edit an existing company model.
    Embedded-document fields given as dicts are merged into the stored
    value; any change to a field in `last_update_fields` bumps the
    model's `last_update` timestamp and refreshes cached tags.
    """
    model_id = call.data["model"]
    model = ModelBLL.get_company_model_by_id(
        company_id=company_id, model_id=model_id
    )

    fields = parse_model_fields(call, create_fields)
    fields = prepare_update_fields(call, company_id, fields)

    for key in fields:
        field = getattr(model, key, None)
        value = fields[key]
        if (
            field
            and isinstance(value, dict)
            and isinstance(field, EmbeddedDocument)
        ):
            # merge the partial dict update into the existing embedded doc
            d = field.to_mongo(use_db_field=False).to_dict()
            d.update(value)
            fields[key] = d

    iteration = call.data.get("iteration")
    task_id = model.task or fields.get("task")
    if task_id and iteration is not None:
        TaskBLL.update_statistics(
            task_id=task_id, company_id=company_id, last_iteration_max=iteration,
        )

    if fields:
        if any(uf in fields for uf in last_update_fields):
            fields.update(last_update=datetime.utcnow())
        updated = model.update(upsert=False, **fields)
        if updated:
            new_project = fields.get("project", model.project)
            if new_project != model.project:
                # model moved between projects: drop both tag caches
                _reset_cached_tags(
                    company_id, projects=[new_project, model.project]
                )
            else:
                _update_cached_tags(
                    company_id, project=model.project, fields=fields
                )
        conform_output_tags(call, fields)
        unescape_metadata(call, fields)
        call.result.data_model = UpdateResponse(updated=updated, fields=fields)
    else:
        call.result.data_model = UpdateResponse(updated=0)
def _update_model(call: APICall, company_id, model_id=None):
    """
    Apply a partial update to a model and return an UpdateResponse.

    If no model_id is passed, it is taken from the call data. Updating a
    field from `last_update_fields` also bumps `last_update`; a project
    change resets the cached tags of both the old and new projects.
    """
    model_id = model_id or call.data["model"]
    model = ModelBLL.get_company_model_by_id(
        company_id=company_id, model_id=model_id
    )

    data = prepare_update_fields(call, company_id, call.data)

    task_id = data.get("task")
    iteration = data.get("iteration")
    if task_id and iteration is not None:
        TaskBLL.update_statistics(
            task_id=task_id, company_id=company_id, last_iteration_max=iteration,
        )

    updated_count, updated_fields = Model.safe_update(company_id, model.id, data)
    if updated_count:
        if any(uf in updated_fields for uf in last_update_fields):
            model.update(upsert=False, last_update=datetime.utcnow())

        new_project = updated_fields.get("project", model.project)
        if new_project != model.project:
            _reset_cached_tags(company_id, projects=[new_project, model.project])
        else:
            _update_cached_tags(
                company_id, project=model.project, fields=updated_fields
            )
    conform_output_tags(call, updated_fields)
    unescape_metadata(call, updated_fields)
    return UpdateResponse(updated=updated_count, fields=updated_fields)
@endpoint(
@ -641,26 +642,25 @@ def add_or_update_metadata(
_: APICall, company_id: str, request: AddOrUpdateMetadataRequest
):
model_id = request.model
ModelBLL.get_company_model_by_id(company_id=company_id, model_id=model_id)
updated = metadata_add_or_update(
cls=Model, _id=model_id, items=get_metadata_from_api(request.metadata),
)
if updated:
Model.objects(id=model_id).update_one(last_update=datetime.utcnow())
return {"updated": updated}
model = ModelBLL.get_company_model_by_id(company_id=company_id, model_id=model_id)
return {
"updated": Metadata.edit_metadata(
model,
items=request.metadata,
replace_metadata=request.replace_metadata,
last_update=datetime.utcnow(),
)
}
@endpoint("models.delete_metadata", min_version="2.13")
def delete_metadata(_: APICall, company_id: str, request: DeleteMetadataRequest):
    """Delete the requested metadata keys from a model and bump last_update."""
    model_id = request.model
    model = ModelBLL.get_company_model_by_id(
        company_id=company_id, model_id=model_id, only_fields=("id",)
    )
    return {
        "updated": Metadata.delete_metadata(
            model, keys=request.keys, last_update=datetime.utcnow()
        )
    }

View File

@ -275,6 +275,23 @@ def get_unique_metric_variants(
call.result.data = {"metrics": metrics}
@endpoint("projects.get_model_metadata_keys",)
def get_model_metadata_keys(call: APICall, company_id: str, request: GetParamsRequest):
    """
    Return a page of distinct model metadata keys for the requested project
    (or all projects when no project id is passed), optionally including
    sub-projects. `remaining` reports how many keys follow the returned page.
    """
    total, remaining, keys = project_queries.get_model_metadata_keys(
        company_id,
        # None means "no project filter" — query across all projects
        project_ids=[request.project] if request.project else None,
        include_subprojects=request.include_subprojects,
        page=request.page,
        page_size=request.page_size,
    )
    call.result.data = {
        "total": total,
        "remaining": remaining,
        "keys": keys,
    }
@endpoint(
"projects.get_hyper_parameters",
min_version="2.9",

View File

@ -13,17 +13,19 @@ from apiserver.apimodels.queues import (
QueueMetrics,
AddOrUpdateMetadataRequest,
DeleteMetadataRequest,
GetNextTaskRequest,
)
from apiserver.bll.model import Metadata
from apiserver.bll.queue import QueueBLL
from apiserver.bll.workers import WorkerBLL
from apiserver.database.model.metadata import metadata_add_or_update, metadata_delete
from apiserver.database.model.queue import Queue
from apiserver.database.model.task.task import Task
from apiserver.service_repo import APICall, endpoint
from apiserver.services.utils import (
conform_tag_fields,
conform_output_tags,
conform_tags,
get_metadata_from_api,
escape_metadata,
unescape_metadata,
)
from apiserver.utilities import extract_properties_to_lists
@ -36,6 +38,7 @@ def get_by_id(call: APICall, company_id, req_model: QueueRequest):
queue = queue_bll.get_by_id(company_id, req_model.queue)
queue_dict = queue.to_proper_dict()
conform_output_tags(call, queue_dict)
unescape_metadata(call, queue_dict)
call.result.data = {"queue": queue_dict}
@ -49,13 +52,13 @@ def get_by_id(call: APICall):
def get_all_ex(call: APICall):
    """Return queue infos (with task/entry details) matching the query."""
    conform_tag_fields(call, call.data)
    ret_params = {}
    # escape metadata keys in the query to match their stored form
    Metadata.escape_query_parameters(call)
    queues = queue_bll.get_queue_infos(
        company_id=call.identity.company, query_dict=call.data, ret_params=ret_params,
    )
    conform_output_tags(call, queues)
    unescape_metadata(call, queues)
    call.result.data = {"queues": queues, **ret_params}
@ -63,13 +66,12 @@ def get_all_ex(call: APICall):
def get_all(call: APICall):
    """Return queues matching the query (flat form)."""
    conform_tag_fields(call, call.data)
    ret_params = {}
    # escape metadata keys in the query to match their stored form
    Metadata.escape_query_parameters(call)
    queues = queue_bll.get_all(
        company_id=call.identity.company, query_dict=call.data, ret_params=ret_params,
    )
    conform_output_tags(call, queues)
    unescape_metadata(call, queues)
    call.result.data = {"queues": queues, **ret_params}
@ -83,7 +85,7 @@ def create(call: APICall, company_id, request: CreateRequest):
name=request.name,
tags=tags,
system_tags=system_tags,
metadata=get_metadata_from_api(request.metadata),
metadata=Metadata.metadata_from_api(request.metadata),
)
call.result.data = {"id": queue.id}
@ -97,10 +99,12 @@ def create(call: APICall, company_id, request: CreateRequest):
def update(call: APICall, company_id, req_model: UpdateRequest):
    """
    Partially update a queue. Tag fields are conformed/validated and
    metadata keys are escaped before storage, then unescaped in the
    returned fields.
    """
    data = call.data_model_for_partial_update
    conform_tag_fields(call, data, validate=True)
    escape_metadata(data)
    updated, fields = queue_bll.update(
        company_id=company_id, queue_id=req_model.queue, **data
    )
    conform_output_tags(call, fields)
    unescape_metadata(call, fields)
    call.result.data_model = UpdateResponse(updated=updated, fields=fields)
@ -121,11 +125,19 @@ def add_task(call: APICall, company_id, req_model: TaskRequest):
}
@endpoint("queues.get_next_task", request_data_model=GetNextTaskRequest)
def get_next_task(call: APICall, company_id, req_model: GetNextTaskRequest):
    """
    Pop the next task entry from the queue. When `get_task_info` is set,
    the response also carries the task's company and user.
    """
    entry = queue_bll.get_next_task(
        company_id=company_id, queue_id=req_model.queue
    )
    if entry:
        data = {"entry": entry.to_proper_dict()}
        if req_model.get_task_info:
            task = Task.objects(id=entry.task).first()
            if task:
                data["task_info"] = {"company": task.company, "user": task.user}
        call.result.data = data
@endpoint("queues.remove_task", min_version="2.4", request_data_model=TaskRequest)
@ -245,21 +257,19 @@ def get_queue_metrics(
@endpoint("queues.add_or_update_metadata", min_version="2.13")
def add_or_update_metadata(
    call: APICall, company_id: str, request: AddOrUpdateMetadataRequest
):
    """Add or update metadata items on a queue; optionally replace all metadata."""
    queue_id = request.queue
    queue = queue_bll.get_by_id(company_id=company_id, queue_id=queue_id, only=("id",))
    return {
        "updated": Metadata.edit_metadata(
            queue, items=request.metadata, replace_metadata=request.replace_metadata
        )
    }
@endpoint("queues.delete_metadata", min_version="2.13")
def delete_metadata(call: APICall, company_id: str, request: DeleteMetadataRequest):
    """Delete the requested metadata keys from a queue."""
    queue_id = request.queue
    queue = queue_bll.get_by_id(company_id=company_id, queue_id=queue_id, only=("id",))
    return {"updated": Metadata.delete_metadata(queue, keys=request.keys)}

View File

@ -2,7 +2,6 @@ from datetime import datetime
from typing import Union, Sequence, Tuple
from apiserver.apierrors import errors
from apiserver.apimodels.metadata import MetadataItem as ApiMetadataItem
from apiserver.apimodels.organization import Filter
from apiserver.database.model.base import GetMixin
from apiserver.database.model.task.task import TaskModelTypes, TaskModelNames
@ -222,22 +221,38 @@ class DockerCmdBackwardsCompatibility:
nested_set(task, cls.field, docker_cmd)
def escape_metadata(document: dict):
    """
    Escape special characters in metadata keys, in place.

    Metadata is stored as a mapping of key -> item; keys may contain
    characters that are not allowed in MongoDB field names, hence the
    escaping before storage.
    """
    metadata = document.get("metadata")
    if not metadata:
        return
    document["metadata"] = {
        ParameterKeyEscaper.escape(k): v for k, v in metadata.items()
    }
def unescape_metadata(call: APICall, documents: Union[dict, Sequence[dict]]):
    """
    Unescape special characters in metadata keys, in place.

    Accepts a single document or a sequence of documents. For clients on
    API version <= 2.16 (which expect the old list form rather than the
    new mapping form) the metadata field is emptied instead of converted.
    """
    if isinstance(documents, dict):
        documents = [documents]

    old_client = call.requested_endpoint_version <= PartialVersion("2.16")
    for doc in documents:
        if old_client and "metadata" in doc:
            # old clients cannot consume the dict form — return empty
            doc["metadata"] = []
            continue
        metadata = doc.get("metadata")
        if not metadata:
            continue
        doc["metadata"] = {
            ParameterKeyEscaper.unescape(k): v for k, v in metadata.items()
        }

View File

@ -1,12 +1,11 @@
from functools import partial
from typing import Sequence
from apiserver.tests.api_client import APIClient
from apiserver.tests.automated import TestService
class TestQueueAndModelMetadata(TestService):
meta1 = [{"key": "test_key", "type": "str", "value": "test_value"}]
meta1 = {"test_key": {"key": "test_key", "type": "str", "value": "test_value"}}
def test_queue_metas(self):
queue_id = self._temp_queue("TestMetadata", metadata=self.meta1)
@ -23,20 +22,43 @@ class TestQueueAndModelMetadata(TestService):
)
model_id = self._temp_model("TestMetadata1")
self.api.models.edit(model=model_id, metadata=[self.meta1[0]])
self.api.models.edit(model=model_id, metadata=self.meta1)
self._assertMeta(service=service, entity=entity, _id=model_id, meta=self.meta1)
    def test_project_meta_query(self):
        """Model metadata keys can be listed per project, with/without sub-projects."""
        # one model at the root level and one inside a project
        self._temp_model("TestMetadata", metadata=self.meta1)
        project = self.temp_project(name="MetaParent")
        model_id = self._temp_model(
            "TestMetadata2",
            project=project,
            metadata={
                "test_key": {"key": "test_key", "type": "str", "value": "test_value"},
                "test_key2": {"key": "test_key2", "type": "str", "value": "test_value"},
            },
        )
        # no project filter + subprojects: keys from both models are returned
        res = self.api.projects.get_model_metadata_keys()
        self.assertTrue({"test_key", "test_key2"}.issubset(set(res["keys"])))
        # excluding subprojects hides the keys of the model under `project`
        res = self.api.projects.get_model_metadata_keys(include_subprojects=False)
        self.assertTrue("test_key" in res["keys"])
        self.assertFalse("test_key2" in res["keys"])
        # only_fields supports addressing individual metadata keys
        model = self.api.models.get_all_ex(
            id=[model_id], only_fields=["metadata.test_key"]
        ).models[0]
        self.assertTrue("test_key" in model.metadata)
        self.assertFalse("test_key2" in model.metadata)
def _test_meta_operations(
self, service: APIClient.Service, entity: str, _id: str,
):
assert_meta = partial(self._assertMeta, service=service, entity=entity)
assert_meta(_id=_id, meta=self.meta1)
meta2 = [
{"key": "test1", "type": "str", "value": "data1"},
{"key": "test2", "type": "str", "value": "data2"},
{"key": "test3", "type": "str", "value": "data3"},
]
meta2 = {
"test1": {"key": "test1", "type": "str", "value": "data1"},
"test2": {"key": "test2", "type": "str", "value": "data2"},
"test3": {"key": "test3", "type": "str", "value": "data3"},
}
service.update(**{entity: _id, "metadata": meta2})
assert_meta(_id=_id, meta=meta2)
@ -48,16 +70,17 @@ class TestQueueAndModelMetadata(TestService):
]
res = service.add_or_update_metadata(**{entity: _id, "metadata": updates})
self.assertEqual(res.updated, 1)
assert_meta(_id=_id, meta=[meta2[0], *updates])
assert_meta(_id=_id, meta={**meta2, **{u["key"]: u for u in updates}})
res = service.delete_metadata(
**{entity: _id, "keys": [f"test{idx}" for idx in range(2, 6)]}
)
self.assertEqual(res.updated, 1)
assert_meta(_id=_id, meta=meta2[:1])
# noinspection PyTypeChecker
assert_meta(_id=_id, meta=dict(list(meta2.items())[:1]))
def _assertMeta(
self, service: APIClient.Service, entity: str, _id: str, meta: Sequence[dict]
self, service: APIClient.Service, entity: str, _id: str, meta: dict
):
res = service.get_all_ex(id=[_id])[f"{entity}s"][0]
self.assertEqual(res.metadata, meta)