Add organization.download_for_get_all endpoint

This commit is contained in:
allegroai 2023-07-26 18:31:20 +03:00
parent 5239755066
commit ff34da3c88
10 changed files with 287 additions and 40 deletions

View File

@ -1,6 +1,10 @@
from jsonmodels import fields, models
from enum import auto
from apiserver.apimodels import DictField
from jsonmodels import fields, models
from jsonmodels.validators import Length
from apiserver.apimodels import DictField, ActualEnumField
from apiserver.utilities.stringenum import StringEnum
class Filter(models.Base):
@ -23,3 +27,26 @@ class EntitiesCountRequest(models.Base):
active_users = fields.ListField(str)
search_hidden = fields.BoolField(default=False)
allow_public = fields.BoolField(default=True)
class DownloadType(StringEnum):
    # File formats supported by prepare_download_for_get_all (CSV only for now)
    csv = auto()
class EntityType(StringEnum):
    # Entity kinds that can be exported through the
    # organization.prepare_download_for_get_all endpoint
    task = auto()
    model = auto()
class PrepareDownloadForGetAll(models.Base):
    """Request model for organization.prepare_download_for_get_all.

    Besides the declared fields, the raw call data is expected to carry the
    usual get_all_ex query parameters (stored as-is for the later download).
    """

    # Output file format; determines the generated file's formatting/mime type
    download_type = ActualEnumField(DownloadType, default=DownloadType.csv)
    # Which entity collection to query (task or model)
    entity_type = ActualEnumField(EntityType)
    # Include public (cross-company) entities in the results
    allow_public = fields.BoolField(default=True)
    # Include hidden entities in the results
    search_hidden = fields.BoolField(default=False)
    # Projection: field paths ('.'-nested) that become the CSV columns;
    # at least one entry is required
    only_fields = fields.ListField(
        items_types=[str], validators=[Length(1)], required=True
    )
class DownloadForGetAll(models.Base):
    """Request model for organization.download_for_get_all."""

    # ID returned by a previous prepare_download_for_get_all call
    prepare_id = fields.StringField(required=True)

View File

@ -5,7 +5,6 @@ from mongoengine import Document
from apiserver.apierrors import errors
from apiserver.apimodels.metadata import MetadataItem
from apiserver.database.model.base import GetMixin
from apiserver.service_repo import APICall
from apiserver.utilities.parameter_key_escaper import (
ParameterKeyEscaper,
mongoengine_safe,
@ -87,13 +86,13 @@ class Metadata:
return paths
@classmethod
def escape_query_parameters(cls, call: APICall) -> dict:
if not call.data:
return call.data
def escape_query_parameters(cls, call_data: dict) -> dict:
if not call_data:
return call_data
keys = list(call.data)
keys = list(call_data)
call_data = {
safe_key: call.data[key]
safe_key: call_data[key]
for key, safe_key in zip(keys, Metadata.escape_paths(keys))
}

View File

@ -1,3 +1,7 @@
tags_cache {
expiration_seconds: 3600
}
download {
redis_timeout_sec: 300
batch_size: 500
}

View File

@ -197,3 +197,70 @@ get_entities_count {
}
}
}
prepare_download_for_get_all {
"999.0": {
description: Prepares download from get_all_ex parameters
request {
type: object
required: [ entity_type, only_fields]
properties {
only_fields {
description: "List of entity field names (nesting is supported using '.', e.g. execution.model_labels). If provided, this list defines the query's projection (only these fields will be returned for each result entry)"
type: array
items {type: string}
}
download_type {
description: "Download type. Determines the downloaded file's formatting and mime type."
type: string
enum: [ csv ]
default: csv
}
allow_public {
description: "Allow public entities to be returned in the results"
type: boolean
default: true
}
search_hidden {
description: "If set to 'true' then hidden entities are included in the search results"
type: boolean
default: false
}
entity_type {
description: "The type of the entity to retrieve"
type: string
enum: [
task
model
]
}
}
}
response {
type: object
properties {
prepare_id {
description: "Prepare ID (use when calling 'download_for_get_all')"
type: string
}
}
}
}
}
download_for_get_all {
"999.0": {
description: Generates a file for the download
request {
type: object
required: [ prepare_id ]
properties {
prepare_id {
description: "Call ID returned by a call to prepare_download_for_get_all"
type: string
}
}
}
response {
type: string
}
}
}

View File

