expose ArrowBatchNode to DataFrame API

This commit is contained in:
Runji Wang
2025-02-28 17:40:31 +08:00
parent ed112db42a
commit 947af97bab
2 changed files with 54 additions and 11 deletions

View File

@@ -5,7 +5,7 @@ import time
from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Iterator
import pandas as pd
import pyarrow as arrow
@@ -578,6 +578,7 @@ class DataFrame:
func: Callable[[arrow.Table], arrow.Table],
*,
batch_size: int = 122880,
streaming: bool = False,
**kwargs,
) -> DataFrame:
"""
@@ -590,18 +591,35 @@ class DataFrame:
It should take a `arrow.Table` as input and returns a `arrow.Table`.
batch_size, optional
The number of rows in each batch. Defaults to 122880.
streaming, optional
If true, the function takes an iterator of `arrow.Table` as input and yields a stream of `arrow.Table` as output.
i.e. func: Callable[[Iterator[arrow.Table]], Iterator[arrow.Table]]
Defaults to false.
"""
def process_func(_runtime_ctx, tables: List[arrow.Table]) -> arrow.Table:
return func(tables[0])
if streaming:
def process_func(_runtime_ctx, readers: List[arrow.RecordBatchReader]) -> Iterator[arrow.Table]:
tables = map(lambda batch: arrow.Table.from_batches([batch]), readers[0])
return func(tables)
plan = ArrowBatchNode(
self.session._ctx,
(self.plan,),
process_func=process_func,
streaming_batch_size=batch_size,
**kwargs,
)
plan = ArrowStreamNode(
self.session._ctx,
(self.plan,),
process_func=process_func,
streaming_batch_size=batch_size,
**kwargs,
)
else:
def process_func(_runtime_ctx, tables: List[arrow.Table]) -> arrow.Table:
return func(tables[0])
plan = ArrowBatchNode(
self.session._ctx,
(self.plan,),
process_func=process_func,
streaming_batch_size=batch_size,
**kwargs,
)
return DataFrame(self.session, plan, recompute=self.need_recompute)
def limit(self, limit: int) -> DataFrame:

View File

@@ -1,4 +1,4 @@
from typing import List
from typing import Iterator, List
import pandas as pd
import pyarrow as pa
@@ -84,6 +84,31 @@ def test_map_batches(sp: Session):
assert df.take_all() == [{"num_rows": 350}, {"num_rows": 350}, {"num_rows": 300}]
def test_map_batches_streaming(sp: Session):
    # Exercises map_batches(streaming=True): the UDF receives an Iterator of
    # arrow Tables (one per batch) and may yield results that aggregate
    # across batch boundaries.
    df = sp.read_parquet("tests/data/mock_urls/*.parquet")
    def batched2(tables: Iterator[pa.Table]) -> Iterator[pa.Table]:
        # Group incoming batches into pairs (similar to itertools.batched
        # with n=2 — NOT pairwise) and emit the combined row count of each
        # pair; any trailing partial group is flushed after the loop.
        num_rows = 0
        count = 0
        for batch in tables:
            num_rows += batch.num_rows
            count += 1
            if count == 2:
                yield pa.table({"num_rows": [num_rows]})
                num_rows = 0
                count = 0
        if count > 0:
            yield pa.table({"num_rows": [num_rows]})
    df = df.map_batches(
        batched2,
        batch_size=350,
        streaming=True,
    )
    # The dataset appears to hold 1000 rows (the non-streaming test above sees
    # batches of 350, 350, 300), so the pairs sum to 700 and a leftover 300.
    assert df.take_all() == [{"num_rows": 700}, {"num_rows": 300}]
def test_filter(sp: Session):
df = sp.from_arrow(pa.table({"a": [1, 2, 3], "b": [4, 5, 6]}))
df1 = df.filter("a > 1")