mirror of
https://github.com/clearml/clearml-server
synced 2025-03-03 18:54:20 +00:00
Add field_mappings to organizations download endpoints
This commit is contained in:
parent
ed86750b24
commit
bc2fe28bdd
@ -1,9 +1,10 @@
|
||||
from enum import auto
|
||||
from typing import Sequence
|
||||
|
||||
from jsonmodels import fields, models
|
||||
from jsonmodels.validators import Length
|
||||
|
||||
from apiserver.apimodels import DictField, ActualEnumField
|
||||
from apiserver.apimodels import DictField, ActualEnumField, ScalarField
|
||||
from apiserver.utilities.stringenum import StringEnum
|
||||
|
||||
|
||||
@ -34,14 +35,28 @@ class EntityType(StringEnum):
|
||||
model = auto()
|
||||
|
||||
|
||||
class PrepareDownloadForGetAll(models.Base):
|
||||
class ValueMapping(models.Base):
|
||||
key = ScalarField(nullable=True)
|
||||
value = ScalarField(nullable=True)
|
||||
|
||||
|
||||
class FieldMapping(models.Base):
|
||||
field = fields.StringField(required=True)
|
||||
name = fields.StringField()
|
||||
values: Sequence[ValueMapping] = fields.ListField(items_types=[ValueMapping])
|
||||
|
||||
|
||||
class PrepareDownloadForGetAllRequest(models.Base):
|
||||
entity_type = ActualEnumField(EntityType)
|
||||
allow_public = fields.BoolField(default=True)
|
||||
search_hidden = fields.BoolField(default=False)
|
||||
only_fields = fields.ListField(
|
||||
items_types=[str], validators=[Length(1)], required=True
|
||||
)
|
||||
field_mappings: Sequence[FieldMapping] = fields.ListField(
|
||||
items_types=[FieldMapping], validators=[Length(1)], required=True
|
||||
)
|
||||
|
||||
|
||||
class DownloadForGetAll(models.Base):
|
||||
class DownloadForGetAllRequest(models.Base):
|
||||
prepare_id = fields.StringField(required=True)
|
||||
|
@ -1,5 +1,38 @@
|
||||
_description: "This service provides organization level operations"
|
||||
|
||||
_definitions {
|
||||
value_mapping {
|
||||
type: object
|
||||
required: [key, value]
|
||||
properties {
|
||||
key {
|
||||
description: Original value
|
||||
type: object
|
||||
}
|
||||
value {
|
||||
description: Translated value
|
||||
type: object
|
||||
}
|
||||
}
|
||||
}
|
||||
field_mapping {
|
||||
type: object
|
||||
required: [field]
|
||||
properties {
|
||||
field {
|
||||
description: The source field name as specified in the only_fields
|
||||
type: string
|
||||
}
|
||||
name {
|
||||
description: The column name in the exported csv file
|
||||
type: string
|
||||
}
|
||||
values {
|
||||
type: array
|
||||
items { "$ref": "#/definitions/value_mapping"}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
get_tags {
|
||||
"2.8" {
|
||||
description: "Get all the user and system tags used for the company tasks and models"
|
||||
@ -202,7 +235,7 @@ prepare_download_for_get_all {
|
||||
description: Prepares download from get_all_ex parameters
|
||||
request {
|
||||
type: object
|
||||
required: [ entity_type, only_fields]
|
||||
required: [ entity_type, only_fields, field_mappings]
|
||||
properties {
|
||||
only_fields {
|
||||
description: "List of task field names (nesting is supported using '.', e.g. execution.model_labels). If provided, this list defines the query's projection (only these fields will be returned for each result entry)"
|
||||
@ -227,6 +260,11 @@ prepare_download_for_get_all {
|
||||
model
|
||||
]
|
||||
}
|
||||
field_mappings {
|
||||
description: The name and value mappings for the exported fields. The fields that are not in the mappings will not be exported
|
||||
type: array
|
||||
items { "$ref": "#/definitions/field_mapping"}
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
|
@ -3,7 +3,7 @@ from collections import defaultdict
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from io import StringIO
|
||||
from operator import itemgetter
|
||||
from typing import Mapping, Type, Sequence, Optional, Callable
|
||||
from typing import Mapping, Type, Sequence, Optional, Callable, Hashable
|
||||
|
||||
from flask import stream_with_context
|
||||
from mongoengine import Q
|
||||
@ -12,9 +12,9 @@ from apiserver.apierrors import errors
|
||||
from apiserver.apimodels.organization import (
|
||||
TagsRequest,
|
||||
EntitiesCountRequest,
|
||||
DownloadForGetAll,
|
||||
DownloadForGetAllRequest,
|
||||
EntityType,
|
||||
PrepareDownloadForGetAll,
|
||||
PrepareDownloadForGetAllRequest,
|
||||
)
|
||||
from apiserver.bll.model import Metadata
|
||||
from apiserver.bll.organization import OrgBLL, Tags
|
||||
@ -47,7 +47,10 @@ def get_tags(call: APICall, company, request: TagsRequest):
|
||||
ret = defaultdict(set)
|
||||
for entity in Tags.Model, Tags.Task:
|
||||
tags = org_bll.get_tags(
|
||||
company, entity, include_system=request.include_system, filter_=filter_dict,
|
||||
company,
|
||||
entity,
|
||||
include_system=request.include_system,
|
||||
filter_=filter_dict,
|
||||
)
|
||||
for field, vals in tags.items():
|
||||
ret[field] |= vals
|
||||
@ -149,7 +152,9 @@ def _get_download_getter_fn(
|
||||
|
||||
def get_model_data() -> Sequence[dict]:
|
||||
models = Model.get_many_with_join(
|
||||
company=company, query_dict=call_data, allow_public=allow_public,
|
||||
company=company,
|
||||
query_dict=call_data,
|
||||
allow_public=allow_public,
|
||||
)
|
||||
conform_model_data(call, models)
|
||||
return models
|
||||
@ -182,9 +187,26 @@ download_conf = config.get("services.organization.download")
|
||||
|
||||
@endpoint("organization.prepare_download_for_get_all")
|
||||
def prepare_download_for_get_all(
|
||||
call: APICall, company: str, request: PrepareDownloadForGetAll
|
||||
call: APICall, company: str, request: PrepareDownloadForGetAllRequest
|
||||
):
|
||||
# validate input params
|
||||
field_names = set()
|
||||
for fm in request.field_mappings:
|
||||
name = fm.name or fm.field
|
||||
if name in field_names:
|
||||
raise errors.bad_request.ValidationError(
|
||||
f"Field_name appears more than once in field_mappings: {str(name)}"
|
||||
)
|
||||
field_names.add(name)
|
||||
if fm.values:
|
||||
value_keys = set()
|
||||
for v in fm.values:
|
||||
if v.key in value_keys:
|
||||
raise errors.bad_request.ValidationError(
|
||||
f"Value key appears more than once in field_mappings: {str(v.key)}"
|
||||
)
|
||||
value_keys.add(v.key)
|
||||
|
||||
getter = _get_download_getter_fn(
|
||||
company,
|
||||
call,
|
||||
@ -205,7 +227,7 @@ def prepare_download_for_get_all(
|
||||
|
||||
|
||||
@endpoint("organization.download_for_get_all")
|
||||
def download_for_get_all(call: APICall, company, request: DownloadForGetAll):
|
||||
def download_for_get_all(call: APICall, company, request: DownloadForGetAllRequest):
|
||||
request_data = redis.get(f"get_all_download_{request.prepare_id}")
|
||||
if not request_data:
|
||||
raise errors.bad_request.InvalidId(
|
||||
@ -214,17 +236,25 @@ def download_for_get_all(call: APICall, company, request: DownloadForGetAll):
|
||||
|
||||
try:
|
||||
call_data = json.loads(request_data)
|
||||
request = PrepareDownloadForGetAll(**call_data)
|
||||
request = PrepareDownloadForGetAllRequest(**call_data)
|
||||
except Exception as ex:
|
||||
raise errors.server_error.DataError("failed parsing prepared data", ex=ex)
|
||||
|
||||
class SingleLine:
|
||||
def write(self, line: str) -> str:
|
||||
@staticmethod
|
||||
def write(line: str) -> str:
|
||||
return line
|
||||
|
||||
def generate():
|
||||
projection = call_data.get("only_fields", [])
|
||||
|
||||
field_mappings = {
|
||||
mapping.get("name", mapping["field"]): {
|
||||
"field_path": mapping["field"].split("."),
|
||||
"values": {
|
||||
v.get("key"): v.get("value") for v in (mapping.get("values") or [])
|
||||
},
|
||||
}
|
||||
for mapping in call_data.get("field_mappings", [])
|
||||
}
|
||||
get_fn = _get_download_getter_fn(
|
||||
company,
|
||||
call,
|
||||
@ -233,26 +263,31 @@ def download_for_get_all(call: APICall, company, request: DownloadForGetAll):
|
||||
entity_type=request.entity_type,
|
||||
)
|
||||
if not get_fn:
|
||||
yield csv.writer(SingleLine()).writerow(projection)
|
||||
yield csv.writer(SingleLine()).writerow(field_mappings)
|
||||
return
|
||||
|
||||
fields = [path.split(".") for path in projection]
|
||||
|
||||
def get_entity_field_as_str(data: dict, field: Sequence[str]) -> str:
|
||||
val = nested_get(data, field, "")
|
||||
def get_entity_field_as_str(
|
||||
data: dict, field_path: Sequence[str], values: Mapping
|
||||
) -> str:
|
||||
val = nested_get(data, field_path, "")
|
||||
if isinstance(val, dict):
|
||||
val = val.get("id", "")
|
||||
if values and isinstance(val, Hashable):
|
||||
val = values.get(val, val)
|
||||
|
||||
return str(val)
|
||||
|
||||
def get_projected_fields(data: dict) -> Sequence[str]:
|
||||
return [get_entity_field_as_str(data, f) for f in fields]
|
||||
return [
|
||||
get_entity_field_as_str(
|
||||
data, field_path=m["field_path"], values=m["values"]
|
||||
)
|
||||
for m in field_mappings.values()
|
||||
]
|
||||
|
||||
with ThreadPoolExecutor(1) as pool:
|
||||
page = 0
|
||||
page_size = int(
|
||||
download_conf.get("batch_size", 500)
|
||||
)
|
||||
page_size = int(download_conf.get("batch_size", 500))
|
||||
future = pool.submit(get_fn, page, page_size)
|
||||
|
||||
while True:
|
||||
@ -266,12 +301,12 @@ def download_for_get_all(call: APICall, company, request: DownloadForGetAll):
|
||||
with StringIO() as fp:
|
||||
writer = csv.writer(fp)
|
||||
if page == 1:
|
||||
writer.writerow(projection)
|
||||
writer.writerow(field_mappings)
|
||||
writer.writerows(get_projected_fields(r) for r in result)
|
||||
yield fp.getvalue()
|
||||
|
||||
if page == 0:
|
||||
yield csv.writer(SingleLine()).writerow(projection)
|
||||
yield csv.writer(SingleLine()).writerow(field_mappings)
|
||||
|
||||
def get_project_name() -> Optional[str]:
|
||||
projects = call_data.get("project")
|
||||
@ -287,12 +322,10 @@ def download_for_get_all(call: APICall, company, request: DownloadForGetAll):
|
||||
if not project:
|
||||
return
|
||||
|
||||
return project.basename[:download_conf.get("max_project_name_length", 60)]
|
||||
return project.basename[: download_conf.get("max_project_name_length", 60)]
|
||||
|
||||
call.result.filename = "-".join(
|
||||
filter(
|
||||
None, ("clearml", get_project_name(), f"{request.entity_type}s.csv")
|
||||
)
|
||||
filter(None, ("clearml", get_project_name(), f"{request.entity_type}s.csv"))
|
||||
)
|
||||
call.result.content_type = "text/csv"
|
||||
call.result.raw_data = stream_with_context(generate())
|
||||
|
Loading…
Reference in New Issue
Block a user