Add max_download_items to users.get_current_user endpoint response

This commit is contained in:
allegroai 2023-07-26 18:45:42 +03:00
parent 752020c66a
commit c196043d2a
5 changed files with 34 additions and 18 deletions

View File

@ -27,7 +27,7 @@
24: ["not_public_object", "object is not public"] 24: ["not_public_object", "object is not public"]
# Auth / Login # Auth / Login
75: ["invalid_access_key", "access key not found for user"] 75: ["invalid_access_key", "access key not found"]
# Tasks # Tasks
100: ["task_error", "general task error"] 100: ["task_error", "general task error"]

View File

@ -4,5 +4,6 @@ tags_cache {
download { download {
redis_timeout_sec: 300 redis_timeout_sec: 300
batch_size: 500 batch_size: 500
max_download_items: 50000
max_project_name_length: 60 max_project_name_length: 60
} }

View File

@ -155,6 +155,17 @@ get_current_user {
} }
} }
} }
"2.26": ${get_current_user."2.20"} {
response.properties.settings {
type: object
properties {
max_download_items {
type: string
description: The maximum items downloaded for this user in csv file downloads
}
}
}
}
} }
get_all_ex { get_all_ex {

View File

@ -39,6 +39,7 @@ from apiserver.utilities.dicts import nested_get
org_bll = OrgBLL() org_bll = OrgBLL()
project_bll = ProjectBLL() project_bll = ProjectBLL()
redis = redman.connection("apiserver") redis = redman.connection("apiserver")
conf = config.get("services.organization")
@endpoint("organization.get_tags", request_data_model=TagsRequest) @endpoint("organization.get_tags", request_data_model=TagsRequest)
@ -182,9 +183,6 @@ def _get_download_getter_fn(
return getter return getter
download_conf = config.get("services.organization.download")
@endpoint("organization.prepare_download_for_get_all") @endpoint("organization.prepare_download_for_get_all")
def prepare_download_for_get_all( def prepare_download_for_get_all(
call: APICall, company: str, request: PrepareDownloadForGetAllRequest call: APICall, company: str, request: PrepareDownloadForGetAllRequest
@ -214,12 +212,13 @@ def prepare_download_for_get_all(
allow_public=request.allow_public, allow_public=request.allow_public,
entity_type=request.entity_type, entity_type=request.entity_type,
) )
# retrieve one element just to make sure that there are no issues with the call parameters
if getter: if getter:
getter(0, 1) getter(0, 1)
redis.setex( redis.setex(
f"get_all_download_{call.id}", f"get_all_download_{call.id}",
int(download_conf.get("redis_timeout_sec", 300)), int(conf.get("download.redis_timeout_sec", 300)),
json.dumps(call.data), json.dumps(call.data),
) )
@ -250,7 +249,8 @@ def download_for_get_all(call: APICall, company, request: DownloadForGetAllReque
mapping.get("name", mapping["field"]): { mapping.get("name", mapping["field"]): {
"field_path": mapping["field"].split("."), "field_path": mapping["field"].split("."),
"values": { "values": {
v.get("key"): v.get("value") for v in (mapping.get("values") or []) v.get("key"): v.get("value")
for v in (mapping.get("values") or [])
}, },
} }
for mapping in call_data.get("field_mappings", []) for mapping in call_data.get("field_mappings", [])
@ -287,16 +287,18 @@ def download_for_get_all(call: APICall, company, request: DownloadForGetAllReque
with ThreadPoolExecutor(1) as pool: with ThreadPoolExecutor(1) as pool:
page = 0 page = 0
page_size = int(download_conf.get("batch_size", 500)) page_size = int(conf.get("download.batch_size", 500))
future = pool.submit(get_fn, page, page_size) items_left = int(conf.get("download.max_download_items", 1000))
future = pool.submit(get_fn, page, min(page_size, items_left))
while True: while items_left > 0:
result = future.result() result = future.result()
if not result: if not result:
break break
items_left -= len(result)
page += 1 page += 1
future = pool.submit(get_fn, page, page_size) if items_left > 0:
future = pool.submit(get_fn, page, min(page_size, items_left))
with StringIO() as fp: with StringIO() as fp:
writer = csv.writer(fp) writer = csv.writer(fp)
@ -323,7 +325,7 @@ def download_for_get_all(call: APICall, company, request: DownloadForGetAllReque
if not project: if not project:
return return
return project.basename[: download_conf.get("max_project_name_length", 60)] return project.basename[: conf.get("download.max_project_name_length", 60)]
call.result.filename = "-".join( call.result.filename = "-".join(
filter(None, ("clearml", get_project_name(), f"{request.entity_type}s.csv")) filter(None, ("clearml", get_project_name(), f"{request.entity_type}s.csv"))

View File

@ -98,9 +98,7 @@ def get_current_user(call: APICall, company_id, _):
user_id = call.identity.user user_id = call.identity.user
projection = ( projection = (
{"company.name"} {"company.name"}.union(User.get_fields()).difference(User.get_exclude_fields())
.union(User.get_fields())
.difference(User.get_exclude_fields())
) )
res = User.get_many_with_join( res = User.get_many_with_join(
query=Q(id=user_id), query=Q(id=user_id),
@ -114,9 +112,13 @@ def get_current_user(call: APICall, company_id, _):
user = res[0] user = res[0]
user["role"] = call.identity.role user["role"] = call.identity.role
resp = { resp = dict(
"user": user, user=user, getting_started=config.get("apiserver.getting_started_info", None)
"getting_started": config.get("apiserver.getting_started_info", None), )
resp["settings"] = {
"max_download_items": int(
config.get("services.organization.download.max_download_items", 1000)
)
} }
call.result.data = resp call.result.data = resp