mirror of
https://github.com/clearml/clearml-server
synced 2025-06-25 11:45:48 +00:00
Add organization.download_for_get_all endpoint
This commit is contained in:
parent
5239755066
commit
ff34da3c88
@ -1,6 +1,10 @@
|
||||
from jsonmodels import fields, models
|
||||
from enum import auto
|
||||
|
||||
from apiserver.apimodels import DictField
|
||||
from jsonmodels import fields, models
|
||||
from jsonmodels.validators import Length
|
||||
|
||||
from apiserver.apimodels import DictField, ActualEnumField
|
||||
from apiserver.utilities.stringenum import StringEnum
|
||||
|
||||
|
||||
class Filter(models.Base):
|
||||
@ -23,3 +27,26 @@ class EntitiesCountRequest(models.Base):
|
||||
active_users = fields.ListField(str)
|
||||
search_hidden = fields.BoolField(default=False)
|
||||
allow_public = fields.BoolField(default=True)
|
||||
|
||||
|
||||
class DownloadType(StringEnum):
    """Formats in which the result of a prepared get_all query can be downloaded."""

    csv = auto()
|
||||
|
||||
|
||||
class EntityType(StringEnum):
    """Kinds of entities that can be selected for a get_all download."""

    task = auto()
    model = auto()
|
||||
|
||||
|
||||
class PrepareDownloadForGetAll(models.Base):
    """Request model of the organization.prepare_download_for_get_all endpoint."""

    # Format of the generated file (csv unless specified otherwise)
    download_type = ActualEnumField(DownloadType, default=DownloadType.csv)
    # Which kind of entity (task or model) the query selects
    entity_type = ActualEnumField(EntityType)
    allow_public = fields.BoolField(default=True)
    search_hidden = fields.BoolField(default=False)
    # The projection is mandatory and must name at least one field
    only_fields = fields.ListField(
        items_types=[str],
        validators=[Length(minimum_value=1)],
        required=True,
    )
|
||||
|
||||
|
||||
class DownloadForGetAll(models.Base):
    """Request model of the organization.download_for_get_all endpoint."""

    # The ID handed out by organization.prepare_download_for_get_all
    prepare_id = fields.StringField(required=True)
