# Source: clearml-server/apiserver/services/utils.py
#
# (Python; listed by the hosting page as 271 lines, 8.2 KiB)

from datetime import datetime
from typing import Union, Sequence, Tuple
from apiserver.apierrors import errors
from apiserver.apimodels.organization import Filter
from apiserver.bll.project import project_ids_with_children
from apiserver.database.model.base import GetMixin
from apiserver.database.model.task.task import TaskModelTypes, TaskModelNames
from apiserver.database.utils import partition_tags
from apiserver.service_repo import APICall
from apiserver.utilities.dicts import nested_set, nested_get, nested_delete
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
from apiserver.utilities.partial_version import PartialVersion
def process_include_subprojects(call_data: dict):
    """
    Expand the "project" filter in call_data with all child project ids.

    The "include_subprojects" flag is always removed from call_data. When it
    was set and a project filter is present, the "project" entry is replaced
    with the ids returned by project_ids_with_children for the requested
    project(s). A single (non-list) project id is wrapped in a list first.
    """
    with_children = call_data.pop("include_subprojects", False)
    requested = call_data.get("project")
    if with_children and requested:
        if isinstance(requested, list):
            ids = requested
        else:
            ids = [requested]
        call_data["project"] = project_ids_with_children(ids)
def get_tags_filter_dictionary(input_: Filter) -> dict:
    """
    Build a query dictionary from the tag fields of a filter.

    Returns an empty dict when no filter is given. Otherwise only those of
    "tags" and "system_tags" that hold a non-empty value are included.
    """
    result = {}
    if not input_:
        return result
    candidates = (
        ("tags", input_.tags),
        ("system_tags", input_.system_tags),
    )
    for name, values in candidates:
        if values:
            result[name] = values
    return result
def sort_tags_response(ret: dict) -> dict:
    """
    Return a new dict with the same keys as ret and every value sorted
    in ascending order.
    """
    sorted_response = {}
    for field_name, values in ret.items():
        sorted_response[field_name] = sorted(values)
    return sorted_response
def conform_output_tags(call: APICall, documents: Union[dict, Sequence[dict]]):
    """
    Make sure that tags are always returned sorted.

    For old clients (requested endpoint version below 2.3) both tags and
    system tags are returned in the 'tags' field.

    :param call: the API call; its requested endpoint version decides whether
        system tags are merged into 'tags'
    :param documents: a single document or a sequence of documents; they are
        updated in place
    """
    if isinstance(documents, dict):
        documents = [documents]
    merge_tags = call.requested_endpoint_version < PartialVersion("2.3")
    for doc in documents:
        if merge_tags:
            system_tags = doc.get("system_tags")
            if system_tags:
                # 'tags' may be missing or explicitly None; treat both as empty
                # instead of failing on set(None)
                doc["tags"] = list(set(doc.get("tags") or []) | set(system_tags))
        for field in ("system_tags", "tags"):
            tags = doc.get(field)
            if tags:
                doc[field] = sorted(tags)
def conform_tag_fields(call: APICall, document: dict, validate=False):
    """
    Upgrade old client tags in place.

    When 'tags' is present, conform_tags is used to split them into 'tags'
    and 'system_tags' (for old clients) and, if requested, to validate both.
    Each document field is written back only when its value changed.

    When 'tags' is absent but validate is set, any non-empty 'system_tags'
    are still validated, so that unsupported prefixes are rejected there too.

    :raises errors.bad_request.FieldsValueError: (via validate_tags) when
        validation is requested and an unsupported tag prefix is found
    """
    if "tags" not in document:
        # No tags to upgrade, but system tags sent on their own must not
        # bypass validation
        system_tags = document.get("system_tags")
        if validate and system_tags:
            validate_tags(None, system_tags)
        return
    tags, system_tags = conform_tags(
        call, document["tags"], document.get("system_tags"), validate
    )
    if tags != document.get("tags"):
        document["tags"] = tags
    if system_tags != document.get("system_tags"):
        document["system_tags"] = system_tags
