Fix csv export handling "," in fields

This commit is contained in:
allegroai 2023-07-26 18:35:31 +03:00
parent 5d3ba4fa73
commit 5cd59ea6e3
3 changed files with 19 additions and 22 deletions

View File

@ -29,17 +29,12 @@ class EntitiesCountRequest(models.Base):
allow_public = fields.BoolField(default=True) allow_public = fields.BoolField(default=True)
class DownloadType(StringEnum):
csv = auto()
class EntityType(StringEnum): class EntityType(StringEnum):
task = auto() task = auto()
model = auto() model = auto()
class PrepareDownloadForGetAll(models.Base): class PrepareDownloadForGetAll(models.Base):
download_type = ActualEnumField(DownloadType, default=DownloadType.csv)
entity_type = ActualEnumField(EntityType) entity_type = ActualEnumField(EntityType)
allow_public = fields.BoolField(default=True) allow_public = fields.BoolField(default=True)
search_hidden = fields.BoolField(default=False) search_hidden = fields.BoolField(default=False)

View File

@ -209,12 +209,6 @@ prepare_download_for_get_all {
type: array type: array
items {type: string} items {type: string}
} }
download_type {
description: "Download type. Determines the downloaded file's formatting and mime type."
type: string
enum: [ csv ]
default: csv
}
allow_public { allow_public {
description: "Allow public entities to be returned in the results" description: "Allow public entities to be returned in the results"
type: boolean type: boolean

View File

@ -1,5 +1,7 @@
import csv
from collections import defaultdict from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from io import StringIO
from operator import itemgetter from operator import itemgetter
from typing import Mapping, Type, Sequence, Optional, Callable from typing import Mapping, Type, Sequence, Optional, Callable
@ -213,9 +215,12 @@ def download_for_get_all(call: APICall, company, request: DownloadForGetAll):
except Exception as ex: except Exception as ex:
raise errors.server_error.DataError("failed parsing prepared data", ex=ex) raise errors.server_error.DataError("failed parsing prepared data", ex=ex)
class SingleLine:
def write(self, line: str) -> str:
return line
def generate(): def generate():
projection = call_data.get("only_fields", []) projection = call_data.get("only_fields", [])
headers = ",".join(projection)
get_fn = _get_download_getter_fn( get_fn = _get_download_getter_fn(
company, company,
@ -225,7 +230,8 @@ def download_for_get_all(call: APICall, company, request: DownloadForGetAll):
entity_type=request.entity_type, entity_type=request.entity_type,
) )
if not get_fn: if not get_fn:
return headers yield csv.writer(SingleLine()).writerow(projection)
return
fields = [path.split(".") for path in projection] fields = [path.split(".") for path in projection]
@ -236,8 +242,8 @@ def download_for_get_all(call: APICall, company, request: DownloadForGetAll):
return str(val) return str(val)
def get_string_from_entity_data(data: dict) -> str: def get_projected_fields(data: dict) -> Sequence[str]:
return ",".join(get_entity_field_as_str(data, f) for f in fields) return [get_entity_field_as_str(data, f) for f in fields]
with ThreadPoolExecutor(1) as pool: with ThreadPoolExecutor(1) as pool:
page = 0 page = 0
@ -245,7 +251,6 @@ def download_for_get_all(call: APICall, company, request: DownloadForGetAll):
config.get("services.organization.download.batch_size", 500) config.get("services.organization.download.batch_size", 500)
) )
future = pool.submit(get_fn, page, page_size) future = pool.submit(get_fn, page, page_size)
out = [headers]
while True: while True:
result = future.result() result = future.result()
@ -255,13 +260,16 @@ def download_for_get_all(call: APICall, company, request: DownloadForGetAll):
page += 1 page += 1
future = pool.submit(get_fn, page, page_size) future = pool.submit(get_fn, page, page_size)
out.extend(get_string_from_entity_data(r) for r in result) with StringIO() as fp:
yield "\n".join(out) + "\n" writer = csv.writer(fp)
out = [] if page == 1:
writer.writerow(projection)
writer.writerows(get_projected_fields(r) for r in result)
yield fp.getvalue()
if out: if page == 0:
yield "\n".join(out) yield csv.writer(SingleLine()).writerow(projection)
call.result.filename = f"{request.entity_type}_export.{request.download_type}" call.result.filename = f"{request.entity_type}_export.csv"
call.result.content_type = "text/csv" call.result.content_type = "text/csv"
call.result.raw_data = stream_with_context(generate()) call.result.raw_data = stream_with_context(generate())