Mirror of https://github.com/deepseek-ai/smallpond, synced 2025-06-26 18:27:45 +00:00.
expose ArrowBatchNode to DataFrame API
This commit is contained in:
@@ -5,7 +5,7 @@ import time
|
||||
from collections import OrderedDict
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Iterator
|
||||
|
||||
import pandas as pd
|
||||
import pyarrow as arrow
|
||||
@@ -578,6 +578,7 @@ class DataFrame:
|
||||
func: Callable[[arrow.Table], arrow.Table],
|
||||
*,
|
||||
batch_size: int = 122880,
|
||||
streaming: bool = False,
|
||||
**kwargs,
|
||||
) -> DataFrame:
|
||||
"""
|
||||
@@ -590,18 +591,35 @@ class DataFrame:
|
||||
It should take a `arrow.Table` as input and returns a `arrow.Table`.
|
||||
batch_size, optional
|
||||
The number of rows in each batch. Defaults to 122880.
|
||||
streaming, optional
|
||||
If true, the function takes an iterator of `arrow.Table` as input and yields a streaming of `arrow.Table` as output.
|
||||
i.e. func: Callable[[Iterator[arrow.Table]], Iterator[arrow.Table]]
|
||||
Defaults to false.
|
||||
"""
|
||||
|
||||
def process_func(_runtime_ctx, tables: List[arrow.Table]) -> arrow.Table:
|
||||
return func(tables[0])
|
||||
if streaming:
|
||||
def process_func(_runtime_ctx, readers: List[arrow.RecordBatchReader]) -> Iterator[arrow.Table]:
|
||||
tables = map(lambda batch: arrow.Table.from_batches([batch]), readers[0])
|
||||
return func(tables)
|
||||
|
||||
plan = ArrowBatchNode(
|
||||
self.session._ctx,
|
||||
(self.plan,),
|
||||
process_func=process_func,
|
||||
streaming_batch_size=batch_size,
|
||||
**kwargs,
|
||||
)
|
||||
plan = ArrowStreamNode(
|
||||
self.session._ctx,
|
||||
(self.plan,),
|
||||
process_func=process_func,
|
||||
streaming_batch_size=batch_size,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
def process_func(_runtime_ctx, tables: List[arrow.Table]) -> arrow.Table:
|
||||
return func(tables[0])
|
||||
|
||||
plan = ArrowBatchNode(
|
||||
self.session._ctx,
|
||||
(self.plan,),
|
||||
process_func=process_func,
|
||||
streaming_batch_size=batch_size,
|
||||
**kwargs,
|
||||
)
|
||||
return DataFrame(self.session, plan, recompute=self.need_recompute)
|
||||
|
||||
def limit(self, limit: int) -> DataFrame:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import List
|
||||
from typing import Iterator, List
|
||||
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
@@ -84,6 +84,31 @@ def test_map_batches(sp: Session):
|
||||
assert df.take_all() == [{"num_rows": 350}, {"num_rows": 350}, {"num_rows": 300}]
|
||||
|
||||
|
||||
def test_map_batches_streaming(sp: Session):
    """With streaming=True, map_batches passes the function an iterator of
    `pa.Table` batches and collects every table the function yields."""
    df = sp.read_parquet("tests/data/mock_urls/*.parquet")

    def batched2(tables: Iterator[pa.Table]) -> Iterator[pa.Table]:
        # Group incoming batches two at a time — like itertools.batched(..., 2),
        # NOT itertools.pairwise (groups are disjoint, not overlapping) — and
        # emit the combined row count of each group.
        num_rows = 0
        count = 0
        for batch in tables:
            num_rows += batch.num_rows
            count += 1
            if count == 2:
                yield pa.table({"num_rows": [num_rows]})
                num_rows = 0
                count = 0
        if count > 0:
            # Flush the trailing partial group (an odd batch left over).
            yield pa.table({"num_rows": [num_rows]})

    df = df.map_batches(
        batched2,
        batch_size=350,
        streaming=True,
    )
    # 1000 input rows split at batch_size=350 -> batches of 350/350/300,
    # grouped two at a time -> 700 and a trailing 300.
    assert df.take_all() == [{"num_rows": 700}, {"num_rows": 300}]
|
||||
|
||||
|
||||
def test_filter(sp: Session):
|
||||
df = sp.from_arrow(pa.table({"a": [1, 2, 3], "b": [4, 5, 6]}))
|
||||
df1 = df.filter("a > 1")
|
||||
|
||||
Reference in New Issue
Block a user