Fix CSV export handling of "," (commas) inside field values

This commit is contained in:
allegroai 2023-07-26 18:35:31 +03:00
parent 5d3ba4fa73
commit 5cd59ea6e3
3 changed files with 19 additions and 22 deletions

View File

@ -29,17 +29,12 @@ class EntitiesCountRequest(models.Base):
allow_public = fields.BoolField(default=True)
class DownloadType(StringEnum):
    """Formats in which a prepared download can be produced.

    The value selects the formatting and mime type of the downloaded file;
    CSV is the only member defined here.
    """

    csv = auto()
class EntityType(StringEnum):
    """Kinds of entities that a download request can target."""

    task = auto()
    model = auto()
class PrepareDownloadForGetAll(models.Base):
download_type = ActualEnumField(DownloadType, default=DownloadType.csv)
entity_type = ActualEnumField(EntityType)
allow_public = fields.BoolField(default=True)
search_hidden = fields.BoolField(default=False)

View File

@ -209,12 +209,6 @@ prepare_download_for_get_all {
type: array
items {type: string}
}
download_type {
description: "Download type. Determines the downloaded file's formatting and mime type."
type: string
enum: [ csv ]
default: csv
}
allow_public {
description: "Allow public entities to be returned in the results"
type: boolean

View File

@ -1,5 +1,7 @@
import csv
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from io import StringIO
from operator import itemgetter
from typing import Mapping, Type, Sequence, Optional, Callable
@ -213,9 +215,12 @@ def download_for_get_all(call: APICall, company, request: DownloadForGetAll):
except Exception as ex:
raise errors.server_error.DataError("failed parsing prepared data", ex=ex)
class SingleLine:
    """Minimal file-like sink that hands each written line back to the caller.

    ``csv.writer(...).writerow(row)`` returns whatever the ``write`` method of
    the underlying object returns, so wrapping an instance of this class lets
    a single row be rendered as a properly quoted CSV string without an
    intermediate buffer.
    """

    def write(self, line: str) -> str:
        """Return ``line`` unchanged instead of storing it anywhere."""
        return line
def generate():
projection = call_data.get("only_fields", [])
headers = ",".join(projection)
get_fn = _get_download_getter_fn(
company,
@ -225,7 +230,8 @@ def download_for_get_all(call: APICall, company, request: DownloadForGetAll):
entity_type=request.entity_type,
)
if not get_fn:
return headers
yield csv.writer(SingleLine()).writerow(projection)
return
fields = [path.split(".") for path in projection]
@ -236,8 +242,8 @@ def download_for_get_all(call: APICall, company, request: DownloadForGetAll):
return str(val)
def get_string_from_entity_data(data: dict) -> str:
return ",".join(get_entity_field_as_str(data, f) for f in fields)
def get_projected_fields(data: dict) -> Sequence[str]:
return [get_entity_field_as_str(data, f) for f in fields]
with ThreadPoolExecutor(1) as pool:
page = 0
@ -245,7 +251,6 @@ def download_for_get_all(call: APICall, company, request: DownloadForGetAll):
config.get("services.organization.download.batch_size", 500)
)
future = pool.submit(get_fn, page, page_size)
out = [headers]
while True:
result = future.result()
@ -255,13 +260,16 @@ def download_for_get_all(call: APICall, company, request: DownloadForGetAll):
page += 1
future = pool.submit(get_fn, page, page_size)
out.extend(get_string_from_entity_data(r) for r in result)
yield "\n".join(out) + "\n"
out = []
with StringIO() as fp:
writer = csv.writer(fp)
if page == 1:
writer.writerow(projection)
writer.writerows(get_projected_fields(r) for r in result)
yield fp.getvalue()
if out:
yield "\n".join(out)
if page == 0:
yield csv.writer(SingleLine()).writerow(projection)
call.result.filename = f"{request.entity_type}_export.{request.download_type}"
call.result.filename = f"{request.entity_type}_export.csv"
call.result.content_type = "text/csv"
call.result.raw_data = stream_with_context(generate())