Add field_mappings to organizations download endpoints

allegroai 2023-07-26 18:39:41 +03:00
parent ed86750b24
commit bc2fe28bdd
3 changed files with 117 additions and 31 deletions

apiserver/apimodels/organization.py

@@ -1,9 +1,10 @@
 from enum import auto
 from typing import Sequence

 from jsonmodels import fields, models
 from jsonmodels.validators import Length

-from apiserver.apimodels import DictField, ActualEnumField
+from apiserver.apimodels import DictField, ActualEnumField, ScalarField
 from apiserver.utilities.stringenum import StringEnum
@@ -34,14 +35,28 @@ class EntityType(StringEnum):
     model = auto()


-class PrepareDownloadForGetAll(models.Base):
+class ValueMapping(models.Base):
+    key = ScalarField(nullable=True)
+    value = ScalarField(nullable=True)
+
+
+class FieldMapping(models.Base):
+    field = fields.StringField(required=True)
+    name = fields.StringField()
+    values: Sequence[ValueMapping] = fields.ListField(items_types=[ValueMapping])
+
+
+class PrepareDownloadForGetAllRequest(models.Base):
     entity_type = ActualEnumField(EntityType)
     allow_public = fields.BoolField(default=True)
     search_hidden = fields.BoolField(default=False)
     only_fields = fields.ListField(
         items_types=[str], validators=[Length(1)], required=True
     )
+    field_mappings: Sequence[FieldMapping] = fields.ListField(
+        items_types=[FieldMapping], validators=[Length(1)], required=True
+    )


-class DownloadForGetAll(models.Base):
+class DownloadForGetAllRequest(models.Base):
     prepare_id = fields.StringField(required=True)

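The new request model in action — a minimal sketch, assuming direct keyword construction per jsonmodels conventions and an EntityType.task member (the diff above only shows "model"):

from apiserver.apimodels.organization import (
    EntityType,
    FieldMapping,
    PrepareDownloadForGetAllRequest,
    ValueMapping,
)

request = PrepareDownloadForGetAllRequest(
    entity_type=EntityType.task,  # assumed member; the diff only shows "model"
    only_fields=["name", "status"],
    field_mappings=[
        # export "status" as a column named "State", translating raw values
        FieldMapping(
            field="status",
            name="State",
            values=[ValueMapping(key="created", value="Draft")],
        ),
        # without "name", the column falls back to the field name (see the service code below)
        FieldMapping(field="name"),
    ],
)
request.validate()  # jsonmodels enforces the required fields and the Length(1) validators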
apiserver/schema/services/organization.conf

@@ -1,5 +1,38 @@
 _description: "This service provides organization level operations"
 _definitions {
+    value_mapping {
+        type: object
+        required: [key, value]
+        properties {
+            key {
+                description: Original value
+                type: object
+            }
+            value {
+                description: Translated value
+                type: object
+            }
+        }
+    }
+    field_mapping {
+        type: object
+        required: [field]
+        properties {
+            field {
+                description: The source field name as specified in the only_fields
+                type: string
+            }
+            name {
+                description: The column name in the exported csv file
+                type: string
+            }
+            values {
+                type: array
+                items { "$ref": "#/definitions/value_mapping" }
+            }
+        }
+    }
 }
 get_tags {
     "2.8" {
         description: "Get all the user and system tags used for the company tasks and models"
@@ -202,7 +235,7 @@ prepare_download_for_get_all {
         description: Prepares download from get_all_ex parameters
         request {
             type: object
-            required: [ entity_type, only_fields]
+            required: [ entity_type, only_fields, field_mappings]
             properties {
                 only_fields {
                     description: "List of task field names (nesting is supported using '.', e.g. execution.model_labels). If provided, this list defines the query's projection (only these fields will be returned for each result entry)"
@@ -227,6 +260,11 @@ prepare_download_for_get_all {
                         model
                     ]
                 }
+                field_mappings {
+                    description: The name and value mappings for the exported fields. The fields that are not in the mappings will not be exported
+                    type: array
+                    items { "$ref": "#/definitions/field_mapping" }
+                }
             }
         }
         response {

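For reference, a hypothetical request body matching the updated schema; field_mappings is now required next to entity_type and only_fields, and per the description above, fields left out of the mappings are not exported:

payload = {
    "entity_type": "task",
    "only_fields": ["name", "status", "project.name"],
    "field_mappings": [
        {"field": "name", "name": "Task Name"},
        {
            "field": "status",
            "name": "State",
            "values": [{"key": "in_progress", "value": "Running"}],
        },
        # no "name": the column is exported under the original field name
        {"field": "project.name"},
    ],
}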
apiserver/services/organization.py

@@ -3,7 +3,7 @@ from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor
 from io import StringIO
 from operator import itemgetter
-from typing import Mapping, Type, Sequence, Optional, Callable
+from typing import Mapping, Type, Sequence, Optional, Callable, Hashable

 from flask import stream_with_context
 from mongoengine import Q
@@ -12,9 +12,9 @@ from apiserver.apierrors import errors
 from apiserver.apimodels.organization import (
     TagsRequest,
     EntitiesCountRequest,
-    DownloadForGetAll,
+    DownloadForGetAllRequest,
     EntityType,
-    PrepareDownloadForGetAll,
+    PrepareDownloadForGetAllRequest,
 )
 from apiserver.bll.model import Metadata
 from apiserver.bll.organization import OrgBLL, Tags
@@ -47,7 +47,10 @@ def get_tags(call: APICall, company, request: TagsRequest):
     ret = defaultdict(set)
     for entity in Tags.Model, Tags.Task:
         tags = org_bll.get_tags(
-            company, entity, include_system=request.include_system, filter_=filter_dict,
+            company,
+            entity,
+            include_system=request.include_system,
+            filter_=filter_dict,
         )
         for field, vals in tags.items():
             ret[field] |= vals
@@ -149,7 +152,9 @@ def _get_download_getter_fn(
     def get_model_data() -> Sequence[dict]:
         models = Model.get_many_with_join(
-            company=company, query_dict=call_data, allow_public=allow_public,
+            company=company,
+            query_dict=call_data,
+            allow_public=allow_public,
         )
         conform_model_data(call, models)
         return models
@@ -182,9 +187,26 @@ download_conf = config.get("services.organization.download")
 @endpoint("organization.prepare_download_for_get_all")
 def prepare_download_for_get_all(
-    call: APICall, company: str, request: PrepareDownloadForGetAll
+    call: APICall, company: str, request: PrepareDownloadForGetAllRequest
 ):
+    # validate input params
+    field_names = set()
+    for fm in request.field_mappings:
+        name = fm.name or fm.field
+        if name in field_names:
+            raise errors.bad_request.ValidationError(
+                f"Field_name appears more than once in field_mappings: {str(name)}"
+            )
+        field_names.add(name)
+        if fm.values:
+            value_keys = set()
+            for v in fm.values:
+                if v.key in value_keys:
+                    raise errors.bad_request.ValidationError(
+                        f"Value key appears more than once in field_mappings: {str(v.key)}"
+                    )
+                value_keys.add(v.key)
+
     getter = _get_download_getter_fn(
         company,
         call,
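A sketch of what the duplicate checks above reject, using the model classes from the first file (EntityType.task is assumed): two mappings that resolve to the same output column name.

dup = PrepareDownloadForGetAllRequest(
    entity_type=EntityType.task,
    only_fields=["name", "project.name"],
    field_mappings=[
        FieldMapping(field="name"),
        # resolves to the same output column "name" as the mapping above
        FieldMapping(field="project.name", name="name"),
    ],
)
# prepare_download_for_get_all(...) would raise
# errors.bad_request.ValidationError("Field_name appears more than once in field_mappings: name")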
@@ -205,7 +227,7 @@ def prepare_download_for_get_all(
 @endpoint("organization.download_for_get_all")
-def download_for_get_all(call: APICall, company, request: DownloadForGetAll):
+def download_for_get_all(call: APICall, company, request: DownloadForGetAllRequest):
     request_data = redis.get(f"get_all_download_{request.prepare_id}")
     if not request_data:
         raise errors.bad_request.InvalidId(
@@ -214,17 +236,25 @@ def download_for_get_all(call: APICall, company, request: DownloadForGetAll):
     try:
         call_data = json.loads(request_data)
-        request = PrepareDownloadForGetAll(**call_data)
+        request = PrepareDownloadForGetAllRequest(**call_data)
     except Exception as ex:
         raise errors.server_error.DataError("failed parsing prepared data", ex=ex)

     class SingleLine:
-        def write(self, line: str) -> str:
+        @staticmethod
+        def write(line: str) -> str:
             return line

     def generate():
-        projection = call_data.get("only_fields", [])
+        field_mappings = {
+            mapping.get("name", mapping["field"]): {
+                "field_path": mapping["field"].split("."),
+                "values": {
+                    v.get("key"): v.get("value") for v in (mapping.get("values") or [])
+                },
+            }
+            for mapping in call_data.get("field_mappings", [])
+        }
         get_fn = _get_download_getter_fn(
             company,
             call,
@@ -233,26 +263,31 @@ def download_for_get_all(call: APICall, company, request: DownloadForGetAll):
             entity_type=request.entity_type,
         )
         if not get_fn:
-            yield csv.writer(SingleLine()).writerow(projection)
+            yield csv.writer(SingleLine()).writerow(field_mappings)
             return

-        fields = [path.split(".") for path in projection]
-
-        def get_entity_field_as_str(data: dict, field: Sequence[str]) -> str:
-            val = nested_get(data, field, "")
+        def get_entity_field_as_str(
+            data: dict, field_path: Sequence[str], values: Mapping
+        ) -> str:
+            val = nested_get(data, field_path, "")
             if isinstance(val, dict):
                 val = val.get("id", "")
+            if values and isinstance(val, Hashable):
+                val = values.get(val, val)
             return str(val)

         def get_projected_fields(data: dict) -> Sequence[str]:
-            return [get_entity_field_as_str(data, f) for f in fields]
+            return [
+                get_entity_field_as_str(
+                    data, field_path=m["field_path"], values=m["values"]
+                )
+                for m in field_mappings.values()
+            ]

         with ThreadPoolExecutor(1) as pool:
             page = 0
-            page_size = int(
-                download_conf.get("batch_size", 500)
-            )
+            page_size = int(download_conf.get("batch_size", 500))

             future = pool.submit(get_fn, page, page_size)
             while True:
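Why writerow(field_mappings) yields the CSV header: csv treats the dict as an iterable of its keys (the output column names), and because SingleLine.write returns its argument, writerow's return value is the encoded row itself. A quick demonstration of that behavior:

import csv

class SingleLine:
    @staticmethod
    def write(line: str) -> str:
        return line

header = csv.writer(SingleLine()).writerow({"State": {}, "Task Name": {}})
assert header == "State,Task Name\r\n"  # csv's default lineterminator is "\r\n"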
@@ -266,12 +301,12 @@ def download_for_get_all(call: APICall, company, request: DownloadForGetAll):
                 with StringIO() as fp:
                     writer = csv.writer(fp)
                     if page == 1:
-                        writer.writerow(projection)
+                        writer.writerow(field_mappings)
                     writer.writerows(get_projected_fields(r) for r in result)
                     yield fp.getvalue()

         if page == 0:
-            yield csv.writer(SingleLine()).writerow(projection)
+            yield csv.writer(SingleLine()).writerow(field_mappings)

     def get_project_name() -> Optional[str]:
         projects = call_data.get("project")
@@ -287,12 +322,10 @@ def download_for_get_all(call: APICall, company, request: DownloadForGetAll):
         if not project:
             return
-        return project.basename[:download_conf.get("max_project_name_length", 60)]
+        return project.basename[: download_conf.get("max_project_name_length", 60)]

     call.result.filename = "-".join(
-        filter(
-            None, ("clearml", get_project_name(), f"{request.entity_type}s.csv")
-        )
+        filter(None, ("clearml", get_project_name(), f"{request.entity_type}s.csv"))
     )
     call.result.content_type = "text/csv"
     call.result.raw_data = stream_with_context(generate())
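End to end, the two endpoints are used in sequence: prepare returns a prepare_id that download redeems for the CSV stream. A hypothetical client-side flow — the URL scheme, port, auth style, and the prepare_id response field are assumptions, not confirmed by this diff:

import requests

base = "http://localhost:8008"  # default ClearML apiserver port; adjust as needed
api_auth = ("<access_key>", "<secret_key>")

prep = requests.post(
    f"{base}/organization.prepare_download_for_get_all",
    json=payload,  # the payload sketched earlier
    auth=api_auth,
)
prepare_id = prep.json()["data"]["prepare_id"]  # response field name assumed

resp = requests.post(
    f"{base}/organization.download_for_get_all",
    json={"prepare_id": prepare_id},
    auth=api_auth,
)
with open("tasks.csv", "wb") as f:
    f.write(resp.content)  # the streamed CSV; a filename is also set in the response headers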