import csv from collections import defaultdict from concurrent.futures import ThreadPoolExecutor from io import StringIO from operator import itemgetter from typing import Mapping, Type, Sequence, Optional, Callable, Hashable from flask import stream_with_context from mongoengine import Q from apiserver.apierrors import errors from apiserver.apimodels.organization import ( TagsRequest, EntitiesCountRequest, DownloadForGetAllRequest, EntityType, PrepareDownloadForGetAllRequest, ) from apiserver.bll.model import Metadata from apiserver.bll.organization import OrgBLL, Tags from apiserver.bll.project import ProjectBLL from apiserver.config_repo import config from apiserver.database.model import User, AttributedDocument, EntityVisibility from apiserver.database.model.model import Model from apiserver.database.model.project import Project from apiserver.database.model.task.task import Task, TaskType from apiserver.redis_manager import redman from apiserver.service_repo import endpoint, APICall from apiserver.services.models import conform_model_data from apiserver.services.tasks import ( escape_execution_parameters, _hidden_query, conform_task_data, ) from apiserver.services.utils import get_tags_filter_dictionary, sort_tags_response from apiserver.utilities import json from apiserver.utilities.dicts import nested_get org_bll = OrgBLL() project_bll = ProjectBLL() redis = redman.connection("apiserver") conf = config.get("services.organization") @endpoint("organization.get_tags", request_data_model=TagsRequest) def get_tags(call: APICall, company, request: TagsRequest): filter_dict = get_tags_filter_dictionary(request.filter) ret = defaultdict(set) for entity in Tags.Model, Tags.Task: tags = org_bll.get_tags( company, entity, include_system=request.include_system, filter_=filter_dict, ) for field, vals in tags.items(): ret[field] |= vals call.result.data = sort_tags_response(ret) @endpoint("organization.get_user_companies") def get_user_companies(call: APICall, company_id: str, _): users = [ {"id": u.id, "name": u.name, "avatar": u.avatar} for u in User.objects(company=company_id).only("avatar", "name", "company") ] call.result.data = { "companies": [ { "id": company_id, "name": call.identity.company_name, "allocated": len(users), "owners": sorted(users, key=itemgetter("name")), } ] } @endpoint("organization.get_entities_count") def get_entities_count(call: APICall, company, request: EntitiesCountRequest): entity_classes: Mapping[str, Type[AttributedDocument]] = { "projects": Project, "tasks": Task, "models": Model, "pipelines": Project, "datasets": Project, "reports": Task, } ret = {} for field, entity_cls in entity_classes.items(): data = call.data.get(field) if data is None: continue if field == "reports": data["type"] = TaskType.report data["include_subprojects"] = True if request.active_users: if entity_cls is Project: requested_ids = data.get("id") if isinstance(requested_ids, str): requested_ids = [requested_ids] ids, _ = project_bll.get_projects_with_selected_children( company=company, users=request.active_users, project_ids=requested_ids, allow_public=request.allow_public, ) if not ids: ret[field] = 0 continue data["id"] = ids elif not data.get("user"): data["user"] = request.active_users query = Q() if ( entity_cls in (Project, Task) and field not in ("reports", "pipelines", "datasets") and not request.search_hidden ): query &= Q(system_tags__ne=EntityVisibility.hidden.value) ret[field] = entity_cls.get_count( company=company, query_dict=data, query=query, allow_public=request.allow_public, ) call.result.data = ret def _get_download_getter_fn( company: str, call: APICall, call_data: dict, allow_public: bool, entity_type: EntityType, ) -> Optional[Callable[[int, int], Sequence[dict]]]: def get_task_data() -> Sequence[dict]: tasks = Task.get_many_with_join( company=company, query_dict=call_data, query=_hidden_query(call_data), allow_public=allow_public, ) conform_task_data(call, tasks) return tasks def get_model_data() -> Sequence[dict]: models = Model.get_many_with_join( company=company, query_dict=call_data, allow_public=allow_public, ) conform_model_data(call, models) return models if entity_type == EntityType.task: call_data = escape_execution_parameters(call_data) get_fn = get_task_data elif entity_type == EntityType.model: call_data = Metadata.escape_query_parameters(call_data) get_fn = get_model_data else: raise errors.bad_request.ValidationError( f"Unsupported entity type: {str(entity_type)}" ) def getter(page: int, page_size: int) -> Sequence[dict]: call_data.pop("scroll_id", None) call_data.pop("start", None) call_data.pop("size", None) call_data.pop("refresh_scroll", None) call_data["page"] = page call_data["page_size"] = page_size return get_fn() return getter @endpoint("organization.prepare_download_for_get_all") def prepare_download_for_get_all( call: APICall, company: str, request: PrepareDownloadForGetAllRequest ): # validate input params field_names = set() for fm in request.field_mappings: name = fm.name or fm.field if name in field_names: raise errors.bad_request.ValidationError( f"Field_name appears more than once in field_mappings: {str(name)}" ) field_names.add(name) if fm.values: value_keys = set() for v in fm.values: if v.key in value_keys: raise errors.bad_request.ValidationError( f"Value key appears more than once in field_mappings: {str(v.key)}" ) value_keys.add(v.key) getter = _get_download_getter_fn( company, call, call_data=call.data.copy(), allow_public=request.allow_public, entity_type=request.entity_type, ) # retrieve one element just to make sure that there are no issues with the call parameters if getter: getter(0, 1) redis.setex( f"get_all_download_{call.id}", int(conf.get("download.redis_timeout_sec", 300)), json.dumps(call.data), ) call.result.data = dict(prepare_id=call.id) @endpoint("organization.download_for_get_all") def download_for_get_all(call: APICall, company, request: DownloadForGetAllRequest): request_data = redis.get(f"get_all_download_{request.prepare_id}") if not request_data: raise errors.bad_request.InvalidId( f"prepare ID not found", prepare_id=request.prepare_id ) try: call_data = json.loads(request_data) request = PrepareDownloadForGetAllRequest(**call_data) except Exception as ex: raise errors.server_error.DataError("failed parsing prepared data", ex=ex) class SingleLine: @staticmethod def write(line: str) -> str: return line def generate(): field_mappings = { mapping.get("name", mapping["field"]): { "field_path": mapping["field"].split("."), "values": { v.get("key"): v.get("value") for v in (mapping.get("values") or []) }, } for mapping in call_data.get("field_mappings", []) } get_fn = _get_download_getter_fn( company, call, call_data=call_data, allow_public=request.allow_public, entity_type=request.entity_type, ) if not get_fn: yield csv.writer(SingleLine()).writerow(field_mappings) return def get_entity_field_as_str( data: dict, field_path: Sequence[str], values: Mapping ) -> str: val = nested_get(data, field_path, "") if isinstance(val, dict): val = val.get("id", "") if values and isinstance(val, Hashable): val = values.get(val, val) return str(val) def get_projected_fields(data: dict) -> Sequence[str]: return [ get_entity_field_as_str( data, field_path=m["field_path"], values=m["values"] ) for m in field_mappings.values() ] with ThreadPoolExecutor(1) as pool: page = 0 page_size = int(conf.get("download.batch_size", 500)) items_left = int(conf.get("download.max_download_items", 1000)) future = pool.submit(get_fn, page, min(page_size, items_left)) while items_left > 0: result = future.result() if not result: break items_left -= len(result) page += 1 if items_left > 0: future = pool.submit(get_fn, page, min(page_size, items_left)) with StringIO() as fp: writer = csv.writer(fp) if page == 1: fp.write("\ufeff") # utf-8 signature writer.writerow(field_mappings) writer.writerows(get_projected_fields(r) for r in result) yield fp.getvalue() if page == 0: yield csv.writer(SingleLine()).writerow(field_mappings) def get_project_name() -> Optional[str]: projects = call_data.get("project") if not projects or not isinstance(projects, (list, str)): return if isinstance(projects, list): if len(projects) > 1: return projects = projects[0] if projects is None: return "root" project: Project = Project.objects(id=projects).only("basename").first() if not project: return return project.basename[: conf.get("download.max_project_name_length", 60)] call.result.filename = "-".join( filter(None, ("clearml", get_project_name(), f"{request.entity_type}s.csv")) ) call.result.content_type = "text/csv" call.result.raw_data = stream_with_context(generate())