clearml-server/apiserver/services/utils.py

199 lines
6.3 KiB
Python
Raw Normal View History

2021-05-03 14:46:00 +00:00
from datetime import datetime
from typing import Union, Sequence, Tuple
2021-01-05 14:28:49 +00:00
from apiserver.apierrors import errors
from apiserver.apimodels.metadata import MetadataItem as ApiMetadataItem
2021-01-05 14:28:49 +00:00
from apiserver.apimodels.organization import Filter
from apiserver.database.model.base import GetMixin
from apiserver.database.utils import partition_tags
from apiserver.service_repo import APICall
2021-05-03 14:46:00 +00:00
from apiserver.utilities.dicts import nested_set, nested_get, nested_delete
2021-01-05 14:28:49 +00:00
from apiserver.utilities.partial_version import PartialVersion
def get_tags_filter_dictionary(input_: Filter) -> dict:
    """
    Build a filter dictionary from the 'tags' and 'system_tags' of the filter

    Only the fields that hold a non-empty value are included.
    An empty dictionary is returned when no filter is passed
    """
    result = {}
    if not input_:
        return result

    for name in ("tags", "system_tags"):
        values = getattr(input_, name)
        if values:
            result[name] = values
    return result
def get_tags_response(ret: dict) -> dict:
    """
    Return a new dictionary with the same fields as 'ret'
    where the values of each field are sorted into a list
    """
    response = {}
    for name, values in ret.items():
        response[name] = sorted(values)
    return response
def conform_output_tags(call: APICall, documents: Union[dict, Sequence[dict]]):
    """
    Make sure that tags are always returned sorted
    For old clients both tags and system tags are returned in 'tags' field

    :param call: the API call, its requested endpoint version decides
        whether the system tags should be merged into 'tags'
    :param documents: a single document or a sequence of documents.
        The documents are updated in place
    """
    if isinstance(documents, dict):
        documents = [documents]

    # Clients that request an endpoint version below 2.3 expect
    # the system tags to be part of the 'tags' field
    merge_tags = call.requested_endpoint_version < PartialVersion("2.3")
    for doc in documents:
        if merge_tags:
            system_tags = doc.get("system_tags")
            if system_tags:
                # 'tags' may be present with a None value, in which case
                # dict.get does not fall back to its default
                doc["tags"] = list(set(doc.get("tags") or ()) | set(system_tags))

        for field in ("system_tags", "tags"):
            tags = doc.get(field)
            if tags:
                doc[field] = sorted(tags)
def conform_tag_fields(call: APICall, document: dict, validate=False):
    """
    Upgrade old client tags in place

    Nothing is done for a document that has no 'tags' key. Otherwise the
    'tags' and 'system_tags' are passed through conform_tags and each of
    the two fields is written back only if its value has changed
    """
    if "tags" not in document:
        return

    new_tags, new_system_tags = conform_tags(
        call, document["tags"], document.get("system_tags"), validate
    )
    for key, value in (("tags", new_tags), ("system_tags", new_system_tags)):
        if value != document.get(key):
            document[key] = value
def conform_tags(
    call: APICall, tags: Sequence, system_tags: Sequence, validate=False
) -> Tuple[Sequence, Sequence]:
    """
    Split the 'tags' that arrive from the old SDK clients
    (endpoint version below 2.3) into 'tags' and 'system_tags'
    For newer clients the tags are returned as they were passed
    If 'validate' is set then the passed tags are validated first
    """
    if validate:
        validate_tags(tags, system_tags)

    old_client = call.requested_endpoint_version < PartialVersion("2.3")
    if not old_client:
        return tags, system_tags

    upgraded_tags, upgraded_system_tags = _upgrade_tags(call, tags, system_tags)
    return upgraded_tags, upgraded_system_tags
def _upgrade_tags(call: APICall, tags: Sequence, system_tags: Sequence):
if tags is not None and not system_tags:
service_name = call.endpoint_name.partition(".")[0]
entity = service_name[:-1] if service_name.endswith("s") else service_name
return partition_tags(entity, tags)
return tags, system_tags
def validate_tags(tags: Sequence[str], system_tags: Sequence[str]):
    """
    Raise FieldsValueError if any of the tags or the system tags
    starts with the list field operation prefix
    Empty or missing tag lists are skipped
    """
    for tag_list in (tags, system_tags):
        if not tag_list:
            continue

        prefix = GetMixin.ListFieldBucketHelper.op_prefix
        rejected = [tag for tag in tag_list if tag.startswith(prefix)]
        if rejected:
            raise errors.bad_request.FieldsValueError(
                "unsupported tag prefix", values=rejected
            )
