mirror of
https://github.com/clearml/clearml-server
synced 2025-06-26 23:15:47 +00:00
Add field_mappings to organizations download endpoints
This commit is contained in:
parent
ed86750b24
commit
bc2fe28bdd
@ -1,9 +1,10 @@
|
|||||||
from enum import auto
|
from enum import auto
|
||||||
|
from typing import Sequence
|
||||||
|
|
||||||
from jsonmodels import fields, models
|
from jsonmodels import fields, models
|
||||||
from jsonmodels.validators import Length
|
from jsonmodels.validators import Length
|
||||||
|
|
||||||
from apiserver.apimodels import DictField, ActualEnumField
|
from apiserver.apimodels import DictField, ActualEnumField, ScalarField
|
||||||
from apiserver.utilities.stringenum import StringEnum
|
from apiserver.utilities.stringenum import StringEnum
|
||||||
|
|
||||||
|
|
||||||
@ -34,14 +35,28 @@ class EntityType(StringEnum):
|
|||||||
model = auto()
|
model = auto()
|
||||||
|
|
||||||
|
|
||||||
class PrepareDownloadForGetAll(models.Base):
|
class ValueMapping(models.Base):
|
||||||
|
key = ScalarField(nullable=True)
|
||||||
|
value = ScalarField(nullable=True)
|
||||||
|
|
||||||
|
|
||||||
|
class FieldMapping(models.Base):
|
||||||
|
field = fields.StringField(required=True)
|
||||||
|
name = fields.StringField()
|
||||||
|
values: Sequence[ValueMapping] = fields.ListField(items_types=[ValueMapping])
|
||||||
|
|
||||||
|
|
||||||
|
class PrepareDownloadForGetAllRequest(models.Base):
|
||||||
entity_type = ActualEnumField(EntityType)
|
entity_type = ActualEnumField(EntityType)
|
||||||
allow_public = fields.BoolField(default=True)
|
allow_public = fields.BoolField(default=True)
|
||||||
search_hidden = fields.BoolField(default=False)
|
search_hidden = fields.BoolField(default=False)
|
||||||
only_fields = fields.ListField(
|
only_fields = fields.ListField(
|
||||||
items_types=[str], validators=[Length(1)], required=True
|
items_types=[str], validators=[Length(1)], required=True
|
||||||
)
|
)
|
||||||
|
field_mappings: Sequence[FieldMapping] = fields.ListField(
|
||||||
|
items_types=[FieldMapping], validators=[Length(1)], required=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class DownloadForGetAll(models.Base):
|
class DownloadForGetAllRequest(models.Base):
|
||||||
prepare_id = fields.StringField(required=True)
|
prepare_id = fields.StringField(required=True)
|
||||||
|
@ -1,5 +1,38 @@
|
|||||||
_description: "This service provides organization level operations"
|
_description: "This service provides organization level operations"
|
||||||
|
_definitions {
|
||||||
|
value_mapping {
|
||||||
|
type: object
|
||||||
|
required: [key, value]
|
||||||
|
properties {
|
||||||
|
key {
|
||||||
|
description: Original value
|
||||||
|
type: object
|
||||||
|
}
|
||||||
|
value {
|
||||||
|
description: Translated value
|
||||||
|
type: object
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
field_mapping {
|
||||||
|
type: object
|
||||||
|
required: [field]
|
||||||
|
properties {
|
||||||
|
field {
|
||||||
|
description: The source field name as specified in the only_fields
|
||||||
|
type: string
|
||||||
|
}
|
||||||
|
name {
|
||||||
|
description: The column name in the exported csv file
|
||||||
|
type: string
|
||||||
|
}
|
||||||
|
values {
|
||||||
|
type: array
|
||||||
|
items { "$ref": "#/definitions/value_mapping"}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
get_tags {
|
get_tags {
|
||||||
"2.8" {
|
"2.8" {
|
||||||
description: "Get all the user and system tags used for the company tasks and models"
|
description: "Get all the user and system tags used for the company tasks and models"
|
||||||
@ -202,7 +235,7 @@ prepare_download_for_get_all {
|
|||||||
description: Prepares download from get_all_ex parameters
|
description: Prepares download from get_all_ex parameters
|
||||||
request {
|
request {
|
||||||
type: object
|
type: object
|
||||||
required: [ entity_type, only_fields]
|
required: [ entity_type, only_fields, field_mappings]
|
||||||
properties {
|
properties {
|
||||||
only_fields {
|
only_fields {
|
||||||
description: "List of task field names (nesting is supported using '.', e.g. execution.model_labels). If provided, this list defines the query's projection (only these fields will be returned for each result entry)"
|
description: "List of task field names (nesting is supported using '.', e.g. execution.model_labels). If provided, this list defines the query's projection (only these fields will be returned for each result entry)"
|
||||||
@ -227,6 +260,11 @@ prepare_download_for_get_all {
|
|||||||
model
|
model
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
field_mappings {
|
||||||
|
description: The name and value mappings for the exported fields. The fields that are not in the mappings will not be exported
|
||||||
|
type: array
|
||||||
|
items { "$ref": "#/definitions/field_mapping"}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
response {
|
response {
|
||||||
|
@ -3,7 +3,7 @@ from collections import defaultdict
|
|||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
from operator import itemgetter
|
from operator import itemgetter
|
||||||
from typing import Mapping, Type, Sequence, Optional, Callable
|
from typing import Mapping, Type, Sequence, Optional, Callable, Hashable
|
||||||
|
|
||||||
from flask import stream_with_context
|
from flask import stream_with_context
|
||||||
from mongoengine import Q
|
from mongoengine import Q
|
||||||
@ -12,9 +12,9 @@ from apiserver.apierrors import errors
|
|||||||
from apiserver.apimodels.organization import (
|
from apiserver.apimodels.organization import (
|
||||||
TagsRequest,
|
TagsRequest,
|
||||||
EntitiesCountRequest,
|
EntitiesCountRequest,
|
||||||
DownloadForGetAll,
|
DownloadForGetAllRequest,
|
||||||
EntityType,
|
EntityType,
|
||||||
PrepareDownloadForGetAll,
|
PrepareDownloadForGetAllRequest,
|
||||||
)
|
)
|
||||||
from apiserver.bll.model import Metadata
|
from apiserver.bll.model import Metadata
|
||||||
from apiserver.bll.organization import OrgBLL, Tags
|
from apiserver.bll.organization import OrgBLL, Tags
|
||||||
@ -47,7 +47,10 @@ def get_tags(call: APICall, company, request: TagsRequest):
|
|||||||
ret = defaultdict(set)
|
ret = defaultdict(set)
|
||||||
for entity in Tags.Model, Tags.Task:
|
for entity in Tags.Model, Tags.Task:
|
||||||
tags = org_bll.get_tags(
|
tags = org_bll.get_tags(
|
||||||
company, entity, include_system=request.include_system, filter_=filter_dict,
|
company,
|
||||||
|
entity,
|
||||||
|
include_system=request.include_system,
|
||||||
|
filter_=filter_dict,
|
||||||
)
|
)
|
||||||
for field, vals in tags.items():
|
for field, vals in tags.items():
|
||||||
ret[field] |= vals
|
ret[field] |= vals
|
||||||
@ -149,7 +152,9 @@ def _get_download_getter_fn(
|
|||||||
|
|
||||||
def get_model_data() -> Sequence[dict]:
|
def get_model_data() -> Sequence[dict]:
|
||||||
models = Model.get_many_with_join(
|
models = Model.get_many_with_join(
|
||||||
company=company, query_dict=call_data, allow_public=allow_public,
|
company=company,
|
||||||
|
query_dict=call_data,
|
||||||
|
allow_public=allow_public,
|
||||||
)
|
)
|
||||||
conform_model_data(call, models)
|
conform_model_data(call, models)
|
||||||
return models
|
return models
|
||||||
@ -182,9 +187,26 @@ download_conf = config.get("services.organization.download")
|
|||||||
|
|
||||||
@endpoint("organization.prepare_download_for_get_all")
|
@endpoint("organization.prepare_download_for_get_all")
|
||||||
def prepare_download_for_get_all(
|
def prepare_download_for_get_all(
|
||||||
call: APICall, company: str, request: PrepareDownloadForGetAll
|
call: APICall, company: str, request: PrepareDownloadForGetAllRequest
|
||||||
):
|
):
|
||||||
# validate input params
|
# validate input params
|
||||||
|
field_names = set()
|
||||||
|
for fm in request.field_mappings:
|
||||||
|
name = fm.name or fm.field
|
||||||
|
if name in field_names:
|
||||||
|
raise errors.bad_request.ValidationError(
|
||||||
|
f"Field_name appears more than once in field_mappings: {str(name)}"
|
||||||
|
)
|
||||||
|
field_names.add(name)
|
||||||
|
if fm.values:
|
||||||
|
value_keys = set()
|
||||||
|
for v in fm.values:
|
||||||
|
if v.key in value_keys:
|
||||||
|
raise errors.bad_request.ValidationError(
|
||||||
|
f"Value key appears more than once in field_mappings: {str(v.key)}"
|
||||||
|
)
|
||||||
|
value_keys.add(v.key)
|
||||||
|
|
||||||
getter = _get_download_getter_fn(
|
getter = _get_download_getter_fn(
|
||||||
company,
|
company,
|
||||||
call,
|
call,
|
||||||
@ -205,7 +227,7 @@ def prepare_download_for_get_all(
|
|||||||
|
|
||||||
|
|
||||||
@endpoint("organization.download_for_get_all")
|
@endpoint("organization.download_for_get_all")
|
||||||
def download_for_get_all(call: APICall, company, request: DownloadForGetAll):
|
def download_for_get_all(call: APICall, company, request: DownloadForGetAllRequest):
|
||||||
request_data = redis.get(f"get_all_download_{request.prepare_id}")
|
request_data = redis.get(f"get_all_download_{request.prepare_id}")
|
||||||
if not request_data:
|
if not request_data:
|
||||||
raise errors.bad_request.InvalidId(
|
raise errors.bad_request.InvalidId(
|
||||||
@ -214,17 +236,25 @@ def download_for_get_all(call: APICall, company, request: DownloadForGetAll):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
call_data = json.loads(request_data)
|
call_data = json.loads(request_data)
|
||||||
request = PrepareDownloadForGetAll(**call_data)
|
request = PrepareDownloadForGetAllRequest(**call_data)
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
raise errors.server_error.DataError("failed parsing prepared data", ex=ex)
|
raise errors.server_error.DataError("failed parsing prepared data", ex=ex)
|
||||||
|
|
||||||
class SingleLine:
|
class SingleLine:
|
||||||
def write(self, line: str) -> str:
|
@staticmethod
|
||||||
|
def write(line: str) -> str:
|
||||||
return line
|
return line
|
||||||
|
|
||||||
def generate():
|
def generate():
|
||||||
projection = call_data.get("only_fields", [])
|
field_mappings = {
|
||||||
|
mapping.get("name", mapping["field"]): {
|
||||||
|
"field_path": mapping["field"].split("."),
|
||||||
|
"values": {
|
||||||
|
v.get("key"): v.get("value") for v in (mapping.get("values") or [])
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for mapping in call_data.get("field_mappings", [])
|
||||||
|
}
|
||||||
get_fn = _get_download_getter_fn(
|
get_fn = _get_download_getter_fn(
|
||||||
company,
|
company,
|
||||||
call,
|
call,
|
||||||
@ -233,26 +263,31 @@ def download_for_get_all(call: APICall, company, request: DownloadForGetAll):
|
|||||||
entity_type=request.entity_type,
|
entity_type=request.entity_type,
|
||||||
)
|
)
|
||||||
if not get_fn:
|
if not get_fn:
|
||||||
yield csv.writer(SingleLine()).writerow(projection)
|
yield csv.writer(SingleLine()).writerow(field_mappings)
|
||||||
return
|
return
|
||||||
|
|
||||||
fields = [path.split(".") for path in projection]
|
def get_entity_field_as_str(
|
||||||
|
data: dict, field_path: Sequence[str], values: Mapping
|
||||||
def get_entity_field_as_str(data: dict, field: Sequence[str]) -> str:
|
) -> str:
|
||||||
val = nested_get(data, field, "")
|
val = nested_get(data, field_path, "")
|
||||||
if isinstance(val, dict):
|
if isinstance(val, dict):
|
||||||
val = val.get("id", "")
|
val = val.get("id", "")
|
||||||
|
if values and isinstance(val, Hashable):
|
||||||
|
val = values.get(val, val)
|
||||||
|
|
||||||
return str(val)
|
return str(val)
|
||||||
|
|
||||||
def get_projected_fields(data: dict) -> Sequence[str]:
|
def get_projected_fields(data: dict) -> Sequence[str]:
|
||||||
return [get_entity_field_as_str(data, f) for f in fields]
|
return [
|
||||||
|
get_entity_field_as_str(
|
||||||
|
data, field_path=m["field_path"], values=m["values"]
|
||||||
|
)
|
||||||
|
for m in field_mappings.values()
|
||||||
|
]
|
||||||
|
|
||||||
with ThreadPoolExecutor(1) as pool:
|
with ThreadPoolExecutor(1) as pool:
|
||||||
page = 0
|
page = 0
|
||||||
page_size = int(
|
page_size = int(download_conf.get("batch_size", 500))
|
||||||
download_conf.get("batch_size", 500)
|
|
||||||
)
|
|
||||||
future = pool.submit(get_fn, page, page_size)
|
future = pool.submit(get_fn, page, page_size)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
@ -266,12 +301,12 @@ def download_for_get_all(call: APICall, company, request: DownloadForGetAll):
|
|||||||
with StringIO() as fp:
|
with StringIO() as fp:
|
||||||
writer = csv.writer(fp)
|
writer = csv.writer(fp)
|
||||||
if page == 1:
|
if page == 1:
|
||||||
writer.writerow(projection)
|
writer.writerow(field_mappings)
|
||||||
writer.writerows(get_projected_fields(r) for r in result)
|
writer.writerows(get_projected_fields(r) for r in result)
|
||||||
yield fp.getvalue()
|
yield fp.getvalue()
|
||||||
|
|
||||||
if page == 0:
|
if page == 0:
|
||||||
yield csv.writer(SingleLine()).writerow(projection)
|
yield csv.writer(SingleLine()).writerow(field_mappings)
|
||||||
|
|
||||||
def get_project_name() -> Optional[str]:
|
def get_project_name() -> Optional[str]:
|
||||||
projects = call_data.get("project")
|
projects = call_data.get("project")
|
||||||
@ -287,12 +322,10 @@ def download_for_get_all(call: APICall, company, request: DownloadForGetAll):
|
|||||||
if not project:
|
if not project:
|
||||||
return
|
return
|
||||||
|
|
||||||
return project.basename[:download_conf.get("max_project_name_length", 60)]
|
return project.basename[: download_conf.get("max_project_name_length", 60)]
|
||||||
|
|
||||||
call.result.filename = "-".join(
|
call.result.filename = "-".join(
|
||||||
filter(
|
filter(None, ("clearml", get_project_name(), f"{request.entity_type}s.csv"))
|
||||||
None, ("clearml", get_project_name(), f"{request.entity_type}s.csv")
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
call.result.content_type = "text/csv"
|
call.result.content_type = "text/csv"
|
||||||
call.result.raw_data = stream_with_context(generate())
|
call.result.raw_data = stream_with_context(generate())
|
||||||
|
Loading…
Reference in New Issue
Block a user