clearml-server/apiserver/services/organization.py

335 lines
11 KiB
Python

import csv
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from io import StringIO
from operator import itemgetter
from typing import Mapping, Type, Sequence, Optional, Callable, Hashable
from flask import stream_with_context
from mongoengine import Q
from apiserver.apierrors import errors
from apiserver.apimodels.organization import (
TagsRequest,
EntitiesCountRequest,
DownloadForGetAllRequest,
EntityType,
PrepareDownloadForGetAllRequest,
)
from apiserver.bll.model import Metadata
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL
from apiserver.config_repo import config
from apiserver.database.model import User, AttributedDocument, EntityVisibility
from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task, TaskType
from apiserver.redis_manager import redman
from apiserver.service_repo import endpoint, APICall
from apiserver.services.models import conform_model_data
from apiserver.services.tasks import (
escape_execution_parameters,
_hidden_query,
conform_task_data,
)
from apiserver.services.utils import get_tags_filter_dictionary, sort_tags_response
from apiserver.utilities import json
from apiserver.utilities.dicts import nested_get
org_bll = OrgBLL()
project_bll = ProjectBLL()
redis = redman.connection("apiserver")
conf = config.get("services.organization")
@endpoint("organization.get_tags", request_data_model=TagsRequest)
def get_tags(call: APICall, company, request: TagsRequest):
filter_dict = get_tags_filter_dictionary(request.filter)
ret = defaultdict(set)
for entity in Tags.Model, Tags.Task:
tags = org_bll.get_tags(
company,
entity,
include_system=request.include_system,
filter_=filter_dict,
)
for field, vals in tags.items():
ret[field] |= vals
call.result.data = sort_tags_response(ret)
@endpoint("organization.get_user_companies")
def get_user_companies(call: APICall, company_id: str, _):
users = [
{"id": u.id, "name": u.name, "avatar": u.avatar}
for u in User.objects(company=company_id).only("avatar", "name", "company")
]
call.result.data = {
"companies": [
{
"id": company_id,
"name": call.identity.company_name,
"allocated": len(users),
"owners": sorted(users, key=itemgetter("name")),
}
]
}
@endpoint("organization.get_entities_count")
def get_entities_count(call: APICall, company, request: EntitiesCountRequest):
entity_classes: Mapping[str, Type[AttributedDocument]] = {
"projects": Project,
"tasks": Task,
"models": Model,
"pipelines": Project,
"datasets": Project,
"reports": Task,
}
ret = {}
for field, entity_cls in entity_classes.items():
data = call.data.get(field)
if data is None:
continue
if field == "reports":
data["type"] = TaskType.report
data["include_subprojects"] = True
if request.active_users:
if entity_cls is Project:
requested_ids = data.get("id")
if isinstance(requested_ids, str):
requested_ids = [requested_ids]
ids, _ = project_bll.get_projects_with_selected_children(
company=company,
users=request.active_users,
project_ids=requested_ids,
allow_public=request.allow_public,
)
if not ids:
ret[field] = 0
continue
data["id"] = ids
elif not data.get("user"):
data["user"] = request.active_users
query = Q()
if (
entity_cls in (Project, Task)
and field not in ("reports", "pipelines", "datasets")
and not request.search_hidden
):
query &= Q(system_tags__ne=EntityVisibility.hidden.value)
ret[field] = entity_cls.get_count(
company=company,
query_dict=data,
query=query,
allow_public=request.allow_public,
)
call.result.data = ret
def _get_download_getter_fn(
company: str,
call: APICall,
call_data: dict,
allow_public: bool,
entity_type: EntityType,
) -> Optional[Callable[[int, int], Sequence[dict]]]:
def get_task_data() -> Sequence[dict]:
tasks = Task.get_many_with_join(
company=company,
query_dict=call_data,
query=_hidden_query(call_data),
allow_public=allow_public,
)
conform_task_data(call, tasks)
return tasks
def get_model_data() -> Sequence[dict]:
models = Model.get_many_with_join(
company=company,
query_dict=call_data,
allow_public=allow_public,
)
conform_model_data(call, models)
return models
if entity_type == EntityType.task:
call_data = escape_execution_parameters(call_data)
get_fn = get_task_data
elif entity_type == EntityType.model:
call_data = Metadata.escape_query_parameters(call_data)
get_fn = get_model_data
else:
raise errors.bad_request.ValidationError(
f"Unsupported entity type: {str(entity_type)}"
)
def getter(page: int, page_size: int) -> Sequence[dict]:
call_data.pop("scroll_id", None)
call_data.pop("start", None)
call_data.pop("size", None)
call_data.pop("refresh_scroll", None)
call_data["page"] = page
call_data["page_size"] = page_size
return get_fn()
return getter
@endpoint("organization.prepare_download_for_get_all")
def prepare_download_for_get_all(
call: APICall, company: str, request: PrepareDownloadForGetAllRequest
):
# validate input params
field_names = set()
for fm in request.field_mappings:
name = fm.name or fm.field
if name in field_names:
raise errors.bad_request.ValidationError(
f"Field_name appears more than once in field_mappings: {str(name)}"
)
field_names.add(name)
if fm.values:
value_keys = set()
for v in fm.values:
if v.key in value_keys:
raise errors.bad_request.ValidationError(
f"Value key appears more than once in field_mappings: {str(v.key)}"
)
value_keys.add(v.key)
getter = _get_download_getter_fn(
company,
call,
call_data=call.data.copy(),
allow_public=request.allow_public,
entity_type=request.entity_type,
)
# retrieve one element just to make sure that there are no issues with the call parameters
if getter:
getter(0, 1)
redis.setex(
f"get_all_download_{call.id}",
int(conf.get("download.redis_timeout_sec", 300)),
json.dumps(call.data),
)
call.result.data = dict(prepare_id=call.id)
@endpoint("organization.download_for_get_all")
def download_for_get_all(call: APICall, company, request: DownloadForGetAllRequest):
request_data = redis.get(f"get_all_download_{request.prepare_id}")
if not request_data:
raise errors.bad_request.InvalidId(
f"prepare ID not found", prepare_id=request.prepare_id
)
try:
call_data = json.loads(request_data)
request = PrepareDownloadForGetAllRequest(**call_data)
except Exception as ex:
raise errors.server_error.DataError("failed parsing prepared data", ex=ex)
class SingleLine:
@staticmethod
def write(line: str) -> str:
return line
def generate():
field_mappings = {
mapping.get("name", mapping["field"]): {
"field_path": mapping["field"].split("."),
"values": {
v.get("key"): v.get("value")
for v in (mapping.get("values") or [])
},
}
for mapping in call_data.get("field_mappings", [])
}
get_fn = _get_download_getter_fn(
company,
call,
call_data=call_data,
allow_public=request.allow_public,
entity_type=request.entity_type,
)
if not get_fn:
yield csv.writer(SingleLine()).writerow(field_mappings)
return
def get_entity_field_as_str(
data: dict, field_path: Sequence[str], values: Mapping
) -> str:
val = nested_get(data, field_path, "")
if isinstance(val, dict):
val = val.get("id", "")
if values and isinstance(val, Hashable):
val = values.get(val, val)
return str(val)
def get_projected_fields(data: dict) -> Sequence[str]:
return [
get_entity_field_as_str(
data, field_path=m["field_path"], values=m["values"]
)
for m in field_mappings.values()
]
with ThreadPoolExecutor(1) as pool:
page = 0
page_size = int(conf.get("download.batch_size", 500))
items_left = int(conf.get("download.max_download_items", 1000))
future = pool.submit(get_fn, page, min(page_size, items_left))
while items_left > 0:
result = future.result()
if not result:
break
items_left -= len(result)
page += 1
if items_left > 0:
future = pool.submit(get_fn, page, min(page_size, items_left))
with StringIO() as fp:
writer = csv.writer(fp)
if page == 1:
fp.write("\ufeff") # utf-8 signature
writer.writerow(field_mappings)
writer.writerows(get_projected_fields(r) for r in result)
yield fp.getvalue()
if page == 0:
yield csv.writer(SingleLine()).writerow(field_mappings)
def get_project_name() -> Optional[str]:
projects = call_data.get("project")
if not projects or not isinstance(projects, (list, str)):
return
if isinstance(projects, list):
if len(projects) > 1:
return
projects = projects[0]
if projects is None:
return "root"
project: Project = Project.objects(id=projects).only("basename").first()
if not project:
return
return project.basename[: conf.get("download.max_project_name_length", 60)]
call.result.filename = "-".join(
filter(None, ("clearml", get_project_name(), f"{request.entity_type}s.csv"))
)
call.result.content_type = "text/csv"
call.result.raw_data = stream_with_context(generate())