@ -70,9 +70,7 @@ def _assert_task_or_model_exists(
@endpoint("events.add")
def add(call: APICall, company_id, _):
data = call.data.copy()
added, err_count, err_info = event_bll.add_events(
company_id, [data], call.worker
)
added, err_count, err_info = event_bll.add_events(company_id, [data], call.worker)
call.result.data = dict(added=added, errors=err_count, errors_info=err_info)
@ -82,11 +80,7 @@ def add_batch(call: APICall, company_id, _):
if events is None or len(events) == 0:
raise errors.bad_request.BatchContainsNoItems()
added, err_count, err_info = event_bll.add_events(
company_id,
events,
call.worker,
)
added, err_count, err_info = event_bll.add_events(company_id, events, call.worker,)
call.result.data = dict(added=added, errors=err_count, errors_info=err_info)
@ -576,6 +570,7 @@ def _get_multitask_plots(
sort=[{"iter": {"order": "desc"}}],
scroll_id=scroll_id,
no_scroll=no_scroll,
size=10000,
)
return_events = _get_top_iter_unique_events_per_task(
result.events, max_iters=last_iters, task_names=task_names
@ -595,10 +590,7 @@ def get_multi_task_plots(call, company_id, _):
company_id, task_ids, model_events=model_events
)
return_events, total_events, next_scroll_id = _get_multitask_plots(
companies=companies,
last_iters=iters,
scroll_id=scroll_id,
no_scroll=no_scroll,
companies=companies, last_iters=iters, scroll_id=scroll_id, no_scroll=no_scroll,
)
call.result.data = dict(
plots=return_events,
@ -784,6 +776,7 @@ def get_debug_images_v1_8(call, company_id, _):
sort=[{"iter": {"order": "desc"}}],
last_iter_count=iters,
scroll_id=scroll_id,
size=10000,
)
return_events = result.events
@ -960,12 +953,12 @@ def clear_task_log(call: APICall, company_id: str, request: ClearTaskLogRequest)
def _get_top_iter_unique_events_per_task(
events, max_iters: int, task_names: Mapping[str, str]
):
key = itemgetter("metric", "variant", "task", "iter")
key_fields = ("metric", "variant", "task")
unique_events = itertools.chain.from_iterable(
itertools.islice(group, max_iters)
for _, group in itertools.groupby(
sorted(events, key=key, reverse=True), key=key
sorted(events, key=itemgetter(*(key_fields + ("iter",))), reverse=True),
key=itemgetter(*key_fields),
)
)

View File

@ -67,7 +67,7 @@ def conform_model_data(call: APICall, model_data: Union[Sequence[dict], dict]):
@endpoint("models.get_by_id", required_fields=["model"])
def get_by_id(call: APICall, company_id, _):
model_id = call.data["model"]
call_data = Metadata.escape_query_parameters(call)
call_data = Metadata.escape_query_parameters(call.data)
models = Model.get_many(
company=company_id,
query_dict=call_data,
@ -112,7 +112,7 @@ def get_by_task_id(call: APICall, company_id, _):
@endpoint("models.get_all_ex", request_data_model=ModelsGetRequest)
def get_all_ex(call: APICall, company_id, request: ModelsGetRequest):
conform_tag_fields(call, call.data)
call_data = Metadata.escape_query_parameters(call)
call_data = Metadata.escape_query_parameters(call.data)
process_include_subprojects(call_data)
ret_params = {}
models = Model.get_many_with_join(
@ -139,7 +139,7 @@ def get_all_ex(call: APICall, company_id, request: ModelsGetRequest):
@endpoint("models.get_by_id_ex", required_fields=["id"])
def get_by_id_ex(call: APICall, company_id, _):
conform_tag_fields(call, call.data)
call_data = Metadata.escape_query_parameters(call)
call_data = Metadata.escape_query_parameters(call.data)
models = Model.get_many_with_join(
company=company_id, query_dict=call_data, allow_public=True
)
@ -150,7 +150,7 @@ def get_by_id_ex(call: APICall, company_id, _):
@endpoint("models.get_all", required_fields=[])
def get_all(call: APICall, company_id, _):
conform_tag_fields(call, call.data)
call_data = Metadata.escape_query_parameters(call)
call_data = Metadata.escape_query_parameters(call.data)
process_include_subprojects(call_data)
ret_params = {}
models = Model.get_many(

View File

@ -1,21 +1,42 @@
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from operator import itemgetter
from typing import Mapping, Type
from typing import Mapping, Type, Sequence, Optional, Callable
from flask import stream_with_context
from mongoengine import Q
from apiserver.apimodels.organization import TagsRequest, EntitiesCountRequest
from apiserver.apierrors import errors
from apiserver.apimodels.organization import (
TagsRequest,
EntitiesCountRequest,
DownloadForGetAll,
EntityType,
PrepareDownloadForGetAll,
)
from apiserver.bll.model import Metadata
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL
from apiserver.config_repo import config
from apiserver.database.model import User, AttributedDocument, EntityVisibility
from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task, TaskType
from apiserver.redis_manager import redman
from apiserver.service_repo import endpoint, APICall
from apiserver.services.models import conform_model_data
from apiserver.services.tasks import (
escape_execution_parameters,
_hidden_query,
conform_task_data,
)
from apiserver.services.utils import get_tags_filter_dictionary, sort_tags_response
from apiserver.utilities import json
from apiserver.utilities.dicts import nested_get
org_bll = OrgBLL()
project_bll = ProjectBLL()
redis = redman.connection("apiserver")
@endpoint("organization.get_tags", request_data_model=TagsRequest)
@ -105,3 +126,139 @@ def get_entities_count(call: APICall, company, request: EntitiesCountRequest):
)
call.result.data = ret
def _get_download_getter_fn(
    company: str,
    call: APICall,
    call_data: dict,
    allow_public: bool,
    entity_type: EntityType,
) -> Optional[Callable[[int, int], Sequence[dict]]]:
    """
    Build a page-fetching callable for the requested entity type.

    The returned callable takes (page, page_size), injects them into the
    (already escaped) query parameters and returns one page of entities as
    dicts, conformed for output. Raises ValidationError on an unsupported
    entity type.
    """
    if entity_type == EntityType.task:
        call_data = escape_execution_parameters(call_data)

        def fetch_page() -> Sequence[dict]:
            entities = Task.get_many_with_join(
                company=company,
                query_dict=call_data,
                query=_hidden_query(call_data),
                allow_public=allow_public,
            )
            conform_task_data(call, entities)
            return entities

    elif entity_type == EntityType.model:
        call_data = Metadata.escape_query_parameters(call_data)

        def fetch_page() -> Sequence[dict]:
            entities = Model.get_many_with_join(
                company=company, query_dict=call_data, allow_public=allow_public,
            )
            conform_model_data(call, entities)
            return entities

    else:
        raise errors.bad_request.ValidationError(
            f"Unsupported entity type: {str(entity_type)}"
        )

    def getter(page: int, page_size: int) -> Sequence[dict]:
        # Page-based retrieval replaces any scroll-based pagination
        call_data.pop("scroll_id", None)
        call_data["page"] = page
        call_data["page_size"] = page_size
        return fetch_page()

    return getter
@endpoint("organization.prepare_download_for_get_all")
def prepare_download_for_get_all(
    call: APICall, company: str, request: PrepareDownloadForGetAll
):
    """
    Validate the passed get_all parameters and store them in Redis under
    this call's ID, so that a later 'download_for_get_all' call can generate
    the actual file from them.
    """
    # Building the getter raises on an unsupported entity type; fetching a
    # single entry validates the remaining query parameters
    validate_fn = _get_download_getter_fn(
        company,
        call,
        call_data=call.data.copy(),
        allow_public=request.allow_public,
        entity_type=request.entity_type,
    )
    if validate_fn:
        validate_fn(0, 1)

    timeout_sec = int(
        config.get("services.organization.download.redis_timeout_sec", 300)
    )
    redis.setex(f"get_all_download_{call.id}", timeout_sec, json.dumps(call.data))

    call.result.data = {"prepare_id": call.id}
@endpoint("organization.download_for_get_all")
def download_for_get_all(call: APICall, company, request: DownloadForGetAll):
    """
    Stream a CSV file built from the get_all parameters that a previous
    'prepare_download_for_get_all' call stored in Redis under the prepare ID.

    Raises InvalidId when the prepare ID is unknown or expired, and DataError
    when the stored data cannot be parsed back into a request.
    """
    request_data = redis.get(f"get_all_download_{request.prepare_id}")
    if not request_data:
        raise errors.bad_request.InvalidId(
            # fixed: was an f-string with no placeholders
            "prepare ID not found", prepare_id=request.prepare_id
        )

    try:
        call_data = json.loads(request_data)
        # separate name (instead of rebinding 'request') to keep the two
        # request models distinguishable
        prepare_request = PrepareDownloadForGetAll(**call_data)
    except Exception as ex:
        raise errors.server_error.DataError("failed parsing prepared data", ex=ex)

    def generate():
        # The projected field paths double as the CSV header line
        projection = call_data.get("only_fields", [])
        headers = ",".join(projection)
        get_fn = _get_download_getter_fn(
            company,
            call,
            call_data=call_data,
            allow_public=prepare_request.allow_public,
            entity_type=prepare_request.entity_type,
        )
        if not get_fn:
            # fixed: 'return headers' inside a generator only sets the
            # StopIteration value and the headers were never emitted
            yield headers
            return

        fields = [path.split(".") for path in projection]

        def get_entity_field_as_str(data: dict, field: Sequence[str]) -> str:
            # Joined reference fields may come back as dicts; represent
            # them by their id
            # NOTE(review): values are not CSV-escaped; a value containing a
            # comma or newline would break the row format — confirm whether
            # projected fields can contain such characters
            val = nested_get(data, field, "")
            if isinstance(val, dict):
                val = val.get("id", "")
            return str(val)

        def get_string_from_entity_data(data: dict) -> str:
            return ",".join(get_entity_field_as_str(data, f) for f in fields)

        with ThreadPoolExecutor(1) as pool:
            page = 0
            page_size = int(
                config.get("services.organization.download.batch_size", 500)
            )
            # Prefetch the next page in a worker thread while the current
            # one is formatted and streamed to the client
            future = pool.submit(get_fn, page, page_size)
            out = [headers]
            while True:
                result = future.result()
                if not result:
                    break
                page += 1
                future = pool.submit(get_fn, page, page_size)
                out.extend(get_string_from_entity_data(r) for r in result)
                yield "\n".join(out) + "\n"
                out = []

        if out:
            yield "\n".join(out)

    call.result.filename = f"{prepare_request.entity_type}_export.{prepare_request.download_type}"
    call.result.content_type = "text/csv"
    call.result.raw_data = stream_with_context(generate())

View File

@ -83,7 +83,7 @@ def get_all_ex(call: APICall, company: str, request: GetAllRequest):
conform_tag_fields(call, call.data)
ret_params = {}
call_data = Metadata.escape_query_parameters(call)
call_data = Metadata.escape_query_parameters(call.data)
queues = queue_bll.get_queue_infos(
company_id=company,
query_dict=call_data,
@ -99,7 +99,7 @@ def get_all_ex(call: APICall, company: str, request: GetAllRequest):
def get_all(call: APICall, company: str, request: GetAllRequest):
conform_tag_fields(call, call.data)
ret_params = {}
call_data = Metadata.escape_query_parameters(call)
call_data = Metadata.escape_query_parameters(call.data)
queues = queue_bll.get_all(
company_id=company,
query_dict=call_data,

View File

@ -222,7 +222,7 @@ def get_task_data(call: APICall, company_id, request: GetTasksDataRequest):
entity_cls = Task
conform_data = conform_task_data
call_data = escape_execution_parameters(call)
call_data = escape_execution_parameters(call.data)
process_include_subprojects(call_data)
ret_params = {}

View File

@ -171,13 +171,13 @@ def get_by_id(call: APICall, company_id, req_model: TaskRequest):
call.result.data = {"task": task_dict}
def escape_execution_parameters(call: APICall) -> dict:
if not call.data:
return call.data
def escape_execution_parameters(call_data: dict) -> dict:
if not call_data:
return call_data
keys = list(call.data)
keys = list(call_data)
call_data = {
safe_key: call.data[key] for key, safe_key in zip(keys, escape_paths(keys))
safe_key: call_data[key] for key, safe_key in zip(keys, escape_paths(keys))
}
projection = Task.get_projection(call_data)
@ -204,7 +204,7 @@ def _hidden_query(data: dict) -> Q:
@endpoint("tasks.get_all_ex")
def get_all_ex(call: APICall, company_id, request: GetAllReq):
conform_tag_fields(call, call.data)
call_data = escape_execution_parameters(call)
call_data = escape_execution_parameters(call.data)
process_include_subprojects(call_data)
ret_params = {}
tasks = Task.get_many_with_join(
@ -221,7 +221,7 @@ def get_all_ex(call: APICall, company_id, request: GetAllReq):
@endpoint("tasks.get_by_id_ex", required_fields=["id"])
def get_by_id_ex(call: APICall, company_id, _):
conform_tag_fields(call, call.data)
call_data = escape_execution_parameters(call)
call_data = escape_execution_parameters(call.data)
tasks = Task.get_many_with_join(
company=company_id, query_dict=call_data, allow_public=True,
)
@ -233,7 +233,7 @@ def get_by_id_ex(call: APICall, company_id, _):
@endpoint("tasks.get_all", required_fields=[])
def get_all(call: APICall, company_id, _):
conform_tag_fields(call, call.data)
call_data = escape_execution_parameters(call)
call_data = escape_execution_parameters(call.data)
process_include_subprojects(call_data)
ret_params = {}