def conform_tags(
    call: APICall, tags: Sequence, system_tags: Sequence, validate=False
) -> Tuple[Sequence, Sequence]:
    """
    Normalize the tags sent by a client.

    When validate is set, both lists are first checked for unsupported
    prefixes. Tags from old SDK clients (requested endpoint version below 2.3)
    are then split into 'tags' and 'system_tags'.

    :return: the (tags, system_tags) pair, upgraded for old clients
    """
    if validate:
        validate_tags(tags, system_tags)
    is_old_client = call.requested_endpoint_version < PartialVersion("2.3")
    if not is_old_client:
        return tags, system_tags
    return _upgrade_tags(call, tags, system_tags)
def _upgrade_tags(call: APICall, tags: Sequence, system_tags: Sequence):
    """
    Split old-style tags into (tags, system_tags) for the called entity.

    This applies only when tags were given and system_tags are empty. The
    entity name is the service part of the endpoint name with one trailing
    "s" removed (e.g. "tasks.create" -> "task"). Otherwise the inputs are
    returned unchanged.
    """
    needs_split = tags is not None and not system_tags
    if not needs_split:
        return tags, system_tags
    service, _, _ = call.endpoint_name.partition(".")
    entity = service[:-1] if service.endswith("s") else service
    return partition_tags(entity, tags)
def validate_tags(tags: Sequence[str], system_tags: Sequence[str]):
    """
    Reject tags that start with GetMixin's list-operator prefix.

    tags and system_tags are checked in that order. None or empty values are
    skipped.

    :raises errors.bad_request.FieldsValueError: listing the offending values
        of the first list that contains any
    """
    for values in (tags, system_tags):
        if not values:
            continue
        prefix = GetMixin.NewListFieldBucketHelper.op_prefix
        unsupported = [tag for tag in values if tag.startswith(prefix)]
        if unsupported:
            raise errors.bad_request.FieldsValueError(
                "unsupported tag prefix", values=unsupported
            )
def escape_dict(data: dict) -> dict:
    """
    Return a copy of data with every key escaped by ParameterKeyEscaper.

    Falsy input (None or an empty dict) is returned as is.
    """
    if data:
        return {
            ParameterKeyEscaper.escape(key): value for key, value in data.items()
        }
    return data
def unescape_dict(data: dict) -> dict:
    """
    Return a copy of data with every key unescaped by ParameterKeyEscaper.

    Falsy input (None or an empty dict) is returned as is.
    """
    if data:
        return {
            ParameterKeyEscaper.unescape(key): value for key, value in data.items()
        }
    return data
def escape_dict_field(fields: dict, path: Union[str, Sequence[str]]):
    """
    Escape, in place, the keys of the dict found at path inside fields.

    path is either a single key or a sequence of keys for a nested lookup.
    Nothing is changed when the value at path is falsy or is not a dict.
    """
    keys = (path,) if isinstance(path, str) else path
    value = nested_get(fields, keys)
    if value and isinstance(value, dict):
        nested_set(fields, keys, escape_dict(value))
def unescape_dict_field(fields: dict, path: Union[str, Sequence[str]]):
    """
    Unescape, in place, the keys of the dict found at path inside fields.

    path is either a single key or a sequence of keys for a nested lookup.
    Nothing is changed when the value at path is falsy or is not a dict.
    """
    keys = (path,) if isinstance(path, str) else path
    value = nested_get(fields, keys)
    if value and isinstance(value, dict):
        nested_set(fields, keys, unescape_dict(value))
