Add field_mappings to organizations download endpoints

allegroai 2023-07-26 18:39:41 +03:00
parent ed86750b24
commit bc2fe28bdd
3 changed files with 117 additions and 31 deletions

apiserver/apimodels/organization.py

@@ -1,9 +1,10 @@
 from enum import auto
+from typing import Sequence

 from jsonmodels import fields, models
 from jsonmodels.validators import Length

-from apiserver.apimodels import DictField, ActualEnumField
+from apiserver.apimodels import DictField, ActualEnumField, ScalarField
 from apiserver.utilities.stringenum import StringEnum
@@ -34,14 +35,28 @@ class EntityType(StringEnum):
     model = auto()


-class PrepareDownloadForGetAll(models.Base):
+class ValueMapping(models.Base):
+    key = ScalarField(nullable=True)
+    value = ScalarField(nullable=True)
+
+
+class FieldMapping(models.Base):
+    field = fields.StringField(required=True)
+    name = fields.StringField()
+    values: Sequence[ValueMapping] = fields.ListField(items_types=[ValueMapping])
+
+
+class PrepareDownloadForGetAllRequest(models.Base):
     entity_type = ActualEnumField(EntityType)
     allow_public = fields.BoolField(default=True)
     search_hidden = fields.BoolField(default=False)
     only_fields = fields.ListField(
         items_types=[str], validators=[Length(1)], required=True
     )
+    field_mappings: Sequence[FieldMapping] = fields.ListField(
+        items_types=[FieldMapping], validators=[Length(1)], required=True
+    )


-class DownloadForGetAll(models.Base):
+class DownloadForGetAllRequest(models.Base):
     prepare_id = fields.StringField(required=True)
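
Below, a minimal sketch of how the new request model validates a payload under jsonmodels; EntityType.task and all payload values are assumptions for illustration, not part of the commit:

    from jsonmodels.errors import ValidationError

    req = PrepareDownloadForGetAllRequest(
        entity_type=EntityType.task,  # assumes the enum also defines task = auto()
        only_fields=["name", "status"],
        field_mappings=[
            FieldMapping(
                field="status",
                name="Status",
                values=[ValueMapping(key="created", value="Draft")],
            )
        ],
    )
    req.validate()  # passes: required fields present, lists non-empty

    try:
        PrepareDownloadForGetAllRequest(
            entity_type=EntityType.task, only_fields=["name"], field_mappings=[]
        ).validate()
    except ValidationError:
        pass  # the Length(1) validator rejects an empty field_mappings list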

apiserver/schema/services/organization.conf

@@ -1,5 +1,38 @@
 _description: "This service provides organization level operations"
+_definitions {
+    value_mapping {
+        type: object
+        required: [key, value]
+        properties {
+            key {
+                description: Original value
+                type: object
+            }
+            value {
+                description: Translated value
+                type: object
+            }
+        }
+    }
+    field_mapping {
+        type: object
+        required: [field]
+        properties {
+            field {
+                description: The source field name as specified in the only_fields
+                type: string
+            }
+            name {
+                description: The column name in the exported csv file
+                type: string
+            }
+            values {
+                type: array
+                items { "$ref": "#/definitions/value_mapping" }
+            }
+        }
+    }
+}
 get_tags {
     "2.8" {
         description: "Get all the user and system tags used for the company tasks and models"
@@ -202,7 +235,7 @@ prepare_download_for_get_all {
         description: Prepares download from get_all_ex parameters
         request {
             type: object
-            required: [ entity_type, only_fields]
+            required: [ entity_type, only_fields, field_mappings]
             properties {
                 only_fields {
                     description: "List of task field names (nesting is supported using '.', e.g. execution.model_labels). If provided, this list defines the query's projection (only these fields will be returned for each result entry)"
@@ -227,6 +260,11 @@ prepare_download_for_get_all {
                         model
                     ]
                 }
+                field_mappings {
+                    description: The name and value mappings for the exported fields. The fields that are not in the mappings will not be exported
+                    type: array
+                    items { "$ref": "#/definitions/field_mapping" }
+                }
             }
         }
         response {
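
Per these definitions, a valid prepare_download_for_get_all request body might look as follows (field paths and translated values are illustrative). Note the new semantics: only the fields listed in field_mappings end up in the exported csv:

    request_body = {
        "entity_type": "task",
        "only_fields": ["name", "status", "project.name"],
        "field_mappings": [
            {"field": "name", "name": "Task Name"},
            {
                "field": "status",
                "values": [
                    {"key": "created", "value": "Draft"},
                    {"key": "in_progress", "value": "Running"},
                ],
            },
            {"field": "project.name", "name": "Project"},
        ],
    }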

apiserver/services/organization.py

@@ -3,7 +3,7 @@ from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor
 from io import StringIO
 from operator import itemgetter
-from typing import Mapping, Type, Sequence, Optional, Callable
+from typing import Mapping, Type, Sequence, Optional, Callable, Hashable

 from flask import stream_with_context
 from mongoengine import Q
@@ -12,9 +12,9 @@ from apiserver.apierrors import errors
 from apiserver.apimodels.organization import (
     TagsRequest,
     EntitiesCountRequest,
-    DownloadForGetAll,
+    DownloadForGetAllRequest,
     EntityType,
-    PrepareDownloadForGetAll,
+    PrepareDownloadForGetAllRequest,
 )
 from apiserver.bll.model import Metadata
 from apiserver.bll.organization import OrgBLL, Tags
@@ -47,7 +47,10 @@ def get_tags(call: APICall, company, request: TagsRequest):
     ret = defaultdict(set)
     for entity in Tags.Model, Tags.Task:
         tags = org_bll.get_tags(
-            company, entity, include_system=request.include_system, filter_=filter_dict,
+            company,
+            entity,
+            include_system=request.include_system,
+            filter_=filter_dict,
         )
         for field, vals in tags.items():
             ret[field] |= vals
@@ -149,7 +152,9 @@ def _get_download_getter_fn(
         def get_model_data() -> Sequence[dict]:
             models = Model.get_many_with_join(
-                company=company, query_dict=call_data, allow_public=allow_public,
+                company=company,
+                query_dict=call_data,
+                allow_public=allow_public,
             )
             conform_model_data(call, models)
             return models
@@ -182,9 +187,26 @@ download_conf = config.get("services.organization.download")
 @endpoint("organization.prepare_download_for_get_all")
 def prepare_download_for_get_all(
-    call: APICall, company: str, request: PrepareDownloadForGetAll
+    call: APICall, company: str, request: PrepareDownloadForGetAllRequest
 ):
     # validate input params
+    field_names = set()
+    for fm in request.field_mappings:
+        name = fm.name or fm.field
+        if name in field_names:
+            raise errors.bad_request.ValidationError(
+                f"Field_name appears more than once in field_mappings: {str(name)}"
+            )
+        field_names.add(name)
+        if fm.values:
+            value_keys = set()
+            for v in fm.values:
+                if v.key in value_keys:
+                    raise errors.bad_request.ValidationError(
+                        f"Value key appears more than once in field_mappings: {str(v.key)}"
+                    )
+                value_keys.add(v.key)
+
     getter = _get_download_getter_fn(
         company,
         call,
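
The validation added here deduplicates on the effective column name (name when given, otherwise field) and on key within each mapping's values. A standalone sketch of the same semantics, with plain dicts standing in for the request models:

    def check_field_mappings(field_mappings):
        names = set()
        for fm in field_mappings:
            name = fm.get("name") or fm["field"]  # effective csv column name
            if name in names:
                raise ValueError(f"duplicate column name: {name}")
            names.add(name)
            keys = [v["key"] for v in fm.get("values") or []]
            if len(keys) != len(set(keys)):
                raise ValueError(f"duplicate value key under: {fm['field']}")

    check_field_mappings([{"field": "status", "name": "Status"}, {"field": "type"}])  # ok
    # rejected: both mappings resolve to the column name "status"
    # check_field_mappings([{"field": "status"}, {"field": "type", "name": "status"}])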
@@ -205,7 +227,7 @@ def prepare_download_for_get_all(
 @endpoint("organization.download_for_get_all")
-def download_for_get_all(call: APICall, company, request: DownloadForGetAll):
+def download_for_get_all(call: APICall, company, request: DownloadForGetAllRequest):
     request_data = redis.get(f"get_all_download_{request.prepare_id}")
     if not request_data:
         raise errors.bad_request.InvalidId(
@@ -214,17 +236,25 @@ def download_for_get_all(call: APICall, company, request: DownloadForGetAll):
     try:
         call_data = json.loads(request_data)
-        request = PrepareDownloadForGetAll(**call_data)
+        request = PrepareDownloadForGetAllRequest(**call_data)
     except Exception as ex:
         raise errors.server_error.DataError("failed parsing prepared data", ex=ex)

     class SingleLine:
-        def write(self, line: str) -> str:
+        @staticmethod
+        def write(line: str) -> str:
             return line

     def generate():
-        projection = call_data.get("only_fields", [])
+        field_mappings = {
+            mapping.get("name", mapping["field"]): {
+                "field_path": mapping["field"].split("."),
+                "values": {
+                    v.get("key"): v.get("value") for v in (mapping.get("values") or [])
+                },
+            }
+            for mapping in call_data.get("field_mappings", [])
+        }
         get_fn = _get_download_getter_fn(
             company,
             call,
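
generate() first folds the prepared field_mappings list into a lookup keyed by csv column name. For a hypothetical prepared payload, the comprehension above yields:

    call_data = {
        "field_mappings": [
            {
                "field": "execution.queue",
                "name": "Queue",
                "values": [{"key": None, "value": "n/a"}],
            }
        ]
    }
    field_mappings = {
        mapping.get("name", mapping["field"]): {
            "field_path": mapping["field"].split("."),
            "values": {
                v.get("key"): v.get("value") for v in (mapping.get("values") or [])
            },
        }
        for mapping in call_data.get("field_mappings", [])
    }
    # {'Queue': {'field_path': ['execution', 'queue'], 'values': {None: 'n/a'}}}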
@@ -233,26 +263,31 @@ def download_for_get_all(call: APICall, company, request: DownloadForGetAll):
             entity_type=request.entity_type,
         )
         if not get_fn:
-            yield csv.writer(SingleLine()).writerow(projection)
+            yield csv.writer(SingleLine()).writerow(field_mappings)
             return

-        fields = [path.split(".") for path in projection]
-
-        def get_entity_field_as_str(data: dict, field: Sequence[str]) -> str:
-            val = nested_get(data, field, "")
+        def get_entity_field_as_str(
+            data: dict, field_path: Sequence[str], values: Mapping
+        ) -> str:
+            val = nested_get(data, field_path, "")
             if isinstance(val, dict):
                 val = val.get("id", "")
+            if values and isinstance(val, Hashable):
+                val = values.get(val, val)
             return str(val)

         def get_projected_fields(data: dict) -> Sequence[str]:
-            return [get_entity_field_as_str(data, f) for f in fields]
+            return [
+                get_entity_field_as_str(
+                    data, field_path=m["field_path"], values=m["values"]
+                )
+                for m in field_mappings.values()
+            ]

         with ThreadPoolExecutor(1) as pool:
             page = 0
-            page_size = int(
-                download_conf.get("batch_size", 500)
-            )
+            page_size = int(download_conf.get("batch_size", 500))

             future = pool.submit(get_fn, page, page_size)
             while True:
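
Two details worth noting here: iterating a dict yields its keys, so writerow(field_mappings) emits the mapped column names as the csv header row; and csv.writer(SingleLine()).writerow(...) returns the formatted line itself, because writerow returns whatever the underlying write() returns. The per-cell conversion can be sketched standalone (this nested_get is a simplified stand-in for the apiserver utility of the same name):

    def nested_get(data, path, default=""):
        for key in path:
            if not isinstance(data, dict) or key not in data:
                return default
            data = data[key]
        return data

    row = {"status": "created", "project": {"id": "abc123", "name": "Demo"}}
    mapping = {"field_path": ["status"], "values": {"created": "Draft"}}

    val = nested_get(row, mapping["field_path"], "")
    if isinstance(val, dict):
        val = val.get("id", "")  # joined objects collapse to their id
    val = mapping["values"].get(val, val)  # translate; fall back to the raw value
    print(val)  # Draft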
@@ -266,12 +301,12 @@ def download_for_get_all(call: APICall, company, request: DownloadForGetAll):
                 with StringIO() as fp:
                     writer = csv.writer(fp)
                     if page == 1:
-                        writer.writerow(projection)
+                        writer.writerow(field_mappings)
                     writer.writerows(get_projected_fields(r) for r in result)
                     yield fp.getvalue()

             if page == 0:
-                yield csv.writer(SingleLine()).writerow(projection)
+                yield csv.writer(SingleLine()).writerow(field_mappings)

     def get_project_name() -> Optional[str]:
         projects = call_data.get("project")
@@ -287,12 +322,10 @@ def download_for_get_all(call: APICall, company, request: DownloadForGetAll):
         if not project:
             return

-        return project.basename[:download_conf.get("max_project_name_length", 60)]
+        return project.basename[: download_conf.get("max_project_name_length", 60)]

     call.result.filename = "-".join(
-        filter(
-            None, ("clearml", get_project_name(), f"{request.entity_type}s.csv")
-        )
+        filter(None, ("clearml", get_project_name(), f"{request.entity_type}s.csv"))
     )
     call.result.content_type = "text/csv"
     call.result.raw_data = stream_with_context(generate())
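
Putting it together, the two-step flow could be exercised as below; the host, credentials, and the assumption that the prepare response carries a prepare_id field are all illustrative:

    import requests

    base = "http://localhost:8008"  # assumed local apiserver address
    auth = ("<access_key>", "<secret_key>")  # placeholder credentials

    prepared = requests.post(
        f"{base}/organization.prepare_download_for_get_all",
        auth=auth,
        json={
            "entity_type": "task",
            "only_fields": ["name", "status"],
            "field_mappings": [
                {"field": "name", "name": "Task Name"},
                {"field": "status", "values": [{"key": "created", "value": "Draft"}]},
            ],
        },
    ).json()["data"]

    download = requests.post(
        f"{base}/organization.download_for_get_all",
        auth=auth,
        json={"prepare_id": prepared["prepare_id"]},  # assumed response field
    )
    print(download.text)  # streamed csv; header row = mapped column names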