diff --git a/apiserver/apierrors/errors.conf b/apiserver/apierrors/errors.conf index 22dbb0d..ca78c81 100644 --- a/apiserver/apierrors/errors.conf +++ b/apiserver/apierrors/errors.conf @@ -27,7 +27,7 @@ 24: ["not_public_object", "object is not public"] # Auth / Login - 75: ["invalid_access_key", "access key not found for user"] + 75: ["invalid_access_key", "access key not found"] # Tasks 100: ["task_error", "general task error"] diff --git a/apiserver/config/default/services/organization.conf b/apiserver/config/default/services/organization.conf index f98638d..14a9432 100644 --- a/apiserver/config/default/services/organization.conf +++ b/apiserver/config/default/services/organization.conf @@ -4,5 +4,6 @@ tags_cache { download { redis_timeout_sec: 300 batch_size: 500 + max_download_items: 50000 max_project_name_length: 60 } \ No newline at end of file diff --git a/apiserver/schema/services/users.conf b/apiserver/schema/services/users.conf index dd3c37f..ed3619d 100644 --- a/apiserver/schema/services/users.conf +++ b/apiserver/schema/services/users.conf @@ -155,6 +155,17 @@ get_current_user { } } } + "2.26": ${get_current_user."2.20"} { + response.properties.settings { + type: object + properties { + max_download_items { + type: string + description: The maximum items downloaded for this user in csv file downloads + } + } + } + } } get_all_ex { diff --git a/apiserver/services/organization.py b/apiserver/services/organization.py index 0c67ddb..3f3868b 100644 --- a/apiserver/services/organization.py +++ b/apiserver/services/organization.py @@ -39,6 +39,7 @@ from apiserver.utilities.dicts import nested_get org_bll = OrgBLL() project_bll = ProjectBLL() redis = redman.connection("apiserver") +conf = config.get("services.organization") @endpoint("organization.get_tags", request_data_model=TagsRequest) @@ -182,9 +183,6 @@ def _get_download_getter_fn( return getter -download_conf = config.get("services.organization.download") - - @endpoint("organization.prepare_download_for_get_all") def prepare_download_for_get_all( call: APICall, company: str, request: PrepareDownloadForGetAllRequest @@ -214,12 +212,13 @@ def prepare_download_for_get_all( allow_public=request.allow_public, entity_type=request.entity_type, ) + # retrieve one element just to make sure that there are no issues with the call parameters if getter: getter(0, 1) redis.setex( f"get_all_download_{call.id}", - int(download_conf.get("redis_timeout_sec", 300)), + int(conf.get("download.redis_timeout_sec", 300)), json.dumps(call.data), ) @@ -250,7 +249,8 @@ def download_for_get_all(call: APICall, company, request: DownloadForGetAllReque mapping.get("name", mapping["field"]): { "field_path": mapping["field"].split("."), "values": { - v.get("key"): v.get("value") for v in (mapping.get("values") or []) + v.get("key"): v.get("value") + for v in (mapping.get("values") or []) }, } for mapping in call_data.get("field_mappings", []) @@ -287,16 +287,18 @@ def download_for_get_all(call: APICall, company, request: DownloadForGetAllReque with ThreadPoolExecutor(1) as pool: page = 0 - page_size = int(download_conf.get("batch_size", 500)) - future = pool.submit(get_fn, page, page_size) - - while True: + page_size = int(conf.get("download.batch_size", 500)) + items_left = int(conf.get("download.max_download_items", 1000)) + future = pool.submit(get_fn, page, min(page_size, items_left)) + while items_left > 0: result = future.result() if not result: break + items_left -= len(result) page += 1 - future = pool.submit(get_fn, page, page_size) + if items_left > 0: + future = pool.submit(get_fn, page, min(page_size, items_left)) with StringIO() as fp: writer = csv.writer(fp) @@ -323,7 +325,7 @@ def download_for_get_all(call: APICall, company, request: DownloadForGetAllReque if not project: return - return project.basename[: download_conf.get("max_project_name_length", 60)] + return project.basename[: conf.get("download.max_project_name_length", 60)] call.result.filename = "-".join( filter(None, ("clearml", get_project_name(), f"{request.entity_type}s.csv")) diff --git a/apiserver/services/users.py b/apiserver/services/users.py index 70f1f93..e2a9365 100644 --- a/apiserver/services/users.py +++ b/apiserver/services/users.py @@ -98,9 +98,7 @@ def get_current_user(call: APICall, company_id, _): user_id = call.identity.user projection = ( - {"company.name"} - .union(User.get_fields()) - .difference(User.get_exclude_fields()) + {"company.name"}.union(User.get_fields()).difference(User.get_exclude_fields()) ) res = User.get_many_with_join( query=Q(id=user_id), @@ -114,9 +112,13 @@ def get_current_user(call: APICall, company_id, _): user = res[0] user["role"] = call.identity.role - resp = { - "user": user, - "getting_started": config.get("apiserver.getting_started_info", None), + resp = dict( + user=user, getting_started=config.get("apiserver.getting_started_info", None) + ) + resp["settings"] = { + "max_download_items": int( + config.get("services.organization.download.max_download_items", 1000) + ) } call.result.data = resp