clearml-server/apiserver/services/utils.py

87 lines
2.8 KiB
Python
Raw Normal View History

from typing import Union, Sequence, Tuple
2021-01-05 14:28:49 +00:00
from apiserver.apierrors import errors
from apiserver.apimodels.organization import Filter
from apiserver.database.model.base import GetMixin
from apiserver.database.utils import partition_tags
from apiserver.service_repo import APICall
from apiserver.utilities.partial_version import PartialVersion
def get_tags_filter_dictionary(input_: Filter) -> dict:
if not input_:
return {}
return {
field: vals
for field, vals in (("tags", input_.tags), ("system_tags", input_.system_tags))
if vals
}
def get_tags_response(ret: dict) -> dict:
return {field: sorted(vals) for field, vals in ret.items()}
def conform_output_tags(call: APICall, documents: Union[dict, Sequence[dict]]):
"""
For old clients both tags and system tags are returned in 'tags' field
"""
if call.requested_endpoint_version >= PartialVersion("2.3"):
return
if isinstance(documents, dict):
documents = [documents]
for doc in documents:
system_tags = doc.get("system_tags")
if system_tags:
doc["tags"] = list(set(doc.get("tags", [])) | set(system_tags))
def conform_tag_fields(call: APICall, document: dict, validate=False):
"""
Upgrade old client tags in place
"""
if "tags" in document:
tags, system_tags = conform_tags(
call, document["tags"], document.get("system_tags"), validate
)
if tags != document.get("tags"):
document["tags"] = tags
if system_tags != document.get("system_tags"):
document["system_tags"] = system_tags
def conform_tags(
call: APICall, tags: Sequence, system_tags: Sequence, validate=False
) -> Tuple[Sequence, Sequence]:
"""
Make sure that 'tags' from the old SDK clients
are correctly split into 'tags' and 'system_tags'
Make sure that there are no duplicate tags
"""
if validate:
validate_tags(tags, system_tags)
if call.requested_endpoint_version < PartialVersion("2.3"):
tags, system_tags = _upgrade_tags(call, tags, system_tags)
return tags, system_tags
def _upgrade_tags(call: APICall, tags: Sequence, system_tags: Sequence):
if tags is not None and not system_tags:
service_name = call.endpoint_name.partition(".")[0]
entity = service_name[:-1] if service_name.endswith("s") else service_name
return partition_tags(entity, tags)
return tags, system_tags
def validate_tags(tags: Sequence[str], system_tags: Sequence[str]):
for values in filter(None, (tags, system_tags)):
unsupported = [
t for t in values if t.startswith(GetMixin.ListFieldBucketHelper.op_prefix)
]
if unsupported:
raise errors.bad_request.FieldsValueError(
"unsupported tag prefix", values=unsupported
)