2021-05-03 14:46:00 +00:00
class ModelsBackwardsCompatibility:
    """
    Converts between the single model fields ('execution.model' and
    'output.model') and the lists kept under the 'models' field
    Applied only to calls whose requested endpoint version
    is not greater than max_version
    """

    # The last endpoint version that gets the conversion
    max_version = PartialVersion("2.13")
    # Models mode -> path of the corresponding single model field
    mode_to_fields = {"input": ("execution", "model"), "output": ("output", "model")}
    models_field = "models"

    @classmethod
    def prepare_for_save(cls, call: APICall, fields: dict):
        """
        Move the single model fields in 'fields' into one item lists
        under 'models.<mode>'. The 'fields' dictionary is updated in place
        """
        if call.requested_endpoint_version > cls.max_version:
            return

        for mode, single_field in cls.mode_to_fields.items():
            model_id = nested_get(fields, single_field)
            if model_id:
                entry = {"name": mode, "model": model_id, "updated": datetime.utcnow()}
                nested_set(fields, (cls.models_field, mode), value=[entry])

            # The single model field is removed whether or not it had a value
            nested_delete(fields, single_field)

    @classmethod
    def unprepare_from_saved(
        cls, call: APICall, tasks_data: Union[Sequence[dict], dict]
    ):
        """
        Fill the single model fields of the tasks from their 'models' lists
        The first model is taken for 'input' and the last one for 'output'
        The tasks are updated in place
        """
        if call.requested_endpoint_version > cls.max_version:
            return

        tasks = [tasks_data] if isinstance(tasks_data, dict) else tasks_data
        for task in tasks:
            for mode, single_field in cls.mode_to_fields.items():
                models = nested_get(task, (cls.models_field, mode))
                if not models:
                    continue

                position = 0 if mode == "input" else -1
                model = models[position]
                if model:
                    nested_set(task, single_field, model.get("model"))
2021-05-03 14:48:01 +00:00
class DockerCmdBackwardsCompatibility:
    """
    Converts between the 'execution.docker_cmd' string and the
    'container.image' and 'container.arguments' fields
    Applied only to calls whose requested endpoint version
    is not greater than max_version
    """

    # The last endpoint version that gets the conversion
    max_version = PartialVersion("2.13")
    # Path of the docker command field
    field = ("execution", "docker_cmd")

    @classmethod
    def prepare_for_save(cls, call: APICall, fields: dict):
        """
        Split the docker command in 'fields' on its first space into the
        container image and arguments. The 'fields' dictionary is updated
        in place
        """
        if call.requested_endpoint_version > cls.max_version:
            return

        docker_cmd = nested_get(fields, cls.field)
        if docker_cmd:
            image, _, arguments = docker_cmd.partition(" ")
            for key, part in (("image", image), ("arguments", arguments)):
                nested_set(fields, ("container", key), value=part)

        # The docker command field is removed whether or not it had a value
        nested_delete(fields, cls.field)

    @classmethod
    def unprepare_from_saved(
        cls, call: APICall, tasks_data: Union[Sequence[dict], dict]
    ):
        """
        Fill the docker command of the tasks by joining the container image
        and arguments with a space. Tasks that have no container image are
        skipped. The tasks are updated in place
        """
        if call.requested_endpoint_version > cls.max_version:
            return

        tasks = [tasks_data] if isinstance(tasks_data, dict) else tasks_data
        for task in tasks:
            container = task.get("container")
            if not container or not container.get("image"):
                continue

            parts = [container.get(key) for key in ("image", "arguments")]
            docker_cmd = " ".join(part for part in parts if part)
            if docker_cmd:
                nested_set(task, cls.field, docker_cmd)
def validate_metadata(metadata: Sequence[dict]):
    """
    Make sure that the metadata items do not share the same key

    An item that has no 'key' is counted under the None key, so a single
    such item is accepted while two of them are reported as duplicates.
    Empty or missing metadata is accepted

    :raise errors.bad_request.ValidationError: if a key appears more than once
    """
    if not metadata:
        return

    keys = [item.get("key") for item in metadata]
    if len(keys) != len(set(keys)):
        raise errors.bad_request.ValidationError("Metadata keys should be unique")
def get_metadata_from_api(api_metadata: Sequence[ApiMetadataItem]) -> Sequence:
    """
    Convert the API metadata items to their struct form (using 'to_struct')
    and validate the result with validate_metadata
    Empty or missing metadata is returned as is
    """
    if api_metadata:
        structs = [item.to_struct() for item in api_metadata]
        validate_metadata(structs)
        return structs

    return api_metadata