import glob
import os.path
import tempfile
import unittest

import pyarrow.parquet as parquet
from loguru import logger

from smallpond.io.arrow import (
    RowRange,
    build_batch_reader_from_files,
    cast_columns_to_large_string,
    dump_to_parquet_files,
    load_from_parquet_files,
)
from smallpond.utility import ConcurrentIter
from tests.test_fabric import TestFabric


class TestArrow(TestFabric, unittest.TestCase):
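    # Loading whole files with load_from_parquet_files should match the tables
    # produced by the _load_parquet_files reference helper from TestFabric.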
    def test_load_from_parquet_files(self):
        for dataset_path in (
            "tests/data/arrow/*.parquet",
            "tests/data/large_array/*.parquet",
        ):
            with self.subTest(dataset_path=dataset_path):
                parquet_files = glob.glob(dataset_path)
                expected = self._load_parquet_files(parquet_files)
                actual = load_from_parquet_files(parquet_files)
                self._compare_arrow_tables(expected, actual)
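
    # A RowRange covering rows [100, 200) of a single file should load the same
    # data as slicing the fully loaded table at offset=100, length=100.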
    def test_load_parquet_row_ranges(self):
        for dataset_path in (
            "tests/data/arrow/data0.parquet",
            "tests/data/large_array/large_array.parquet",
        ):
            with self.subTest(dataset_path=dataset_path):
                metadata = parquet.read_metadata(dataset_path)
                file_num_rows = metadata.num_rows
                data_size = sum(metadata.row_group(i).total_byte_size for i in range(metadata.num_row_groups))
                row_range = RowRange(
                    path=dataset_path,
                    begin=100,
                    end=200,
                    data_size=data_size,
                    file_num_rows=file_num_rows,
                )
                expected = self._load_parquet_files([dataset_path]).slice(offset=100, length=100)
                actual = load_from_parquet_files([row_range])
                self._compare_arrow_tables(expected, actual)
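
    # Round trip: dump a loaded table with dump_to_parquet_files, reload the
    # written files, and compare against the original table.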
    def test_dump_to_parquet_files(self):
        for dataset_path in (
            "tests/data/arrow/*.parquet",
            "tests/data/large_array/*.parquet",
        ):
            with self.subTest(dataset_path=dataset_path):
                parquet_files = glob.glob(dataset_path)
                expected = self._load_parquet_files(parquet_files)
                with tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir:
                    ok = dump_to_parquet_files(expected, output_dir)
                    self.assertTrue(ok)
                    actual = self._load_parquet_files(glob.glob(f"{output_dir}/*.parquet"))
                    self._compare_arrow_tables(expected, actual)
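
    # A zero-row table should survive the dump/load round trip unchanged.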
    def test_dump_load_empty_table(self):
        # create empty table
        empty_table = self._load_parquet_files(["tests/data/arrow/data0.parquet"]).slice(length=0)
        self.assertEqual(empty_table.num_rows, 0)
        # dump empty table
        with tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir:
            ok = dump_to_parquet_files(empty_table, output_dir)
            self.assertTrue(ok)
            parquet_files = glob.glob(f"{output_dir}/*.parquet")
            # load empty table from file
            actual_table = load_from_parquet_files(parquet_files)
            self._compare_arrow_tables(empty_table, actual_table)
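
    # build_batch_reader_from_files, drained through ConcurrentIter, should yield
    # batches no larger than batch_size and cover every row exactly once.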
    def test_parquet_batch_reader(self):
        for dataset_path in (
            "tests/data/arrow/*.parquet",
            "tests/data/large_array/*.parquet",
        ):
            with self.subTest(dataset_path=dataset_path):
                parquet_files = glob.glob(dataset_path)
                expected_num_rows = sum(parquet.read_metadata(file).num_rows for file in parquet_files)
                with build_batch_reader_from_files(
                    parquet_files,
                    batch_size=expected_num_rows,
                    max_batch_byte_size=None,
                ) as batch_reader, ConcurrentIter(batch_reader) as concurrent_iter:
                    total_num_rows = 0
                    for batch in concurrent_iter:
                        print(f"batch.num_rows {batch.num_rows}, max_batch_row_size {expected_num_rows}")
                        self.assertLessEqual(batch.num_rows, expected_num_rows)
                        total_num_rows += batch.num_rows
                    self.assertEqual(total_num_rows, expected_num_rows)
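
    # Sanity check of pyarrow's Table.to_batches: with max_chunksize set, batch
    # sizes stay bounded and the total row count is preserved.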
    def test_table_to_batches(self):
        for dataset_path in (
            "tests/data/arrow/*.parquet",
            "tests/data/large_array/*.parquet",
        ):
            with self.subTest(dataset_path=dataset_path):
                parquet_files = glob.glob(dataset_path)
                table = self._load_parquet_files(parquet_files)
                total_num_rows = 0
                for batch in table.to_batches(max_chunksize=table.num_rows):
                    print(f"batch.num_rows {batch.num_rows}, max_batch_row_size {table.num_rows}")
                    self.assertLessEqual(batch.num_rows, table.num_rows)
                    total_num_rows += batch.num_rows
                self.assertEqual(total_num_rows, table.num_rows)
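
    # Schema metadata attached via replace_schema_metadata should survive the
    # dump/load round trip, even when only a subset of columns is loaded.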
    def test_arrow_schema_metadata(self):
        table = self._load_parquet_files(glob.glob("tests/data/arrow/*.parquet"))
        metadata = {b"a": b"1", b"b": b"2"}
        table_with_meta = table.replace_schema_metadata(metadata)
        print(f"table_with_meta.schema.metadata {table_with_meta.schema.metadata}")

        with tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir:
            self.assertTrue(dump_to_parquet_files(table_with_meta, output_dir, "arrow_schema_metadata", max_workers=2))
            parquet_files = glob.glob(os.path.join(output_dir, "arrow_schema_metadata*.parquet"))
            loaded_table = load_from_parquet_files(parquet_files, table.column_names[:1])
            print(f"loaded_table.schema.metadata {loaded_table.schema.metadata}")
            self.assertEqual(table_with_meta.schema.metadata, loaded_table.schema.metadata)
            with parquet.ParquetFile(parquet_files[0]) as file:
                print(f"file.schema_arrow.metadata {file.schema_arrow.metadata}")
                self.assertEqual(table_with_meta.schema.metadata, file.schema_arrow.metadata)
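
    # Loading a mix of the original files and copies whose columns were cast to
    # large_string should still combine into a single table with doubled rows.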
    def test_load_mixed_string_types(self):
        parquet_paths = glob.glob("tests/data/arrow/*.parquet")
        table = self._load_parquet_files(parquet_paths)

        with tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir:
            dump_to_parquet_files(cast_columns_to_large_string(table), output_dir)
            parquet_paths += glob.glob(os.path.join(output_dir, "*.parquet"))
            loaded_table = load_from_parquet_files(parquet_paths)
            self.assertEqual(table.num_rows * 2, loaded_table.num_rows)
            batch_reader = build_batch_reader_from_files(parquet_paths)
            self.assertEqual(table.num_rows * 2, sum(batch.num_rows for batch in batch_reader))
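
    # Helper wrapping load_from_parquet_files in loguru's logger.catch so that
    # failures are logged before being re-raised.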
    @logger.catch(reraise=True, message="failed to load parquet files")
    def _load_from_parquet_files_with_log(self, paths, columns):
        load_from_parquet_files(paths, columns)
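
    # Requesting a column that does not exist should raise an AssertionError.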
    def test_load_not_exist_column(self):
        parquet_files = glob.glob("tests/data/arrow/*.parquet")
        with self.assertRaises(AssertionError) as context:
            self._load_from_parquet_files_with_log(parquet_files, ["not_exist_column"])
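
    # The column order of the loaded table should follow the requested column
    # list rather than the order stored in the files.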
    def test_change_ordering_of_columns(self):
        parquet_files = glob.glob("tests/data/arrow/*.parquet")
        loaded_table = load_from_parquet_files(parquet_files)
        reversed_cols = list(reversed(loaded_table.column_names))
        loaded_table = load_from_parquet_files(parquet_files, reversed_cols)
        self.assertEqual(loaded_table.column_names, reversed_cols)