class ModelsBackwardsCompatibility:
    """
    Translate between the legacy single-model task fields used by clients
    requesting an endpoint version below 2.13 and the newer 'models' field.

    Legacy locations: ("execution", "model") for the input model and
    ("output", "model") for the output model.
    """

    max_version = PartialVersion("2.13")
    mode_to_fields = {
        TaskModelTypes.input: ("execution", "model"),
        TaskModelTypes.output: ("output", "model"),
    }
    models_field = "models"

    @classmethod
    def prepare_for_save(cls, call: APICall, fields: dict):
        """
        Move the legacy model fields into 'models' before saving (old clients
        only).

        For each legacy field whose value is not None, the 'models' entry for
        that mode is set to a single-item list (or to an empty list when the
        value is falsy), and the legacy field is removed from fields.
        """
        if call.requested_endpoint_version >= cls.max_version:
            return
        for mode, legacy_path in cls.mode_to_fields.items():
            legacy_value = nested_get(fields, legacy_path)
            if legacy_value is None:
                continue
            entries = []
            if legacy_value:
                entries.append(
                    {
                        "name": TaskModelNames[mode],
                        "model": legacy_value,
                        "updated": datetime.utcnow(),
                    }
                )
            nested_set(fields, (cls.models_field, mode), value=entries)
            nested_delete(fields, legacy_path)

    @classmethod
    def unprepare_from_saved(
        cls, call: APICall, tasks_data: Union[Sequence[dict], dict]
    ):
        """
        Restore the legacy model fields on returned tasks (old clients only).

        For the input mode the first 'models' entry is used; for the other
        mode the last one is used. The legacy field is set to that entry's
        "model" value. Tasks are updated in place.
        """
        if call.requested_endpoint_version >= cls.max_version:
            return
        tasks = [tasks_data] if isinstance(tasks_data, dict) else tasks_data
        for task in tasks:
            for mode, legacy_path in cls.mode_to_fields.items():
                entries = nested_get(task, (cls.models_field, mode))
                if not entries:
                    continue
                is_input = mode == TaskModelTypes.input
                chosen = entries[0] if is_input else entries[-1]
                if chosen:
                    nested_set(task, legacy_path, chosen.get("model"))
class DockerCmdBackwardsCompatibility:
    """
    Translate between the legacy ("execution", "docker_cmd") string used by
    clients requesting an endpoint version below 2.13 and the newer
    container image / arguments fields.
    """

    max_version = PartialVersion("2.13")
    field = ("execution", "docker_cmd")

    @classmethod
    def prepare_for_save(cls, call: APICall, fields: dict):
        """
        Split a legacy docker_cmd into container image and arguments (old
        clients only), then remove the legacy field from fields.

        The first space-separated token becomes the image and the rest becomes
        the arguments. Surrounding whitespace is ignored, so a command such as
        " img --flag" does not produce an empty image.
        """
        if call.requested_endpoint_version >= cls.max_version:
            return
        docker_cmd = nested_get(fields, cls.field)
        if docker_cmd is not None:
            # Strip first: a leading space would otherwise yield image == ""
            # and push the real image name into the arguments
            image, _, arguments = docker_cmd.strip().partition(" ")
            nested_set(fields, ("container", "image"), value=image)
            nested_set(
                fields, ("container", "arguments"), value=arguments.lstrip()
            )
        # The legacy field is never persisted
        nested_delete(fields, cls.field)

    @classmethod
    def unprepare_from_saved(
        cls, call: APICall, tasks_data: Union[Sequence[dict], dict]
    ):
        """
        Rebuild the legacy docker_cmd from the container image and arguments
        on returned tasks (old clients only). Tasks without a container image
        are left untouched; tasks are updated in place.
        """
        if call.requested_endpoint_version >= cls.max_version:
            return
        if isinstance(tasks_data, dict):
            tasks_data = [tasks_data]
        for task in tasks_data:
            container = task.get("container")
            if not container or not container.get("image"):
                continue
            # Join the non-empty parts only, so missing arguments add no space
            docker_cmd = " ".join(
                filter(None, map(container.get, ("image", "arguments")))
            )
            if docker_cmd:
                nested_set(task, cls.field, docker_cmd)
def escape_metadata(document: dict):
    """
    Escape special characters in the keys of document["metadata"], in place.

    Documents whose metadata is missing or empty are left untouched.
    """
    metadata = document.get("metadata")
    if metadata:
        escaped = {}
        for key, value in metadata.items():
            escaped[ParameterKeyEscaper.escape(key)] = value
        document["metadata"] = escaped
def unescape_metadata(call: APICall, documents: Union[dict, Sequence[dict]]):
    """
    Unescape special characters in the metadata keys of returned documents.

    For clients requesting endpoint version 2.16 or lower, any "metadata"
    field is replaced with an empty list. For newer clients, non-empty
    metadata dicts have their keys unescaped. Documents are modified in place.
    """
    docs = [documents] if isinstance(documents, dict) else documents
    old_client = call.requested_endpoint_version <= PartialVersion("2.16")
    for doc in docs:
        if "metadata" not in doc:
            continue
        if old_client:
            doc["metadata"] = []
            continue
        metadata = doc["metadata"]
        if metadata:
            doc["metadata"] = {
                ParameterKeyEscaper.unescape(key): value
                for key, value in metadata.items()
            }