From ff34da3c88197afd0eff8f08eba145fd14de8492 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Wed, 26 Jul 2023 18:31:20 +0300 Subject: [PATCH] Add organization.download_for_get_all endpoint --- apiserver/apimodels/organization.py | 31 +++- apiserver/bll/model/metadata.py | 11 +- .../config/default/services/organization.conf | 4 + apiserver/schema/services/organization.conf | 67 ++++++++ apiserver/services/events.py | 23 +-- apiserver/services/models.py | 8 +- apiserver/services/organization.py | 161 +++++++++++++++++- apiserver/services/queues.py | 4 +- apiserver/services/reports.py | 2 +- apiserver/services/tasks.py | 16 +- 10 files changed, 287 insertions(+), 40 deletions(-) diff --git a/apiserver/apimodels/organization.py b/apiserver/apimodels/organization.py index 46bcd1f..ab67393 100644 --- a/apiserver/apimodels/organization.py +++ b/apiserver/apimodels/organization.py @@ -1,6 +1,10 @@ -from jsonmodels import fields, models +from enum import auto -from apiserver.apimodels import DictField +from jsonmodels import fields, models +from jsonmodels.validators import Length + +from apiserver.apimodels import DictField, ActualEnumField +from apiserver.utilities.stringenum import StringEnum class Filter(models.Base): @@ -23,3 +27,26 @@ class EntitiesCountRequest(models.Base): active_users = fields.ListField(str) search_hidden = fields.BoolField(default=False) allow_public = fields.BoolField(default=True) + + +class DownloadType(StringEnum): + csv = auto() + + +class EntityType(StringEnum): + task = auto() + model = auto() + + +class PrepareDownloadForGetAll(models.Base): + download_type = ActualEnumField(DownloadType, default=DownloadType.csv) + entity_type = ActualEnumField(EntityType) + allow_public = fields.BoolField(default=True) + search_hidden = fields.BoolField(default=False) + only_fields = fields.ListField( + items_types=[str], validators=[Length(1)], required=True + ) + + +class DownloadForGetAll(models.Base): + prepare_id = fields.StringField(required=True) diff --git a/apiserver/bll/model/metadata.py b/apiserver/bll/model/metadata.py index 7c40576..3a1fee3 100644 --- a/apiserver/bll/model/metadata.py +++ b/apiserver/bll/model/metadata.py @@ -5,7 +5,6 @@ from mongoengine import Document from apiserver.apierrors import errors from apiserver.apimodels.metadata import MetadataItem from apiserver.database.model.base import GetMixin -from apiserver.service_repo import APICall from apiserver.utilities.parameter_key_escaper import ( ParameterKeyEscaper, mongoengine_safe, @@ -87,13 +86,13 @@ class Metadata: return paths @classmethod - def escape_query_parameters(cls, call: APICall) -> dict: - if not call.data: - return call.data + def escape_query_parameters(cls, call_data: dict) -> dict: + if not call_data: + return call_data - keys = list(call.data) + keys = list(call_data) call_data = { - safe_key: call.data[key] + safe_key: call_data[key] for key, safe_key in zip(keys, Metadata.escape_paths(keys)) } diff --git a/apiserver/config/default/services/organization.conf b/apiserver/config/default/services/organization.conf index 03a322b..a0b7969 100644 --- a/apiserver/config/default/services/organization.conf +++ b/apiserver/config/default/services/organization.conf @@ -1,3 +1,7 @@ tags_cache { expiration_seconds: 3600 +} +download { + redis_timeout_sec: 300 + batch_size: 500 } \ No newline at end of file diff --git a/apiserver/schema/services/organization.conf b/apiserver/schema/services/organization.conf index abd8c7c..c754a62 100644 --- a/apiserver/schema/services/organization.conf +++ b/apiserver/schema/services/organization.conf @@ -197,3 +197,70 @@ get_entities_count { } } } +prepare_download_for_get_all { + "999.0": { + description: Prepares download from get_all_ex parameters + request { + type: object + required: [ entity_type, only_fields] + properties { + only_fields { + description: "List of task field names (nesting is supported using '.', e.g. execution.model_labels). If provided, this list defines the query's projection (only these fields will be returned for each result entry)" + type: array + items {type: string} + } + download_type { + description: "Download type. Determines the downloaded file's formatting and mime type." + type: string + enum: [ csv ] + default: csv + } + allow_public { + description: "Allow public entities to be returned in the results" + type: boolean + default: true + } + search_hidden { + description: "If set to 'true' then hidden entities are included in the search results" + type: boolean + default: false + } + entity_type { + description: "The type of the entity to retrieve" + type: string + enum: [ + task + model + ] + } + } + } + response { + type: object + properties { + prepare_id { + description: "Prepare ID (use when calling 'download_for_get_all')" + type: string + } + } + } + } +} +download_for_get_all { + "999.0": { + description: Generates a file for the download + request { + type: object + required: [ prepare_id ] + properties { + prepare_id { + description: "Call ID returned by a call to prepare_download_for_get_all" + type: string + } + } + } + response { + type: string + } + } +} \ No newline at end of file diff --git a/apiserver/services/events.py b/apiserver/services/events.py index 977194c..28f162c 100644 --- a/apiserver/services/events.py +++ b/apiserver/services/events.py @@ -70,9 +70,7 @@ def _assert_task_or_model_exists( @endpoint("events.add") def add(call: APICall, company_id, _): data = call.data.copy() - added, err_count, err_info = event_bll.add_events( - company_id, [data], call.worker - ) + added, err_count, err_info = event_bll.add_events(company_id, [data], call.worker) call.result.data = dict(added=added, errors=err_count, errors_info=err_info) @@ -82,11 +80,7 @@ def add_batch(call: APICall, company_id, _): if events is None or len(events) == 0: raise errors.bad_request.BatchContainsNoItems() - added, err_count, err_info = event_bll.add_events( - company_id, - events, - call.worker, - ) + added, err_count, err_info = event_bll.add_events(company_id, events, call.worker,) call.result.data = dict(added=added, errors=err_count, errors_info=err_info) @@ -576,6 +570,7 @@ def _get_multitask_plots( sort=[{"iter": {"order": "desc"}}], scroll_id=scroll_id, no_scroll=no_scroll, + size=10000, ) return_events = _get_top_iter_unique_events_per_task( result.events, max_iters=last_iters, task_names=task_names @@ -595,10 +590,7 @@ def get_multi_task_plots(call, company_id, _): company_id, task_ids, model_events=model_events ) return_events, total_events, next_scroll_id = _get_multitask_plots( - companies=companies, - last_iters=iters, - scroll_id=scroll_id, - no_scroll=no_scroll, + companies=companies, last_iters=iters, scroll_id=scroll_id, no_scroll=no_scroll, ) call.result.data = dict( plots=return_events, @@ -784,6 +776,7 @@ def get_debug_images_v1_8(call, company_id, _): sort=[{"iter": {"order": "desc"}}], last_iter_count=iters, scroll_id=scroll_id, + size=10000, ) return_events = result.events @@ -960,12 +953,12 @@ def clear_task_log(call: APICall, company_id: str, request: ClearTaskLogRequest) def _get_top_iter_unique_events_per_task( events, max_iters: int, task_names: Mapping[str, str] ): - key = itemgetter("metric", "variant", "task", "iter") - + key_fields = ("metric", "variant", "task") unique_events = itertools.chain.from_iterable( itertools.islice(group, max_iters) for _, group in itertools.groupby( - sorted(events, key=key, reverse=True), key=key + sorted(events, key=itemgetter(*(key_fields + ("iter",))), reverse=True), + key=itemgetter(*key_fields), ) ) diff --git a/apiserver/services/models.py b/apiserver/services/models.py index 1816520..cdc4d7b 100644 --- a/apiserver/services/models.py +++ b/apiserver/services/models.py @@ -67,7 +67,7 @@ def conform_model_data(call: APICall, model_data: Union[Sequence[dict], dict]): @endpoint("models.get_by_id", required_fields=["model"]) def get_by_id(call: APICall, company_id, _): model_id = call.data["model"] - call_data = Metadata.escape_query_parameters(call) + call_data = Metadata.escape_query_parameters(call.data) models = Model.get_many( company=company_id, query_dict=call_data, @@ -112,7 +112,7 @@ def get_by_task_id(call: APICall, company_id, _): @endpoint("models.get_all_ex", request_data_model=ModelsGetRequest) def get_all_ex(call: APICall, company_id, request: ModelsGetRequest): conform_tag_fields(call, call.data) - call_data = Metadata.escape_query_parameters(call) + call_data = Metadata.escape_query_parameters(call.data) process_include_subprojects(call_data) ret_params = {} models = Model.get_many_with_join( @@ -139,7 +139,7 @@ def get_all_ex(call: APICall, company_id, request: ModelsGetRequest): @endpoint("models.get_by_id_ex", required_fields=["id"]) def get_by_id_ex(call: APICall, company_id, _): conform_tag_fields(call, call.data) - call_data = Metadata.escape_query_parameters(call) + call_data = Metadata.escape_query_parameters(call.data) models = Model.get_many_with_join( company=company_id, query_dict=call_data, allow_public=True ) @@ -150,7 +150,7 @@ def get_by_id_ex(call: APICall, company_id, _): @endpoint("models.get_all", required_fields=[]) def get_all(call: APICall, company_id, _): conform_tag_fields(call, call.data) - call_data = Metadata.escape_query_parameters(call) + call_data = Metadata.escape_query_parameters(call.data) process_include_subprojects(call_data) ret_params = {} models = Model.get_many( diff --git a/apiserver/services/organization.py b/apiserver/services/organization.py index 20ff54d..1bbd245 100644 --- a/apiserver/services/organization.py +++ b/apiserver/services/organization.py @@ -1,21 +1,42 @@ from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor from operator import itemgetter -from typing import Mapping, Type +from typing import Mapping, Type, Sequence, Optional, Callable +from flask import stream_with_context from mongoengine import Q -from apiserver.apimodels.organization import TagsRequest, EntitiesCountRequest +from apiserver.apierrors import errors +from apiserver.apimodels.organization import ( + TagsRequest, + EntitiesCountRequest, + DownloadForGetAll, + EntityType, + PrepareDownloadForGetAll, +) +from apiserver.bll.model import Metadata from apiserver.bll.organization import OrgBLL, Tags from apiserver.bll.project import ProjectBLL +from apiserver.config_repo import config from apiserver.database.model import User, AttributedDocument, EntityVisibility from apiserver.database.model.model import Model from apiserver.database.model.project import Project from apiserver.database.model.task.task import Task, TaskType +from apiserver.redis_manager import redman from apiserver.service_repo import endpoint, APICall +from apiserver.services.models import conform_model_data +from apiserver.services.tasks import ( + escape_execution_parameters, + _hidden_query, + conform_task_data, +) from apiserver.services.utils import get_tags_filter_dictionary, sort_tags_response +from apiserver.utilities import json +from apiserver.utilities.dicts import nested_get org_bll = OrgBLL() project_bll = ProjectBLL() +redis = redman.connection("apiserver") @endpoint("organization.get_tags", request_data_model=TagsRequest) @@ -105,3 +126,139 @@ def get_entities_count(call: APICall, company, request: EntitiesCountRequest): ) call.result.data = ret + + +def _get_download_getter_fn( + company: str, + call: APICall, + call_data: dict, + allow_public: bool, + entity_type: EntityType, +) -> Optional[Callable[[int, int], Sequence[dict]]]: + def get_task_data() -> Sequence[dict]: + tasks = Task.get_many_with_join( + company=company, + query_dict=call_data, + query=_hidden_query(call_data), + allow_public=allow_public, + ) + conform_task_data(call, tasks) + return tasks + + def get_model_data() -> Sequence[dict]: + models = Model.get_many_with_join( + company=company, query_dict=call_data, allow_public=allow_public, + ) + conform_model_data(call, models) + return models + + if entity_type == EntityType.task: + call_data = escape_execution_parameters(call_data) + get_fn = get_task_data + elif entity_type == EntityType.model: + call_data = Metadata.escape_query_parameters(call_data) + get_fn = get_model_data + else: + raise errors.bad_request.ValidationError( + f"Unsupported entity type: {str(entity_type)}" + ) + + def getter(page: int, page_size: int) -> Sequence[dict]: + call_data.pop("scroll_id", None) + call_data["page"] = page + call_data["page_size"] = page_size + return get_fn() + + return getter + + +@endpoint("organization.prepare_download_for_get_all") +def prepare_download_for_get_all( + call: APICall, company: str, request: PrepareDownloadForGetAll +): + # validate input params + getter = _get_download_getter_fn( + company, + call, + call_data=call.data.copy(), + allow_public=request.allow_public, + entity_type=request.entity_type, + ) + if getter: + getter(0, 1) + + redis.setex( + f"get_all_download_{call.id}", + int(config.get("services.organization.download.redis_timeout_sec", 300)), + json.dumps(call.data), + ) + + call.result.data = dict(prepare_id=call.id) + + +@endpoint("organization.download_for_get_all") +def download_for_get_all(call: APICall, company, request: DownloadForGetAll): + request_data = redis.get(f"get_all_download_{request.prepare_id}") + if not request_data: + raise errors.bad_request.InvalidId( + f"prepare ID not found", prepare_id=request.prepare_id + ) + + try: + call_data = json.loads(request_data) + request = PrepareDownloadForGetAll(**call_data) + except Exception as ex: + raise errors.server_error.DataError("failed parsing prepared data", ex=ex) + + def generate(): + projection = call_data.get("only_fields", []) + headers = ",".join(projection) + + get_fn = _get_download_getter_fn( + company, + call, + call_data=call_data, + allow_public=request.allow_public, + entity_type=request.entity_type, + ) + if not get_fn: + return headers + + fields = [path.split(".") for path in projection] + + def get_entity_field_as_str(data: dict, field: Sequence[str]) -> str: + val = nested_get(data, field, "") + if isinstance(val, dict): + val = val.get("id", "") + + return str(val) + + def get_string_from_entity_data(data: dict) -> str: + return ",".join(get_entity_field_as_str(data, f) for f in fields) + + with ThreadPoolExecutor(1) as pool: + page = 0 + page_size = int( + config.get("services.organization.download.batch_size", 500) + ) + future = pool.submit(get_fn, page, page_size) + out = [headers] + + while True: + result = future.result() + if not result: + break + + page += 1 + future = pool.submit(get_fn, page, page_size) + + out.extend(get_string_from_entity_data(r) for r in result) + yield "\n".join(out) + "\n" + out = [] + + if out: + yield "\n".join(out) + + call.result.filename = f"{request.entity_type}_export.{request.download_type}" + call.result.content_type = "text/csv" + call.result.raw_data = stream_with_context(generate()) diff --git a/apiserver/services/queues.py b/apiserver/services/queues.py index 54e255a..c9fdf36 100644 --- a/apiserver/services/queues.py +++ b/apiserver/services/queues.py @@ -83,7 +83,7 @@ def get_all_ex(call: APICall, company: str, request: GetAllRequest): conform_tag_fields(call, call.data) ret_params = {} - call_data = Metadata.escape_query_parameters(call) + call_data = Metadata.escape_query_parameters(call.data) queues = queue_bll.get_queue_infos( company_id=company, query_dict=call_data, @@ -99,7 +99,7 @@ def get_all_ex(call: APICall, company: str, request: GetAllRequest): def get_all(call: APICall, company: str, request: GetAllRequest): conform_tag_fields(call, call.data) ret_params = {} - call_data = Metadata.escape_query_parameters(call) + call_data = Metadata.escape_query_parameters(call.data) queues = queue_bll.get_all( company_id=company, query_dict=call_data, diff --git a/apiserver/services/reports.py b/apiserver/services/reports.py index 1a52472..13172b6 100644 --- a/apiserver/services/reports.py +++ b/apiserver/services/reports.py @@ -222,7 +222,7 @@ def get_task_data(call: APICall, company_id, request: GetTasksDataRequest): entity_cls = Task conform_data = conform_task_data - call_data = escape_execution_parameters(call) + call_data = escape_execution_parameters(call.data) process_include_subprojects(call_data) ret_params = {} diff --git a/apiserver/services/tasks.py b/apiserver/services/tasks.py index 9f0961d..9c7be8d 100644 --- a/apiserver/services/tasks.py +++ b/apiserver/services/tasks.py @@ -171,13 +171,13 @@ def get_by_id(call: APICall, company_id, req_model: TaskRequest): call.result.data = {"task": task_dict} -def escape_execution_parameters(call: APICall) -> dict: - if not call.data: - return call.data +def escape_execution_parameters(call_data: dict) -> dict: + if not call_data: + return call_data - keys = list(call.data) + keys = list(call_data) call_data = { - safe_key: call.data[key] for key, safe_key in zip(keys, escape_paths(keys)) + safe_key: call_data[key] for key, safe_key in zip(keys, escape_paths(keys)) } projection = Task.get_projection(call_data) @@ -204,7 +204,7 @@ def _hidden_query(data: dict) -> Q: @endpoint("tasks.get_all_ex") def get_all_ex(call: APICall, company_id, request: GetAllReq): conform_tag_fields(call, call.data) - call_data = escape_execution_parameters(call) + call_data = escape_execution_parameters(call.data) process_include_subprojects(call_data) ret_params = {} tasks = Task.get_many_with_join( @@ -221,7 +221,7 @@ def get_all_ex(call: APICall, company_id, request: GetAllReq): @endpoint("tasks.get_by_id_ex", required_fields=["id"]) def get_by_id_ex(call: APICall, company_id, _): conform_tag_fields(call, call.data) - call_data = escape_execution_parameters(call) + call_data = escape_execution_parameters(call.data) tasks = Task.get_many_with_join( company=company_id, query_dict=call_data, allow_public=True, ) @@ -233,7 +233,7 @@ def get_by_id_ex(call: APICall, company_id, _): @endpoint("tasks.get_all", required_fields=[]) def get_all(call: APICall, company_id, _): conform_tag_fields(call, call.data) - call_data = escape_execution_parameters(call) + call_data = escape_execution_parameters(call.data) process_include_subprojects(call_data) ret_params = {}