# smallpond/tests/test_arrow.py
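"""Unit tests for smallpond.io.arrow: Parquet load/dump round trips, row-range
reads, batch readers, schema metadata, and mixed string column types."""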

import glob
import os.path
import tempfile
import unittest

import pyarrow.parquet as parquet
from loguru import logger

from smallpond.io.arrow import (
    RowRange,
    build_batch_reader_from_files,
    cast_columns_to_large_string,
    dump_to_parquet_files,
    load_from_parquet_files,
)
from smallpond.utility import ConcurrentIter
from tests.test_fabric import TestFabric


class TestArrow(TestFabric, unittest.TestCase):
    def test_load_from_parquet_files(self):
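        """load_from_parquet_files should match the reference loader on each test dataset."""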
        for dataset_path in (
            "tests/data/arrow/*.parquet",
            "tests/data/large_array/*.parquet",
        ):
            with self.subTest(dataset_path=dataset_path):
                parquet_files = glob.glob(dataset_path)
                expected = self._load_parquet_files(parquet_files)
                actual = load_from_parquet_files(parquet_files)
                self._compare_arrow_tables(expected, actual)

    def test_load_parquet_row_ranges(self):
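        """Loading a RowRange of rows [100, 200) should equal slicing the fully loaded table."""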
        for dataset_path in (
            "tests/data/arrow/data0.parquet",
            "tests/data/large_array/large_array.parquet",
        ):
            with self.subTest(dataset_path=dataset_path):
                metadata = parquet.read_metadata(dataset_path)
                file_num_rows = metadata.num_rows
                data_size = sum(metadata.row_group(i).total_byte_size for i in range(metadata.num_row_groups))
                row_range = RowRange(
                    path=dataset_path,
                    begin=100,
                    end=200,
                    data_size=data_size,
                    file_num_rows=file_num_rows,
                )
                expected = self._load_parquet_files([dataset_path]).slice(offset=100, length=100)
                actual = load_from_parquet_files([row_range])
                self._compare_arrow_tables(expected, actual)

    def test_dump_to_parquet_files(self):
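        """Tables written by dump_to_parquet_files should reload to the same contents."""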
        for dataset_path in (
            "tests/data/arrow/*.parquet",
            "tests/data/large_array/*.parquet",
        ):
            with self.subTest(dataset_path=dataset_path):
                parquet_files = glob.glob(dataset_path)
                expected = self._load_parquet_files(parquet_files)
                with tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir:
                    ok = dump_to_parquet_files(expected, output_dir)
                    self.assertTrue(ok)
                    actual = self._load_parquet_files(glob.glob(f"{output_dir}/*.parquet"))
                    self._compare_arrow_tables(expected, actual)

    def test_dump_load_empty_table(self):
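        """A zero-row table should survive a dump/load round trip."""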
        # create an empty table by slicing a real one down to zero rows
        empty_table = self._load_parquet_files(["tests/data/arrow/data0.parquet"]).slice(length=0)
        self.assertEqual(empty_table.num_rows, 0)
        # dump the empty table
        with tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir:
            ok = dump_to_parquet_files(empty_table, output_dir)
            self.assertTrue(ok)
            parquet_files = glob.glob(f"{output_dir}/*.parquet")
            # load the empty table back from file
            actual_table = load_from_parquet_files(parquet_files)
            self._compare_arrow_tables(empty_table, actual_table)

    def test_parquet_batch_reader(self):
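        """Batches from build_batch_reader_from_files should sum to the total row count."""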
        for dataset_path in (
            "tests/data/arrow/*.parquet",
            "tests/data/large_array/*.parquet",
        ):
            with self.subTest(dataset_path=dataset_path):
                parquet_files = glob.glob(dataset_path)
                expected_num_rows = sum(parquet.read_metadata(file).num_rows for file in parquet_files)
                with build_batch_reader_from_files(
                    parquet_files,
                    batch_size=expected_num_rows,
                    max_batch_byte_size=None,
                ) as batch_reader, ConcurrentIter(batch_reader) as concurrent_iter:
                    total_num_rows = 0
                    for batch in concurrent_iter:
                        print(f"batch.num_rows {batch.num_rows}, batch_size {expected_num_rows}")
                        self.assertLessEqual(batch.num_rows, expected_num_rows)
                        total_num_rows += batch.num_rows
                    self.assertEqual(total_num_rows, expected_num_rows)

    def test_table_to_batches(self):
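        """pyarrow's Table.to_batches should also cover every row exactly once."""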
        for dataset_path in (
            "tests/data/arrow/*.parquet",
            "tests/data/large_array/*.parquet",
        ):
            with self.subTest(dataset_path=dataset_path):
                parquet_files = glob.glob(dataset_path)
                table = self._load_parquet_files(parquet_files)
                total_num_rows = 0
                for batch in table.to_batches(max_chunksize=table.num_rows):
                    print(f"batch.num_rows {batch.num_rows}, max_chunksize {table.num_rows}")
                    self.assertLessEqual(batch.num_rows, table.num_rows)
                    total_num_rows += batch.num_rows
                self.assertEqual(total_num_rows, table.num_rows)

    def test_arrow_schema_metadata(self):
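        """Arrow schema metadata should be preserved by dump/load, even for partial column loads."""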
        table = self._load_parquet_files(glob.glob("tests/data/arrow/*.parquet"))
        metadata = {b"a": b"1", b"b": b"2"}
        table_with_meta = table.replace_schema_metadata(metadata)
        print(f"table_with_meta.schema.metadata {table_with_meta.schema.metadata}")
        with tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir:
            self.assertTrue(dump_to_parquet_files(table_with_meta, output_dir, "arrow_schema_metadata", max_workers=2))
            parquet_files = glob.glob(os.path.join(output_dir, "arrow_schema_metadata*.parquet"))
            loaded_table = load_from_parquet_files(parquet_files, table.column_names[:1])
            print(f"loaded_table.schema.metadata {loaded_table.schema.metadata}")
            self.assertEqual(table_with_meta.schema.metadata, loaded_table.schema.metadata)
            with parquet.ParquetFile(parquet_files[0]) as file:
                print(f"file.schema_arrow.metadata {file.schema_arrow.metadata}")
                self.assertEqual(table_with_meta.schema.metadata, file.schema_arrow.metadata)

    def test_load_mixed_string_types(self):
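        """Datasets mixing string and large_string columns should load into a single table."""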
        parquet_paths = glob.glob("tests/data/arrow/*.parquet")
        table = self._load_parquet_files(parquet_paths)
        with tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir:
            dump_to_parquet_files(cast_columns_to_large_string(table), output_dir)
            parquet_paths += glob.glob(os.path.join(output_dir, "*.parquet"))
            loaded_table = load_from_parquet_files(parquet_paths)
            self.assertEqual(table.num_rows * 2, loaded_table.num_rows)
            batch_reader = build_batch_reader_from_files(parquet_paths)
            self.assertEqual(table.num_rows * 2, sum(batch.num_rows for batch in batch_reader))

    @logger.catch(reraise=True, message="failed to load parquet files")
    def _load_from_parquet_files_with_log(self, paths, columns):
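        """Log failures of load_from_parquet_files via loguru before re-raising."""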
        return load_from_parquet_files(paths, columns)

    def test_load_not_exist_column(self):
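        """Selecting a column that does not exist should raise an AssertionError."""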
        parquet_files = glob.glob("tests/data/arrow/*.parquet")
        with self.assertRaises(AssertionError):
            self._load_from_parquet_files_with_log(parquet_files, ["not_exist_column"])

    def test_change_ordering_of_columns(self):
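        """The column order of the loaded table should follow the requested column list."""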
        parquet_files = glob.glob("tests/data/arrow/*.parquet")
        loaded_table = load_from_parquet_files(parquet_files)
        reversed_cols = list(reversed(loaded_table.column_names))
        loaded_table = load_from_parquet_files(parquet_files, reversed_cols)
        self.assertEqual(loaded_table.column_names, reversed_cols)