|
||||
|
@ -5,7 +5,6 @@ from mongoengine import Document
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apimodels.metadata import MetadataItem
|
||||
from apiserver.database.model.base import GetMixin
|
||||
from apiserver.service_repo import APICall
|
||||
from apiserver.utilities.parameter_key_escaper import (
|
||||
ParameterKeyEscaper,
|
||||
mongoengine_safe,
|
||||
@ -87,13 +86,13 @@ class Metadata:
|
||||
return paths
|
||||
|
||||
@classmethod
|
||||
def escape_query_parameters(cls, call: APICall) -> dict:
|
||||
if not call.data:
|
||||
return call.data
|
||||
def escape_query_parameters(cls, call_data: dict) -> dict:
|
||||
if not call_data:
|
||||
return call_data
|
||||
|
||||
keys = list(call.data)
|
||||
keys = list(call_data)
|
||||
call_data = {
|
||||
safe_key: call.data[key]
|
||||
safe_key: call_data[key]
|
||||
for key, safe_key in zip(keys, Metadata.escape_paths(keys))
|
||||
}
|
||||
|
||||
|
@ -1,3 +1,7 @@
|
||||
tags_cache {
|
||||
expiration_seconds: 3600
|
||||
}
|
||||
download {
|
||||
redis_timeout_sec: 300
|
||||
batch_size: 500
|
||||
}
|
@ -197,3 +197,70 @@ get_entities_count {
|
||||
}
|
||||
}
|
||||
}
|
||||
prepare_download_for_get_all {
|
||||
"999.0": {
|
||||
description: Prepares download from get_all_ex parameters
|
||||
request {
|
||||
type: object
|
||||
required: [ entity_type, only_fields]
|
||||
properties {
|
||||
only_fields {
|
||||
description: "List of entity field names (nesting is supported using '.', e.g. execution.model_labels). This list defines the query's projection: only these fields are retrieved for each entity and written to the downloaded file"
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
download_type {
|
||||
description: "Download type. Determines the downloaded file's formatting and mime type."
|
||||
type: string
|
||||
enum: [ csv ]
|
||||
default: csv
|
||||
}
|
||||
allow_public {
|
||||
description: "Allow public entities to be returned in the results"
|
||||
type: boolean
|
||||
default: true
|
||||
}
|
||||
search_hidden {
|
||||
description: "If set to 'true' then hidden entities are included in the search results"
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
entity_type {
|
||||
description: "The type of the entity to retrieve"
|
||||
type: string
|
||||
enum: [
|
||||
task
|
||||
model
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
prepare_id {
|
||||
description: "Prepare ID (use when calling 'download_for_get_all')"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
download_for_get_all {
|
||||
"999.0": {
|
||||
description: Generates a file for the download
|
||||
request {
|
||||
type: object
|
||||
required: [ prepare_id ]
|
||||
properties {
|
||||
prepare_id {
|
||||
description: "Prepare ID returned by a call to prepare_download_for_get_all"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
@ -70,9 +70,7 @@ def _assert_task_or_model_exists(
|
||||
@endpoint("events.add")
|
||||
def add(call: APICall, company_id, _):
|
||||
data = call.data.copy()
|
||||
added, err_count, err_info = event_bll.add_events(
|
||||
company_id, [data], call.worker
|
||||
)
|
||||
added, err_count, err_info = event_bll.add_events(company_id, [data], call.worker)
|
||||
call.result.data = dict(added=added, errors=err_count, errors_info=err_info)
|
||||
|
||||
|
||||
@ -82,11 +80,7 @@ def add_batch(call: APICall, company_id, _):
|
||||
if events is None or len(events) == 0:
|
||||
raise errors.bad_request.BatchContainsNoItems()
|
||||
|
||||
added, err_count, err_info = event_bll.add_events(
|
||||
company_id,
|
||||
events,
|
||||
call.worker,
|
||||
)
|
||||
added, err_count, err_info = event_bll.add_events(company_id, events, call.worker,)
|
||||
call.result.data = dict(added=added, errors=err_count, errors_info=err_info)
|
||||
|
||||
|
||||
@ -576,6 +570,7 @@ def _get_multitask_plots(
|
||||
sort=[{"iter": {"order": "desc"}}],
|
||||
scroll_id=scroll_id,
|
||||
no_scroll=no_scroll,
|
||||
size=10000,
|
||||
)
|
||||
return_events = _get_top_iter_unique_events_per_task(
|
||||
result.events, max_iters=last_iters, task_names=task_names
|
||||
@ -595,10 +590,7 @@ def get_multi_task_plots(call, company_id, _):
|
||||
company_id, task_ids, model_events=model_events
|
||||
)
|
||||
return_events, total_events, next_scroll_id = _get_multitask_plots(
|
||||
companies=companies,
|
||||
last_iters=iters,
|
||||
scroll_id=scroll_id,
|
||||
no_scroll=no_scroll,
|
||||
companies=companies, last_iters=iters, scroll_id=scroll_id, no_scroll=no_scroll,
|
||||
)
|
||||
call.result.data = dict(
|
||||
plots=return_events,
|
||||
@ -784,6 +776,7 @@ def get_debug_images_v1_8(call, company_id, _):
|
||||
sort=[{"iter": {"order": "desc"}}],
|
||||
last_iter_count=iters,
|
||||
scroll_id=scroll_id,
|
||||
size=10000,
|
||||
)
|
||||
|
||||
return_events = result.events
|
||||
@ -960,12 +953,12 @@ def clear_task_log(call: APICall, company_id: str, request: ClearTaskLogRequest)
|
||||
def _get_top_iter_unique_events_per_task(
|
||||
events, max_iters: int, task_names: Mapping[str, str]
|
||||
):
|
||||
key = itemgetter("metric", "variant", "task", "iter")
|
||||
|
||||
key_fields = ("metric", "variant", "task")
|
||||
unique_events = itertools.chain.from_iterable(
|
||||
itertools.islice(group, max_iters)
|
||||
for _, group in itertools.groupby(
|
||||
sorted(events, key=key, reverse=True), key=key
|
||||
sorted(events, key=itemgetter(*(key_fields + ("iter",))), reverse=True),
|
||||
key=itemgetter(*key_fields),
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -67,7 +67,7 @@ def conform_model_data(call: APICall, model_data: Union[Sequence[dict], dict]):
|
||||
@endpoint("models.get_by_id", required_fields=["model"])
|
||||
def get_by_id(call: APICall, company_id, _):
|
||||
model_id = call.data["model"]
|
||||
call_data = Metadata.escape_query_parameters(call)
|
||||
call_data = Metadata.escape_query_parameters(call.data)
|
||||
models = Model.get_many(
|
||||
company=company_id,
|
||||
query_dict=call_data,
|
||||
@ -112,7 +112,7 @@ def get_by_task_id(call: APICall, company_id, _):
|
||||
@endpoint("models.get_all_ex", request_data_model=ModelsGetRequest)
|
||||
def get_all_ex(call: APICall, company_id, request: ModelsGetRequest):
|
||||
conform_tag_fields(call, call.data)
|
||||
call_data = Metadata.escape_query_parameters(call)
|
||||
call_data = Metadata.escape_query_parameters(call.data)
|
||||
process_include_subprojects(call_data)
|
||||
ret_params = {}
|
||||
models = Model.get_many_with_join(
|
||||
@ -139,7 +139,7 @@ def get_all_ex(call: APICall, company_id, request: ModelsGetRequest):
|
||||
@endpoint("models.get_by_id_ex", required_fields=["id"])
|
||||
def get_by_id_ex(call: APICall, company_id, _):
|
||||
conform_tag_fields(call, call.data)
|
||||
call_data = Metadata.escape_query_parameters(call)
|
||||
call_data = Metadata.escape_query_parameters(call.data)
|
||||
models = Model.get_many_with_join(
|
||||
company=company_id, query_dict=call_data, allow_public=True
|
||||
)
|
||||
@ -150,7 +150,7 @@ def get_by_id_ex(call: APICall, company_id, _):
|
||||
@endpoint("models.get_all", required_fields=[])
|
||||
def get_all(call: APICall, company_id, _):
|
||||
conform_tag_fields(call, call.data)
|
||||
call_data = Metadata.escape_query_parameters(call)
|
||||
call_data = Metadata.escape_query_parameters(call.data)
|
||||
process_include_subprojects(call_data)
|
||||
ret_params = {}
|
||||
models = Model.get_many(
|
||||
|
@ -1,21 +1,42 @@
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from operator import itemgetter
|
||||
from typing import Mapping, Type
|
||||
from typing import Mapping, Type, Sequence, Optional, Callable
|
||||
|
||||
from flask import stream_with_context
|
||||
from mongoengine import Q
|
||||
|
||||
from apiserver.apimodels.organization import TagsRequest, EntitiesCountRequest
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apimodels.organization import (
|
||||
TagsRequest,
|
||||
EntitiesCountRequest,
|
||||
DownloadForGetAll,
|
||||
EntityType,
|
||||
PrepareDownloadForGetAll,
|
||||
)
|
||||
from apiserver.bll.model import Metadata
|
||||
from apiserver.bll.organization import OrgBLL, Tags
|
||||
from apiserver.bll.project import ProjectBLL
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model import User, AttributedDocument, EntityVisibility
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.task.task import Task, TaskType
|
||||
from apiserver.redis_manager import redman
|
||||
from apiserver.service_repo import endpoint, APICall
|
||||
from apiserver.services.models import conform_model_data
|
||||
from apiserver.services.tasks import (
|
||||
escape_execution_parameters,
|
||||
_hidden_query,
|
||||
conform_task_data,
|
||||
)
|
||||
from apiserver.services.utils import get_tags_filter_dictionary, sort_tags_response
|
||||
from apiserver.utilities import json
|
||||
from apiserver.utilities.dicts import nested_get
|
||||
|
||||
org_bll = OrgBLL()
|
||||
project_bll = ProjectBLL()
|
||||
redis = redman.connection("apiserver")
|
||||
|
||||
|
||||
@endpoint("organization.get_tags", request_data_model=TagsRequest)
|
||||
@ -105,3 +126,139 @@ def get_entities_count(call: APICall, company, request: EntitiesCountRequest):
|
||||
)
|
||||
|
||||
call.result.data = ret
|
||||
|
||||
|
||||
def _get_download_getter_fn(
    company: str,
    call: APICall,
    call_data: dict,
    allow_public: bool,
    entity_type: EntityType,
) -> Optional[Callable[[int, int], Sequence[dict]]]:
    """
    Build a function that fetches one page of the entities selected by call_data.

    :param company: ID of the company in which the entities are queried
    :param call: the API call; passed on to the conform_*_data helpers
    :param call_data: get_all style query parameters
    :param allow_public: forwarded as is to the entity query
    :param entity_type: selects between the task and the model query
    :return: a callable receiving (page, page_size) and returning the entities
        of that page
    :raise errors.bad_request.ValidationError: the entity type is neither
        task nor model
    """
    if entity_type == EntityType.task:
        query_data = escape_execution_parameters(call_data)

        def fetch() -> Sequence[dict]:
            tasks = Task.get_many_with_join(
                company=company,
                query_dict=query_data,
                query=_hidden_query(query_data),
                allow_public=allow_public,
            )
            conform_task_data(call, tasks)
            return tasks

    elif entity_type == EntityType.model:
        query_data = Metadata.escape_query_parameters(call_data)

        def fetch() -> Sequence[dict]:
            models = Model.get_many_with_join(
                company=company, query_dict=query_data, allow_public=allow_public,
            )
            conform_model_data(call, models)
            return models

    else:
        raise errors.bad_request.ValidationError(
            f"Unsupported entity type: {str(entity_type)}"
        )

    def getter(page: int, page_size: int) -> Sequence[dict]:
        # The paging parameters are written into the escaped query itself,
        # replacing any scroll ID that came with the original request
        query_data.pop("scroll_id", None)
        query_data["page"] = page
        query_data["page_size"] = page_size
        return fetch()

    return getter
|
||||
|
||||
|
||||
@endpoint("organization.prepare_download_for_get_all")
def prepare_download_for_get_all(
    call: APICall, company: str, request: PrepareDownloadForGetAll
):
    """
    Validate a get_all style query and keep it in redis for a later download.

    The query is executed once for a single entity so that invalid parameters
    fail here and not in the middle of the download. The original call data is
    then stored in redis under a key derived from the call ID, with the expiry
    taken from services.organization.download.redis_timeout_sec (300 when not
    configured). The call ID is returned to the caller as the prepare ID.
    """
    # validate input params
    fetch_page = _get_download_getter_fn(
        company,
        call,
        call_data=call.data.copy(),
        allow_public=request.allow_public,
        entity_type=request.entity_type,
    )
    if fetch_page:
        fetch_page(0, 1)

    redis_key = f"get_all_download_{call.id}"
    timeout_sec = int(
        config.get("services.organization.download.redis_timeout_sec", 300)
    )
    redis.setex(redis_key, timeout_sec, json.dumps(call.data))

    call.result.data = {"prepare_id": call.id}
|
||||
|
||||
|
||||
@endpoint("organization.download_for_get_all")
def download_for_get_all(call: APICall, company, request: DownloadForGetAll):
    """
    Stream the entities selected by a previously prepared query as a CSV file.

    The query parameters are loaded from redis using the prepare ID returned by
    organization.prepare_download_for_get_all. The entities are fetched page by
    page (page size is services.organization.download.batch_size, 500 when not
    configured). While one page is being formatted and sent, the next one is
    already being fetched on a worker thread.

    The first line of the file lists the requested fields; every following line
    holds one entity. The lines are written with the csv module so that values
    containing commas, quotes or line breaks are quoted instead of corrupting
    the file structure.

    :raise errors.bad_request.InvalidId: nothing is stored in redis for the
        passed prepare ID
    :raise errors.server_error.DataError: the stored data could not be parsed
    """
    # Imported here so that the block does not rely on module level imports
    import csv
    import io

    request_data = redis.get(f"get_all_download_{request.prepare_id}")
    if not request_data:
        raise errors.bad_request.InvalidId(
            "prepare ID not found", prepare_id=request.prepare_id
        )

    try:
        call_data = json.loads(request_data)
        # kept under its own name instead of rebinding the "request" parameter
        # to a model of a different type
        prepared = PrepareDownloadForGetAll(**call_data)
    except Exception as ex:
        raise errors.server_error.DataError("failed parsing prepared data", ex=ex)

    def to_csv(rows: Sequence[Sequence[str]]) -> str:
        """Format the rows as CSV text, one newline terminated line per row"""
        buffer = io.StringIO()
        csv.writer(buffer, lineterminator="\n").writerows(rows)
        return buffer.getvalue()

    def generate():
        projection = call_data.get("only_fields", [])

        get_fn = _get_download_getter_fn(
            company,
            call,
            call_data=call_data,
            allow_public=prepared.allow_public,
            entity_type=prepared.entity_type,
        )
        if not get_fn:
            # The value of a "return" statement inside a generator never
            # reaches the consumer, so the header line is yielded explicitly
            yield to_csv([projection]).rstrip("\n")
            return

        field_paths = [path.split(".") for path in projection]

        def get_entity_field_as_str(data: dict, field: Sequence[str]) -> str:
            val = nested_get(data, field, "")
            if isinstance(val, dict):
                # nested documents (presumably joined references - TODO confirm)
                # are exported by their id only
                val = val.get("id", "")

            return str(val)

        def get_entity_row(data: dict) -> Sequence[str]:
            return [get_entity_field_as_str(data, f) for f in field_paths]

        with ThreadPoolExecutor(1) as pool:
            page = 0
            page_size = int(
                config.get("services.organization.download.batch_size", 500)
            )
            future = pool.submit(get_fn, page, page_size)
            # the header line is sent together with the first page
            pending = [projection]

            while True:
                result = future.result()
                if not result:
                    break

                # start fetching the next page before formatting the current one
                page += 1
                future = pool.submit(get_fn, page, page_size)

                pending.extend(get_entity_row(r) for r in result)
                yield to_csv(pending)
                pending = []

            if pending:
                # nothing matched the query: the file contains the header only
                # (no trailing line break, as before)
                yield to_csv(pending).rstrip("\n")

    call.result.filename = f"{prepared.entity_type}_export.{prepared.download_type}"
    call.result.content_type = "text/csv"
    call.result.raw_data = stream_with_context(generate())
|
||||
|
@ -83,7 +83,7 @@ def get_all_ex(call: APICall, company: str, request: GetAllRequest):
|
||||
conform_tag_fields(call, call.data)
|
||||
ret_params = {}
|
||||
|
||||
call_data = Metadata.escape_query_parameters(call)
|
||||
call_data = Metadata.escape_query_parameters(call.data)
|
||||
queues = queue_bll.get_queue_infos(
|
||||
company_id=company,
|
||||
query_dict=call_data,
|
||||
@ -99,7 +99,7 @@ def get_all_ex(call: APICall, company: str, request: GetAllRequest):
|
||||
def get_all(call: APICall, company: str, request: GetAllRequest):
|
||||
conform_tag_fields(call, call.data)
|
||||
ret_params = {}
|
||||
call_data = Metadata.escape_query_parameters(call)
|
||||
call_data = Metadata.escape_query_parameters(call.data)
|
||||
queues = queue_bll.get_all(
|
||||
company_id=company,
|
||||
query_dict=call_data,
|
||||
|
@ -222,7 +222,7 @@ def get_task_data(call: APICall, company_id, request: GetTasksDataRequest):
|
||||
entity_cls = Task
|
||||
conform_data = conform_task_data
|
||||
|
||||
call_data = escape_execution_parameters(call)
|
||||
call_data = escape_execution_parameters(call.data)
|
||||
process_include_subprojects(call_data)
|
||||
|
||||
ret_params = {}
|
||||
|
@ -171,13 +171,13 @@ def get_by_id(call: APICall, company_id, req_model: TaskRequest):
|
||||
call.result.data = {"task": task_dict}
|
||||
|
||||
|
||||
def escape_execution_parameters(call: APICall) -> dict:
|
||||
if not call.data:
|
||||
return call.data
|
||||
def escape_execution_parameters(call_data: dict) -> dict:
|
||||
if not call_data:
|
||||
return call_data
|
||||
|
||||
keys = list(call.data)
|
||||
keys = list(call_data)
|
||||
call_data = {
|
||||
safe_key: call.data[key] for key, safe_key in zip(keys, escape_paths(keys))
|
||||
safe_key: call_data[key] for key, safe_key in zip(keys, escape_paths(keys))
|
||||
}
|
||||
|
||||
projection = Task.get_projection(call_data)
|
||||
@ -204,7 +204,7 @@ def _hidden_query(data: dict) -> Q:
|
||||
@endpoint("tasks.get_all_ex")
|
||||
def get_all_ex(call: APICall, company_id, request: GetAllReq):
|
||||
conform_tag_fields(call, call.data)
|
||||
call_data = escape_execution_parameters(call)
|
||||
call_data = escape_execution_parameters(call.data)
|
||||
process_include_subprojects(call_data)
|
||||
ret_params = {}
|
||||
tasks = Task.get_many_with_join(
|
||||
@ -221,7 +221,7 @@ def get_all_ex(call: APICall, company_id, request: GetAllReq):
|
||||
@endpoint("tasks.get_by_id_ex", required_fields=["id"])
|
||||
def get_by_id_ex(call: APICall, company_id, _):
|
||||
conform_tag_fields(call, call.data)
|
||||
call_data = escape_execution_parameters(call)
|
||||
call_data = escape_execution_parameters(call.data)
|
||||
tasks = Task.get_many_with_join(
|
||||
company=company_id, query_dict=call_data, allow_public=True,
|
||||
)
|
||||
@ -233,7 +233,7 @@ def get_by_id_ex(call: APICall, company_id, _):
|
||||
@endpoint("tasks.get_all", required_fields=[])
|
||||
def get_all(call: APICall, company_id, _):
|
||||
conform_tag_fields(call, call.data)
|
||||
call_data = escape_execution_parameters(call)
|
||||
call_data = escape_execution_parameters(call.data)
|
||||
process_include_subprojects(call_data)
|
||||
|
||||
ret_params = {}
|
||||
|
Loading…
Reference in New Issue
Block a user