This commit is contained in:
Runji Wang
2025-02-25 18:16:31 +08:00
commit 770aa417d5
77 changed files with 18785 additions and 0 deletions

126
.github/workflows/ci.yml vendored Normal file
View File

@@ -0,0 +1,126 @@
name: CI Pipeline
on:
push:
branches: [ main ]
pull_request:
schedule:
- cron: '0 0 * * *'
jobs:
fmt:
name: Format Check
runs-on: self-hosted
steps:
- uses: actions/checkout@v4
- name: Set up Python 3.12
uses: actions/setup-python@v5
with:
python-version: "3.12"
- name: Install formatters
run: |
python -m pip install --upgrade pip
pip install black
- name: Check code formatting
run: |
black --check .
# - name: Check typos
# uses: crate-ci/typos@v1.29.10
test:
name: Test (${{ matrix.os }}, Python ${{ matrix.python-version }})
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [self-hosted] # macos-latest
python-version: ["3.9", "3.10", "3.11", "3.12"]
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -e .[dev]
- name: Install gensort on Linux
if: matrix.os == 'self-hosted'
run: |
wget https://www.ordinal.com/try.cgi/gensort-linux-1.5.tar.gz
tar -xzf gensort-linux-1.5.tar.gz
chmod +x 64/gensort 64/valsort
export PATH=$PATH:$(pwd)/64
- name: Run tests
timeout-minutes: 20
run: |
pytest -n 4 -v -x --durations=50 --timeout=600 \
--junitxml=pytest.xml \
--cov-report term --cov-report xml:coverage.xml \
--cov=smallpond --cov=examples --cov=benchmarks --cov=tests \
tests/test_*.py
- name: Archive test results
uses: actions/upload-artifact@v4
with:
name: test-results-${{ matrix.os }}-py${{ matrix.python-version }}
path: |
pytest.xml
coverage.xml
build-docs:
name: Build Documentation
runs-on: self-hosted
steps:
- uses: actions/checkout@v4
- name: Set up Python 3.8
uses: actions/setup-python@v5
with:
python-version: "3.8"
- name: Install dependencies
run: pip install -e .[docs]
- name: Build HTML docs
run: |
cd docs
make html
- name: Archive documentation
uses: actions/upload-artifact@v4
with:
name: documentation
path: docs/build/html
deploy-docs:
name: Deploy Documentation
runs-on: self-hosted
needs: [build-docs]
if: github.event_name == 'push' && github.ref == 'refs/heads/main'
steps:
- uses: actions/checkout@v4
- name: Download artifact
uses: actions/download-artifact@v4
with:
name: documentation
path: docs/build/html
- name: Deploy to docs branch
uses: peaceiris/actions-gh-pages@v4
with:
github_token: ${{ secrets.GITHUB_TOKEN }}
publish_dir: docs/build/html
destination_dir: ./
keep_files: true
branch: docs
force_orphan: true

18
.gitignore vendored Normal file
View File

@@ -0,0 +1,18 @@
__pycache__
.ipynb_checkpoints
.tmp/
dist/
build/
*.egg-info/
tests/data/
tests/runtime/
*.log
*.pyc
*.xml
.tmp/
.idea
.coverage
.vscode/
.hypothesis/
docs/*/generated/
.venv*/

7
LICENSE Normal file
View File

@@ -0,0 +1,7 @@
Copyright 2025 DeepSeek
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

1
MANIFEST.in Normal file
View File

@@ -0,0 +1 @@
exclude tests/data/** tests/runtime/**

81
README.md Normal file
View File

@@ -0,0 +1,81 @@
# smallpond
[![CI](https://github.com/deepseek-ai/smallpond/actions/workflows/ci.yml/badge.svg?branch=main)](https://github.com/deepseek-ai/smallpond/actions/workflows/ci.yml)
[![PyPI](https://img.shields.io/pypi/v/smallpond)](https://pypi.org/project/smallpond/)
[![Docs](https://img.shields.io/badge/docs-latest-brightgreen.svg)](https://deepseek-ai.github.io/smallpond/)
[![License](https://img.shields.io/badge/license-MIT-blue.svg)](LICENSE)
A lightweight data processing framework built on DuckDB and [3FS].
## Features
- 🚀 High-performance data processing powered by DuckDB
- 🌍 Scalable to handle PB-scale datasets
- 🛠️ Easy operations with no long-running services
## Installation
Python 3.8 to 3.12 is supported.
```bash
pip install smallpond
```
## Quick Start
```bash
# Download example data
wget https://duckdb.org/data/prices.parquet
```
```python
import smallpond
# Initialize session
sp = smallpond.init()
# Load data
df = sp.read_parquet("prices.parquet")
# Process data
df = df.repartition(3, hash_by="ticker")
df = sp.partial_sql("SELECT ticker, min(price), max(price) FROM {0} GROUP BY ticker", df)
# Save results
df.write_parquet("output/")
# Show results
print(df.to_pandas())
```
## Documentation
For detailed guides and API reference:
- [Getting Started](docs/source/getstarted.rst)
- [API Reference](docs/source/api.rst)
## Performance
We executed the [Gray Sort benchmark] using [smallpond] on a cluster comprising 50 compute nodes and 25 storage nodes running [3FS]. The benchmark sorted 110.5TiB of data in 30 minutes and 14 seconds, achieving a throughput of 3.66TiB/min.
[3FS]: https://github.com/deepseek-ai/3fs
[Gray Sort benchmark]: https://sortbenchmark.org/
[smallpond]: benchmarks/gray_sort_benchmark.py
## Development
```bash
pip install .[dev]
# run unit tests
pytest -v tests/test*.py
# build documentation
pip install .[docs]
cd docs
make html
python -m http.server --directory build/html
```
## License
This project is licensed under the [MIT License](LICENSE).

View File

@@ -0,0 +1,84 @@
from smallpond.common import DEFAULT_BATCH_SIZE, DEFAULT_ROW_GROUP_SIZE, GB
from smallpond.contrib.copy_table import CopyArrowTable, StreamCopy
from smallpond.execution.driver import Driver
from smallpond.logical.dataset import ParquetDataSet
from smallpond.logical.node import (
Context,
DataSetPartitionNode,
DataSourceNode,
LogicalPlan,
SqlEngineNode,
)
def file_io_benchmark(
input_paths,
npartitions,
io_engine="duckdb",
batch_size=DEFAULT_BATCH_SIZE,
row_group_size=DEFAULT_ROW_GROUP_SIZE,
output_name="data",
**kwargs,
) -> LogicalPlan:
ctx = Context()
dataset = ParquetDataSet(input_paths)
data_files = DataSourceNode(ctx, dataset)
data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=npartitions)
if io_engine == "duckdb":
data_copy = SqlEngineNode(
ctx,
(data_partitions,),
r"select * from {0}",
parquet_row_group_size=row_group_size,
per_thread_output=False,
output_name=output_name,
cpu_limit=1,
memory_limit=10 * GB,
)
elif io_engine == "arrow":
data_copy = CopyArrowTable(
ctx,
(data_partitions,),
parquet_row_group_size=row_group_size,
output_name=output_name,
cpu_limit=1,
memory_limit=10 * GB,
)
elif io_engine == "stream":
data_copy = StreamCopy(
ctx,
(data_partitions,),
streaming_batch_size=batch_size,
parquet_row_group_size=row_group_size,
output_name=output_name,
cpu_limit=1,
memory_limit=10 * GB,
)
plan = LogicalPlan(ctx, data_copy)
return plan
def main():
driver = Driver()
driver.add_argument("-i", "--input_paths", nargs="+")
driver.add_argument("-n", "--npartitions", type=int, default=None)
driver.add_argument(
"-e", "--io_engine", default="duckdb", choices=("duckdb", "arrow", "stream")
)
driver.add_argument("-b", "--batch_size", type=int, default=1024 * 1024)
driver.add_argument("-s", "--row_group_size", type=int, default=1024 * 1024)
driver.add_argument("-o", "--output_name", default="data")
driver.add_argument("-NC", "--cpus_per_node", type=int, default=128)
user_args, driver_args = driver.parse_arguments()
total_num_cpus = driver_args.num_executors * user_args.cpus_per_node
user_args.npartitions = user_args.npartitions or total_num_cpus
plan = file_io_benchmark(**driver.get_arguments())
driver.run(plan)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,417 @@
import functools
import logging
import os.path
import shutil
import subprocess
import tempfile
from pathlib import PurePath
from typing import Iterable, List
import duckdb
import polars
import psutil
import pyarrow as arrow
import pyarrow.compute as pc
from smallpond.common import GB, MB, next_power_of_two, pytest_running
from smallpond.execution.driver import Driver
from smallpond.execution.task import (
ArrowStreamTask,
PythonScriptTask,
RuntimeContext,
StreamOutput,
)
from smallpond.logical.dataset import ArrowTableDataSet, DataSet, ParquetDataSet
from smallpond.logical.node import (
ArrowStreamNode,
Context,
DataSetPartitionNode,
DataSourceNode,
LogicalPlan,
ProjectionNode,
PythonScriptNode,
ShuffleNode,
)
class SortBenchTool(object):
gensort_path = shutil.which("gensort")
valsort_path = shutil.which("valsort")
@staticmethod
def ensure_installed():
if not SortBenchTool.gensort_path or not SortBenchTool.valsort_path:
raise Exception("gensort or valsort not found")
def generate_records(
runtime_ctx: RuntimeContext,
input_readers: List[arrow.RecordBatchReader],
record_nbytes=100,
key_nbytes=10,
bucket_nbits=12,
gensort_batch_nbytes=500 * MB,
) -> Iterable[arrow.Table]:
runtime_task: ArrowStreamTask = runtime_ctx.task
batch_size = gensort_batch_nbytes // record_nbytes
schema = arrow.schema(
[
arrow.field("buckets", arrow.uint16()),
arrow.field("keys", arrow.binary()),
arrow.field("records", arrow.binary()),
]
)
with tempfile.NamedTemporaryFile(dir="/dev/shm", buffering=0) as shm_file:
for batch_idx, batch in enumerate(input_readers[0]):
for begin_at, num_records in zip(*batch.columns):
begin_at, num_records = begin_at.as_py(), num_records.as_py()
for offset in range(begin_at, begin_at + num_records, batch_size):
record_count = min(batch_size, begin_at + num_records - offset)
gensort_cmd = f"{SortBenchTool.gensort_path} -t2 -b{offset} {record_count} {shm_file.name},buf,trans=100m"
subprocess.run(gensort_cmd.split()).check_returncode()
runtime_task.add_elapsed_time("generate records (secs)")
shm_file.seek(0)
buffer = arrow.py_buffer(
shm_file.read(record_count * record_nbytes)
)
runtime_task.add_elapsed_time("read records (secs)")
# https://arrow.apache.org/docs/format/Columnar.html#fixed-size-primitive-layout
records = arrow.Array.from_buffers(
arrow.binary(record_nbytes), record_count, [None, buffer]
)
keys = pc.binary_slice(records, 0, key_nbytes)
# get first 2 bytes and convert to big-endian uint16
binary_prefix = pc.binary_slice(records, 0, 2).cast(arrow.binary())
reversed_prefix = pc.binary_reverse(binary_prefix).cast(
arrow.binary(2)
)
uint16_prefix = reversed_prefix.view(arrow.uint16())
buckets = pc.shift_right(uint16_prefix, 16 - bucket_nbits)
runtime_task.add_elapsed_time("build arrow table (secs)")
yield arrow.Table.from_arrays(
[buckets, keys, records], schema=schema
)
yield StreamOutput(
schema.empty_table(),
batch_indices=[batch_idx],
force_checkpoint=pytest_running(),
)
def sort_records(
runtime_ctx: RuntimeContext,
input_datasets: List[DataSet],
output_path: str,
sort_engine="polars",
write_io_nbytes=500 * MB,
) -> bool:
runtime_task: PythonScriptTask = runtime_ctx.task
data_file_path = os.path.join(
runtime_task.runtime_output_abspath, f"{runtime_task.output_filename}.dat"
)
if sort_engine == "polars":
input_data = polars.read_parquet(
input_datasets[0].resolved_paths,
rechunk=False,
hive_partitioning=False,
columns=input_datasets[0].columns,
)
runtime_task.perf_metrics["num input rows"] += len(input_data)
runtime_task.add_elapsed_time("input load time (secs)")
sorted_records = input_data.sort("keys").get_column("records")
runtime_task.add_elapsed_time("sort by keys (secs)")
record_arrays = [chunk.to_arrow() for chunk in sorted_records.get_chunks()]
runtime_task.add_elapsed_time("convert to chunks (secs)")
elif sort_engine == "arrow":
input_table = input_datasets[0].to_arrow_table(runtime_task.cpu_limit)
runtime_task.perf_metrics["num input rows"] += input_table.num_rows
runtime_task.add_elapsed_time("input load time (secs)")
sorted_table = input_table.sort_by("keys")
runtime_task.add_elapsed_time("sort by keys (secs)")
record_arrays = sorted_table.column("records").chunks
runtime_task.add_elapsed_time("convert to chunks (secs)")
elif sort_engine == "duckdb":
with duckdb.connect(
database=":memory:", config={"allow_unsigned_extensions": "true"}
) as conn:
runtime_task.prepare_connection(conn)
input_views = runtime_task.create_input_views(conn, input_datasets)
sql_query = "select records from {0} order by keys".format(*input_views)
sorted_table = conn.sql(sql_query).to_arrow_table()
runtime_task.add_elapsed_time("sort by keys (secs)")
record_arrays = sorted_table.column("records").chunks
runtime_task.add_elapsed_time("convert to chunks (secs)")
else:
raise Exception(f"unknown sort engine: {sort_engine}")
with open(data_file_path, "wb") as fout:
for record_array in record_arrays:
# https://arrow.apache.org/docs/format/Columnar.html#variable-size-binary-layout
validity_bitmap, offsets, values = record_array.buffers()
buffer_mem = memoryview(values)
total_write_nbytes = sum(
fout.write(buffer_mem[offset : offset + write_io_nbytes])
for offset in range(0, len(buffer_mem), write_io_nbytes)
)
assert total_write_nbytes == len(buffer_mem)
runtime_task.perf_metrics["num output rows"] += len(record_array)
runtime_task.add_elapsed_time("output dump time (secs)")
return True
def validate_records(
runtime_ctx: RuntimeContext, input_datasets: List[DataSet], output_path: str
) -> bool:
for data_path in input_datasets[0].resolved_paths:
summary_path = os.path.join(
output_path, PurePath(data_path).with_suffix(".sum").name
)
cmdstr = (
f"{SortBenchTool.valsort_path} -o {summary_path} {data_path},buf,trans=10m"
)
logging.debug(f"running command: {cmdstr}")
result = subprocess.run(cmdstr.split(), capture_output=True, encoding="utf8")
if result.stderr:
logging.info(f"valsort stderr: {result.stderr}")
if result.stdout:
logging.info(f"valsort stdout: {result.stdout}")
if result.returncode != 0:
return False
return True
def validate_summary(
runtime_ctx: RuntimeContext, input_datasets: List[DataSet], output_path: str
) -> bool:
concated_summary_path = os.path.join(output_path, "merged.sum")
with open(concated_summary_path, "wb") as fout:
for path in input_datasets[0].resolved_paths:
with open(path, "rb") as fin:
fout.write(fin.read())
cmdstr = f"{SortBenchTool.valsort_path} -s {concated_summary_path}"
logging.debug(f"running command: {cmdstr}")
result = subprocess.run(cmdstr.split(), capture_output=True, encoding="utf8")
if result.stderr:
logging.info(f"valsort stderr: {result.stderr}")
if result.stdout:
logging.info(f"valsort stdout: {result.stdout}")
return result.returncode == 0
def generate_random_records(
ctx,
record_nbytes,
key_nbytes,
total_data_nbytes,
gensort_batch_nbytes,
num_data_partitions,
num_sort_partitions,
parquet_compression=None,
parquet_compression_level=None,
):
num_record_ranges = num_data_partitions * 10
total_num_records = total_data_nbytes // record_nbytes
record_range_size = (total_num_records + num_record_ranges - 1) // num_record_ranges
logging.warning(
f"{record_nbytes} bytes/record x total {total_num_records:,d} records = "
f"{total_data_nbytes/GB:.3f}GB / {num_record_ranges} record ranges = "
f"{record_range_size * record_nbytes/GB:.3f}GB ({record_range_size:,d} records) per record range"
)
range_begin_at = [pos for pos in range(0, total_num_records, record_range_size)]
range_num_records = [
min(total_num_records, record_range_size * (range_idx + 1)) - begin_at
for range_idx, begin_at in enumerate(range_begin_at)
]
assert sum(range_num_records) == total_num_records
record_range = DataSourceNode(
ctx,
ArrowTableDataSet(
arrow.Table.from_arrays(
[range_begin_at, range_num_records], names=["begin_at", "num_records"]
)
),
)
record_range_partitions = DataSetPartitionNode(
ctx, (record_range,), npartitions=num_data_partitions, partition_by_rows=True
)
random_records = ArrowStreamNode(
ctx,
(record_range_partitions,),
process_func=functools.partial(
generate_records,
record_nbytes=record_nbytes,
key_nbytes=key_nbytes,
bucket_nbits=num_sort_partitions.bit_length() - 1,
gensort_batch_nbytes=gensort_batch_nbytes,
),
background_io_thread=True,
streaming_batch_size=10,
parquet_row_group_size=1024 * 1024,
parquet_compression=parquet_compression,
parquet_compression_level=parquet_compression_level,
output_name="random_records",
cpu_limit=2,
)
return random_records
def gray_sort_benchmark(
record_nbytes,
key_nbytes,
total_data_nbytes,
gensort_batch_nbytes,
num_data_partitions,
num_sort_partitions,
input_paths=None,
shuffle_engine="duckdb",
sort_engine="polars",
hive_partitioning=False,
validate_results=False,
shuffle_cpu_limit=32,
shuffle_memory_limit=None,
sort_cpu_limit=8,
sort_memory_limit=None,
parquet_compression=None,
parquet_compression_level=None,
**kwargs,
) -> LogicalPlan:
ctx = Context()
num_sort_partitions = next_power_of_two(num_sort_partitions)
if input_paths:
input_dataset = ParquetDataSet(input_paths)
input_nbytes = sum(os.path.getsize(p) for p in input_dataset.resolved_paths)
logging.warning(
f"input data size: {input_nbytes/GB:.3f}GB, {input_dataset.num_files} files"
)
random_records = DataSourceNode(ctx, input_dataset)
else:
random_records = generate_random_records(
ctx,
record_nbytes,
key_nbytes,
total_data_nbytes,
gensort_batch_nbytes,
num_data_partitions,
num_sort_partitions,
parquet_compression,
parquet_compression_level,
)
partitioned_records = ShuffleNode(
ctx,
(random_records,),
npartitions=num_sort_partitions,
data_partition_column="buckets",
engine_type=shuffle_engine,
hive_partitioning=hive_partitioning,
parquet_row_group_size=10 * 1024 * 1024,
parquet_compression=parquet_compression,
parquet_compression_level=parquet_compression_level,
cpu_limit=shuffle_cpu_limit,
memory_limit=shuffle_memory_limit,
)
sorted_records = PythonScriptNode(
ctx,
(ProjectionNode(ctx, partitioned_records, ["keys", "records"]),),
process_func=functools.partial(sort_records, sort_engine=sort_engine),
output_name="sorted_records",
cpu_limit=sort_cpu_limit,
memory_limit=sort_memory_limit,
)
if validate_results:
partitioned_summaries = PythonScriptNode(
ctx,
(sorted_records,),
process_func=validate_records,
output_name="partitioned_summaries",
)
merged_summaries = DataSetPartitionNode(
ctx, (partitioned_summaries,), npartitions=1
)
final_check = PythonScriptNode(
ctx, (merged_summaries,), process_func=validate_summary
)
root = final_check
else:
root = sorted_records
return LogicalPlan(ctx, root)
def main():
SortBenchTool.ensure_installed()
driver = Driver()
driver.add_argument("-R", "--record_nbytes", type=int, default=100)
driver.add_argument("-K", "--key_nbytes", type=int, default=10)
driver.add_argument("-T", "--total_data_nbytes", type=int, default=None)
driver.add_argument("-B", "--gensort_batch_nbytes", type=int, default=512 * MB)
driver.add_argument("-n", "--num_data_partitions", type=int, default=None)
driver.add_argument("-t", "--num_sort_partitions", type=int, default=None)
driver.add_argument("-i", "--input_paths", nargs="+", default=[])
driver.add_argument(
"-e", "--shuffle_engine", default="duckdb", choices=("duckdb", "arrow")
)
driver.add_argument(
"-s", "--sort_engine", default="duckdb", choices=("duckdb", "arrow", "polars")
)
driver.add_argument("-H", "--hive_partitioning", action="store_true")
driver.add_argument("-V", "--validate_results", action="store_true")
driver.add_argument(
"-C", "--shuffle_cpu_limit", type=int, default=ShuffleNode.default_cpu_limit
)
driver.add_argument(
"-M",
"--shuffle_memory_limit",
type=int,
default=ShuffleNode.default_memory_limit,
)
driver.add_argument("-TC", "--sort_cpu_limit", type=int, default=8)
driver.add_argument("-TM", "--sort_memory_limit", type=int, default=None)
driver.add_argument(
"-NC", "--cpus_per_node", type=int, default=psutil.cpu_count(logical=False)
)
driver.add_argument(
"-NM", "--memory_per_node", type=int, default=psutil.virtual_memory().total
)
driver.add_argument("-CP", "--parquet_compression", default=None)
driver.add_argument("-LV", "--parquet_compression_level", type=int, default=None)
user_args, driver_args = driver.parse_arguments()
assert len(user_args.input_paths) == 0 or user_args.num_sort_partitions is not None
total_num_cpus = max(1, driver_args.num_executors) * user_args.cpus_per_node
memory_per_cpu = user_args.memory_per_node // user_args.cpus_per_node
user_args.sort_cpu_limit = (
1 if user_args.sort_engine == "arrow" else user_args.sort_cpu_limit
)
sort_memory_limit = (
user_args.sort_memory_limit or user_args.sort_cpu_limit * memory_per_cpu
)
user_args.total_data_nbytes = (
user_args.total_data_nbytes
or max(1, driver_args.num_executors) * user_args.memory_per_node
)
user_args.num_data_partitions = user_args.num_data_partitions or total_num_cpus // 2
user_args.num_sort_partitions = user_args.num_sort_partitions or max(
total_num_cpus // user_args.sort_cpu_limit,
user_args.total_data_nbytes // (sort_memory_limit // 4),
)
plan = gray_sort_benchmark(**vars(user_args))
driver.run(plan)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,97 @@
from smallpond.common import GB
from smallpond.contrib.log_dataset import LogDataSet
from smallpond.execution.driver import Driver
from smallpond.logical.dataset import ParquetDataSet
from smallpond.logical.node import (
ConsolidateNode,
Context,
DataSourceNode,
HashPartitionNode,
LogicalPlan,
SqlEngineNode,
)
def hash_partition_benchmark(
input_paths,
npartitions,
hash_columns,
engine_type="duckdb",
use_parquet_writer=False,
hive_partitioning=False,
cpu_limit=None,
memory_limit=None,
partition_stats=True,
**kwargs,
) -> LogicalPlan:
ctx = Context()
dataset = ParquetDataSet(input_paths)
data_files = DataSourceNode(ctx, dataset)
partitioned_datasets = HashPartitionNode(
ctx,
(data_files,),
npartitions=npartitions,
hash_columns=hash_columns,
data_partition_column="partition_keys",
engine_type=engine_type,
use_parquet_writer=use_parquet_writer,
hive_partitioning=hive_partitioning,
output_name="partitioned_datasets",
cpu_limit=cpu_limit,
memory_limit=memory_limit,
)
if partition_stats:
partition_stats = SqlEngineNode(
ctx,
(partitioned_datasets,),
f"""
select partition_keys, count(*) as row_cnt, count( distinct ( {', '.join(hash_columns)} ) ) as uniq_key_cnt from {{0}}
group by partition_keys""",
output_name="partition_stats",
cpu_limit=1,
memory_limit=10 * GB,
)
sorted_stats = SqlEngineNode(
ctx,
(ConsolidateNode(ctx, partition_stats, []),),
r"select * from {0} order by row_cnt desc",
)
plan = LogicalPlan(ctx, LogDataSet(ctx, (sorted_stats,), num_rows=npartitions))
else:
plan = LogicalPlan(ctx, partitioned_datasets)
return plan
def main():
driver = Driver()
driver.add_argument("-i", "--input_paths", nargs="+", required=True)
driver.add_argument("-n", "--npartitions", type=int, default=None)
driver.add_argument("-c", "--hash_columns", nargs="+", required=True)
driver.add_argument(
"-e", "--engine_type", default="duckdb", choices=("duckdb", "arrow")
)
driver.add_argument("-S", "--partition_stats", action="store_true")
driver.add_argument("-W", "--use_parquet_writer", action="store_true")
driver.add_argument("-H", "--hive_partitioning", action="store_true")
driver.add_argument(
"-C", "--cpu_limit", type=int, default=HashPartitionNode.default_cpu_limit
)
driver.add_argument(
"-M", "--memory_limit", type=int, default=HashPartitionNode.default_memory_limit
)
driver.add_argument("-NC", "--cpus_per_node", type=int, default=192)
driver.add_argument("-NM", "--memory_per_node", type=int, default=2000 * GB)
user_args, driver_args = driver.parse_arguments()
total_num_cpus = driver_args.num_executors * user_args.cpus_per_node
user_args.npartitions = user_args.npartitions or total_num_cpus
plan = hash_partition_benchmark(**vars(user_args))
driver.run(plan)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,126 @@
from typing import List, OrderedDict
from smallpond.common import GB
from smallpond.dataframe import Session
from smallpond.execution.driver import Driver
from smallpond.logical.dataset import CsvDataSet
from smallpond.logical.node import (
Context,
DataSetPartitionNode,
DataSourceNode,
HashPartitionNode,
LogicalPlan,
SqlEngineNode,
)
def urls_sort_benchmark(
input_paths: List[str],
num_data_partitions: int,
num_hash_partitions: int,
engine_type="duckdb",
sort_cpu_limit=8,
sort_memory_limit=None,
) -> LogicalPlan:
ctx = Context()
dataset = CsvDataSet(
input_paths,
schema=OrderedDict([("urlstr", "varchar"), ("valstr", "varchar")]),
delim=r"\t",
)
data_files = DataSourceNode(ctx, dataset)
data_partitions = DataSetPartitionNode(
ctx, (data_files,), npartitions=num_data_partitions
)
imported_urls = SqlEngineNode(
ctx,
(data_partitions,),
r"""
select urlstr, valstr from {0}
""",
output_name="imported_urls",
parquet_row_group_size=1024 * 1024,
cpu_limit=1,
memory_limit=16 * GB,
)
urls_partitions = HashPartitionNode(
ctx,
(imported_urls,),
npartitions=num_hash_partitions,
hash_columns=["urlstr"],
engine_type=engine_type,
parquet_row_group_size=1024 * 1024,
cpu_limit=1,
memory_limit=16 * GB,
)
sorted_urls = SqlEngineNode(
ctx,
(urls_partitions,),
r"select * from {0} order by urlstr",
output_name="sorted_urls",
parquet_row_group_size=1024 * 1024,
cpu_limit=sort_cpu_limit,
memory_limit=sort_memory_limit,
)
plan = LogicalPlan(ctx, sorted_urls)
return plan
def urls_sort_benchmark_v2(
sp: Session,
input_paths: List[str],
output_path: str,
num_data_partitions: int,
num_hash_partitions: int,
engine_type="duckdb",
sort_cpu_limit=8,
sort_memory_limit=None,
):
dataset = sp.read_csv(
input_paths, schema={"urlstr": "varchar", "valstr": "varchar"}, delim=r"\t"
)
data_partitions = dataset.repartition(num_data_partitions)
urls_partitions = data_partitions.repartition(
num_hash_partitions, hash_by="urlstr", engine_type=engine_type
)
sorted_urls = urls_partitions.partial_sort(
by="urlstr", cpu_limit=sort_cpu_limit, memory_limit=sort_memory_limit
)
sorted_urls.write_parquet(output_path)
def main():
driver = Driver()
driver.add_argument("-i", "--input_paths", nargs="+")
driver.add_argument("-n", "--num_data_partitions", type=int, default=None)
driver.add_argument("-m", "--num_hash_partitions", type=int, default=None)
driver.add_argument("-e", "--engine_type", default="duckdb")
driver.add_argument("-TC", "--sort_cpu_limit", type=int, default=8)
driver.add_argument("-TM", "--sort_memory_limit", type=int, default=None)
user_args, driver_args = driver.parse_arguments()
num_nodes = driver_args.num_executors
cpus_per_node = 120
partition_rounds = 2
user_args.num_data_partitions = (
user_args.num_data_partitions or num_nodes * cpus_per_node * partition_rounds
)
user_args.num_hash_partitions = (
user_args.num_hash_partitions or num_nodes * cpus_per_node
)
# v1
plan = urls_sort_benchmark(**vars(user_args))
driver.run(plan)
# v2
# sp = smallpond.init()
# urls_sort_benchmark_v2(sp, **vars(user_args))
if __name__ == "__main__":
main()

20
docs/Makefile Normal file
View File

@@ -0,0 +1,20 @@
# Minimal makefile for Sphinx documentation
#
# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
SOURCEDIR = source
BUILDDIR = build
# Put it first so that "make" without argument is like "make help".
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
.PHONY: help Makefile
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

35
docs/make.bat Normal file
View File

@@ -0,0 +1,35 @@
@ECHO OFF
pushd %~dp0
REM Command file for Sphinx documentation
if "%SPHINXBUILD%" == "" (
set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=source
set BUILDDIR=build
%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
echo.
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
echo.installed, then set the SPHINXBUILD environment variable to point
echo.to the full path of the 'sphinx-build' executable. Alternatively you
echo.may add the Sphinx directory to PATH.
echo.
echo.If you don't have Sphinx installed, grab it from
echo.https://www.sphinx-doc.org/
exit /b 1
)
if "%1" == "" goto help
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end
:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
:end
popd

81
docs/source/api.rst Normal file
View File

@@ -0,0 +1,81 @@
API Reference
=============
Smallpond provides both high-level and low-level APIs.
.. note::
Currently, smallpond provides two different APIs, supporting dynamic and static construction of data flow graphs respectively. Due to historical reasons, these two APIs use different scheduler backends and support different configuration options.
- The High-level API currently uses Ray as the backend, supporting dynamic construction and execution of data flow graphs.
- The Low-level API uses a built-in scheduler and only supports one-time execution of static data flow graphs. However, it offers more performance optimizations and richer configuration options.
We are working to merge them so that in the future, you can use a unified high-level API and freely choose between Ray or the built-in scheduler.
High-level API
--------------
The high-level API is centered around :ref:`dataframe`. It allows dynamic construction of data flow graphs, execution, and result retrieval.
A typical workflow looks like this:
.. code-block:: python
import smallpond
sp = smallpond.init()
df = sp.read_parquet("path/to/dataset/*.parquet")
df = df.repartition(10)
df = df.map("x + 1")
df.write_parquet("path/to/output")
.. toctree::
:maxdepth: 2
api/dataframe
It is recommended to use the DataFrame API.
Low-level API
-------------
In the low-level API, users manually create :ref:`nodes` to construct static data flow graphs, then submit them to smallpond to generate :ref:`tasks` and wait for all tasks to complete.
A complete example is shown below.
.. code-block:: python
from smallpond.logical.dataset import ParquetDataSet
from smallpond.logical.node import Context, DataSourceNode, DataSetPartitionNode, SqlEngineNode, LogicalPlan
from smallpond.execution.driver import Driver
def my_pipeline(input_paths: List[str], npartitions: int):
ctx = Context()
dataset = ParquetDataSet(input_paths)
node = DataSourceNode(ctx, dataset)
node = DataSetPartitionNode(ctx, (node,), npartitions=npartitions)
node = SqlEngineNode(ctx, (node,), "SELECT * FROM {0}")
return LogicalPlan(ctx, node)
if __name__ == "__main__":
driver = Driver()
driver.add_argument("-i", "--input_paths", nargs="+")
driver.add_argument("-n", "--npartitions", type=int, default=10)
plan = my_pipeline(**driver.get_arguments())
driver.run(plan)
To run this script:
.. code-block:: bash
python script.py -i "path/to/*.parquet" -n 10
.. toctree::
:maxdepth: 2
api/dataset
api/nodes
api/tasks
api/execution

View File

@@ -0,0 +1,104 @@
.. _dataframe:
DataFrame
=========
DataFrame is the main class in smallpond. It represents a lazily computed, partitioned data set.
A typical workflow looks like this:
.. code-block:: python
import smallpond
sp = smallpond.init()
df = sp.read_parquet("path/to/dataset/*.parquet")
df = df.repartition(10)
df = df.map("x + 1")
df.write_parquet("path/to/output")
Initialization
--------------
.. autosummary::
:toctree: ../generated
smallpond.init
.. currentmodule:: smallpond.dataframe
.. _loading_data:
Loading Data
------------
.. autosummary::
:toctree: ../generated
Session.from_items
Session.from_arrow
Session.from_pandas
Session.read_csv
Session.read_json
Session.read_parquet
.. _partitioning_data:
Partitioning Data
-----------------
.. autosummary::
:toctree: ../generated
DataFrame.repartition
.. _transformations:
Transformations
---------------
Apply transformations and return a new DataFrame.
.. autosummary::
:toctree: ../generated
Session.partial_sql
DataFrame.map
DataFrame.map_batches
DataFrame.flat_map
DataFrame.filter
DataFrame.limit
DataFrame.partial_sort
DataFrame.random_shuffle
.. _consuming_data:
Consuming Data
--------------
These operations will trigger execution of the lazy transformations performed on this DataFrame.
.. autosummary::
:toctree: ../generated
DataFrame.count
DataFrame.take
DataFrame.take_all
DataFrame.to_arrow
DataFrame.to_pandas
DataFrame.write_parquet
DataFrame.write_parquet_lazy
Execution
---------
DataFrames are lazily computed. You can use these methods to manually trigger computation.
.. autosummary::
:toctree: ../generated
DataFrame.compute
DataFrame.is_computed
DataFrame.recompute
Session.wait

View File

@@ -0,0 +1,28 @@
.. currentmodule:: smallpond.logical.dataset
Dataset
=======
Dataset represents a collection of files.
To create a dataset:
.. code-block:: python
dataset = ParquetDataSet("path/to/dataset/*.parquet")
DataSets
--------
.. autosummary::
:toctree: ../generated
DataSet
FileSet
ParquetDataSet
CsvDataSet
JsonDataSet
ArrowTableDataSet
PandasDataSet
PartitionedDataSet
SqlQueryDataSet

View File

@@ -0,0 +1,85 @@
.. currentmodule:: smallpond.execution
.. _execution:
Execution
=========
Submit a Job
------------
After constructing the LogicalPlan, you can use the JobManager to create a Job in the cluster to execute it. However, in most cases, you only need to use the Driver as the entry point of the entire script and then submit the plan. The Driver is a simple wrapper around the JobManager. It reads the configuration from the command line arguments and passes it to the JobManager.
.. code-block:: python
from smallpond.execution.driver import Driver
if __name__ == "__main__":
driver = Driver()
# add your own arguments
driver.add_argument("-i", "--input_paths", nargs="+")
driver.add_argument("-n", "--npartitions", type=int, default=10)
# build and run logical plan
plan = my_pipeline(**driver.get_arguments())
driver.run(plan)
.. autosummary::
:toctree: ../generated
~driver.Driver
~manager.JobManager
Scheduler and Executor
----------------------
Scheduler and Executor are lower-level APIs. They are directly responsible for scheduling and executing tasks, respectively. Generally, users do not need to use them directly.
.. autosummary::
:toctree: ../generated
~scheduler.Scheduler
~executor.Executor
.. _platform:
Customize Platform
------------------
Smallpond supports user-defined task execution platforms. A Platform includes methods for submitting jobs and a series of default configurations. By default, smallpond automatically detects the current environment and selects the most suitable platform. If it cannot detect one, it uses the default platform.
You can specify a built-in platform via parameters:
.. code-block:: bash
# run with your platform
python script.py --platform mpi
Or implement your own Platform class:
.. code-block:: python
# path/to/my/platform.py
from smallpond.platform import Platform
class MyPlatform(Platform):
def start_job(self, ...) -> List[str]:
...
.. code-block:: bash
# run with your platform
# if using Driver
python script.py --platform path.to.my.platform
# if using smallpond.init
SP_PLATFORM=path.to.my.platform python script.py
.. currentmodule:: smallpond
.. autosummary::
:toctree: ../generated
~platform.Platform
~platform.MPI

89
docs/source/api/nodes.rst Normal file
View File

@@ -0,0 +1,89 @@
.. currentmodule:: smallpond.logical.node
.. _nodes:
Nodes
=====
Nodes represent the fundamental building blocks of a data processing pipeline. Each node encapsulates a specific operation or transformation that can be applied to a dataset.
Nodes can be chained together to form a logical plan, which is a directed acyclic graph (DAG) of nodes that represent the overall data processing workflow.
A typical workflow to create a logical plan is as follows:
.. code-block:: python
# Create a global context
ctx = Context()
# Create a dataset
dataset = ParquetDataSet("path/to/dataset/*.parquet")
# Create a data source node
node = DataSourceNode(ctx, dataset)
# Partition the data
node = DataSetPartitionNode(ctx, (node,), npartitions=2)
# Create a SQL engine node to transform the data
node = SqlEngineNode(ctx, (node,), "SELECT * FROM {0}")
# Create a logical plan from the root node
plan = LogicalPlan(ctx, node)
You can then create tasks from the logical plan, see :ref:`tasks`.
Notable properties of Node:
1. Nodes are partitioned. Each Node generates a series of tasks, with each task processing one partition of data.
2. The input and output of a Node are a series of partitioned Datasets. A Node may write data to shared storage and return a new Dataset, or it may simply recombine the input Datasets.
Context
-------
.. autosummary::
:toctree: ../generated
Context
NodeId
LogicalPlan
-----------
.. autosummary::
:toctree: ../generated
LogicalPlan
LogicalPlanVisitor
.. Planner
Nodes
-----
.. autosummary::
:toctree: ../generated
Node
DataSetPartitionNode
ArrowBatchNode
ArrowComputeNode
ArrowStreamNode
ConsolidateNode
DataSinkNode
DataSourceNode
EvenlyDistributedPartitionNode
HashPartitionNode
LimitNode
LoadPartitionedDataSetNode
PandasBatchNode
PandasComputeNode
PartitionNode
ProjectionNode
PythonScriptNode
RangePartitionNode
RepeatPartitionNode
RootNode
ShuffleNode
SqlEngineNode
UnionNode
UserDefinedPartitionNode
UserPartitionedDataSourceNode

76
docs/source/api/tasks.rst Normal file
View File

@@ -0,0 +1,76 @@
.. currentmodule:: smallpond.execution.task
.. _tasks:
Tasks
=====
.. code-block:: python
# create a runtime context
runtime_ctx = RuntimeContext(JobId.new(), data_root)
runtime_ctx.initialize(socket.gethostname(), cleanup_root=True)
# create a logical plan
plan = create_logical_plan()
# create an execution plan
planner = Planner(runtime_ctx)
exec_plan = planner.create_exec_plan(plan)
You can then execute the tasks in a scheduler, see :ref:`execution`.
RuntimeContext
--------------
.. autosummary::
:toctree: ../generated
RuntimeContext
JobId
TaskId
TaskRuntimeId
PartitionInfo
PerfStats
ExecutionPlan
-------------
.. autosummary::
:toctree: ../generated
ExecutionPlan
Tasks
-----
.. autosummary::
:toctree: ../generated
Task
ArrowBatchTask
ArrowComputeTask
ArrowStreamTask
DataSinkTask
DataSourceTask
EvenlyDistributedPartitionProducerTask
HashPartitionArrowTask
HashPartitionDuckDbTask
HashPartitionTask
LoadPartitionedDataSetProducerTask
MergeDataSetsTask
PandasBatchTask
PandasComputeTask
PartitionConsumerTask
PartitionProducerTask
ProjectionTask
PythonScriptTask
RangePartitionTask
RepeatPartitionProducerTask
RootTask
SplitDataSetTask
SqlEngineTask
UserDefinedPartitionProducerTask

44
docs/source/conf.py Normal file
View File

@@ -0,0 +1,44 @@
# Configuration file for the Sphinx documentation builder.
#
# For the full list of built-in configuration values, see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html
# -- Project information -----------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
project = "smallpond"
copyright = "2025, deepseek"
author = "deepseek"
# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
extensions = [
"sphinx.ext.autodoc",
"sphinx.ext.autosummary",
]
templates_path = ["_templates"]
exclude_patterns = []
# -- Options for HTML output -------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
html_theme = "pydata_sphinx_theme"
html_static_path = ["_static"]
html_theme_options = {
"icon_links": [
{
"name": "GitHub",
"url": "https://github.com/deepseek-ai/smallpond",
"icon": "fa-brands fa-square-github",
}
]
}
import os
import sys
sys.path.insert(0, os.path.abspath("../.."))

View File

@@ -0,0 +1,83 @@
Getting Started
===============
Installation
------------
Python 3.8+ is required.
.. code-block:: bash
pip install smallpond
Initialization
--------------
The first step is to initialize the smallpond session:
.. code-block:: python
import smallpond
sp = smallpond.init()
Loading Data
------------
Create a DataFrame from a set of files:
.. code-block:: python
df = sp.read_parquet("path/to/dataset/*.parquet")
To learn more about loading data, please refer to :ref:`loading_data`.
Partitioning Data
-----------------
Smallpond requires users to manually specify data partitions for now.
.. code-block:: python
df = df.repartition(3) # repartition by files
df = df.repartition(3, by_row=True) # repartition by rows
df = df.repartition(3, hash_by="host") # repartition by hash of column
To learn more about partitioning data, please refer to :ref:`partitioning_data`.
Transforming Data
-----------------
Apply python functions or SQL expressions to transform data.
.. code-block:: python
df = df.map('a + b as c')
df = df.map(lambda row: {'c': row['a'] + row['b']})
To learn more about transforming data, please refer to :ref:`transformations`.
Saving Data
-----------
Save the transformed data to a set of files:
.. code-block:: python
df.write_parquet("path/to/output")
To learn more about saving data, please refer to :ref:`consuming_data`.
Monitoring
----------
Smallpond uses `Ray Core`_ as the task scheduler. You can use `Ray Dashboard`_ to monitor the task execution.
.. _Ray Core: https://docs.ray.io/en/latest/ray-core/walkthrough.html
.. _Ray Dashboard: https://docs.ray.io/en/latest/ray-observability/getting-started.html
When smallpond starts, it will print the Ray Dashboard URL:
.. code-block:: bash
... Started a local Ray instance. View the dashboard at http://127.0.0.1:8008

27
docs/source/index.rst Normal file
View File

@@ -0,0 +1,27 @@
smallpond
=========
Smallpond is a lightweight distributed data processing framework.
It uses `duckdb`_ as the compute engine and stores data in `parquet`_ format on a distributed file system (e.g. `3FS`_).
.. _duckdb: https://duckdb.org/
.. _parquet: https://parquet.apache.org/
.. _3FS: https://github.com/deepseek-ai/3fs
Why smallpond?
--------------
- **Performance**: Smallpond uses DuckDB to deliver native-level performance for efficient data processing.
- **Scalability**: Leverages high-performance distributed file systems for intermediate storage, enabling PB-scale data handling without memory bottlenecks.
- **Simplicity**: No long-running services or complex dependencies, making it easy to deploy and maintain.
.. toctree::
:maxdepth: 1
getstarted
internals
.. toctree::
:maxdepth: 3
api

37
docs/source/internals.rst Normal file
View File

@@ -0,0 +1,37 @@
Internals
=========
Data Root
---------
Smallpond stores all data in a single directory called data root.
This directory has the following structure:
.. code-block:: bash
data_root
└── 2024-12-11-12-00-28.2cc39990-296f-48a3-8063-78cf6dca460b # job_time.job_id
├── config # configuration and state
│ ├── exec_plan.pickle
│ ├── logical_plan.pickle
│ └── runtime_ctx.pickle
├── log # logs
│ ├── graph.png
│ └── scheduler.log
├── queue # message queue between scheduler and workers
├── output # output data
├── staging # intermediate data
│ ├── DataSourceTask.000001
│ ├── EvenlyDistributedPartitionProducerTask.000002
│ ├── completed_tasks # output dataset of completed tasks
│ └── started_tasks # used for checkpoint
└── temp # temporary data
├── DataSourceTask.000001
└── EvenlyDistributedPartitionProducerTask.000002
Failure Recovery
----------------
Smallpond can recover from failure and resume execution from the last checkpoint.
Checkpoint is task-level. A few tasks, such as `ArrowBatchTask`, support checkpointing at the batch level.

0
examples/__init__.py Normal file
View File

225
examples/fstest.py Normal file
View File

@@ -0,0 +1,225 @@
# Test the correctness of file system read and write.
#
# This script runs multiple tasks to write and read data to/from the file system.
# Each task writes to an individual file in the given directory.
# Then it reads the data back and verifies the correctness.
import argparse
import glob
import logging
import os
import random
import time
from typing import Any, Dict, Iterator, Optional, Tuple, Union
import numpy as np
import smallpond
from smallpond.dataframe import Session
def fswrite(
path: str,
size: int,
blocksize: Union[int, Tuple[int, int]],
) -> Dict[str, Any]:
t0 = time.time()
with open(path, "wb") as f:
for start, length in iter_io_slice(0, size, blocksize):
logging.debug(f"writing {length} bytes at offset {start}")
f.write(generate_data(start, length))
t1 = time.time()
elapsed = t1 - t0
logging.info(f"write done: {path} in {elapsed:.2f}s")
return {
"path": path,
"size": size,
"elapsed(s)": elapsed,
"throughput(MB/s)": size / elapsed / 1024 / 1024,
}
def fsread(
path: str,
blocksize: Union[int, Tuple[int, int]],
randread: bool,
) -> Dict[str, Any]:
t0 = time.time()
with open(path, "rb") as f:
size = os.path.getsize(path)
slices = list(iter_io_slice(0, size, blocksize))
if randread:
random.shuffle(slices)
for start, length in slices:
logging.debug(f"reading {length} bytes at offset {start}")
f.seek(start)
data = f.read(length)
expected_data = generate_data(start, length)
check_data(data, expected_data, start)
t1 = time.time()
elapsed = t1 - t0
logging.info(f"read done: {path} in {elapsed:.2f}s")
return {
"path": path,
"size": size,
"elapsed(s)": elapsed,
"throughput(MB/s)": size / elapsed / 1024 / 1024,
}
def check_data(actual: bytes, expected: bytes, offset: int) -> None:
"""
Check if the expected data matches the actual data. Raise an error if there is a mismatch.
"""
if expected == actual:
return
# find the first mismatch
index = next(
(i for i, (b1, b2) in enumerate(zip(actual, expected)) if b1 != b2),
min(len(actual), len(expected)),
)
expected = expected[index : index + 16]
actual = actual[index : index + 16]
raise ValueError(
f"Data mismatch at offset {offset + index}.\nexpect: {expected}\nactual: {actual}"
)
def generate_data(offset: int, length: int) -> bytes:
"""
Generate data for the slice [offset, offset + length).
The full data is a repeated sequence of [0x00000000, 0x00000001, ..., 0xffffffff] in little-endian.
"""
istart = offset // 4
iend = (offset + length + 3) // 4
return (
np.arange(istart, iend)
.astype(np.uint32)
.tobytes()[offset % 4 : offset % 4 + length]
)
def iter_io_slice(
offset: int, length: int, block_size: Union[int, Tuple[int, int]]
) -> Iterator[Tuple[int, int]]:
"""
Generate the IO (offset, size) for the slice [offset, offset + length) with the given block size.
`block_size` can be an integer or a range [start, end]. If a range is provided, the IO size will be randomly selected from the range.
"""
start = offset
end = offset + length
while start < end:
if isinstance(block_size, int):
size = block_size
else:
smin, smax = block_size
size = random.randint(smin, smax)
size = min(size, end - start)
yield (start, size)
start += size
def size_str_to_bytes(size_str: str) -> int:
"""
Parse size string to bytes.
e.g. 1k -> 1024, 1M -> 1024^2, 1G -> 1024^3, 1T -> 1024^4
"""
if size_str.endswith("k"):
return int(size_str[:-1]) * 1024
elif size_str.endswith("M"):
return int(size_str[:-1]) * 1024 * 1024
elif size_str.endswith("G"):
return int(size_str[:-1]) * 1024 * 1024 * 1024
elif size_str.endswith("T"):
return int(size_str[:-1]) * 1024 * 1024 * 1024 * 1024
else:
return int(size_str)
def fstest(
sp: Session,
input_path: Optional[str],
output_path: Optional[str],
size: Optional[str],
npartitions: int,
blocksize: Optional[str] = "4k",
blocksize_range: Optional[str] = None,
randread: bool = False,
) -> None:
# preprocess arguments
if output_path is not None and size is None:
raise ValueError("--size is required if --output_path is provided")
if size is not None:
size = size_str_to_bytes(size)
if blocksize_range is not None:
start, end = blocksize_range.split("-")
blocksize = (size_str_to_bytes(start), size_str_to_bytes(end))
elif blocksize is not None:
blocksize = size_str_to_bytes(blocksize)
else:
raise ValueError("either --blocksize or --blocksize_range must be provided")
if output_path is not None:
os.makedirs(output_path, exist_ok=True)
df = sp.from_items(
[{"path": os.path.join(output_path, f"{i}")} for i in range(npartitions)]
)
df = df.repartition(npartitions, by_rows=True)
stats = df.map(lambda x: fswrite(x["path"], size, blocksize)).to_pandas()
logging.info(f"write stats:\n{stats}")
if input_path is not None:
paths = list(glob.glob(input_path))
df = sp.from_items([{"path": path} for path in paths])
df = df.repartition(len(paths), by_rows=True)
stats = df.map(lambda x: fsread(x["path"], blocksize, randread)).to_pandas()
logging.info(f"read stats:\n{stats}")
if __name__ == "__main__":
"""
Example usage:
- write only:
python example/fstest.py -o 'fstest' -j 8 -s 1G
- read only:
python example/fstest.py -i 'fstest/*'
- write and then read:
python example/fstest.py -o 'fstest' -j 8 -s 1G -i 'fstest/*'
"""
parser = argparse.ArgumentParser()
parser.add_argument(
"-o", "--output_path", type=str, help="The output path to write data to."
)
parser.add_argument(
"-i",
"--input_path",
type=str,
help="The input path to read data from. If -o is provided, this is ignored.",
)
parser.add_argument(
"-j", "--npartitions", type=int, help="The number of parallel jobs", default=10
)
parser.add_argument(
"-s",
"--size",
type=str,
help="The size for each file. Required if -o is provided.",
)
parser.add_argument("-bs", "--blocksize", type=str, help="Block size", default="4k")
parser.add_argument(
"-bsrange",
"--blocksize_range",
type=str,
help="A range of I/O block sizes. e.g. 4k-128k",
)
parser.add_argument(
"-randread",
"--randread",
action="store_true",
help="Whether to read data randomly",
default=False,
)
args = parser.parse_args()
sp = smallpond.init()
fstest(sp, **vars(args))

78
examples/shuffle_data.py Normal file
View File

@@ -0,0 +1,78 @@
from smallpond.contrib.copy_table import StreamCopy
from smallpond.execution.driver import Driver
from smallpond.logical.dataset import ParquetDataSet
from smallpond.logical.node import (
Context,
DataSetPartitionNode,
DataSourceNode,
HashPartitionNode,
LogicalPlan,
SqlEngineNode,
)
def shuffle_data(
input_paths,
num_out_data_partitions: int = 0,
num_data_partitions: int = 10,
num_hash_partitions: int = 10,
engine_type="duckdb",
skip_hash_partition=False,
) -> LogicalPlan:
ctx = Context()
dataset = ParquetDataSet(input_paths, union_by_name=True)
data_files = DataSourceNode(ctx, dataset)
data_partitions = DataSetPartitionNode(
ctx,
(data_files,),
npartitions=num_data_partitions,
partition_by_rows=True,
random_shuffle=skip_hash_partition,
)
if skip_hash_partition:
urls_partitions = data_partitions
else:
urls_partitions = HashPartitionNode(
ctx,
(data_partitions,),
npartitions=num_hash_partitions,
hash_columns=None,
random_shuffle=True,
engine_type=engine_type,
)
shuffled_urls = SqlEngineNode(
ctx,
(urls_partitions,),
r"select *, cast(random() * 2147483647 as integer) as sort_key from {0} order by sort_key",
cpu_limit=16,
)
repartitioned = DataSetPartitionNode(
ctx,
(shuffled_urls,),
npartitions=num_out_data_partitions,
partition_by_rows=True,
)
shuffled_urls = StreamCopy(
ctx, (repartitioned,), output_name="data_copy", cpu_limit=1
)
plan = LogicalPlan(ctx, shuffled_urls)
return plan
def main():
driver = Driver()
driver.add_argument("-i", "--input_paths", nargs="+")
driver.add_argument("-nd", "--num_data_partitions", type=int, default=1024)
driver.add_argument("-nh", "--num_hash_partitions", type=int, default=3840)
driver.add_argument("-no", "--num_out_data_partitions", type=int, default=1920)
driver.add_argument(
"-e", "--engine_type", default="duckdb", choices=("duckdb", "arrow")
)
driver.add_argument("-x", "--skip_hash_partition", action="store_true")
plan = shuffle_data(**driver.get_arguments())
driver.run(plan)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,73 @@
from smallpond.common import GB
from smallpond.execution.driver import Driver
from smallpond.logical.dataset import ParquetDataSet
from smallpond.logical.node import (
Context,
DataSetPartitionNode,
DataSourceNode,
HashPartitionNode,
LogicalPlan,
SqlEngineNode,
)
def shuffle_mock_urls(
input_paths, npartitions: int = 10, sort_rand_keys=True, engine_type="duckdb"
) -> LogicalPlan:
ctx = Context()
dataset = ParquetDataSet(input_paths)
data_files = DataSourceNode(ctx, dataset)
data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=npartitions)
urls_partitions = HashPartitionNode(
ctx,
(data_partitions,),
npartitions=npartitions,
hash_columns=None,
random_shuffle=True,
engine_type=engine_type,
output_name="urls_partitions",
cpu_limit=1,
memory_limit=20 * GB,
)
if sort_rand_keys:
# shuffle as sorting partition keys
shuffled_urls = SqlEngineNode(
ctx,
(urls_partitions,),
r"select *, random() as partition_key from {0} order by partition_key",
output_name="shuffled_urls",
cpu_limit=1,
memory_limit=40 * GB,
)
else:
# shuffle as reservoir sampling
shuffled_urls = SqlEngineNode(
ctx,
(urls_partitions,),
r"select * from {0} using sample 100% (reservoir, {rand_seed})",
output_name="shuffled_urls",
cpu_limit=1,
memory_limit=40 * GB,
)
plan = LogicalPlan(ctx, shuffled_urls)
return plan
def main():
driver = Driver()
driver.add_argument("-i", "--input_paths", nargs="+")
driver.add_argument("-n", "--npartitions", type=int, default=500)
driver.add_argument("-s", "--sort_rand_keys", action="store_true")
driver.add_argument(
"-e", "--engine_type", default="duckdb", choices=("duckdb", "arrow")
)
plan = shuffle_mock_urls(**driver.get_arguments())
driver.run(plan)
if __name__ == "__main__":
main()

104
examples/sort_mock_urls.py Normal file
View File

@@ -0,0 +1,104 @@
import logging
import os.path
from typing import List, Optional, OrderedDict
import pyarrow as arrow
from smallpond.execution.driver import Driver
from smallpond.execution.task import RuntimeContext
from smallpond.logical.dataset import CsvDataSet
from smallpond.logical.node import (
ArrowComputeNode,
Context,
DataSetPartitionNode,
DataSinkNode,
DataSourceNode,
HashPartitionNode,
LogicalPlan,
SqlEngineNode,
)
class SortUrlsNode(ArrowComputeNode):
def process(
self, runtime_ctx: RuntimeContext, input_tables: List[arrow.Table]
) -> arrow.Table:
logging.info(f"sorting urls by 'host', table shape: {input_tables[0].shape}")
return input_tables[0].sort_by("host")
def sort_mock_urls(
input_paths,
npartitions: int,
engine_type="duckdb",
external_output_path: Optional[str] = None,
) -> LogicalPlan:
ctx = Context()
dataset = CsvDataSet(
input_paths,
schema=OrderedDict([("urlstr", "varchar"), ("valstr", "varchar")]),
delim=r"\t",
)
data_files = DataSourceNode(ctx, dataset)
data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=npartitions)
imported_urls = SqlEngineNode(
ctx,
(data_partitions,),
r"""
select split_part(urlstr, '/', 1) as host, split_part(urlstr, ' ', 1) as url, from_base64(valstr) AS payload from {0}
""",
output_name="imported_urls",
output_path=external_output_path,
)
urls_partitions = HashPartitionNode(
ctx,
(imported_urls,),
npartitions=npartitions,
hash_columns=["host"],
engine_type=engine_type,
output_name="urls_partitions",
output_path=external_output_path,
)
if engine_type == "duckdb":
sorted_urls = SqlEngineNode(
ctx,
(urls_partitions,),
r"select * from {0} order by host",
output_name="sorted_urls",
)
else:
sorted_urls = SortUrlsNode(
ctx,
(urls_partitions,),
output_name="sorted_urls",
output_path=external_output_path,
)
final_result = DataSetPartitionNode(ctx, (sorted_urls,), npartitions=1)
if external_output_path:
final_result = DataSinkNode(
ctx,
(final_result,),
output_path=os.path.join(external_output_path, "data_sink"),
)
plan = LogicalPlan(ctx, final_result)
return plan
def main():
driver = Driver()
driver.add_argument(
"-i", "--input_paths", nargs="+", default=["tests/data/mock_urls/*.tsv"]
)
driver.add_argument("-n", "--npartitions", type=int, default=10)
driver.add_argument("-e", "--engine_type", default="duckdb")
plan = sort_mock_urls(**driver.get_arguments())
driver.run(plan)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,36 @@
import argparse
from typing import List
import smallpond
from smallpond.dataframe import Session
def sort_mock_urls_v2(
sp: Session, input_paths: List[str], output_path: str, npartitions: int
):
dataset = sp.read_csv(
input_paths, schema={"urlstr": "varchar", "valstr": "varchar"}, delim=r"\t"
).repartition(npartitions)
urls = dataset.map(
"""
split_part(urlstr, '/', 1) as host,
split_part(urlstr, ' ', 1) as url,
from_base64(valstr) AS payload
"""
)
urls = urls.repartition(npartitions, hash_by="host")
sorted_urls = urls.partial_sort(by=["host"])
sorted_urls.write_parquet(output_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"-i", "--input_paths", nargs="+", default=["tests/data/mock_urls/*.tsv"]
)
parser.add_argument("-o", "--output_path", type=str, default="sort_mock_urls")
parser.add_argument("-n", "--npartitions", type=int, default=10)
args = parser.parse_args()
sp = smallpond.init()
sort_mock_urls_v2(sp, **vars(args))

79
pyproject.toml Normal file
View File

@@ -0,0 +1,79 @@
[build-system]
requires = ["setuptools", "setuptools-scm"]
build-backend = "setuptools.build_meta"
[project]
name = "smallpond"
version = "0.15.0"
description = "A lightweight data processing framework built on DuckDB and shared file system."
authors = [
{ name = "DeepSeek-AI", email = "research@deepseek.com" },
{ name = "Runji Wang" },
{ name = "Yiliang Xiong" },
{ name = "Yiyuan Liu" },
{ name = "Yuheng Zou" },
{ name = "Yichao Zhang" },
{ name = "Wenjun Gao" },
{ name = "Wentao Zhang" },
{ name = "Xiaotao Nie" },
{ name = "Minghua Zhang" },
{ name = "Zhewen Hao" },
]
urls = { Homepage = "https://github.com/deepseek-ai/smallpond" }
keywords = ["distributed query processing", "SQL", "parquet"]
requires-python = ">=3.8"
dependencies = [
"duckdb >= 1.2.0",
"pyarrow ~= 16.1.0",
"polars ~= 0.20.9",
"pandas >= 1.3.4",
"plotly >= 5.22.0",
"lxml >= 4.9.3",
"cloudpickle >= 2.0.0",
"zstandard >= 0.22.0",
"loguru >= 0.7.2",
"psutil >= 5.9.8",
"GPUtil >= 1.4.0",
"py-libnuma >= 1.2",
"fsspec >= 2023.12.2",
"ray[default] >= 2.10.0",
"graphviz >= 0.19.1",
]
[project.optional-dependencies]
dev = [
"coverage~=7.4.4",
"hypothesis~=6.100.0",
"pytest==8.2.1",
"pytest-cov==5.0.0",
"pytest-forked==1.6.0",
"pytest-xdist==3.6.1",
"pytest-timeout==2.3.1",
"pytest-benchmark==4.0.0",
"setproctitle==1.3.3",
"soupsieve~=2.5",
"setuptools-scm==8.1.0",
"packaging==24.2",
"jaraco.functools==4.1.0",
]
docs = [
"sphinx==7.1.2",
"pydata-sphinx-theme==0.14.4",
]
warc = [
"warcio >= 1.7.4",
"beautifulsoup4 >= 4.12.2",
]
[tool.setuptools]
packages = [
"smallpond",
"smallpond.io",
"smallpond.execution",
"smallpond.logical",
"smallpond.contrib",
"smallpond.platform",
]
[tool.setuptools_scm]
fallback_version = "0.0.0"

89
smallpond/__init__.py Normal file
View File

@@ -0,0 +1,89 @@
from importlib.metadata import PackageNotFoundError, version
from typing import Optional
try:
__version__ = version("smallpond")
except PackageNotFoundError:
# package is not installed
__version__ = "unknown"
def init(
job_id: Optional[str] = None,
job_time: Optional[float] = None,
job_name: Optional[str] = None,
data_root: Optional[str] = None,
num_executors: Optional[int] = None,
ray_address: Optional[str] = None,
bind_numa_node: Optional[bool] = None,
platform: Optional[str] = None,
_remove_output_root: bool = True,
**kwargs,
) -> "Session":
"""
Initialize smallpond environment.
This is the entry point for smallpond::
import smallpond
sp = smallpond.init()
By default, it will use a local ray cluster as worker node.
To use more worker nodes, please specify the argument::
sp = smallpond.init(num_executors=10)
It will create an task to run the workers.
Parameters
----------
All parameters are optional. If not specified, read from environment variables. If not set, use default values.
job_id (SP_JOBID)
Unique job id. Default to a random uuid.
job_time (SP_JOB_TIME)
Job create time (seconds since epoch). Default to current time.
job_name (SP_JOB_NAME)
Job display name. Default to the filename of the current script.
data_root (SP_DATA_ROOT)
The root folder for all files generated at runtime.
num_executors (SP_NUM_EXECUTORS)
The number of executors.
Default to 0, which means all tasks will be run on the current node.
ray_address (SP_RAY_ADDRESS)
If specified, use the given address to connect to an existing ray cluster.
Otherwise, create a new ray cluster.
bind_numa_node (SP_BIND_NUMA_NODE)
If true, bind executor processes to numa nodes.
memory_allocator (SP_MEMORY_ALLOCATOR)
The memory allocator used by worker processes.
Choices: "system", "jemalloc", "mimalloc". Default to "mimalloc".
platform (SP_PLATFORM)
The platform to use. Choices: "mpi".
By default, it will automatically detect the environment and choose the most suitable platform.
_remove_output_root
If true, remove the "{data_root}/output" directory after the job is finished.
Default to True. This is only used for compatibility. User should use `write_parquet` for saving outputs.
Spawning a new job
------------------
If the environment variable `SP_SPAWN` is set to "1", it will spawn a new job to run the current script.
"""
import atexit
from smallpond.dataframe import Session
session = Session(
job_id=job_id,
job_time=job_time,
job_name=job_name,
data_root=data_root,
num_executors=num_executors,
ray_address=ray_address,
bind_numa_node=bind_numa_node,
platform=platform,
_remove_output_root=_remove_output_root,
**kwargs,
)
atexit.register(session.shutdown)
return session

117
smallpond/common.py Normal file
View File

@@ -0,0 +1,117 @@
import itertools
import math
import sys
from typing import Dict, List, TypeVar
import numpy as np
from smallpond.logical.udf import *
KB = 1024
MB = 1024 * KB
GB = 1024 * MB
TB = 1024 * GB
DEFAULT_MAX_RETRY_COUNT = 5
DEFAULT_MAX_FAIL_COUNT = 3
# duckdb default row group size https://duckdb.org/docs/data/parquet/tips#selecting-a-row_group_size
MAX_ROW_GROUP_SIZE = 10 * 1024 * 1024
MAX_ROW_GROUP_BYTES = 2 * GB
MAX_NUM_ROW_GROUPS = 256
MAX_PARQUET_FILE_BYTES = 8 * GB
DEFAULT_ROW_GROUP_SIZE = 122880
DEFAULT_ROW_GROUP_BYTES = DEFAULT_ROW_GROUP_SIZE * 4 * KB
DEFAULT_BATCH_SIZE = 122880
RAND_SEED_BYTE_LEN = 128
DATA_PARTITION_COLUMN_NAME = "__data_partition__"
PARQUET_METADATA_KEY_PREFIX = "SMALLPOND:"
INPUT_VIEW_PREFIX = "__input"
GENERATED_COLUMNS = ("filename", "file_row_number")
def pytest_running():
return "pytest" in sys.modules
def clamp_value(val, minval, maxval):
return max(minval, min(val, maxval))
def clamp_row_group_size(val, minval=DEFAULT_ROW_GROUP_SIZE, maxval=MAX_ROW_GROUP_SIZE):
return clamp_value(val, minval, maxval)
def clamp_row_group_bytes(
val, minval=DEFAULT_ROW_GROUP_BYTES, maxval=MAX_ROW_GROUP_BYTES
):
return clamp_value(val, minval, maxval)
class SmallpondError(Exception):
"""Base class for all errors in smallpond."""
class InjectedFault(SmallpondError):
pass
class OutOfMemory(SmallpondError):
pass
class NonzeroExitCode(SmallpondError):
pass
K = TypeVar("K")
V = TypeVar("V")
def first_value_in_dict(d: Dict[K, V]) -> V:
return next(iter(d.values())) if d else None
def split_into_cols(items: List[V], npartitions: int) -> List[List[V]]:
none = object()
chunks = [items[i : i + npartitions] for i in range(0, len(items), npartitions)]
return [
[x for x in col if x is not none]
for col in itertools.zip_longest([none] * npartitions, *chunks, fillvalue=none)
]
def split_into_rows(items: List[V], npartitions: int) -> List[List[V]]:
"""
Evenly split items into npartitions.
Example::
>>> split_into_rows(list(range(10)), 3)
[[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]]
"""
split_idxs = np.array_split(np.arange(len(items)), npartitions)
return [[items[i] for i in idxs] for idxs in split_idxs]
def get_nth_partition(items: List[V], n: int, npartitions: int) -> List[V]:
num_items = len(items)
large_partition_size = (num_items + npartitions - 1) // npartitions
small_partition_size = num_items // npartitions
num_large_partitions = num_items - small_partition_size * npartitions
if n < num_large_partitions:
start = n * large_partition_size
items_in_partition = items[start : start + large_partition_size]
else:
start = (
large_partition_size * num_large_partitions
+ (n - num_large_partitions) * small_partition_size
)
items_in_partition = items[start : start + small_partition_size]
return items_in_partition
def next_power_of_two(x) -> int:
return 2 ** (x - 1).bit_length()
def round_up(x, align_size=MB) -> int:
return math.ceil(x / align_size) * align_size

View File

View File

@@ -0,0 +1,24 @@
from typing import Iterable, List
import pyarrow as arrow
from loguru import logger
from smallpond.execution.task import RuntimeContext
from smallpond.logical.node import ArrowComputeNode, ArrowStreamNode
class CopyArrowTable(ArrowComputeNode):
def process(
self, runtime_ctx: RuntimeContext, input_tables: List[arrow.Table]
) -> arrow.Table:
logger.info(f"copying table: {input_tables[0].num_rows} rows ...")
return input_tables[0]
class StreamCopy(ArrowStreamNode):
def process(
self, runtime_ctx: RuntimeContext, input_readers: List[arrow.RecordBatchReader]
) -> Iterable[arrow.Table]:
for batch in input_readers[0]:
logger.info(f"copying batch: {batch.num_rows} rows ...")
yield arrow.Table.from_batches([batch])

View File

@@ -0,0 +1,37 @@
from typing import List, Tuple
from smallpond.execution.task import PythonScriptTask, RuntimeContext
from smallpond.logical.dataset import DataSet
from smallpond.logical.node import Context, Node, PythonScriptNode
class LogDataSetTask(PythonScriptTask):
num_rows = 200
@property
def exec_on_scheduler(self) -> bool:
return True
def process(
self,
runtime_ctx: RuntimeContext,
input_datasets: List[DataSet],
output_path: str,
) -> bool:
for dataset in input_datasets:
dataset.log(self.num_rows)
return True
class LogDataSet(PythonScriptNode):
def __init__(
self, ctx: Context, input_deps: Tuple[Node, ...], num_rows=200, **kwargs
) -> None:
super().__init__(ctx, input_deps, **kwargs)
self.num_rows = num_rows
def spawn(self, *args, **kwargs) -> LogDataSetTask:
task = LogDataSetTask(*args, **kwargs)
task.num_rows = self.num_rows
return task

143
smallpond/contrib/warc.py Normal file
View File

@@ -0,0 +1,143 @@
import string
import sys
import unicodedata
from pathlib import PurePath
from typing import Iterable, List, Tuple
from urllib.parse import urlparse
import pyarrow as arrow
import zstandard as zstd
from bs4 import BeautifulSoup
from loguru import logger
from warcio import ArchiveIterator
from smallpond.common import MB
from smallpond.execution.task import RuntimeContext
from smallpond.io.arrow import dump_to_parquet_files
from smallpond.logical.dataset import DataSet
from smallpond.logical.node import ArrowStreamNode, PythonScriptNode
class ImportWarcFiles(PythonScriptNode):
schema = arrow.schema(
[
arrow.field("url", arrow.string()),
arrow.field("domain", arrow.string()),
arrow.field("date", arrow.string()),
arrow.field("content", arrow.binary()),
]
)
def import_warc_file(
self, warc_path: PurePath, parquet_path: PurePath
) -> Tuple[int, int]:
total_size = 0
docs = []
with open(warc_path, "rb") as warc_file:
zstd_reader = zstd.ZstdDecompressor().stream_reader(
warc_file, read_size=16 * MB
)
for record in ArchiveIterator(zstd_reader):
if record.rec_type == "response":
url = record.rec_headers.get_header("WARC-Target-URI")
domain = urlparse(url).netloc
date = record.rec_headers.get_header("WARC-Date")
content = record.content_stream().read()
total_size += len(content)
docs.append((url, domain, date, content))
table = arrow.Table.from_arrays(
[arrow.array(column) for column in zip(*docs)], schema=self.schema
)
dump_to_parquet_files(table, parquet_path.parent, parquet_path.name)
return len(docs), total_size
def process(
self,
runtime_ctx: RuntimeContext,
input_datasets: List[DataSet],
output_path: str,
) -> bool:
warc_paths = [
PurePath(warc_path)
for dataset in input_datasets
for warc_path in dataset.resolved_paths
]
parquet_paths = [
PurePath(output_path)
/ f"data{path_index}-{PurePath(warc_path.name).with_suffix('.parquet')}"
for path_index, warc_path in enumerate(warc_paths)
]
logger.info(f"importing web pages from {len(warc_paths)} warc files...")
for warc_path, parquet_path in zip(warc_paths, parquet_paths):
try:
doc_count, total_size = self.import_warc_file(warc_path, parquet_path)
logger.info(
f"imported {doc_count} web pages ({total_size/MB:.3f}MB) from file '{warc_path}' to '{parquet_path}'"
)
except Exception as ex:
logger.opt(exception=ex).error(
f"failed to import web pages from file '{warc_path}'"
)
return False
return True
class ExtractHtmlBody(ArrowStreamNode):
unicode_punctuation = "".join(
chr(i)
for i in range(sys.maxunicode)
if unicodedata.category(chr(i)).startswith("P")
)
separator_str = string.whitespace + string.punctuation + unicode_punctuation
translator = str.maketrans(separator_str, " " * len(separator_str))
schema = arrow.schema(
[
arrow.field("url", arrow.string()),
arrow.field("domain", arrow.string()),
arrow.field("date", arrow.string()),
arrow.field("tokens", arrow.list_(arrow.string())),
]
)
def split_string(self, s: str):
return s.translate(self.translator).split()
def extract_tokens(self, url: arrow.string, content: arrow.binary) -> List[str]:
tokens = []
try:
doc = BeautifulSoup(content.as_py(), "lxml")
# if doc.title is not None and doc.title.string is not None:
# tokens.extend(self.split_string(doc.title.string.lower()))
tokens.extend(self.split_string(doc.get_text(" ", strip=True).lower()))
return tokens
except Exception as ex:
logger.opt(exception=ex).error(
f"failed to extract tokens from {url.as_py()}"
)
return []
def process(
self, runtime_ctx: RuntimeContext, input_readers: List[arrow.RecordBatchReader]
) -> Iterable[arrow.Table]:
for batch in input_readers[0]:
urls, domains, dates, contents = batch.columns
doc_tokens = []
try:
for i, (url, content) in enumerate(zip(urls, contents)):
tokens = self.extract_tokens(url, content)
logger.info(
f"#{i}/{len(urls)} extracted {len(tokens)} tokens from {url}"
)
doc_tokens.append(tokens)
yield arrow.Table.from_arrays(
[urls, domains, dates, arrow.array(doc_tokens)], schema=self.schema
)
except Exception as ex:
logger.opt(exception=ex).error(f"failed to extract tokens")
break

715
smallpond/dataframe.py Normal file
View File

@@ -0,0 +1,715 @@
from __future__ import annotations
import os
import time
from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import pandas as pd
import pyarrow as arrow
import ray
import ray.exceptions
from loguru import logger
from smallpond.execution.task import Task
from smallpond.io.filesystem import remove_path
from smallpond.logical.dataset import *
from smallpond.logical.node import *
from smallpond.logical.optimizer import Optimizer
from smallpond.logical.planner import Planner
from smallpond.session import SessionBase
class Session(SessionBase):
# Extended session class with additional methods to create DataFrames.
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._nodes: List[Node] = []
self._node_to_tasks: Dict[Node, List[Task]] = {}
"""
When a DataFrame is evaluated, the tasks of the logical plan are stored here.
Subsequent DataFrames can reuse the tasks to avoid recomputation.
"""
def read_csv(
self, paths: Union[str, List[str]], schema: Dict[str, str], delim=","
) -> DataFrame:
"""
Create a DataFrame from CSV files.
"""
dataset = CsvDataSet(paths, OrderedDict(schema), delim)
plan = DataSourceNode(self._ctx, dataset)
return DataFrame(self, plan)
def read_parquet(
self,
paths: Union[str, List[str]],
recursive: bool = False,
columns: Optional[List[str]] = None,
union_by_name: bool = False,
) -> DataFrame:
"""
Create a DataFrame from Parquet files.
"""
dataset = ParquetDataSet(
paths, columns=columns, union_by_name=union_by_name, recursive=recursive
)
plan = DataSourceNode(self._ctx, dataset)
return DataFrame(self, plan)
def read_json(
self, paths: Union[str, List[str]], schema: Dict[str, str]
) -> DataFrame:
"""
Create a DataFrame from JSON files.
"""
dataset = JsonDataSet(paths, schema)
plan = DataSourceNode(self._ctx, dataset)
return DataFrame(self, plan)
def from_items(self, items: List[Any]) -> DataFrame:
"""
Create a DataFrame from a list of local Python objects.
"""
assert isinstance(items, list), "items must be a list"
assert len(items) > 0, "items must not be empty"
if isinstance(items[0], dict):
return self.from_arrow(arrow.Table.from_pylist(items))
else:
return self.from_arrow(arrow.table({"item": items}))
def from_pandas(self, df: pd.DataFrame) -> DataFrame:
"""
Create a DataFrame from a pandas DataFrame.
"""
plan = DataSourceNode(self._ctx, PandasDataSet(df))
return DataFrame(self, plan)
def from_arrow(self, table: arrow.Table) -> DataFrame:
"""
Create a DataFrame from a pyarrow Table.
"""
plan = DataSourceNode(self._ctx, ArrowTableDataSet(table))
return DataFrame(self, plan)
def partial_sql(self, query: str, *inputs: DataFrame, **kwargs) -> DataFrame:
"""
Execute a SQL query on each partition of the input DataFrames.
The query can contain placeholder `{0}`, `{1}`, etc. for the input DataFrames.
If multiple DataFrames are provided, they must have the same number of partitions.
Examples
--------
Join two datasets. You need to make sure the join key is correctly partitioned.
.. code-block::
a = sp.read_parquet("a/*.parquet").repartition(10, hash_by="id")
b = sp.read_parquet("b/*.parquet").repartition(10, hash_by="id")
c = sp.partial_sql("select * from {0} join {1} on a.id = b.id", a, b)
"""
plan = SqlEngineNode(
self._ctx, tuple(input.plan for input in inputs), query, **kwargs
)
recompute = any(input.need_recompute for input in inputs)
return DataFrame(self, plan, recompute=recompute)
def wait(self, *dfs: DataFrame):
"""
Wait for all DataFrames to be computed.
Example
-------
This can be used to wait for multiple outputs from a pipeline:
.. code-block::
df = sp.read_parquet("input/*.parquet")
output1 = df.write_parquet("output1")
output2 = df.map("col1, col2").write_parquet("output2")
sp.wait(output1, output2)
"""
ray.get([task.run_on_ray() for df in dfs for task in df._get_or_create_tasks()])
def graph(self) -> Digraph:
"""
Get the DataFrame graph.
"""
dot = Digraph(comment="SmallPond")
for node in self._nodes:
dot.node(str(node.id), repr(node))
for dep in node.input_deps:
dot.edge(str(dep.id), str(node.id))
return dot
def shutdown(self):
"""
Shutdown the session.
"""
# prevent shutdown from being called multiple times
if hasattr(self, "_shutdown_called"):
return
self._shutdown_called = True
# log status
finished = self._all_tasks_finished()
with open(self._runtime_ctx.job_status_path, "a") as fout:
status = "success" if finished else "failure"
fout.write(f"{status}@{datetime.now():%Y-%m-%d-%H-%M-%S}\n")
# clean up runtime directories if success
if finished:
logger.info("all tasks are finished, cleaning up")
self._runtime_ctx.cleanup(remove_output_root=self.config.remove_output_root)
else:
logger.warning("tasks are not finished!")
super().shutdown()
def _summarize_task(self) -> Tuple[int, int]:
"""
Return the total number of tasks and the number of tasks that are finished.
"""
dataset_refs = [
task._dataset_ref
for tasks in self._node_to_tasks.values()
for task in tasks
if task._dataset_ref is not None
]
ready_tasks, _ = ray.wait(
dataset_refs, num_returns=len(dataset_refs), timeout=0, fetch_local=False
)
return len(dataset_refs), len(ready_tasks)
def _all_tasks_finished(self) -> bool:
"""
Check if all tasks are finished.
"""
dataset_refs = [
task._dataset_ref
for tasks in self._node_to_tasks.values()
for task in tasks
]
try:
ray.get(dataset_refs, timeout=0)
except Exception:
# GetTimeoutError is raised if any task is not finished
# RuntimeError is raised if any task failed
return False
return True
class DataFrame:
"""
A distributed data collection. It represents a 2 dimensional table of rows and columns.
Internally, it's a wrapper around a `Node` and a `Session` required for execution.
"""
def __init__(self, session: Session, plan: Node, recompute: bool = False):
self.session = session
self.plan = plan
self.optimized_plan: Optional[Node] = None
self.need_recompute = recompute
"""Whether to recompute the data regardless of whether it's already computed."""
session._nodes.append(plan)
def __str__(self) -> str:
return repr(self.plan)
def _get_or_create_tasks(self) -> List[Task]:
"""
Get or create tasks to compute the data.
"""
# optimize the plan
if self.optimized_plan is None:
logger.info(f"optimizing\n{LogicalPlan(self.session._ctx, self.plan)}")
self.optimized_plan = Optimizer(
exclude_nodes=set(self.session._node_to_tasks.keys())
).visit(self.plan)
logger.info(
f"optimized\n{LogicalPlan(self.session._ctx, self.optimized_plan)}"
)
# return the tasks if already created
if tasks := self.session._node_to_tasks.get(self.optimized_plan):
return tasks
# remove all completed task files if recompute is needed
if self.need_recompute:
remove_path(
os.path.join(
self.session._runtime_ctx.completed_task_dir,
str(self.optimized_plan.id),
)
)
logger.info(f"cleared all results of {self.optimized_plan!r}")
# create tasks for the optimized plan
planner = Planner(self.session._runtime_ctx)
# let planner update self.session._node_to_tasks
planner.node_to_tasks = self.session._node_to_tasks
return planner.visit(self.optimized_plan)
def is_computed(self) -> bool:
"""
Check if the data is ready on disk.
"""
if tasks := self.session._node_to_tasks.get(self.plan):
_, unready_tasks = ray.wait(tasks, timeout=0)
return len(unready_tasks) == 0
return False
def compute(self) -> None:
"""
Compute the data.
This operation will trigger execution of the lazy transformations performed on this DataFrame.
"""
self._compute()
def _compute(self) -> List[DataSet]:
"""
Compute the data and return the datasets.
"""
for retry_count in range(3):
try:
return ray.get(
[task.run_on_ray() for task in self._get_or_create_tasks()]
)
except ray.exceptions.RuntimeEnvSetupError as e:
# XXX: Ray may raise this error when a worker is interrupted.
# ```
# ray.exceptions.RuntimeEnvSetupError: Failed to set up runtime environment.
# Failed to create runtime env for job 01000000, status = IOError:
# on_read bad version, maybe there are some network problems, will fail the request.
# ```
# This is a bug of Ray and has been fixed in Ray 2.24: <https://github.com/ray-project/ray/pull/45513>
# However, since Ray dropped support for Python 3.8 since 2.11, we can not upgrade Ray.
# So we catch this error and retry by ourselves.
logger.error(f"found ray RuntimeEnvSetupError, retrying...\n{e}")
time.sleep(10 << retry_count)
raise RuntimeError("Failed to compute data after 3 retries")
# operations
def recompute(self) -> DataFrame:
"""
Always recompute the data regardless of whether it's already computed.
Examples
--------
Modify the code as follows and rerun:
.. code-block:: diff
- df = input.select('a')
+ df = input.select('b').recompute()
The result of `input` can be reused.
"""
self.need_recompute = True
return self
def repartition(
self,
npartitions: int,
hash_by: Union[str, List[str], None] = None,
by: Optional[str] = None,
by_rows: bool = False,
**kwargs,
) -> DataFrame:
"""
Repartition the data into the given number of partitions.
Parameters
----------
npartitions
The dataset would be split and distributed to `npartitions` partitions.
If not specified, the number of partitions would be the default partition size of the context.
hash_by, optional
If specified, the dataset would be repartitioned by the hash of the given columns.
by, optional
If specified, the dataset would be repartitioned by the given column.
by_rows, optional
If specified, the dataset would be repartitioned by rows instead of by files.
Examples
--------
.. code-block::
df = df.repartition(10) # evenly distributed
df = df.repartition(10, by_rows=True) # evenly distributed by rows
df = df.repartition(10, hash_by='host') # hash partitioned
df = df.repartition(10, by='bucket') # partitioned by column
"""
if by is not None:
assert hash_by is None, "cannot specify both by and hash_by"
plan = ShuffleNode(
self.session._ctx,
(self.plan,),
npartitions,
data_partition_column=by,
**kwargs,
)
elif hash_by is not None:
hash_columns = [hash_by] if isinstance(hash_by, str) else hash_by
plan = HashPartitionNode(
self.session._ctx, (self.plan,), npartitions, hash_columns, **kwargs
)
else:
plan = EvenlyDistributedPartitionNode(
self.session._ctx,
(self.plan,),
npartitions,
partition_by_rows=by_rows,
**kwargs,
)
return DataFrame(self.session, plan, recompute=self.need_recompute)
def random_shuffle(self, **kwargs) -> DataFrame:
"""
Randomly shuffle all rows globally.
"""
repartition = HashPartitionNode(
self.session._ctx,
(self.plan,),
self.plan.num_partitions,
random_shuffle=True,
**kwargs,
)
plan = SqlEngineNode(
self.session._ctx,
(repartition,),
r"select * from {0} order by random()",
**kwargs,
)
return DataFrame(self.session, plan, recompute=self.need_recompute)
def partial_sort(self, by: Union[str, List[str]], **kwargs) -> DataFrame:
"""
Sort rows by the given columns in each partition.
Parameters
----------
by
A column or a list of columns to sort by.
Examples
--------
.. code-block::
df = df.partial_sort(by='a')
df = df.partial_sort(by=['a', 'b desc'])
"""
by = [by] if isinstance(by, str) else by
plan = SqlEngineNode(
self.session._ctx,
(self.plan,),
f"select * from {{0}} order by {', '.join(by)}",
**kwargs,
)
return DataFrame(self.session, plan, recompute=self.need_recompute)
def filter(
self, sql_or_func: Union[str, Callable[[Dict[str, Any]], bool]], **kwargs
) -> DataFrame:
"""
Filter out rows that don't satisfy the given predicate.
Parameters
----------
sql_or_func
A SQL expression or a predicate function.
For functions, it should take a dictionary of columns as input and returns a boolean.
SQL expression is preferred as it's more efficient.
Examples
--------
.. code-block::
df = df.filter('a > 1')
df = df.filter(lambda r: r['a'] > 1)
"""
if isinstance(sql := sql_or_func, str):
plan = SqlEngineNode(
self.session._ctx,
(self.plan,),
f"select * from {{0}} where ({sql})",
**kwargs,
)
elif isinstance(func := sql_or_func, Callable):
def process_func(_runtime_ctx, tables: List[arrow.Table]) -> arrow.Table:
table = tables[0]
return table.filter([func(row) for row in table.to_pylist()])
plan = ArrowBatchNode(
self.session._ctx, (self.plan,), process_func=process_func, **kwargs
)
else:
raise ValueError(
"condition must be a SQL expression or a predicate function"
)
return DataFrame(self.session, plan, recompute=self.need_recompute)
def map(
self,
sql_or_func: Union[str, Callable[[Dict[str, Any]], Dict[str, Any]]],
*,
schema: Optional[arrow.Schema] = None,
**kwargs,
) -> DataFrame:
"""
Apply a function to each row.
Parameters
----------
sql_or_func
A SQL expression or a function to apply to each row.
For functions, it should take a dictionary of columns as input and returns a dictionary of columns.
SQL expression is preferred as it's more efficient.
schema, optional
The schema of the output DataFrame.
If not passed, will be inferred from the first row of the mapping values.
udfs, optional
A list of user defined functions to be referenced in the SQL expression.
Examples
--------
.. code-block::
df = df.map('a, b')
df = df.map('a + b as c')
df = df.map(lambda row: {'c': row['a'] + row['b']})
Use user-defined functions in SQL expression:
.. code-block::
@udf(params=[UDFType.INT, UDFType.INT], return_type=UDFType.INT)
def gcd(a: int, b: int) -> int:
while b:
a, b = b, a % b
return a
# load python udf
df = df.map('gcd(a, b)', udfs=[gcd])
# load udf from duckdb extension
df = df.map('gcd(a, b)', udfs=['path/to/udf.duckdb_extension'])
"""
if isinstance(sql := sql_or_func, str):
plan = SqlEngineNode(
self.session._ctx, (self.plan,), f"select {sql} from {{0}}", **kwargs
)
elif isinstance(func := sql_or_func, Callable):
def process_func(_runtime_ctx, tables: List[arrow.Table]) -> arrow.Table:
output_rows = [func(row) for row in tables[0].to_pylist()]
return arrow.Table.from_pylist(output_rows, schema=schema)
plan = ArrowBatchNode(
self.session._ctx, (self.plan,), process_func=process_func, **kwargs
)
else:
raise ValueError(f"must be a SQL expression or a function: {sql_or_func!r}")
return DataFrame(self.session, plan, recompute=self.need_recompute)
def flat_map(
self,
sql_or_func: Union[str, Callable[[Dict[str, Any]], List[Dict[str, Any]]]],
*,
schema: Optional[arrow.Schema] = None,
**kwargs,
) -> DataFrame:
"""
Apply a function to each row and flatten the result.
Parameters
----------
sql_or_func
A SQL expression or a function to apply to each row.
For functions, it should take a dictionary of columns as input and returns a list of dictionaries.
SQL expression is preferred as it's more efficient.
schema, optional
The schema of the output DataFrame.
If not passed, will be inferred from the first row of the mapping values.
Examples
--------
.. code-block::
df = df.flat_map('unnest(array[a, b]) as c')
df = df.flat_map(lambda row: [{'c': row['a']}, {'c': row['b']}])
"""
if isinstance(sql := sql_or_func, str):
plan = SqlEngineNode(
self.session._ctx, (self.plan,), f"select {sql} from {{0}}", **kwargs
)
elif isinstance(func := sql_or_func, Callable):
def process_func(_runtime_ctx, tables: List[arrow.Table]) -> arrow.Table:
output_rows = [
item for row in tables[0].to_pylist() for item in func(row)
]
return arrow.Table.from_pylist(output_rows, schema=schema)
plan = ArrowBatchNode(
self.session._ctx, (self.plan,), process_func=process_func, **kwargs
)
else:
raise ValueError(f"must be a SQL expression or a function: {sql_or_func!r}")
return DataFrame(self.session, plan, recompute=self.need_recompute)
def map_batches(
self,
func: Callable[[arrow.Table], arrow.Table],
*,
batch_size: int = 122880,
**kwargs,
) -> DataFrame:
"""
Apply the given function to batches of data.
Parameters
----------
func
A function or a callable class to apply to each batch of data.
It should take a `arrow.Table` as input and returns a `arrow.Table`.
batch_size, optional
The number of rows in each batch. Defaults to 122880.
"""
def process_func(_runtime_ctx, tables: List[arrow.Table]) -> arrow.Table:
return func(tables[0])
plan = ArrowBatchNode(
self.session._ctx,
(self.plan,),
process_func=process_func,
streaming_batch_size=batch_size,
**kwargs,
)
return DataFrame(self.session, plan, recompute=self.need_recompute)
def limit(self, limit: int) -> DataFrame:
"""
Limit the number of rows to the given number.
Unlike `take`, this method doesn't trigger execution.
"""
plan = LimitNode(self.session._ctx, self.plan, limit)
return DataFrame(self.session, plan, recompute=self.need_recompute)
def write_parquet(self, path: str) -> None:
"""
Write data to a series of parquet files under the given path.
This is a blocking operation. See :func:`write_parquet_lazy` for a non-blocking version.
Examples
--------
.. code-block::
df.write_parquet('output')
"""
self.write_parquet_lazy(path).compute()
def write_parquet_lazy(self, path: str) -> DataFrame:
"""
Write data to a series of parquet files under the given path.
This is a non-blocking operation. See :func:`write_parquet` for a blocking version.
Examples
--------
.. code-block::
o1 = df.write_parquet_lazy('output1')
o2 = df.write_parquet_lazy('output2')
sp.wait(o1, o2)
"""
plan = DataSinkNode(
self.session._ctx, (self.plan,), os.path.abspath(path), type="link_or_copy"
)
return DataFrame(self.session, plan, recompute=self.need_recompute)
# inspection
def count(self) -> int:
"""
Count the number of rows.
If this dataframe consists of more than a read, or if the row count can't be determined from
the metadata provided by the datasource, then this operation will trigger execution of the
lazy transformations performed on this dataframe.
"""
datasets = self._compute()
# FIXME: don't use ThreadPoolExecutor because duckdb results will be mixed up
return sum(dataset.num_rows for dataset in datasets)
def take(self, limit: int) -> List[Dict[str, Any]]:
"""
Return up to `limit` rows.
This operation will trigger execution of the lazy transformations performed on this DataFrame.
"""
if self.is_computed() or isinstance(self.plan, DataSourceNode):
datasets = self._compute()
else:
datasets = self.limit(limit)._compute()
rows = []
for dataset in datasets:
for batch in dataset.to_batch_reader():
rows.extend(batch.to_pylist())
if len(rows) >= limit:
return rows[:limit]
return rows
def take_all(self) -> List[Dict[str, Any]]:
"""
Return all rows.
This operation will trigger execution of the lazy transformations performed on this DataFrame.
"""
datasets = self._compute()
rows = []
for dataset in datasets:
for batch in dataset.to_batch_reader():
rows.extend(batch.to_pylist())
return rows
def to_pandas(self) -> pd.DataFrame:
"""
Convert to a pandas DataFrame.
This operation will trigger execution of the lazy transformations performed on this DataFrame.
"""
datasets = self._compute()
with ThreadPoolExecutor() as pool:
return pd.concat(pool.map(lambda dataset: dataset.to_pandas(), datasets))
def to_arrow(self) -> arrow.Table:
"""
Convert to an arrow Table.
This operation will trigger execution of the lazy transformations performed on this DataFrame.
"""
datasets = self._compute()
with ThreadPoolExecutor() as pool:
return arrow.concat_tables(
pool.map(lambda dataset: dataset.to_arrow_table(), datasets)
)

View File

View File

@@ -0,0 +1,433 @@
import argparse
import os
import socket
import sys
from multiprocessing import Process
from typing import List, Optional
from loguru import logger
import smallpond
from smallpond.common import DEFAULT_MAX_FAIL_COUNT, DEFAULT_MAX_RETRY_COUNT
from smallpond.dataframe import DataFrame
from smallpond.execution.task import ExecutionPlan
from smallpond.io.filesystem import load
from smallpond.logical.node import LogicalPlan
class Driver(object):
"""
A helper class that includes boilerplate code to execute a logical plan.
"""
def __init__(self) -> None:
self.driver_args_parser = self._create_driver_args_parser()
self.user_args_parser = argparse.ArgumentParser(add_help=False)
self.driver_args = None
self.user_args = None
self.all_args = None
def _create_driver_args_parser(self):
parser = argparse.ArgumentParser(
prog="driver.py", description="Smallpond Driver", add_help=False
)
parser.add_argument(
"mode", choices=["executor", "scheduler", "ray"], default="executor"
)
parser.add_argument(
"--exec_id", default=socket.gethostname(), help="Unique executor id"
)
parser.add_argument("--job_id", type=str, help="Unique job id")
parser.add_argument(
"--job_time", type=float, help="Job create time (seconds since epoch)"
)
parser.add_argument(
"--job_name", default="smallpond", help="Display name of the job"
)
parser.add_argument(
"--job_priority",
type=int,
help="Job priority",
)
parser.add_argument("--resource_group", type=str, help="Resource group")
parser.add_argument(
"--env_variables", nargs="*", default=[], help="Env variables for the job"
)
parser.add_argument(
"--sidecars", nargs="*", default=[], help="Sidecars for the job"
)
parser.add_argument(
"--tags", nargs="*", default=[], help="Tags for submitted platform task"
)
parser.add_argument(
"--task_image", default="default", help="Container image of platform task"
)
parser.add_argument(
"--python_venv", type=str, help="Python virtual env for the job"
)
parser.add_argument(
"--data_root",
type=str,
help="The root folder for all files generated at runtime",
)
parser.add_argument(
"--runtime_ctx_path",
default=None,
help="The path of pickled runtime context passed from scheduler to executor",
)
parser.add_argument(
"--num_executors",
default=0,
type=int,
help="The number of nodes/executors (run all tasks on scheduler if set to zero)",
)
parser.add_argument(
"--num_executors_per_task",
default=5,
type=int,
help="The number of nodes/executors in each platform task.",
)
parser.add_argument(
"--random_seed",
type=int,
default=None,
help="Random seed for the job, default: int.from_bytes((os.urandom(128))",
)
parser.add_argument(
"--max_retry",
"--max_retry_count",
dest="max_retry_count",
default=DEFAULT_MAX_RETRY_COUNT,
type=int,
help="The max number of times a task is retried by speculative execution",
)
parser.add_argument(
"--max_fail",
"--max_fail_count",
dest="max_fail_count",
default=DEFAULT_MAX_FAIL_COUNT,
type=int,
help="The number of times a task is allowed to fail or crash before giving up",
)
parser.add_argument(
"--prioritize_retry",
action="store_true",
help="Prioritize retry tasks in scheduling",
)
parser.add_argument(
"--speculative_exec",
dest="speculative_exec",
choices=["disable", "enable", "aggressive"],
help="Level of speculative execution",
)
parser.add_argument(
"--enable_speculative_exec",
dest="speculative_exec",
action="store_const",
const="enable",
help="Enable speculative execution",
)
parser.add_argument(
"--disable_speculative_exec",
dest="speculative_exec",
action="store_const",
const="disable",
help="Disable speculative execution",
)
parser.set_defaults(speculative_exec="enable")
parser.add_argument(
"--stop_executor_on_failure",
action="store_true",
help="Stop an executor if any task fails on it",
)
parser.add_argument(
"--fail_fast",
"--fail_fast_on_failure",
dest="fail_fast_on_failure",
action="store_true",
help="Stop all executors if any task fails",
)
parser.add_argument(
"--nonzero_exitcode_as_oom",
action="store_true",
help="Treat task crash as oom error",
)
parser.add_argument(
"--fault_inject",
"--fault_inject_prob",
dest="fault_inject_prob",
type=float,
default=0.0,
help="Inject random errors at runtime (for test)",
)
parser.add_argument(
"--enable_profiling",
action="store_true",
help="Enable Python profiling for each task",
)
parser.add_argument(
"--enable_diagnostic_metrics",
action="store_true",
help="Enable diagnostic metrcis which may have performance impact",
)
parser.add_argument(
"--disable_sched_state_dump",
dest="enable_sched_state_dump",
action="store_false",
help="Disable periodic dump of scheduler state",
)
parser.add_argument(
"--enable_sched_state_dump",
dest="enable_sched_state_dump",
action="store_true",
help="Enable periodic dump of scheduler state so that scheduler can resume execution after restart",
)
parser.set_defaults(enable_sched_state_dump=False)
parser.add_argument(
"--remove_empty_parquet",
action="store_true",
help="Remove empty parquet files from hash partition output",
)
parser.add_argument(
"--remove_output_root",
action="store_true",
help="Remove all output files after job completed (for test)",
)
parser.add_argument(
"--skip_task_with_empty_input",
action="store_true",
help="Skip running a task if any of its input datasets is empty",
)
parser.add_argument(
"--self_contained_final_results",
action="store_true",
help="Build self-contained final results, i.e., create hard/symbolic links in output folder of final results",
)
parser.add_argument(
"--malloc",
"--memory_allocator",
dest="memory_allocator",
default="mimalloc",
choices=["system", "jemalloc", "mimalloc"],
help="Override memory allocator used by worker processes",
)
parser.add_argument(
"--memory_purge_delay",
default=10000,
help="The delay in milliseconds after which jemalloc/mimalloc will purge memory pages that are not in use.",
)
parser.add_argument(
"--bind_numa_node",
action="store_true",
help="Bind executor processes to numa nodes.",
)
parser.add_argument(
"--enforce_memory_limit",
action="store_true",
help="Set soft/hard memory limit for each task process",
)
parser.add_argument(
"--enable_log_analytics",
dest="share_log_analytics",
action="store_true",
help="Share log analytics with smallpond team",
)
parser.add_argument(
"--disable_log_analytics",
dest="share_log_analytics",
action="store_false",
help="Do not share log analytics with smallpond team",
)
log_level_choices = [
"CRITICAL",
"ERROR",
"WARNING",
"SUCCESS",
"INFO",
"DEBUG",
"TRACE",
]
parser.add_argument(
"--console_log_level",
default="INFO",
choices=log_level_choices,
)
parser.add_argument(
"--file_log_level",
default="DEBUG",
choices=log_level_choices,
)
parser.add_argument(
"--disable_log_rotation", action="store_true", help="Disable log rotation"
)
parser.add_argument(
"--output_path",
help="Set the output directory of final results and all nodes that have output_name but no output_path specified",
)
parser.add_argument(
"--platform",
type=str,
help="The platform to use for the job. available: mpi",
)
return parser
def add_argument(self, *args, **kwargs):
"""
Add a command-line argument. This is a wrapper of `argparse.ArgumentParser.add_argument(...)`.
"""
self.user_args_parser.add_argument(*args, **kwargs)
def parse_arguments(self, args=None):
if self.user_args is None or self.driver_args is None:
args_parser = argparse.ArgumentParser(
parents=[self.driver_args_parser, self.user_args_parser]
)
self.all_args = args_parser.parse_args(args)
self.user_args, other_args = self.user_args_parser.parse_known_args(args)
self.driver_args = self.driver_args_parser.parse_args(other_args)
return self.user_args, self.driver_args
def get_user_arguments(self, to_dict=True):
"""
Get user-defined arguments.
"""
user_args, _ = self.parse_arguments()
return vars(user_args) if to_dict else user_args
get_arguments = get_user_arguments
def get_driver_arguments(self, to_dict=True):
_, driver_args = self.parse_arguments()
return vars(driver_args) if to_dict else driver_args
@property
def mode(self) -> str:
return self.get_driver_arguments(to_dict=False).mode
@property
def job_id(self) -> str:
return self.get_driver_arguments(to_dict=False).job_id
@property
def data_root(self) -> str:
return self.get_driver_arguments(to_dict=False).data_root
@property
def num_executors(self) -> str:
return self.get_driver_arguments(to_dict=False).num_executors
def run(
self,
plan: LogicalPlan,
stop_process_on_done=True,
tags: List[str] = None,
) -> Optional[ExecutionPlan]:
"""
The entry point for executor and scheduler of `plan`.
"""
from smallpond.execution.executor import Executor
from smallpond.execution.manager import JobManager
from smallpond.execution.task import RuntimeContext
_, args = self.parse_arguments()
retval = None
def run_executor(runtime_ctx: RuntimeContext, exec_id: str, numa_node_id=None):
if numa_node_id is not None:
import numa
exec_id += f".numa{numa_node_id}"
numa.schedule.bind(numa_node_id)
runtime_ctx.numa_node_id = numa_node_id
runtime_ctx.initialize(exec_id)
executor = Executor.create(runtime_ctx, exec_id)
return executor.run()
if args.mode == "ray":
assert plan is not None
sp = smallpond.init(_remove_output_root=args.remove_output_root)
DataFrame(sp, plan.root_node).compute()
retval = True
elif args.mode == "executor":
assert os.path.isfile(
args.runtime_ctx_path
), f"cannot find runtime context: {args.runtime_ctx_path}"
runtime_ctx: RuntimeContext = load(args.runtime_ctx_path)
if runtime_ctx.bind_numa_node:
exec_procs = [
Process(
target=run_executor,
args=(runtime_ctx, args.exec_id, numa_node_id),
)
for numa_node_id in range(runtime_ctx.numa_node_count)
]
for proc in exec_procs:
proc.start()
for proc in exec_procs:
proc.join()
retval = all(proc.exitcode == 0 for proc in exec_procs)
else:
retval = run_executor(runtime_ctx, args.exec_id)
elif args.mode == "scheduler":
assert plan is not None
jobmgr = JobManager(
args.data_root, args.python_venv, args.task_image, args.platform
)
exec_plan = jobmgr.run(
plan,
job_id=args.job_id,
job_time=args.job_time,
job_name=args.job_name,
job_priority=args.job_priority,
num_executors=args.num_executors,
num_executors_per_task=args.num_executors_per_task,
resource_group=args.resource_group,
env_variables=args.env_variables,
sidecars=args.sidecars,
user_tags=args.tags + (tags or []),
random_seed=args.random_seed,
max_retry_count=args.max_retry_count,
max_fail_count=args.max_fail_count,
prioritize_retry=args.prioritize_retry,
speculative_exec=args.speculative_exec,
stop_executor_on_failure=args.stop_executor_on_failure,
fail_fast_on_failure=args.fail_fast_on_failure,
nonzero_exitcode_as_oom=args.nonzero_exitcode_as_oom,
fault_inject_prob=args.fault_inject_prob,
enable_profiling=args.enable_profiling,
enable_diagnostic_metrics=args.enable_diagnostic_metrics,
enable_sched_state_dump=args.enable_sched_state_dump,
remove_empty_parquet=args.remove_empty_parquet,
remove_output_root=args.remove_output_root,
skip_task_with_empty_input=args.skip_task_with_empty_input,
manifest_only_final_results=not args.self_contained_final_results,
memory_allocator=args.memory_allocator,
memory_purge_delay=args.memory_purge_delay,
bind_numa_node=args.bind_numa_node,
enforce_memory_limit=args.enforce_memory_limit,
share_log_analytics=args.share_log_analytics,
console_log_level=args.console_log_level,
file_log_level=args.file_log_level,
disable_log_rotation=args.disable_log_rotation,
output_path=args.output_path,
)
retval = exec_plan if exec_plan.successful else None
if stop_process_on_done:
exit_code = os.EX_OK if retval else os.EX_SOFTWARE
logger.info(f"exit code: {exit_code}")
sys.exit(exit_code)
else:
logger.info(f"return value: {repr(retval)}")
return retval
def main():
# run in executor mode
driver = Driver()
driver.run(plan=None)
if __name__ == "__main__":
main()

341
smallpond/execution/executor.py Executable file
View File

@@ -0,0 +1,341 @@
import multiprocessing as mp
import os.path
import socket
import time
from collections import OrderedDict
from typing import Dict, List, Tuple
from GPUtil import GPU
from loguru import logger
from smallpond.common import NonzeroExitCode, pytest_running
from smallpond.execution.task import Probe, RuntimeContext
from smallpond.execution.workqueue import (
StopExecutor,
StopWorkItem,
WorkItem,
WorkQueue,
WorkQueueOnFilesystem,
WorkStatus,
)
class SimplePoolTask(object):
def __init__(self, func, args, name: str):
self.proc: mp.Process = mp.Process(target=func, args=args, name=name)
self.stopping = False
def start(self):
self.proc.start()
def terminate(self):
self.proc.terminate()
self.stopping = True
def join(self, timeout=None):
self.proc.join(timeout)
if not self.ready() and timeout is not None:
logger.warning(
f"worker process {self.proc.name}({self.proc.pid}) does not exit after {timeout} secs, stopping it"
)
self.terminate()
self.proc.join()
def ready(self):
return self.proc.pid and not self.proc.is_alive()
def exitcode(self):
assert (
self.ready()
), f"worker process {self.proc.name}({self.proc.pid}) has not exited yet"
if self.stopping:
logger.info(
f"worker process stopped: {self.proc.name}({self.proc.pid}), exitcode: {self.proc.exitcode}"
)
elif self.proc.exitcode != 0:
logger.error(
f"worker process crashed: {self.proc.name}({self.proc.pid}), exitcode: {self.proc.exitcode}"
)
return self.proc.exitcode
class SimplePool(object):
def __init__(self, pool_size: int):
self.pool_size = pool_size
self.queued_tasks: List[SimplePoolTask] = []
self.running_tasks: List[SimplePoolTask] = []
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
self.join(1)
def apply_async(self, func, args, name: str = None):
task = SimplePoolTask(func, args, name)
self.queued_tasks.append(task)
return task
def update_queue(self):
self.running_tasks = [t for t in self.running_tasks if not t.ready()]
tasks_to_run = self.queued_tasks[: self.pool_size - len(self.running_tasks)]
self.queued_tasks = self.queued_tasks[
self.pool_size - len(self.running_tasks) :
]
for task in tasks_to_run:
task.start()
self.running_tasks += tasks_to_run
def join(self, timeout=None):
for task in self.running_tasks:
logger.info(f"joining process: {task.proc.name}({task.proc.pid})")
task.join(timeout)
class Executor(object):
"""
The task executor.
"""
def __init__(
self, ctx: RuntimeContext, id: str, wq: WorkQueue, cq: WorkQueue
) -> None:
self.ctx = ctx
self.id = id
self.wq = wq
self.cq = cq
self.running_works: Dict[str, Tuple[SimplePoolTask, WorkItem]] = OrderedDict()
self.running = True
self.epochs_to_skip = 0
self.numa_node = self.ctx.numa_node_id
self.local_gpus = {gpu: 1.0 for gpu in self.ctx.get_local_gpus()}
""" { GPU: available_quota } """
def __str__(self) -> str:
return f"Executor({self.id}), running_works[{len(self.running_works)}]={list(self.running_works.keys())[:3]}..."
@property
def busy(self) -> bool:
return len(self.running_works) > 0
@property
def available_gpu_quota(self) -> float:
return sum(self.local_gpus.values())
@staticmethod
def create(ctx: RuntimeContext, id: str) -> "Executor":
queue_dir = os.path.join(ctx.queue_root, id)
wq = WorkQueueOnFilesystem(os.path.join(queue_dir, "wq"))
cq = WorkQueueOnFilesystem(os.path.join(queue_dir, "cq"))
executor = Executor(ctx, id, wq, cq)
return executor
@staticmethod
@logger.catch(reraise=True, message="work item failed unexpectedly")
def process_work(item: WorkItem, cq: WorkQueue):
item.exec(cq)
cq.push(item)
logger.info(
f"finished work: {repr(item)}, status: {item.status}, elapsed time: {item.elapsed_time:.3f} secs"
)
logger.complete()
# for test
def stop(self):
self.wq.push(StopExecutor(f".FailStop-{self.id}", ack=False))
# for test
def skip_probes(self, epochs: int):
self.wq.push(
Probe(self.ctx, f".FalseFail-{self.id}", epoch=0, epochs_to_skip=epochs)
)
@logger.catch(reraise=True, message="executor terminated unexpectedly")
def run(self) -> bool:
mp.current_process().name = "ExecutorMainProcess"
logger.info(
f"start to run executor {self.id} on numa node #{self.ctx.numa_node_id} of {socket.gethostname()}"
)
with SimplePool(self.ctx.usable_cpu_count + 1) as pool:
retval = self.exec_loop(pool)
logger.info(f"executor exits: {self}")
logger.complete()
return retval
def exec_loop(self, pool: SimplePool) -> bool:
stop_request = None
latest_probe_time = time.time()
while self.running:
# get new work items
try:
items = self.wq.pop(count=self.ctx.usable_cpu_count)
except Exception as ex:
logger.opt(exception=ex).critical(
f"failed to pop from work queue: {self.wq}"
)
self.running = False
items = []
if not items:
secs_quiet_period = time.time() - latest_probe_time
if (
secs_quiet_period > self.ctx.secs_executor_probe_interval * 2
and os.path.exists(self.ctx.job_status_path)
):
with open(self.ctx.job_status_path) as status_file:
if (
status := status_file.read().strip()
) and not status.startswith("running"):
logger.critical(
f"job scheduler already stopped: {status}, stopping executor"
)
self.running = False
break
if (
secs_quiet_period > self.ctx.secs_executor_probe_timeout * 2
and not pytest_running()
):
logger.critical(
f"no probe received for {secs_quiet_period:.1f} secs, stopping executor"
)
self.running = False
break
# no pending works, so wait a few seconds before checking results
time.sleep(self.ctx.secs_wq_poll_interval)
for item in items:
if isinstance(item, StopExecutor):
logger.info(f"stop request received from scheduler: {item}")
stop_request = item
self.running = False
break
if isinstance(item, StopWorkItem):
running_work = self.running_works.get(item.work_to_stop, None)
if running_work is None:
logger.debug(
f"cannot find {item.work_to_stop} in running works of {self.id}"
)
self.cq.push(item)
else:
logger.info(f"stopping work: {item.work_to_stop}")
task, _ = running_work
task.terminate()
continue
if isinstance(item, Probe):
latest_probe_time = time.time()
if item.epochs_to_skip > 0:
self.epochs_to_skip += item.epochs_to_skip
if self.epochs_to_skip > 0:
self.epochs_to_skip -= 1
continue
if self.numa_node is not None:
item._numa_node = self.numa_node
# wait and allocate GPU to work item
if item.gpu_limit > 0:
if item.gpu_limit > len(self.local_gpus):
logger.warning(
f"task {item.key} requires more GPUs than physical GPUs, downgrading from {item.gpu_limit} to {len(self.local_gpus)}"
)
item.gpu_limit = len(self.local_gpus)
# FIXME: this will block the executor if there is no available GPU
while not (granted_gpus := self.acquire_gpu(item.gpu_limit)):
logger.info(f"collecting finished works to find available GPUs")
self.collect_finished_works()
time.sleep(self.ctx.secs_wq_poll_interval)
item._local_gpu = granted_gpus
logger.info(
f"{repr(item)} is assigned to run on GPU: { {gpu.id: quota for gpu, quota in item._local_gpu.items()} }"
)
# enqueue work item to the pool
self.running_works[item.key] = (
pool.apply_async(
func=Executor.process_work, args=(item, self.cq), name=item.key
),
item,
)
logger.info(
f"started work: {repr(item)}, {len(self.running_works)} running works: {list(self.running_works.keys())[:10]}..."
)
# start to run works
pool.update_queue()
self.collect_finished_works()
pool.join(self.ctx.secs_executor_probe_interval)
if stop_request and stop_request.ack:
self.collect_finished_works()
stop_request.exec()
self.cq.push(stop_request)
return True
def collect_finished_works(self):
finished_works: List[WorkItem] = []
for work, item in self.running_works.values():
if not work.ready():
continue
else:
work.join()
if (exitcode := work.exitcode()) != 0:
item.status = WorkStatus.CRASHED
item.exception = NonzeroExitCode(
f"worker process {work.proc.name}({work.proc.pid}) exited with non-zero code {exitcode}"
)
try:
self.cq.push(item)
except Exception as ex:
logger.opt(exception=ex).critical(
f"failed to push into completion queue: {self.cq}"
)
self.running = False
finished_works.append(item)
# remove finished works
for item in finished_works:
self.running_works.pop(item.key)
if item._local_gpu:
self.release_gpu(item._local_gpu)
logger.info(
f"{repr(item)} released GPU: { {gpu.id: quota for gpu, quota in item._local_gpu.items()} }"
)
def acquire_gpu(self, quota: float) -> Dict[GPU, float]:
"""
Acquire GPU resources with the given quota.
Return a dict of granted GPUs with their quotas.
`release_gpu` should be called later to release GPUs.
"""
if self.available_gpu_quota < quota:
# no enough GPU resources, return empty dict
return {}
granted_gpus: Dict[GPU, float] = {}
for gpu in self.local_gpus:
gpu_available_quota = self.local_gpus[gpu]
if gpu_available_quota <= 0:
continue
granted_quota = min(gpu_available_quota, quota)
granted_gpus[gpu] = granted_quota
self.local_gpus[gpu] -= granted_quota
quota -= granted_quota
if quota <= 0:
break
return granted_gpus
def release_gpu(self, gpus: Dict[GPU, float]):
"""
Release GPU resources to the pool.
"""
for gpu, quota in gpus.items():
self.local_gpus[gpu] += quota
assert (
self.local_gpus[gpu] <= 1.0
), f"GPU {gpu} quota is greater than 1.0: {self.local_gpus[gpu]}"

View File

@@ -0,0 +1,283 @@
import os
import shutil
import socket
import sys
from datetime import datetime
from typing import Dict, List, Literal, Optional
from loguru import logger
import smallpond
from smallpond.common import DEFAULT_MAX_FAIL_COUNT, DEFAULT_MAX_RETRY_COUNT, MB
from smallpond.execution.scheduler import Scheduler
from smallpond.execution.task import ExecutionPlan, JobId, RuntimeContext
from smallpond.io.filesystem import dump, load
from smallpond.logical.node import LogicalPlan
from smallpond.logical.planner import Planner
from smallpond.platform import get_platform
class SchedStateExporter(Scheduler.StateObserver):
def __init__(self, sched_state_path: str) -> None:
super().__init__()
self.sched_state_path = sched_state_path
def update(self, sched_state: Scheduler):
if sched_state.large_runtime_state:
logger.debug(f"pause exporting scheduler state")
elif sched_state.num_local_running_works > 0:
logger.debug(
f"pause exporting scheduler state: {sched_state.num_local_running_works} local running works"
)
else:
dump(
sched_state, self.sched_state_path, buffering=32 * MB, atomic_write=True
)
sched_state.log_overall_progress()
logger.debug(f"exported scheduler state to {self.sched_state_path}")
class JobManager(object):
jemalloc_filename = "libjemalloc.so.2"
mimalloc_filename = "libmimalloc.so.2.1"
env_template = r"""
ARROW_DEFAULT_MEMORY_POOL={arrow_default_malloc}
ARROW_IO_THREADS=2
OMP_NUM_THREADS=2
POLARS_MAX_THREADS=2
NUMEXPR_MAX_THREADS=2
"""
def __init__(
self,
data_root: Optional[str] = None,
python_venv: Optional[str] = None,
task_image: str = "default",
platform: Optional[str] = None,
) -> None:
self.platform = get_platform(platform)
self.data_root = os.path.abspath(data_root or self.platform.default_data_root())
self.python_venv = python_venv
self.task_image = task_image
@logger.catch(reraise=True, message="job manager terminated unexpectedly")
def run(
self,
plan: LogicalPlan,
job_id: Optional[str] = None,
job_time: Optional[float] = None,
job_name: str = "smallpond",
job_priority: Optional[int] = None,
num_executors: int = 1,
num_executors_per_task: int = 5,
resource_group: str = "localhost",
env_variables: List[str] = None,
sidecars: List[str] = None,
user_tags: List[str] = None,
random_seed: int = None,
max_retry_count: int = DEFAULT_MAX_RETRY_COUNT,
max_fail_count: int = DEFAULT_MAX_FAIL_COUNT,
prioritize_retry=False,
speculative_exec: Literal["disable", "enable", "aggressive"] = "enable",
stop_executor_on_failure=False,
fail_fast_on_failure=False,
nonzero_exitcode_as_oom=False,
fault_inject_prob: float = 0.0,
enable_profiling=False,
enable_diagnostic_metrics=False,
enable_sched_state_dump=False,
remove_empty_parquet=False,
remove_output_root=False,
skip_task_with_empty_input=False,
manifest_only_final_results=True,
memory_allocator: Literal["system", "jemalloc", "mimalloc"] = "mimalloc",
memory_purge_delay: int = 10000,
bind_numa_node=False,
enforce_memory_limit=False,
share_log_analytics: Optional[bool] = None,
console_log_level: Literal[
"CRITICAL", "ERROR", "WARNING", "SUCCESS", "INFO", "DEBUG", "TRACE"
] = "INFO",
file_log_level: Literal[
"CRITICAL", "ERROR", "WARNING", "SUCCESS", "INFO", "DEBUG", "TRACE"
] = "DEBUG",
disable_log_rotation=False,
sched_state_observers: Optional[List[Scheduler.StateObserver]] = None,
output_path: Optional[str] = None,
**kwargs,
) -> ExecutionPlan:
logger.info(f"using platform: {self.platform}")
job_id = JobId(hex=job_id or self.platform.default_job_id())
job_time = (
datetime.fromtimestamp(job_time)
if job_time is not None
else self.platform.default_job_time()
)
malloc_path = ""
if memory_allocator == "system":
malloc_path = ""
elif memory_allocator == "jemalloc":
malloc_path = shutil.which(self.jemalloc_filename)
elif memory_allocator == "mimalloc":
malloc_path = shutil.which(self.mimalloc_filename)
else:
logger.critical(f"failed to find memory allocator: {memory_allocator}")
env_overrides = self.env_template.format(
arrow_default_malloc=memory_allocator,
).splitlines()
env_overrides = env_overrides + (env_variables or [])
env_overrides = dict(
tuple(kv.strip().split("=", maxsplit=1))
for kv in filter(None, env_overrides)
)
share_log_analytics = (
share_log_analytics
if share_log_analytics is not None
else self.platform.default_share_log_analytics()
)
shared_log_root = (
self.platform.shared_log_root() if share_log_analytics else None
)
runtime_ctx = RuntimeContext(
job_id,
job_time,
self.data_root,
num_executors=num_executors,
random_seed=random_seed,
env_overrides=env_overrides,
bind_numa_node=bind_numa_node,
enforce_memory_limit=enforce_memory_limit,
fault_inject_prob=fault_inject_prob,
enable_profiling=enable_profiling,
enable_diagnostic_metrics=enable_diagnostic_metrics,
remove_empty_parquet=remove_empty_parquet,
skip_task_with_empty_input=skip_task_with_empty_input,
shared_log_root=shared_log_root,
console_log_level=console_log_level,
file_log_level=file_log_level,
disable_log_rotation=disable_log_rotation,
output_path=output_path,
**kwargs,
)
runtime_ctx.initialize(socket.gethostname(), root_exist_ok=True)
logger.info(
f"command-line arguments: {' '.join([os.path.basename(sys.argv[0]), *sys.argv[1:]])}"
)
dump(runtime_ctx, runtime_ctx.runtime_ctx_path, atomic_write=True)
logger.info(f"saved runtime context at {runtime_ctx.runtime_ctx_path}")
dump(plan, runtime_ctx.logcial_plan_path, atomic_write=True)
logger.info(f"saved logcial plan at {runtime_ctx.logcial_plan_path}")
plan.graph().render(runtime_ctx.logcial_plan_graph_path, format="png")
logger.info(
f"saved logcial plan graph at {runtime_ctx.logcial_plan_graph_path}.png"
)
exec_plan = Planner(runtime_ctx).create_exec_plan(
plan, manifest_only_final_results
)
dump(exec_plan, runtime_ctx.exec_plan_path, atomic_write=True)
logger.info(f"saved execution plan at {runtime_ctx.exec_plan_path}")
sidecar_list = sidecars or []
fs_name, cluster = self.data_root.split("/")[1:3]
tag_list = [
"smallpond",
"executor",
smallpond.__version__,
job_name,
fs_name,
cluster,
f"malloc:{memory_allocator}",
] + (user_tags or [])
if self.python_venv:
tag_list.append(self.python_venv)
if fail_fast_on_failure:
max_fail_count = 0
if max_fail_count == 0:
tag_list.append("fail_fast")
tag_list.append(f"max_fail:{max_fail_count}")
tag_list.append(f"speculative_exec:{speculative_exec}")
tag_list.append(f"max_retry:{max_retry_count}")
if prioritize_retry:
tag_list.append("prioritize_retry")
if stop_executor_on_failure:
tag_list.append("stop_executor")
if bind_numa_node:
tag_list.append("bind_numa_node")
if enforce_memory_limit:
tag_list.append("enforce_memory_limit")
nonzero_exitcode_as_oom = True
sched_state_observers = sched_state_observers or []
if enable_sched_state_dump:
sched_state_exporter = SchedStateExporter(runtime_ctx.sched_state_path)
sched_state_observers.insert(0, sched_state_exporter)
if os.path.exists(runtime_ctx.sched_state_path):
logger.warning(
f"loading scheduler state from: {runtime_ctx.sched_state_path}"
)
scheduler: Scheduler = load(runtime_ctx.sched_state_path)
scheduler.sched_epoch += 1
scheduler.sched_state_observers = sched_state_observers
else:
scheduler = Scheduler(
exec_plan,
max_retry_count,
max_fail_count,
prioritize_retry,
speculative_exec,
stop_executor_on_failure,
nonzero_exitcode_as_oom,
remove_output_root,
sched_state_observers,
)
# start executors
self.platform.start_job(
num_nodes=num_executors,
entrypoint=os.path.join(os.path.dirname(__file__), "driver.py"),
args=[
"--job_id",
str(job_id),
"--data_root",
self.data_root,
"--runtime_ctx_path",
runtime_ctx.runtime_ctx_path,
"executor",
],
envs={
"LD_PRELOAD": malloc_path,
"MALLOC_CONF": f"percpu_arena:percpu,background_thread:true,metadata_thp:auto,dirty_decay_ms:{memory_purge_delay},muzzy_decay_ms:{memory_purge_delay},oversize_threshold:0,lg_tcache_max:16",
"MIMALLOC_PURGE_DELAY": memory_purge_delay,
},
extra_opts=dict(
job_id=job_id,
job_name=job_name,
job_priority=job_priority,
num_executors_per_task=num_executors_per_task,
resource_group=resource_group,
image=self.task_image,
python_venv=self.python_venv,
tags=tag_list,
sidecars=sidecar_list,
),
)
# run scheduler
scheduler.run()
return scheduler.exec_plan

File diff suppressed because it is too large Load Diff

3671
smallpond/execution/task.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,479 @@
import os
import os.path
import queue
import random
import sys
import time
import uuid
from enum import Enum
from typing import Dict, Iterable, List, Optional
import numpy as np
from GPUtil import GPU
from loguru import logger
from smallpond.common import MB, NonzeroExitCode, OutOfMemory
from smallpond.io.filesystem import dump, load
class WorkStatus(Enum):
INCOMPLETE = 1
SUCCEED = 2
FAILED = 3
CRASHED = 4
EXEC_FAILED = 5
class WorkItem(object):
__slots__ = (
"_cpu_limit",
"_gpu_limit",
"_memory_limit",
"_cpu_boost",
"_memory_boost",
"_numa_node",
"_local_gpu",
"key",
"status",
"start_time",
"finish_time",
"retry_count",
"fail_count",
"exception",
"exec_id",
"exec_cq",
"location",
)
def __init__(
self,
key: str,
cpu_limit: int = None,
gpu_limit: float = None,
memory_limit: int = None,
) -> None:
self._cpu_limit = cpu_limit
self._gpu_limit = gpu_limit
self._memory_limit = (
np.int64(memory_limit) if memory_limit is not None else None
)
self._cpu_boost = 1
self._memory_boost = 1
self._numa_node = None
self._local_gpu: Dict[GPU, float] = {}
self.key = key
self.status = WorkStatus.INCOMPLETE
self.start_time = None
self.finish_time = None
self.retry_count = 0
self.fail_count = 0
self.exception = None
self.exec_id = "unknown"
self.exec_cq = None
self.location: Optional[str] = None
def __repr__(self) -> str:
return self.key
__str__ = __repr__
@property
def cpu_limit(self) -> int:
return int(self._cpu_boost * self._cpu_limit)
@property
def gpu_limit(self) -> int:
return self._gpu_limit
@property
def memory_limit(self) -> np.int64:
return (
np.int64(self._memory_boost * self._memory_limit)
if self._memory_limit
else 0
)
@property
def elapsed_time(self) -> float:
if self.start_time is None:
return 0
if self.finish_time is None:
return time.time() - self.start_time
return self.finish_time - self.start_time
@property
def exec_on_scheduler(self) -> bool:
return False
@property
def local_gpu(self) -> Optional[GPU]:
"""
Return the first GPU granted to this task.
If gpu_limit is 0, return None.
If gpu_limit is greater than 1, only the first GPU is returned.
"""
return next(iter(self._local_gpu.keys()), None)
@property
# @deprecated("use `local_gpu_ranks` instead")
def local_rank(self) -> Optional[int]:
"""
Return the first GPU rank granted to this task.
If gpu_limit is 0, return None.
If gpu_limit is greater than 1, only the first GPU rank is returned. Caller should use `local_gpu_ranks` instead.
"""
if self.gpu_limit > 1:
logger.warning(
f"task {self.key} requires more than 1 GPU, but only the first GPU rank is returned. please use `local_gpu_ranks` instead."
)
return next(iter(self._local_gpu.keys())).id if self._local_gpu else None
@property
def local_gpu_ranks(self) -> List[int]:
"""Return all GPU ranks granted to this task."""
return [gpu.id for gpu in self._local_gpu.keys()]
@property
def numa_node(self) -> int:
return self._numa_node
def oom(self, nonzero_exitcode_as_oom=False):
return (
self._memory_limit is not None
and self.status == WorkStatus.CRASHED
and (
isinstance(self.exception, (OutOfMemory, MemoryError))
or (
isinstance(self.exception, NonzeroExitCode)
and nonzero_exitcode_as_oom
)
)
)
def run(self) -> bool:
return True
def initialize(self):
"""Called before run() to prepare for running the task."""
def finalize(self):
"""Called after run() to finalize the processing."""
def cleanup(self):
"""Called after run() even if there is an exception."""
def exec(self, cq: Optional["WorkQueue"] = None) -> WorkStatus:
if self.status == WorkStatus.INCOMPLETE:
try:
self.start_time = time.time()
self.exec_cq = cq
self.initialize()
if self.run():
self.status = WorkStatus.SUCCEED
self.finalize()
else:
self.status = WorkStatus.FAILED
except Exception as ex:
logger.opt(exception=ex).error(
f"{repr(self)} crashed with error. node location at {self.location}"
)
self.status = WorkStatus.CRASHED
self.exception = ex
finally:
self.cleanup()
self.exec_cq = None
self.finish_time = time.time()
return self.status
class StopExecutor(WorkItem):
def __init__(self, key: str, ack=True) -> None:
super().__init__(key, cpu_limit=0, gpu_limit=0, memory_limit=0)
self.ack = ack
class StopWorkItem(WorkItem):
def __init__(self, key: str, work_to_stop: str) -> None:
super().__init__(key, cpu_limit=0, gpu_limit=0, memory_limit=0)
self.work_to_stop = work_to_stop
class WorkBatch(WorkItem):
def __init__(self, key: str, works: List[WorkItem]) -> None:
cpu_limit = max(w.cpu_limit for w in works)
gpu_limit = max(w.gpu_limit for w in works)
memory_limit = max(w.memory_limit for w in works)
super().__init__(
f"{self.__class__.__name__}-{key}", cpu_limit, gpu_limit, memory_limit
)
self.works = works
def __str__(self) -> str:
return (
super().__str__()
+ f", works[{len(self.works)}]={self.works[:1]}...{self.works[-1:]}"
)
def run(self) -> bool:
logger.info(f"processing {len(self.works)} works in the batch")
for index, work in enumerate(self.works):
work.exec_id = self.exec_id
if work.exec(self.exec_cq) != WorkStatus.SUCCEED:
logger.error(
f"work item #{index+1}/{len(self.works)} in {self.key} failed: {work}"
)
return False
logger.info(f"done {len(self.works)} works in the batch")
return True
class WorkQueue(object):
def __init__(self) -> None:
self.outbound_works: List[WorkItem] = []
def _pop_unbuffered(self, count: int) -> List[WorkItem]:
raise NotImplementedError
def _push_unbuffered(self, item: WorkItem) -> bool:
raise NotImplementedError
def pop(self, count=1) -> List[WorkItem]:
inbound_works: List[WorkItem] = []
for item in self._pop_unbuffered(count):
if isinstance(item, WorkBatch):
inbound_works.extend(item.works)
else:
inbound_works.append(item)
return inbound_works
def push(self, item: WorkItem, buffering=False) -> bool:
if buffering:
self.outbound_works.append(item)
return True
elif len(self.outbound_works) > 0:
self.outbound_works.append(item)
return self.flush()
else:
return self._push_unbuffered(item)
def flush(self) -> bool:
if len(self.outbound_works) == 0:
return True
batch = WorkBatch(self.outbound_works[0].key, self.outbound_works)
self.outbound_works = []
return self._push_unbuffered(batch)
class WorkQueueInMemory(WorkQueue):
def __init__(self, queue_type=queue.Queue) -> None:
super().__init__()
self.queue = queue_type()
def _pop_unbuffered(self, count: int) -> List[WorkItem]:
try:
return [self.queue.get(block=False)]
except queue.Empty:
return []
def _push_unbuffered(self, item: WorkItem) -> bool:
self.queue.put(item)
return True
class WorkQueueOnFilesystem(WorkQueue):
buffer_size = 16 * MB
def __init__(self, workq_root: str, sort=True, random=False) -> None:
super().__init__()
self.workq_root = workq_root
self.sort = sort
self.random = random
self.buffered_works: List[WorkItem] = []
self.temp_dir = os.path.join(self.workq_root, "temp")
self.enqueue_dir = os.path.join(self.workq_root, "enqueued")
self.dequeue_dir = os.path.join(self.workq_root, "dequeued")
os.makedirs(self.temp_dir, exist_ok=True)
os.makedirs(self.enqueue_dir, exist_ok=True)
os.makedirs(self.dequeue_dir, exist_ok=True)
def __str__(self) -> str:
return f"{self.__class__.__name__}@{self.workq_root}"
def _get_entries(self, path=None) -> List[os.DirEntry]:
dentries: List[os.DirEntry] = []
with os.scandir(path or self.enqueue_dir) as dir_iter:
for entry in dir_iter:
if entry.is_file():
dentries.append(entry)
return dentries
def size(self) -> int:
return len(self._get_entries())
def _list_works(self, path: str, expand_batch=True) -> Iterable[WorkItem]:
dentries = self._get_entries(path)
logger.info("listing {} entries in {}", len(dentries), path)
for entry in dentries:
item: WorkItem = load(entry.path, self.buffer_size)
if expand_batch and isinstance(item, WorkBatch):
for work in item.works:
yield work
else:
yield item
def list_enqueued(self, expand_batch=True) -> Iterable[WorkItem]:
yield from self._list_works(self.enqueue_dir, expand_batch)
def list_dequeued(self, expand_batch=True) -> Iterable[WorkItem]:
yield from self._list_works(self.dequeue_dir, expand_batch)
def list_works(self, expand_batch=True) -> Iterable[WorkItem]:
yield from self.list_enqueued(expand_batch)
yield from self.list_dequeued(expand_batch)
def _pop_unbuffered(self, count: int) -> List[WorkItem]:
items = []
dentries = self._get_entries()
if self.sort:
dentries = sorted(dentries, key=lambda entry: entry.name)
elif self.random:
random.shuffle(dentries)
for entry in dentries:
uuid_hex = uuid.uuid4().hex
filename = f"{entry.name}-DEQ{uuid_hex}"
dequeued_path = os.path.join(self.dequeue_dir, filename)
try:
os.rename(entry.path, dequeued_path)
except OSError as err:
logger.debug(f"cannot rename {entry.path} to {dequeued_path}: {err}")
if items:
break
else:
continue
items.append(load(dequeued_path, self.buffer_size))
if len(items) >= count:
break
return items
def _push_unbuffered(self, item: WorkItem) -> bool:
timestamp = time.time_ns()
uuid_hex = uuid.uuid4().hex
filename = f"{item.key}-{timestamp:x}-ENQ{uuid_hex}"
tempfile_path = os.path.join(self.temp_dir, filename)
enqueued_path = os.path.join(self.enqueue_dir, filename)
dump(item, tempfile_path, self.buffer_size)
try:
os.rename(tempfile_path, enqueued_path)
return True
except OSError as err:
logger.critical(
f"failed to rename {tempfile_path} to {enqueued_path}: {err}"
)
return False
def count_objects(obj, object_cnt=None, visited_objs=None, depth=0):
object_cnt = {} if object_cnt is None else object_cnt
visited_objs = set() if visited_objs is None else visited_objs
if id(obj) in visited_objs:
return
else:
visited_objs.add(id(obj))
if isinstance(obj, dict):
for key, value in obj.items():
count_objects(key, object_cnt, visited_objs, depth + 1)
count_objects(value, object_cnt, visited_objs, depth + 1)
elif isinstance(obj, list) or isinstance(obj, tuple):
for item in obj:
count_objects(item, object_cnt, visited_objs, depth + 1)
else:
class_name = obj.__class__.__name__
if class_name not in object_cnt:
object_cnt[class_name] = (0, 0)
cnt, size = object_cnt[class_name]
object_cnt[class_name] = (cnt + 1, size + sys.getsizeof(obj))
key_attributes = ("__self__", "__dict__", "__slots__")
if not isinstance(obj, (bool, str, int, float, type(None))) and any(
attr_name in key_attributes for attr_name in dir(obj)
):
logger.debug(f"{' ' * depth}{class_name}@{id(obj):x}")
for attr_name in dir(obj):
try:
if (
not attr_name.startswith("__") or attr_name in key_attributes
) and not isinstance(
getattr(obj.__class__, attr_name, None), property
):
logger.debug(
f"{' ' * depth}{class_name}.{attr_name}@{id(obj):x}"
)
count_objects(
getattr(obj, attr_name), object_cnt, visited_objs, depth + 1
)
except Exception as ex:
logger.warning(
f"failed to get '{attr_name}' from {repr(obj)}: {ex}"
)
def main():
import argparse
from smallpond.execution.task import Probe
parser = argparse.ArgumentParser(
prog="workqueue.py", description="Work Queue Reader"
)
parser.add_argument("wq_root", help="Work queue root path")
parser.add_argument("-f", "--work_filter", default="", help="Work item filter")
parser.add_argument(
"-x", "--expand_batch", action="store_true", help="Expand batched works"
)
parser.add_argument(
"-c", "--count_object", action="store_true", help="Count number of objects"
)
parser.add_argument(
"-n", "--top_n_class", default=20, type=int, help="Show the top n classes"
)
parser.add_argument(
"-l", "--log_level", default="INFO", help="Logging message level"
)
args = parser.parse_args()
logger.remove()
logger.add(
sys.stdout,
format=r"[{time:%Y-%m-%d %H:%M:%S.%f}] {level} {message}",
level=args.log_level,
)
wq = WorkQueueOnFilesystem(args.wq_root)
for work in wq.list_works(args.expand_batch):
if isinstance(work, Probe):
continue
if args.work_filter in work.key:
logger.info(work)
if args.count_object:
object_cnt = {}
count_objects(work, object_cnt)
sorted_counts = sorted(
[(v, k) for k, v in object_cnt.items()], reverse=True
)
for count, class_name in sorted_counts[: args.top_n_class]:
logger.info(f" {class_name}: {count}")
if __name__ == "__main__":
main()

0
smallpond/io/__init__.py Normal file
View File

419
smallpond/io/arrow.py Normal file
View File

@@ -0,0 +1,419 @@
import copy
import math
import os.path
import time
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import Iterable, List, Optional, Union
import fsspec
import pyarrow as arrow
import pyarrow.parquet as parquet
from loguru import logger
from smallpond.common import (
DEFAULT_BATCH_SIZE,
DEFAULT_ROW_GROUP_BYTES,
DEFAULT_ROW_GROUP_SIZE,
MAX_PARQUET_FILE_BYTES,
MB,
split_into_rows,
)
@dataclass
class RowRange:
"""A range of rows in a file."""
path: str
"""Path to the file."""
data_size: int
"""The uncompressed data size in bytes."""
file_num_rows: int
"""The number of rows in the file."""
begin: int
"""Index of first row in the file."""
end: int
"""Index of last row + 1 in the file."""
@property
def num_rows(self) -> int:
"""The number of rows in the range."""
return self.end - self.begin
@property
def estimated_data_size(self) -> int:
"""The estimated uncompressed data size in bytes."""
return (
self.data_size * self.num_rows // self.file_num_rows
if self.file_num_rows > 0
else 0
)
def take(self, num_rows: int) -> "RowRange":
"""
Take `num_rows` rows from the range.
NOTE: this function modifies the current row range.
"""
num_rows = min(num_rows, self.num_rows)
head = copy.copy(self)
head.end = head.begin + num_rows
self.begin += num_rows
return head
@staticmethod
def partition_by_rows(
row_ranges: List["RowRange"], npartition: int
) -> List[List["RowRange"]]:
"""Evenly split a list of row ranges into `npartition` partitions."""
# NOTE: `row_ranges` should not be modified by this function
row_ranges = copy.deepcopy(row_ranges)
num_rows: int = sum(row_range.num_rows for row_range in row_ranges)
num_partitions: int = npartition
row_range_partitions: List[List[RowRange]] = []
while num_partitions:
rows_in_partition = (num_rows + num_partitions - 1) // num_partitions
num_rows -= rows_in_partition
num_partitions -= 1
row_ranges_in_partition = []
while rows_in_partition:
current_range = row_ranges[0]
if current_range.num_rows == 0:
row_ranges.pop(0)
continue
taken_range = current_range.take(rows_in_partition)
row_ranges_in_partition.append(taken_range)
rows_in_partition -= taken_range.num_rows
row_range_partitions.append(row_ranges_in_partition)
assert num_rows == 0 and num_partitions == 0
return row_range_partitions
def convert_type_to_large(type_: arrow.DataType) -> arrow.DataType:
"""
Convert all string and binary types to large types recursively.
"""
# Since arrow uses 32-bit signed offsets for string and binary types, convert all string and binary columns
# to large_string and large_binary to avoid offset overflow, see https://issues.apache.org/jira/browse/ARROW-17828.
if arrow.types.is_string(type_):
return arrow.large_string()
elif arrow.types.is_binary(type_):
return arrow.large_binary()
elif isinstance(type_, arrow.ListType):
return arrow.list_(convert_type_to_large(type_.value_type))
elif isinstance(type_, arrow.StructType):
return arrow.struct(
[
arrow.field(
field.name,
convert_type_to_large(field.type),
nullable=field.nullable,
)
for field in type_
]
)
elif isinstance(type_, arrow.MapType):
return arrow.map_(
convert_type_to_large(type_.key_type),
convert_type_to_large(type_.item_type),
)
else:
return type_
def convert_types_to_large_string(schema: arrow.Schema) -> arrow.Schema:
"""
Convert all string and binary types to large types in the schema.
"""
new_fields = []
for field in schema:
new_type = convert_type_to_large(field.type)
new_field = arrow.field(
field.name, new_type, nullable=field.nullable, metadata=field.metadata
)
new_fields.append(new_field)
return arrow.schema(new_fields, metadata=schema.metadata)
def cast_columns_to_large_string(table: arrow.Table) -> arrow.Table:
schema = convert_types_to_large_string(table.schema)
return table.cast(schema)
def filter_schema(
schema: arrow.Schema,
included_cols: Optional[List[str]] = None,
excluded_cols: Optional[List[str]] = None,
):
assert included_cols is None or excluded_cols is None
if included_cols is None and excluded_cols is None:
return schema
if included_cols is not None:
fields = [schema.field(col_name) for col_name in included_cols]
if excluded_cols is not None:
fields = [
schema.field(col_name)
for col_name in schema.names
if col_name not in excluded_cols
]
return arrow.schema(fields, metadata=schema.metadata)
def _iter_record_batches(
file: parquet.ParquetFile,
columns: List[str],
offset: int,
length: int,
batch_size: int,
) -> Iterable[arrow.RecordBatch]:
"""
Read record batches from a range of a parquet file.
"""
current_offset = 0
required_l, required_r = offset, offset + length
for batch in file.iter_batches(
batch_size=batch_size, columns=columns, use_threads=False
):
current_l, current_r = current_offset, current_offset + batch.num_rows
# check if intersection is null
if current_r <= required_l:
pass
elif current_l >= required_r:
break
else:
intersection_l = max(required_l, current_l)
intersection_r = min(required_r, current_r)
trimmed = batch.slice(
intersection_l - current_offset, intersection_r - intersection_l
)
assert (
trimmed.num_rows == intersection_r - intersection_l
), f"trimmed.num_rows {trimmed.num_rows} != batch_length {intersection_r - intersection_l}"
yield cast_columns_to_large_string(trimmed)
current_offset += batch.num_rows
def build_batch_reader_from_files(
paths_or_ranges: Union[List[str], List[RowRange]],
*,
columns: Optional[List[str]] = None,
batch_size: int = DEFAULT_BATCH_SIZE,
max_batch_byte_size: Optional[int] = None,
filesystem: fsspec.AbstractFileSystem = None,
) -> arrow.RecordBatchReader:
assert len(paths_or_ranges) > 0, "paths_or_ranges must be a non-empty list"
schema = _read_schema_from_file(paths_or_ranges[0], columns, filesystem)
iterator = _iter_record_batches_from_files(
paths_or_ranges, columns, batch_size, max_batch_byte_size, filesystem
)
return arrow.RecordBatchReader.from_batches(schema, iterator)
def _read_schema_from_file(
path_or_range: Union[str, RowRange],
columns: Optional[List[str]] = None,
filesystem: fsspec.AbstractFileSystem = None,
) -> arrow.Schema:
path = path_or_range.path if isinstance(path_or_range, RowRange) else path_or_range
schema = parquet.read_schema(
filesystem.unstrip_protocol(path) if filesystem else path, filesystem=filesystem
)
if columns is not None:
assert all(
c in schema.names for c in columns
), f"""some of {columns} cannot be found in schema of {path}:
{schema}
The following query can help to find files with missing columns:
duckdb-dev -c "select * from ( select file_name, list(name) as column_names, list_filter({columns}, c -> not list_contains(column_names, c)) as missing_columns FROM parquet_schema(['{os.path.join(os.path.dirname(path), '*.parquet')}']) group by file_name ) where len(missing_columns) > 0"
"""
schema = filter_schema(schema, columns)
return convert_types_to_large_string(schema)
def _iter_record_batches_from_files(
paths_or_ranges: Union[List[str], List[RowRange]],
columns: Optional[List[str]] = None,
batch_size: int = DEFAULT_BATCH_SIZE,
max_batch_byte_size: Optional[int] = None,
filesystem: fsspec.AbstractFileSystem = None,
) -> Iterable[arrow.RecordBatch]:
"""
Build a batch reader from a list of row ranges.
"""
buffered_batches = []
buffered_rows = 0
buffered_bytes = 0
def combine_buffered_batches(
batches: List[arrow.RecordBatch],
) -> Iterable[arrow.RecordBatch]:
table = arrow.Table.from_batches(batches)
yield from table.combine_chunks().to_batches(batch_size)
for path_or_range in paths_or_ranges:
path = (
path_or_range.path if isinstance(path_or_range, RowRange) else path_or_range
)
with parquet.ParquetFile(
filesystem.unstrip_protocol(path) if filesystem else path,
buffer_size=16 * MB,
filesystem=filesystem,
) as file:
if isinstance(path_or_range, RowRange):
offset, length = path_or_range.begin, path_or_range.num_rows
else:
offset, length = 0, file.metadata.num_rows
for batch in _iter_record_batches(
file, columns, offset, length, batch_size
):
batch_size_exceeded = batch.num_rows + buffered_rows >= batch_size
batch_byte_size_exceeded = (
max_batch_byte_size is not None
and batch.nbytes + buffered_bytes >= max_batch_byte_size
)
if not batch_size_exceeded and not batch_byte_size_exceeded:
buffered_batches.append(batch)
buffered_rows += batch.num_rows
buffered_bytes += batch.nbytes
else:
if batch_size_exceeded:
buffered_batches.append(
batch.slice(0, batch_size - buffered_rows)
)
batch = batch.slice(batch_size - buffered_rows)
if buffered_batches:
yield from combine_buffered_batches(buffered_batches)
buffered_batches = [batch]
buffered_rows = batch.num_rows
buffered_bytes = batch.nbytes
if buffered_batches:
yield from combine_buffered_batches(buffered_batches)
def read_parquet_files_into_table(
paths_or_ranges: Union[List[str], List[RowRange]],
columns: List[str] = None,
filesystem: fsspec.AbstractFileSystem = None,
) -> arrow.Table:
batch_reader = build_batch_reader_from_files(
paths_or_ranges, columns=columns, filesystem=filesystem
)
return batch_reader.read_all()
def load_from_parquet_files(
paths_or_ranges: Union[List[str], List[RowRange]],
columns: List[str] = None,
max_workers: int = 16,
filesystem: fsspec.AbstractFileSystem = None,
) -> arrow.Table:
start_time = time.time()
assert len(paths_or_ranges) > 0, "paths_or_ranges must be a non-empty list"
paths = [
path_or_range.path if isinstance(path_or_range, RowRange) else path_or_range
for path_or_range in paths_or_ranges
]
total_compressed_size = sum(
(
path_or_range.data_size
if isinstance(path_or_range, RowRange)
else os.path.getsize(path_or_range)
)
for path_or_range in paths_or_ranges
)
logger.debug(
f"loading {len(paths)} parquet files (compressed size: {total_compressed_size/MB:.3f}MB): {paths[:3]}..."
)
num_workers = min(len(paths), max_workers)
with ThreadPoolExecutor(num_workers) as pool:
running_works = [
pool.submit(read_parquet_files_into_table, batch, columns, filesystem)
for batch in split_into_rows(paths_or_ranges, num_workers)
]
tables = [work.result() for work in running_works]
logger.debug(
f"collected {len(tables)} tables from: {paths[:3]}... (elapsed: {time.time() - start_time:.3f} secs)"
)
return arrow.concat_tables(tables)
def parquet_write_table(
table, where, filesystem: fsspec.AbstractFileSystem = None, **write_table_args
) -> int:
if filesystem is not None:
return parquet.write_table(
table,
where=(filesystem.unstrip_protocol(where) if filesystem else where),
filesystem=filesystem,
**write_table_args,
)
with open(where, "wb", buffering=32 * MB) as file:
return parquet.write_table(table, where=file, **write_table_args)
def dump_to_parquet_files(
table: arrow.Table,
output_dir: str,
filename: str = "data",
compression="ZSTD",
compression_level=3,
row_group_size=DEFAULT_ROW_GROUP_SIZE,
row_group_bytes=DEFAULT_ROW_GROUP_BYTES,
use_dictionary=False,
max_workers=16,
filesystem: fsspec.AbstractFileSystem = None,
) -> bool:
table = cast_columns_to_large_string(table)
if table.num_rows == 0:
logger.warning(f"creating empty parquet file in {output_dir}")
parquet_write_table(
table,
os.path.join(output_dir, f"{filename}-0.parquet"),
compression=compression,
row_group_size=row_group_size,
)
return True
start_time = time.time()
avg_row_size = max(1, table.nbytes // table.num_rows)
row_group_size = min(row_group_bytes // avg_row_size, row_group_size)
logger.debug(
f"dumping arrow table ({table.nbytes/MB:.3f}MB, {table.num_rows} rows) to {output_dir}, avg row size: {avg_row_size}, row group size: {row_group_size}"
)
batches = table.to_batches(max_chunksize=row_group_size)
num_workers = min(len(batches), max_workers)
num_tables = max(math.ceil(table.nbytes / MAX_PARQUET_FILE_BYTES), num_workers)
logger.debug(f"evenly distributed {len(batches)} batches into {num_tables} files")
tables = [
arrow.Table.from_batches(batch, table.schema)
for batch in split_into_rows(batches, num_tables)
]
assert sum(t.num_rows for t in tables) == table.num_rows
logger.debug(f"writing {len(tables)} files to {output_dir}")
with ThreadPoolExecutor(num_workers) as pool:
running_works = [
pool.submit(
parquet_write_table,
table=table,
where=os.path.join(output_dir, f"{filename}-{i}.parquet"),
use_dictionary=use_dictionary,
compression=compression,
compression_level=compression_level,
row_group_size=row_group_size,
write_batch_size=max(16 * 1024, row_group_size // 8),
data_page_size=max(64 * MB, row_group_bytes // 8),
filesystem=filesystem,
)
for i, table in enumerate(tables)
]
assert all(work.result() or True for work in running_works)
logger.debug(
f"finished writing {len(tables)} files to {output_dir} (elapsed: {time.time() - start_time:.3f} secs)"
)
return True

136
smallpond/io/filesystem.py Normal file
View File

@@ -0,0 +1,136 @@
import io
import os
import shutil
import tempfile
import time
from typing import Any
import cloudpickle
import zstandard as zstd
from loguru import logger
from smallpond.common import MB
HF3FS_MOUNT_POINT_PREFIX = "/hf3fs"
HF3FS_FSSPEC_PROTOCOL = "hf3fs://"
def on_hf3fs(path: str):
return path.startswith(HF3FS_MOUNT_POINT_PREFIX)
def extract_hf3fs_mount_point(path: str):
return os.path.join("/", *path.split("/")[1:3]) if on_hf3fs(path) else None
def remove_path(path: str):
realpath = os.path.realpath(path)
if os.path.islink(path):
logger.debug(f"removing link: {path}")
os.unlink(path)
if not os.path.exists(realpath):
return
logger.debug(f"removing path: {realpath}")
if on_hf3fs(realpath):
try:
link = os.path.join(
extract_hf3fs_mount_point(realpath),
"3fs-virt/rm-rf",
f"{os.path.basename(realpath)}-{time.time_ns()}",
)
os.symlink(realpath, link)
return
except Exception as ex:
logger.opt(exception=ex).debug(
f"fast recursive remove failed, fall back to shutil.rmtree('{realpath}')"
)
shutil.rmtree(realpath, ignore_errors=True)
def find_mount_point(path: str):
path = os.path.abspath(path)
while not os.path.ismount(path):
path = os.path.dirname(path)
return path
def dump(obj: Any, path: str, buffering=2 * MB, atomic_write=False) -> int:
"""
Dump an object to a file.
Args:
obj: The object to dump.
path: The path to the file to dump the object to.
buffering: The buffering size.
atomic_write: Whether to atomically write the file.
Returns:
The size of the file.
"""
def get_pickle_trace(obj):
try:
import dill
import dill.detect
except ImportError:
return None, None
pickle_trace = io.StringIO()
pickle_error = None
with dill.detect.trace(pickle_trace):
try:
dill.dumps(obj, recurse=True)
except Exception as ex:
pickle_error = ex
return pickle_trace.getvalue(), pickle_error
def write_to_file(fout):
with zstd.ZstdCompressor().stream_writer(fout, closefd=False) as zstd_writer:
try:
cloudpickle.dump(obj, zstd_writer)
except zstd.ZstdError as ex:
raise
except Exception as ex:
trace_str, trace_err = get_pickle_trace(obj)
logger.opt(exception=ex).error(
f"pickle trace of {repr(obj)}:{os.linesep}{trace_str}"
)
if trace_err is None:
raise
else:
raise trace_err from ex
logger.trace("{} saved to {}", repr(obj), path)
size = 0
if atomic_write:
directory, filename = os.path.split(path)
with tempfile.NamedTemporaryFile(
"wb", buffering=buffering, dir=directory, prefix=filename, delete=False
) as fout:
write_to_file(fout)
fout.seek(0, os.SEEK_END)
size = fout.tell()
os.rename(fout.name, path)
else:
with open(path, "wb", buffering=buffering) as fout:
write_to_file(fout)
fout.seek(0, os.SEEK_END)
size = fout.tell()
if size >= buffering:
logger.debug(f"created a large pickle file ({size/MB:.3f}MB): {path}")
return size
def load(path: str, buffering=2 * MB) -> Any:
"""
Load an object from a file.
"""
with open(path, "rb", buffering=buffering) as fin:
with zstd.ZstdDecompressor().stream_reader(fin) as zstd_reader:
obj = cloudpickle.load(zstd_reader)
logger.trace("{} loaded from {}", repr(obj), path)
return obj

View File

1098
smallpond/logical/dataset.py Normal file

File diff suppressed because it is too large Load Diff

2136
smallpond/logical/node.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,59 @@
from smallpond.execution.task import *
from smallpond.logical.node import *
class Optimizer(LogicalPlanVisitor[Node]):
"""
Optimize the logical plan.
"""
def __init__(self, exclude_nodes: Set[Node]):
self.exclude_nodes = exclude_nodes
"""A set of nodes that will not be optimized."""
self.optimized_node_map: Dict[Node, Node] = {}
"""A map from original node to optimized node."""
def visit(self, node: Node, depth: int = 0) -> Node:
# stop recursion if the node is excluded
if node in self.exclude_nodes:
return node
# memoize the optimized node
if node in self.optimized_node_map:
return self.optimized_node_map[node]
optimized_node = super().visit(node, depth)
self.optimized_node_map[node] = optimized_node
return optimized_node
def generic_visit(self, node: Node, depth: int) -> Node:
# by default, recursively optimize the input deps
new_node = copy.copy(node)
new_node.input_deps = [self.visit(dep, depth + 1) for dep in node.input_deps]
return new_node
def visit_query_engine_node(self, node: SqlEngineNode, depth: int) -> Node:
# fuse consecutive SqlEngineNodes
if len(node.input_deps) == 1 and isinstance(
child := self.visit(node.input_deps[0], depth + 1), SqlEngineNode
):
fused = copy.copy(node)
fused.input_deps = child.input_deps
fused.udfs = node.udfs + child.udfs
fused.cpu_limit = max(node.cpu_limit, child.cpu_limit)
fused.gpu_limit = max(node.gpu_limit, child.gpu_limit)
fused.memory_limit = (
max(node.memory_limit, child.memory_limit)
if node.memory_limit is not None and child.memory_limit is not None
else node.memory_limit or child.memory_limit
)
# merge the sql queries
# example:
# ```
# child.sql_queries = ["select * from {0}"]
# node.sql_queries = ["select a, b from {0}"]
# fused.sql_queries = ["select a, b from (select * from {0})"]
# ```
fused.sql_queries = child.sql_queries[:-1] + [
query.format(f"({child.sql_queries[-1]})") for query in node.sql_queries
]
return fused
return self.generic_visit(node, depth)

View File

@@ -0,0 +1,348 @@
from smallpond.execution.task import *
from smallpond.logical.node import *
TaskGroup = List[Task]
class Planner(LogicalPlanVisitor[TaskGroup]):
"""
Create an execution plan (tasks) from a logical plan (nodes).
"""
def __init__(self, runtime_ctx: RuntimeContext) -> None:
self.runtime_ctx = runtime_ctx
self.node_to_tasks: Dict[Node, TaskGroup] = {}
@logger.catch(reraise=True, message="failed to build computation graph")
def create_exec_plan(
self, logical_plan: LogicalPlan, manifest_only_final_results=True
) -> ExecutionPlan:
logical_plan = copy.deepcopy(logical_plan)
# if --output_path is specified, copy files to the output path
# otherwise, create manifest files only
sink_type = (
"copy" if self.runtime_ctx.final_output_path is not None else "manifest"
)
final_sink_type = (
"copy"
if self.runtime_ctx.final_output_path is not None
else "manifest" if manifest_only_final_results else "link"
)
# create DataSinkNode for each named output node (same name share the same sink node)
nodes_groupby_output_name: Dict[str, List[Node]] = defaultdict(list)
for node in logical_plan.nodes.values():
if node.output_name is not None:
if node.output_name in nodes_groupby_output_name:
warnings.warn(
f"{node} has duplicate output name: {node.output_name}"
)
nodes_groupby_output_name[node.output_name].append(node)
sink_nodes = {} # { output_name: DataSinkNode }
for output_name, nodes in nodes_groupby_output_name.items():
output_path = os.path.join(
self.runtime_ctx.final_output_path or self.runtime_ctx.output_root,
output_name,
)
sink_nodes[output_name] = DataSinkNode(
logical_plan.ctx, tuple(nodes), output_path, type=sink_type
)
# create DataSinkNode for root node
# XXX: special case optimization to avoid copying files twice
# if root node is DataSetPartitionNode(npartitions=1), and all its input nodes are named, create manifest files instead of copying files.
if (
isinstance(logical_plan.root_node, ConsolidateNode)
and len(logical_plan.root_node.input_deps) == 1
and isinstance(
partition_node := logical_plan.root_node.input_deps[0],
EvenlyDistributedPartitionNode,
)
and all(node.output_name is not None for node in partition_node.input_deps)
):
sink_nodes["FinalResults"] = DataSinkNode(
logical_plan.ctx,
tuple(
sink_nodes[node.output_name] for node in partition_node.input_deps
),
output_path=os.path.join(
self.runtime_ctx.final_output_path or self.runtime_ctx.output_root,
"FinalResults",
),
type="manifest",
is_final_node=True,
)
# if root node also has output_name, create manifest files instead of copying files.
elif (output_name := logical_plan.root_node.output_name) is not None:
sink_nodes["FinalResults"] = DataSinkNode(
logical_plan.ctx,
(sink_nodes[output_name],),
output_path=os.path.join(
self.runtime_ctx.final_output_path or self.runtime_ctx.output_root,
"FinalResults",
),
type="manifest",
is_final_node=True,
)
else:
# normal case
sink_nodes["FinalResults"] = DataSinkNode(
logical_plan.ctx,
(logical_plan.root_node,),
output_path=os.path.join(
self.runtime_ctx.final_output_path or self.runtime_ctx.output_root,
"FinalResults",
),
type=final_sink_type,
is_final_node=True,
)
# assemble sink nodes as new root node
root_node = RootNode(logical_plan.ctx, tuple(sink_nodes.values()))
logical_plan = LogicalPlan(logical_plan.ctx, root_node)
# generate tasks
[root_task] = self.visit(root_node)
# print logical plan with the generated runtime tasks
logger.info(f"logical plan:{os.linesep}{str(logical_plan)}")
exec_plan = ExecutionPlan(self.runtime_ctx, root_task, logical_plan)
return exec_plan
def visit(self, node: Node, depth: int = 0) -> TaskGroup:
# memoize the tasks
if node in self.node_to_tasks:
return self.node_to_tasks[node]
retvals = super().visit(node, depth)
self.node_to_tasks[node] = retvals
return retvals
def visit_data_source_node(self, node: DataSourceNode, depth: int) -> TaskGroup:
assert not node.input_deps, f"data source should be leaf node: {node}"
return [node.create_task(self.runtime_ctx, [], [PartitionInfo()])]
def visit_data_sink_node(self, node: DataSinkNode, depth: int) -> TaskGroup:
all_input_deps = [
task for dep in node.input_deps for task in self.visit(dep, depth + 1)
]
return [node.create_task(self.runtime_ctx, all_input_deps, [PartitionInfo()])]
def visit_root_node(self, node: RootNode, depth: int) -> TaskGroup:
all_input_deps = [
task for dep in node.input_deps for task in self.visit(dep, depth + 1)
]
return [RootTask(self.runtime_ctx, all_input_deps, [PartitionInfo()])]
def visit_union_node(self, node: UnionNode, depth: int) -> TaskGroup:
all_input_deps = [
task for dep in node.input_deps for task in self.visit(dep, depth + 1)
]
unique_partition_dims = set(task.partition_dims for task in all_input_deps)
assert (
len(unique_partition_dims) == 1
), f"cannot union partitions with different dimensions: {unique_partition_dims}"
return all_input_deps
def visit_consolidate_node(self, node: ConsolidateNode, depth: int) -> TaskGroup:
input_deps_taskgroups = [self.visit(dep, depth + 1) for dep in node.input_deps]
assert (
len(input_deps_taskgroups) == 1
), f"consolidate node only accepts one input node, but found: {input_deps_taskgroups}"
unique_partition_dims = set(
task.partition_dims for task in input_deps_taskgroups[0]
)
assert (
len(unique_partition_dims) == 1
), f"cannot consolidate partitions with different dimensions: {unique_partition_dims}"
existing_dimensions = set(unique_partition_dims.pop())
assert (
node.dimensions.intersection(existing_dimensions) == node.dimensions
), f"cannot found some of {node.dimensions} in {existing_dimensions}"
# group tasks by partitions
input_deps_groupby_partitions: Dict[Tuple, List[Task]] = defaultdict(list)
for task in input_deps_taskgroups[0]:
partition_infos = tuple(
info
for info in task.partition_infos
if info.dimension in node.dimensions
)
input_deps_groupby_partitions[partition_infos].append(task)
return [
node.create_task(self.runtime_ctx, input_deps, partition_infos)
for partition_infos, input_deps in input_deps_groupby_partitions.items()
]
def visit_partition_node(self, node: PartitionNode, depth: int) -> TaskGroup:
all_input_deps = [
task for dep in node.input_deps for task in self.visit(dep, depth + 1)
]
unique_partition_dims = set(task.partition_dims for task in all_input_deps)
assert (
len(unique_partition_dims) == 1
), f"cannot partition input_deps with different dimensions: {unique_partition_dims}"
if node.nested:
assert (
node.dimension not in unique_partition_dims
), f"found duplicate partition dimension '{node.dimension}', existing dimensions: {unique_partition_dims}"
assert (
len(all_input_deps) * node.npartitions
<= node.max_card_of_producers_x_consumers
), f"{len(all_input_deps)=} * {node.npartitions=} > {node.max_card_of_producers_x_consumers=}"
producer_tasks = [
node.create_producer_task(
self.runtime_ctx, [task], task.partition_infos
)
for task in all_input_deps
]
return [
node.create_consumer_task(
self.runtime_ctx,
[producer],
list(producer.partition_infos)
+ [PartitionInfo(partition_idx, node.npartitions, node.dimension)],
)
for producer in producer_tasks
for partition_idx in range(node.npartitions)
]
else:
max_num_producer_tasks = min(
node.max_num_producer_tasks,
math.ceil(node.max_card_of_producers_x_consumers / node.npartitions),
)
num_parallel_tasks = (
2
* self.runtime_ctx.num_executors
* math.ceil(self.runtime_ctx.usable_cpu_count / node.cpu_limit)
)
num_producer_tasks = max(1, min(max_num_producer_tasks, num_parallel_tasks))
if len(all_input_deps) < num_producer_tasks:
merge_datasets_task = node.create_merge_task(
self.runtime_ctx, all_input_deps, [PartitionInfo()]
)
split_dataset_tasks = [
node.create_split_task(
self.runtime_ctx,
[merge_datasets_task],
[PartitionInfo(partition_idx, num_producer_tasks)],
)
for partition_idx in range(num_producer_tasks)
]
else:
split_dataset_tasks = [
node.create_merge_task(
self.runtime_ctx,
tasks,
[PartitionInfo(partition_idx, num_producer_tasks)],
)
for partition_idx, tasks in enumerate(
split_into_rows(all_input_deps, num_producer_tasks)
)
]
producer_tasks = [
node.create_producer_task(
self.runtime_ctx, [split_dataset], split_dataset.partition_infos
)
for split_dataset in split_dataset_tasks
]
return [
node.create_consumer_task(
self.runtime_ctx,
producer_tasks,
[
PartitionInfo(),
PartitionInfo(partition_idx, node.npartitions, node.dimension),
],
)
for partition_idx in range(node.npartitions)
]
def broadcast_input_deps(self, node: Node, depth: int):
# if no input deps, return a single partition
if not node.input_deps:
yield [], (PartitionInfo(),)
return
input_deps_taskgroups = [self.visit(dep, depth + 1) for dep in node.input_deps]
input_deps_most_ndims = max(
input_deps_taskgroups,
key=lambda taskgroup: (
len(taskgroup[0].partition_dims),
max(t.partition_infos for t in taskgroup),
),
)
input_deps_maps = [
(
taskgroup[0].partition_dims,
dict((t.partition_infos, t) for t in taskgroup),
)
for taskgroup in input_deps_taskgroups
]
for main_input in input_deps_most_ndims:
input_deps = []
for input_deps_dims, input_deps_map in input_deps_maps:
partition_infos = tuple(
info
for info in main_input.partition_infos
if info.dimension in input_deps_dims
)
input_dep = input_deps_map.get(partition_infos, None)
assert (
input_dep is not None
), f"""the partition dimensions or npartitions of inputs {node.input_deps} of {repr(node)} are not compatible
cannot match {main_input.partition_infos} against any of {input_deps_map.keys()}"""
input_deps.append(input_dep)
yield input_deps, main_input.partition_infos
def visit_python_script_node(self, node: PythonScriptNode, depth: int) -> TaskGroup:
return [
node.create_task(self.runtime_ctx, input_deps, partition_infos)
for input_deps, partition_infos in self.broadcast_input_deps(node, depth)
]
def visit_arrow_compute_node(self, node: ArrowComputeNode, depth: int) -> TaskGroup:
return [
node.create_task(self.runtime_ctx, input_deps, partition_infos)
for input_deps, partition_infos in self.broadcast_input_deps(node, depth)
]
def visit_arrow_stream_node(self, node: ArrowStreamNode, depth: int) -> TaskGroup:
return [
node.create_task(self.runtime_ctx, input_deps, partition_infos)
for input_deps, partition_infos in self.broadcast_input_deps(node, depth)
]
def visit_query_engine_node(self, node: SqlEngineNode, depth: int) -> TaskGroup:
return [
node.create_task(self.runtime_ctx, input_deps, partition_infos)
for input_deps, partition_infos in self.broadcast_input_deps(node, depth)
]
def visit_projection_node(self, node: ProjectionNode, depth: int) -> TaskGroup:
assert (
len(node.input_deps) == 1
), f"projection node only accepts one input node, but found: {node.input_deps}"
return [
node.create_task(self.runtime_ctx, [task], task.partition_infos)
for task in self.visit(node.input_deps[0], depth + 1)
]
def visit_limit_node(self, node: LimitNode, depth: int) -> TaskGroup:
assert (
len(node.input_deps) == 1
), f"limit node only accepts one input node, but found: {node.input_deps}"
all_input_deps = self.visit(node.input_deps[0], depth + 1)
partial_limit_tasks = [
node.create_task(self.runtime_ctx, [task], task.partition_infos)
for task in all_input_deps
]
merge_task = node.create_merge_task(
self.runtime_ctx, partial_limit_tasks, [PartitionInfo()]
)
global_limit_task = node.create_task(
self.runtime_ctx, [merge_task], merge_task.partition_infos
)
return [global_limit_task]

269
smallpond/logical/udf.py Normal file
View File

@@ -0,0 +1,269 @@
import importlib
import os.path
from dataclasses import dataclass
from enum import Enum
from typing import Callable, Dict, List, Optional, Union
import duckdb
import duckdb.typing
class UDFType(Enum):
"""
A wrapper of duckdb.typing.DuckDBPyType
See https://duckdb.org/docs/api/python/types.html
"""
SQLNULL = 1
BOOLEAN = 2
TINYINT = 3
UTINYINT = 4
SMALLINT = 5
USMALLINT = 6
INTEGER = 7
UINTEGER = 8
BIGINT = 9
UBIGINT = 10
HUGEINT = 11
UUID = 12
FLOAT = 13
DOUBLE = 14
DATE = 15
TIMESTAMP = 16
TIMESTAMP_MS = 17
TIMESTAMP_NS = 18
TIMESTAMP_S = 19
TIME = 20
TIME_TZ = 21
TIMESTAMP_TZ = 22
VARCHAR = 23
BLOB = 24
BIT = 25
INTERVAL = 26
def to_duckdb_type(self) -> duckdb.typing.DuckDBPyType:
if self == UDFType.SQLNULL:
return duckdb.typing.SQLNULL
elif self == UDFType.BOOLEAN:
return duckdb.typing.BOOLEAN
elif self == UDFType.TINYINT:
return duckdb.typing.TINYINT
elif self == UDFType.UTINYINT:
return duckdb.typing.UTINYINT
elif self == UDFType.SMALLINT:
return duckdb.typing.SMALLINT
elif self == UDFType.USMALLINT:
return duckdb.typing.USMALLINT
elif self == UDFType.INTEGER:
return duckdb.typing.INTEGER
elif self == UDFType.UINTEGER:
return duckdb.typing.UINTEGER
elif self == UDFType.BIGINT:
return duckdb.typing.BIGINT
elif self == UDFType.UBIGINT:
return duckdb.typing.UBIGINT
elif self == UDFType.HUGEINT:
return duckdb.typing.HUGEINT
elif self == UDFType.UUID:
return duckdb.typing.UUID
elif self == UDFType.FLOAT:
return duckdb.typing.FLOAT
elif self == UDFType.DOUBLE:
return duckdb.typing.DOUBLE
elif self == UDFType.DATE:
return duckdb.typing.DATE
elif self == UDFType.TIMESTAMP:
return duckdb.typing.TIMESTAMP
elif self == UDFType.TIMESTAMP_MS:
return duckdb.typing.TIMESTAMP_MS
elif self == UDFType.TIMESTAMP_NS:
return duckdb.typing.TIMESTAMP_NS
elif self == UDFType.TIMESTAMP_S:
return duckdb.typing.TIMESTAMP_S
elif self == UDFType.TIME:
return duckdb.typing.TIME
elif self == UDFType.TIME_TZ:
return duckdb.typing.TIME_TZ
elif self == UDFType.TIMESTAMP_TZ:
return duckdb.typing.TIMESTAMP_TZ
elif self == UDFType.VARCHAR:
return duckdb.typing.VARCHAR
elif self == UDFType.BLOB:
return duckdb.typing.BLOB
elif self == UDFType.BIT:
return duckdb.typing.BIT
elif self == UDFType.INTERVAL:
return duckdb.typing.INTERVAL
return None
class UDFStructType:
"""
A wrapper of duckdb.struct_type, eg: UDFStructType({'host': 'VARCHAR', 'path:' 'VARCHAR', 'query': 'VARCHAR'})
See https://duckdb.org/docs/api/python/types.html#a-field_one-b-field_two--n-field_n
"""
def __init__(self, fields: Union[Dict[str, str], List[str]]) -> None:
self.fields = fields
def to_duckdb_type(self) -> duckdb.typing.DuckDBPyType:
return duckdb.struct_type(self.fields)
class UDFListType:
"""
A wrapper of duckdb.list_type, eg: UDFListType(UDFType.INTEGER)
See https://duckdb.org/docs/api/python/types.html#listchild_type
"""
def __init__(self, child) -> None:
self.child = child
def to_duckdb_type(self) -> duckdb.typing.DuckDBPyType:
return duckdb.list_type(self.child.to_duckdb_type())
class UDFMapType:
"""
A wrapper of duckdb.map_type, eg: UDFMapType(UDFType.VARCHAR, UDFType.INTEGER)
See https://duckdb.org/docs/api/python/types.html#dictkey_type-value_type
"""
def __init__(self, key, value) -> None:
self.key = key
self.value = value
def to_duckdb_type(self) -> duckdb.typing.DuckDBPyType:
return duckdb.map_type(self.key.to_duckdb_type(), self.value.to_duckdb_type())
class UDFAnyParameters:
"""
Accept parameters of any types in UDF.
"""
def __init__(self) -> None:
pass
def to_duckdb_type(self) -> duckdb.typing.DuckDBPyType:
return None
class UDFContext(object):
def bind(self, conn: duckdb.DuckDBPyConnection):
raise NotImplementedError
class PythonUDFContext(UDFContext):
def __init__(
self,
name: str,
func: Callable,
params: Optional[List[UDFType]],
return_type: Optional[UDFType],
use_arrow_type=False,
):
self.name = name
self.func = func
self.params = params
self.return_type = return_type
self.use_arrow_type = use_arrow_type
def __str__(self) -> str:
return f"{self.name}@{self.func}"
__repr__ = __str__
def bind(self, conn: duckdb.DuckDBPyConnection):
if isinstance(self.params, UDFAnyParameters):
duckdb_args = self.params.to_duckdb_type()
else:
duckdb_args = [arg.to_duckdb_type() for arg in self.params]
conn.create_function(
self.name,
self.func,
duckdb_args,
self.return_type.to_duckdb_type(),
type=("arrow" if self.use_arrow_type else "native"),
)
# logger.debug(f"created python udf: {self.name}({self.params}) -> {self.return_type}")
class ExternalModuleContext(UDFContext):
def __init__(self, name: str, module_path: str) -> None:
self.name = name
self.module_path = module_path
def __str__(self) -> str:
return f"{self.name}@{self.module_path}"
__repr__ = __str__
def bind(self, conn: duckdb.DuckDBPyConnection):
module_name, _ = os.path.splitext(os.path.basename(self.module_path))
spec = importlib.util.spec_from_file_location(module_name, self.module_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
module.create_duckdb_udfs(conn)
# logger.debug(f"loaded external module at {self.module_path}, udf functions: {module.udfs}")
class DuckDbExtensionContext(UDFContext):
def __init__(self, name: str, extension_path: str) -> None:
self.name = name
self.extension_path = extension_path
def __str__(self) -> str:
return f"{self.name}@{self.extension_path}"
__repr__ = __str__
def bind(self, conn: duckdb.DuckDBPyConnection):
conn.load_extension(self.extension_path)
# logger.debug(f"loaded duckdb extension at {self.extension_path}")
@dataclass
class UserDefinedFunction:
"""
A python user-defined function.
"""
name: str
func: Callable
params: List[UDFType]
return_type: UDFType
use_arrow_type: bool
def __call__(self, *args, **kwargs):
return self.func(*args, **kwargs)
def udf(
params: List[UDFType],
return_type: UDFType,
use_arrow_type: bool = False,
name: Optional[str] = None,
) -> Callable[[Callable], UserDefinedFunction]:
"""
A decorator to define a Python UDF.
Examples
--------
```
@udf(params=[UDFType.INTEGER, UDFType.INTEGER], return_type=UDFType.INTEGER)
def gcd(a: int, b: int) -> int:
while b:
a, b = b, a % b
return a
```
See `Context.create_function` for more details.
"""
return lambda func: UserDefinedFunction(
name or func.__name__, func, params, return_type, use_arrow_type
)

View File

@@ -0,0 +1,36 @@
from typing import Optional
from smallpond.platform.base import Platform
from smallpond.platform.mpi import MPI
_platforms = {
"mpi": MPI,
}
def get_platform(name: Optional[str] = None) -> Platform:
"""
Get a platform by name.
If name is not specified, try to get an available platform.
"""
if name is None:
for platform in _platforms.values():
if platform.is_available():
return platform()
return Platform()
if name in _platforms:
return _platforms[name]()
# load platform from a custom python module
from importlib import import_module
module = import_module(name)
# find the exact class that inherits from Platform
for name in dir(module):
cls = getattr(module, name)
if isinstance(cls, type) and issubclass(cls, Platform):
return cls()
raise RuntimeError(f"no Platform class found in module: {name}")

109
smallpond/platform/base.py Normal file
View File

@@ -0,0 +1,109 @@
import os
import signal
import subprocess
import uuid
from datetime import datetime
from typing import List, Optional
class Platform:
"""
Base class for all platforms.
"""
@staticmethod
def is_available() -> bool:
"""
Whether the platform is available in the current environment.
"""
return False
@classmethod
def __str__(cls) -> str:
return cls.__name__
def start_job(
self,
num_nodes: int,
entrypoint: str,
args: List[str],
envs: dict = {},
extra_opts: dict = {},
) -> List[str]:
"""
Start a job on the platform.
Return the job ids.
"""
pids = []
for _ in range(num_nodes):
popen = subprocess.Popen(
["python", entrypoint, *args],
env={**os.environ, **envs},
stdout=subprocess.DEVNULL,
stderr=subprocess.STDOUT,
)
pids.append(str(popen.pid))
return pids
def stop_job(self, pid: str) -> None:
"""
Stop the job.
"""
os.kill(int(pid), signal.SIGKILL)
@staticmethod
def default_job_id() -> str:
"""
Return the default job id.
"""
return str(uuid.uuid4())
@staticmethod
def default_job_time() -> datetime:
"""
Return the default job time.
"""
return datetime.now()
@staticmethod
def default_data_root() -> Optional[str]:
"""
Get the default data root for the platform.
If the platform does not have a default data root, return None.
"""
from loguru import logger
default = os.path.expanduser("~/.smallpond/data")
logger.warning(f"data root is not set, using default: {default}")
return default
@staticmethod
def default_share_log_analytics() -> bool:
"""
Whether to share log analytics by default.
"""
return False
@staticmethod
def shared_log_root() -> Optional[str]:
"""
Return the shared log root.
"""
return None
@staticmethod
def grafana_homepath() -> Optional[str]:
"""
Return the homepath of grafana.
"""
homebrew_installed_homepath = "/opt/homebrew/opt/grafana/share/grafana"
if os.path.exists(homebrew_installed_homepath):
return homebrew_installed_homepath
return None
@staticmethod
def default_memory_allocator() -> str:
"""
Get the default memory allocator for the platform.
"""
return "system"

39
smallpond/platform/mpi.py Normal file
View File

@@ -0,0 +1,39 @@
import shutil
import subprocess
from typing import List
from loguru import logger
from smallpond.platform.base import Platform
class MPI(Platform):
"""
MPI platform.
"""
@staticmethod
def is_available() -> bool:
return shutil.which("mpirun") is not None
def start_job(
self,
num_nodes: int,
entrypoint: str,
args: List[str],
envs: dict = {},
extra_opts: dict = {},
) -> List[str]:
mpirun_cmd = ["mpirun", "-n", str(num_nodes)]
for key, value in envs.items():
mpirun_cmd += ["-x", f"{key}={value}"]
mpirun_cmd += ["python", entrypoint] + args
logger.debug(f"start job with command: {' '.join(mpirun_cmd)}")
subprocess.Popen(
mpirun_cmd,
stdout=subprocess.DEVNULL,
stderr=subprocess.STDOUT,
text=True,
)
return []

388
smallpond/session.py Normal file
View File

@@ -0,0 +1,388 @@
"""
This module defines the `Session` class, which is the entry point for smallpond interactive mode.
"""
from __future__ import annotations
import json
import os
import shutil
import socket
import subprocess
import sys
import threading
from dataclasses import dataclass
from datetime import datetime
from typing import Optional, Tuple
import ray
from graphviz import Digraph
import graphviz.backend.execute
from loguru import logger
import smallpond
from smallpond.execution.task import JobId, RuntimeContext
from smallpond.logical.node import Context
from smallpond.platform import Platform, get_platform
class SessionBase:
def __init__(self, **kwargs):
"""
Create a smallpond environment.
"""
super().__init__()
self._ctx = Context()
self.config, self._platform = Config.from_args_and_env(**kwargs)
# construct runtime context for Tasks
runtime_ctx = RuntimeContext(
job_id=JobId(hex=self.config.job_id),
job_time=self.config.job_time,
data_root=self.config.data_root,
num_executors=self.config.num_executors,
bind_numa_node=self.config.bind_numa_node,
shared_log_root=self._platform.shared_log_root(),
)
self._runtime_ctx = runtime_ctx
# if `spawn` is specified, spawn a job and exit
if os.environ.get("SP_SPAWN") == "1":
self._spawn_self()
exit(0)
self._runtime_ctx.initialize(exec_id=socket.gethostname())
logger.info(f"using platform: {self._platform}")
logger.info(f"command-line arguments: {' '.join(sys.argv)}")
logger.info(f"session config: {self.config}")
def setup_worker():
runtime_ctx._init_logs(
exec_id=socket.gethostname(), capture_stdout_stderr=True
)
if self.config.ray_address is None:
# find the memory allocator
if self.config.memory_allocator == "system":
malloc_path = ""
elif self.config.memory_allocator == "jemalloc":
malloc_path = shutil.which("libjemalloc.so.2")
assert malloc_path is not None, "jemalloc is not installed"
elif self.config.memory_allocator == "mimalloc":
malloc_path = shutil.which("libmimalloc.so.2.1")
assert malloc_path is not None, "mimalloc is not installed"
else:
raise ValueError(
f"unsupported memory allocator: {self.config.memory_allocator}"
)
memory_purge_delay = 10000
# start ray head node
# for ray head node to access grafana
os.environ["RAY_GRAFANA_HOST"] = "http://localhost:8122"
self._ray_address = ray.init(
# start a new local cluster
address="local",
# disable local CPU resource if not running on localhost
num_cpus=(
0
if self.config.num_executors > 0
else self._runtime_ctx.usable_cpu_count
),
# set the memory limit to the available memory size
_memory=self._runtime_ctx.usable_memory_size,
# setup logging for workers
log_to_driver=False,
runtime_env={
"worker_process_setup_hook": setup_worker,
"env_vars": {
"LD_PRELOAD": malloc_path,
"MALLOC_CONF": f"percpu_arena:percpu,background_thread:true,metadata_thp:auto,dirty_decay_ms:{memory_purge_delay},muzzy_decay_ms:{memory_purge_delay},oversize_threshold:0,lg_tcache_max:16",
"MIMALLOC_PURGE_DELAY": f"{memory_purge_delay}",
"ARROW_DEFAULT_MEMORY_POOL": self.config.memory_allocator,
"ARROW_IO_THREADS": "2",
"OMP_NUM_THREADS": "2",
"POLARS_MAX_THREADS": "2",
"NUMEXPR_MAX_THREADS": "2",
"RAY_PROFILING": "1",
},
},
dashboard_host="0.0.0.0",
dashboard_port=8008,
# for prometheus to scrape metrics
_metrics_export_port=8080,
).address_info["gcs_address"]
logger.info(f"started ray cluster at {self._ray_address}")
self._prometheus_process = self._start_prometheus()
self._grafana_process = self._start_grafana()
else:
self._ray_address = self.config.ray_address
self._prometheus_process = None
self._grafana_process = None
logger.info(f"connected to ray cluster at {self._ray_address}")
# start workers
if self.config.num_executors > 0:
# override configs
kwargs["job_id"] = self.config.job_id
self._job_names = self._platform.start_job(
self.config.num_executors,
entrypoint=os.path.join(os.path.dirname(__file__), "worker.py"),
args=[
f"--ray_address={self._ray_address}",
f"--log_dir={self._runtime_ctx.log_root}",
*(["--bind_numa_node"] if self.config.bind_numa_node else []),
],
extra_opts=kwargs,
)
else:
self._job_names = []
# spawn a thread to periodically dump metrics
self._stop_event = threading.Event()
self._dump_thread = threading.Thread(
name="dump_thread", target=self._dump_periodically, daemon=True
)
self._dump_thread.start()
def shutdown(self):
"""
Shutdown the session.
"""
logger.info("shutting down session")
self._stop_event.set()
# stop all jobs
for job_name in self._job_names:
self._platform.stop_job(job_name)
self._job_names = []
self._dump_thread.join()
if self.config.ray_address is None:
ray.shutdown()
if self._prometheus_process is not None:
self._prometheus_process.terminate()
self._prometheus_process.wait()
self._prometheus_process = None
logger.info("stopped prometheus")
if self._grafana_process is not None:
self._grafana_process.terminate()
self._grafana_process.wait()
self._grafana_process = None
logger.info("stopped grafana")
def _spawn_self(self):
"""
Spawn a new job to run the current script.
"""
self._platform.start_job(
num_nodes=1,
entrypoint=sys.argv[0],
args=sys.argv[1:],
extra_opts=dict(
tags=["smallpond", "scheduler", smallpond.__version__],
),
envs={
k: v
for k, v in os.environ.items()
if k.startswith("SP_") and k != "SP_SPAWN"
},
)
def _start_prometheus(self) -> Optional[subprocess.Popen]:
"""
Start prometheus server if it exists.
"""
prometheus_path = shutil.which("prometheus")
if prometheus_path is None:
logger.warning("prometheus is not found")
return None
os.makedirs(f"{self._runtime_ctx.log_root}/prometheus", exist_ok=True)
proc = subprocess.Popen(
[
prometheus_path,
"--config.file=/tmp/ray/session_latest/metrics/prometheus/prometheus.yml",
f"--storage.tsdb.path={self._runtime_ctx.log_root}/prometheus/data",
],
stderr=open(f"{self._runtime_ctx.log_root}/prometheus/prometheus.log", "w"),
)
logger.info("started prometheus")
return proc
def _start_grafana(self) -> Optional[subprocess.Popen]:
"""
Start grafana server if it exists.
"""
homepath = self._platform.grafana_homepath()
if homepath is None:
logger.warning("grafana is not found")
return None
os.makedirs(f"{self._runtime_ctx.log_root}/grafana", exist_ok=True)
proc = subprocess.Popen(
[
shutil.which("grafana"),
"server",
"--config",
"/tmp/ray/session_latest/metrics/grafana/grafana.ini",
"-homepath",
homepath,
"web",
],
stdout=open(f"{self._runtime_ctx.log_root}/grafana/grafana.log", "w"),
env={
"GF_SERVER_HTTP_PORT": "8122", # redirect to an available port
"GF_SERVER_ROOT_URL": os.environ.get("RAY_GRAFANA_IFRAME_HOST")
or "http://localhost:8122",
"GF_PATHS_DATA": f"{self._runtime_ctx.log_root}/grafana/data",
},
)
logger.info(f"started grafana at http://localhost:8122")
return proc
@property
def runtime_ctx(self) -> RuntimeContext:
return self._runtime_ctx
def graph(self) -> Digraph:
"""
Get the logical plan graph.
"""
# implemented in Session class
raise NotImplementedError("graph")
def dump_graph(self, path: Optional[str] = None):
"""
Dump the logical plan graph to a file.
"""
path = path or os.path.join(self.runtime_ctx.log_root, "graph")
try:
self.graph().render(path, format="png")
logger.debug(f"dumped graph to {path}")
except graphviz.backend.execute.ExecutableNotFound as e:
logger.warning(f"graphviz is not installed, skipping graph dump")
def dump_timeline(self, path: Optional[str] = None):
"""
Dump the task timeline to a file.
"""
path = path or os.path.join(self.runtime_ctx.log_root, "timeline")
# the default timeline is grouped by worker
exec_path = f"{path}_exec"
ray.timeline(exec_path)
logger.debug(f"dumped timeline to {exec_path}")
# generate another timeline grouped by node
with open(exec_path) as f:
records = json.load(f)
new_records = []
for record in records:
# swap record name and pid-tid
name = record["name"]
try:
node_id = name.split(",")[-1]
task_id = name.split("-")[1].split(".")[0]
task_name = name.split("-")[0]
record["pid"] = f"{node_id}-{task_name}"
record["tid"] = f"task {task_id}"
new_records.append(record)
except Exception:
# filter out other records
pass
node_path = f"{path}_plan"
with open(node_path, "w") as f:
json.dump(new_records, f)
logger.debug(f"dumped timeline to {node_path}")
def _summarize_task(self) -> Tuple[int, int]:
# implemented in Session class
raise NotImplementedError("summarize_task")
def _dump_periodically(self):
"""
Dump the graph and timeline every minute.
Set `self._stop_event` to have a final dump and stop this thread.
"""
while not self._stop_event.is_set():
self._stop_event.wait(60)
self.dump_graph()
self.dump_timeline()
num_total_tasks, num_finished_tasks = self._summarize_task()
percent = (
num_finished_tasks / num_total_tasks * 100 if num_total_tasks > 0 else 0
)
logger.info(
f"progress: {num_finished_tasks}/{num_total_tasks} tasks ({percent:.1f}%)"
)
@dataclass
class Config:
"""
Configuration for a session.
"""
job_id: str # JOBID
job_time: datetime # JOB_TIME
data_root: str # DATA_ROOT
num_executors: int # NUM_NODES_TOTAL
ray_address: Optional[str] # RAY_ADDRESS
bind_numa_node: bool # BIND_NUMA_NODE
memory_allocator: str # MEMORY_ALLOCATOR
remove_output_root: bool
@staticmethod
def from_args_and_env(
platform: Optional[str] = None,
job_id: Optional[str] = None,
job_time: Optional[datetime] = None,
data_root: Optional[str] = None,
num_executors: Optional[int] = None,
ray_address: Optional[str] = None,
bind_numa_node: Optional[bool] = None,
memory_allocator: Optional[str] = None,
_remove_output_root: bool = True,
**kwargs,
) -> Config:
"""
Load config from arguments and environment variables.
If not specified, use the default value.
"""
def get_env(key: str, type: type = str):
"""
Get an environment variable and convert it to the given type.
If the variable is not set, return None.
"""
value = os.environ.get(f"SP_{key}")
return type(value) if value is not None else None
platform = get_platform(get_env("PLATFORM") or platform)
job_id = get_env("JOBID") or job_id or platform.default_job_id()
job_time = (
get_env("JOB_TIME", datetime.fromisoformat)
or job_time
or platform.default_job_time()
)
data_root = get_env("DATA_ROOT") or data_root or platform.default_data_root()
num_executors = get_env("NUM_EXECUTORS", int) or num_executors or 0
ray_address = get_env("RAY_ADDRESS") or ray_address
bind_numa_node = get_env("BIND_NUMA_NODE") == "1" or bind_numa_node
memory_allocator = (
get_env("MEMORY_ALLOCATOR")
or memory_allocator
or platform.default_memory_allocator()
)
config = Config(
job_id=job_id,
job_time=job_time,
data_root=data_root,
num_executors=num_executors,
ray_address=ray_address,
bind_numa_node=bind_numa_node,
memory_allocator=memory_allocator,
remove_output_root=_remove_output_root,
)
return config, platform

199
smallpond/utility.py Normal file
View File

@@ -0,0 +1,199 @@
import cProfile
import inspect
import io
import logging
import pstats
import queue
import subprocess
import sys
import threading
from typing import Any, Dict, Iterable
from loguru import logger
def overall_stats(
ctx,
inp,
sql_per_part,
sql_on_merged,
output_name,
output_dir=None,
cpu_limit=2,
memory_limit=30 << 30,
):
from smallpond.logical.node import DataSetPartitionNode, DataSinkNode, SqlEngineNode
n = SqlEngineNode(
ctx, inp, sql_per_part, cpu_limit=cpu_limit, memory_limit=memory_limit
)
p = DataSetPartitionNode(ctx, (n,), npartitions=1)
n2 = SqlEngineNode(
ctx,
(p,),
sql_on_merged,
output_name=output_name,
cpu_limit=cpu_limit,
memory_limit=memory_limit,
)
if output_dir is not None:
return DataSinkNode(ctx, (n2,), output_dir)
else:
return n2
def execute_command(cmd: str, env: Dict[str, str] = None, shell=False):
with subprocess.Popen(
cmd.split(),
env=env,
shell=shell,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
universal_newlines=True,
encoding="utf8",
) as proc:
for line in proc.stdout:
yield line.rstrip()
return_code = proc.wait()
if return_code != 0:
raise subprocess.CalledProcessError(return_code, cmd)
def cprofile_to_string(
perf_profile: cProfile.Profile, order_by=pstats.SortKey.TIME, top_k=20
):
perf_profile.disable()
pstats_output = io.StringIO()
profile_stats = pstats.Stats(perf_profile, stream=pstats_output)
profile_stats.strip_dirs().sort_stats(order_by).print_stats(top_k)
return pstats_output.getvalue()
class Wrapper(object):
def __init__(self, base_obj: Any):
self._base_obj = base_obj
def __str__(self) -> str:
return str(self._base_obj)
def __repr__(self) -> str:
return repr(self._base_obj)
def __getattr__(self, name):
return getattr(self._base_obj, name)
def __setattr__(self, name: str, value: Any) -> None:
if name.startswith("_"):
super().__setattr__(name, value)
else:
return setattr(self._base_obj, name, value)
class ConcurrentIterError(Exception):
pass
class ConcurrentIter(object):
"""
Use a background thread to iterate over an iterable.
Examples
--------
The following code snippet is a common pattern to read record batches from parquet files asynchronously in arrow stream task.
```
from smallpond.utility import ConcurrentIter
with ConcurrentIter(input_readers[0], max_buffer_size=1) as async_reader:
for batch_idx, batch in enumerate(async_reader):
# your code here
yield StreamOutput(output_table, batch_indices=[batch_idx])
```
"""
def __init__(self, iterable: Iterable, max_buffer_size=1) -> None:
assert isinstance(
iterable, Iterable
), f"expect an iterable but found: {repr(iterable)}"
self.__iterable = iterable
self.__queue = queue.Queue(max_buffer_size)
self.__last = object()
self.__stop = threading.Event()
self.__thread = threading.Thread(target=self._producer)
def __enter__(self):
self.__thread.start()
return iter(self)
def __exit__(self, exc_type, exc_value, traceback):
self.join()
def __iter__(self):
try:
yield from self._consumer()
finally:
self.join()
def join(self):
self.__stop.set()
self.clear()
self.__thread.join(timeout=1)
if self.__thread.is_alive():
print(f"waiting {self.__thread.name} of {self}", file=sys.stderr)
self.__thread.join()
print(f"joined {self.__thread.name} of {self}", file=sys.stderr)
def clear(self):
try:
while self.__queue.get_nowait() is not None:
pass
except queue.Empty:
pass
def _producer(self):
try:
for item in self.__iterable:
self.__queue.put(item)
if self.__stop.is_set():
self.clear()
break
except Exception as ex:
print(f"Error in {self}: {ex}", file=sys.stderr)
self.clear()
self.__queue.put(ConcurrentIterError(ex))
else:
self.__queue.put(self.__last)
def _consumer(self):
while True:
item = self.__queue.get()
if item is self.__last:
break
if isinstance(item, ConcurrentIterError):
(ex,) = item.args
raise item from ex
yield item
class InterceptHandler(logging.Handler):
"""
Intercept standard logging messages toward loguru sinks.
See https://loguru.readthedocs.io/en/stable/overview.html#entirely-compatible-with-standard-logging
"""
def emit(self, record: logging.LogRecord) -> None:
# Get corresponding Loguru level if it exists.
level: str | int
try:
level = logger.level(record.levelname).name
except ValueError:
level = record.levelno
# Find caller from where originated the logged message.
frame, depth = inspect.currentframe(), 0
while frame and (depth == 0 or frame.f_code.co_filename == logging.__file__):
frame = frame.f_back
depth += 1
logger.opt(depth=depth, exception=record.exc_info).log(
level, record.getMessage()
)

78
smallpond/worker.py Normal file
View File

@@ -0,0 +1,78 @@
# this file is the entry point for smallpond workers
import argparse
import os
import socket
import subprocess
import psutil
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="smallpond worker")
parser.add_argument(
"--ray_address",
required=True,
help="The address of the Ray cluster to connect to",
)
parser.add_argument(
"--log_dir", required=True, help="The directory where logs will be stored"
)
parser.add_argument(
"--bind_numa_node",
action="store_true",
help="Bind executor processes to numa nodes",
)
args = parser.parse_args()
log_path = os.path.join(args.log_dir, f"{socket.gethostname()}.log")
# limit the number of CPUs to the number of physical cores
cpu_count = psutil.cpu_count(logical=False)
memory = psutil.virtual_memory().total
if args.bind_numa_node:
import numa
numa_node_count = numa.info.get_num_configured_nodes()
cpu_count_per_socket = cpu_count // numa_node_count
memory_per_socket = memory // numa_node_count
for i in range(numa_node_count):
subprocess.run(
[
"numactl",
"-N",
str(i),
"-m",
str(i),
"ray",
"start",
"--address",
args.ray_address,
"--num-cpus",
str(cpu_count_per_socket),
"--memory",
str(memory_per_socket),
],
check=True,
)
else:
subprocess.run(
[
"ray",
"start",
"--address",
args.ray_address,
"--num-cpus",
str(cpu_count),
],
check=True,
)
# keep printing logs
while True:
try:
subprocess.run(["tail", "-F", log_path], check=True)
except subprocess.CalledProcessError as e:
# XXX: sometimes it raises `No such file or directory`
# don't know why. just ignore it
print(e)

0
tests/__init__.py Normal file
View File

30
tests/conftest.py Normal file
View File

@@ -0,0 +1,30 @@
import os
import pytest
import ray
import smallpond
@pytest.fixture(scope="session")
def ray_address():
"""A global Ray instance for all tests"""
ray_address = ray.init(
address="local",
# disable dashboard in unit tests
include_dashboard=False,
).address_info["gcs_address"]
yield ray_address
ray.shutdown()
@pytest.fixture
def sp(ray_address: str, request):
"""A smallpond session for each test"""
runtime_root = os.getenv("TEST_RUNTIME_ROOT") or f"tests/runtime"
sp = smallpond.init(
data_root=os.path.join(runtime_root, request.node.name),
ray_address=ray_address,
)
yield sp
sp.shutdown()

186
tests/datagen.py Normal file
View File

@@ -0,0 +1,186 @@
import base64
import glob
import os
import random
import string
from concurrent.futures import ProcessPoolExecutor
from datetime import datetime, timedelta, timezone
from typing import Tuple
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
from filelock import FileLock
def generate_url_and_domain() -> Tuple[str, str]:
domain_part = "".join(
random.choices(string.ascii_lowercase, k=random.randint(5, 15))
)
tld = random.choice(["com", "net", "org", "cn", "edu", "gov", "co", "io"])
domain = f"www.{domain_part}.{tld}"
path_segments = []
for _ in range(random.randint(1, 3)):
segment = "".join(
random.choices(
string.ascii_lowercase + string.digits, k=random.randint(3, 10)
)
)
path_segments.append(segment)
path = "/" + "/".join(path_segments)
protocol = random.choice(["http", "https"])
if random.random() < 0.3:
path += random.choice([".html", ".php", ".htm", ".aspx"])
return f"{protocol}://{domain}{path}", domain
def generate_random_date() -> str:
start = datetime(2023, 1, 1, tzinfo=timezone.utc)
end = datetime(2023, 12, 31, tzinfo=timezone.utc)
delta = end - start
random_date = start + timedelta(
seconds=random.randint(0, int(delta.total_seconds()))
)
return random_date.strftime("%Y-%m-%dT%H:%M:%SZ")
def generate_content() -> bytes:
target_length = (
random.randint(1000, 100000)
if random.random() < 0.8
else random.randint(100000, 1000000)
)
before = b"<!DOCTYPE html><html><head><title>Random Page</title></head><body>"
after = b"</body></html>"
total_before_after = len(before) + len(after)
fill_length = max(target_length - total_before_after, 0)
filler = "".join(random.choices(string.printable, k=fill_length)).encode("ascii")[
:fill_length
]
return before + filler + after
def generate_arrow_parquet(path: str, num_rows=100):
data = []
for _ in range(num_rows):
url, domain = generate_url_and_domain()
date = generate_random_date()
content = generate_content()
data.append({"url": url, "domain": domain, "date": date, "content": content})
df = pd.DataFrame(data)
df.to_parquet(path, engine="pyarrow")
def generate_arrow_files(output_dir: str, num_files=10):
os.makedirs(output_dir, exist_ok=True)
with ProcessPoolExecutor(max_workers=10) as executor:
executor.map(
generate_arrow_parquet,
[f"{output_dir}/data{i}.parquet" for i in range(num_files)],
)
def concat_arrow_files(input_dir: str, output_dir: str, repeat: int = 10):
os.makedirs(output_dir, exist_ok=True)
files = glob.glob(os.path.join(input_dir, "*.parquet"))
table = pa.concat_tables([pa.parquet.read_table(file) for file in files] * repeat)
pq.write_table(table, os.path.join(output_dir, "large_array.parquet"))
def generate_random_string(length: int) -> str:
"""Generate a random string of a specified length"""
return "".join(random.choices(string.ascii_letters + string.digits, k=length))
def generate_random_url() -> str:
"""Generate a random URL"""
path = generate_random_string(random.randint(10, 20))
return (
f"com.{random.randint(10000, 999999)}.{random.randint(100, 9999)}/{path}.html"
)
def generate_random_data() -> str:
"""Generate random data"""
url = generate_random_url()
content = generate_random_string(random.randint(50, 100))
encoded_content = base64.b64encode(content.encode()).decode()
return f"{url}\t{encoded_content}"
def generate_url_parquet(path: str, num_rows=100):
"""Generate a parquet file with a specified number of random data lines"""
data = []
for _ in range(num_rows):
url = generate_random_url()
host = url.split("/")[0]
data.append({"host": host, "url": url})
df = pd.DataFrame(data)
df.to_parquet(path, engine="pyarrow")
def generate_url_parquet_files(output_dir: str, num_files: int = 10):
"""Generate multiple parquet files with a specified number of random data lines"""
os.makedirs(output_dir, exist_ok=True)
with ProcessPoolExecutor(max_workers=10) as executor:
executor.map(
generate_url_parquet,
[f"{output_dir}/urls{i}.parquet" for i in range(num_files)],
)
def generate_url_tsv_files(
output_dir: str, num_files: int = 10, lines_per_file: int = 100
):
"""Generate multiple files, each containing a specified number of random data lines"""
os.makedirs(output_dir, exist_ok=True)
for i in range(num_files):
with open(f"{output_dir}/urls{i}.tsv", "w") as f:
for _ in range(lines_per_file):
f.write(generate_random_data() + "\n")
def generate_long_path_list(path: str, num_lines: int = 1048576):
"""Generate a list of long paths"""
with open(path, "w", buffering=16 * 1024 * 1024) as f:
for i in range(num_lines):
path = os.path.abspath(f"tests/data/arrow/data{i % 10}.parquet")
f.write(f"{path}\n")
def generate_data(path: str = "tests/data"):
"""
Generate all data for testing.
"""
os.makedirs(path, exist_ok=True)
try:
with FileLock(path + "/data.lock"):
print("Generating data...")
if not os.path.exists(path + "/mock_urls"):
generate_url_tsv_files(
output_dir=path + "/mock_urls", num_files=10, lines_per_file=100
)
generate_url_parquet_files(output_dir=path + "/mock_urls", num_files=10)
if not os.path.exists(path + "/arrow"):
generate_arrow_files(output_dir=path + "/arrow", num_files=10)
if not os.path.exists(path + "/large_array"):
concat_arrow_files(
input_dir=path + "/arrow", output_dir=path + "/large_array"
)
if not os.path.exists(path + "/long_path_list.txt"):
generate_long_path_list(path=path + "/long_path_list.txt")
except Exception as e:
print(f"Error generating data: {e}")
if __name__ == "__main__":
generate_data()

187
tests/test_arrow.py Normal file
View File

@@ -0,0 +1,187 @@
import glob
import os.path
import tempfile
import unittest
import pyarrow.parquet as parquet
from loguru import logger
from smallpond.io.arrow import (
RowRange,
build_batch_reader_from_files,
cast_columns_to_large_string,
dump_to_parquet_files,
load_from_parquet_files,
)
from smallpond.utility import ConcurrentIter
from tests.test_fabric import TestFabric
class TestArrow(TestFabric, unittest.TestCase):
def test_load_from_parquet_files(self):
for dataset_path in (
"tests/data/arrow/*.parquet",
"tests/data/large_array/*.parquet",
):
with self.subTest(dataset_path=dataset_path):
parquet_files = glob.glob(dataset_path)
expected = self._load_parquet_files(parquet_files)
actual = load_from_parquet_files(parquet_files)
self._compare_arrow_tables(expected, actual)
def test_load_parquet_row_ranges(self):
for dataset_path in (
"tests/data/arrow/data0.parquet",
"tests/data/large_array/large_array.parquet",
):
with self.subTest(dataset_path=dataset_path):
metadata = parquet.read_metadata(dataset_path)
file_num_rows = metadata.num_rows
data_size = sum(
metadata.row_group(i).total_byte_size
for i in range(metadata.num_row_groups)
)
row_range = RowRange(
path=dataset_path,
begin=100,
end=200,
data_size=data_size,
file_num_rows=file_num_rows,
)
expected = self._load_parquet_files([dataset_path]).slice(
offset=100, length=100
)
actual = load_from_parquet_files([row_range])
self._compare_arrow_tables(expected, actual)
def test_dump_to_parquet_files(self):
for dataset_path in (
"tests/data/arrow/*.parquet",
"tests/data/large_array/*.parquet",
):
with self.subTest(dataset_path=dataset_path):
parquet_files = glob.glob(dataset_path)
expected = self._load_parquet_files(parquet_files)
with tempfile.TemporaryDirectory(
dir=self.output_root_abspath
) as output_dir:
ok = dump_to_parquet_files(expected, output_dir)
self.assertTrue(ok)
actual = self._load_parquet_files(
glob.glob(f"{output_dir}/*.parquet")
)
self._compare_arrow_tables(expected, actual)
def test_dump_load_empty_table(self):
# create empty table
empty_table = self._load_parquet_files(
["tests/data/arrow/data0.parquet"]
).slice(length=0)
self.assertEqual(empty_table.num_rows, 0)
# dump empty table
with tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir:
ok = dump_to_parquet_files(empty_table, output_dir)
self.assertTrue(ok)
parquet_files = glob.glob(f"{output_dir}/*.parquet")
# load empty table from file
actual_table = load_from_parquet_files(parquet_files)
self._compare_arrow_tables(empty_table, actual_table)
def test_parquet_batch_reader(self):
for dataset_path in (
"tests/data/arrow/*.parquet",
"tests/data/large_array/*.parquet",
):
with self.subTest(dataset_path=dataset_path):
parquet_files = glob.glob(dataset_path)
expected_num_rows = sum(
parquet.read_metadata(file).num_rows for file in parquet_files
)
with build_batch_reader_from_files(
parquet_files,
batch_size=expected_num_rows,
max_batch_byte_size=None,
) as batch_reader, ConcurrentIter(batch_reader) as concurrent_iter:
total_num_rows = 0
for batch in concurrent_iter:
print(
f"batch.num_rows {batch.num_rows}, max_batch_row_size {expected_num_rows}"
)
self.assertLessEqual(batch.num_rows, expected_num_rows)
total_num_rows += batch.num_rows
self.assertEqual(total_num_rows, expected_num_rows)
def test_table_to_batches(self):
for dataset_path in (
"tests/data/arrow/*.parquet",
"tests/data/large_array/*.parquet",
):
with self.subTest(dataset_path=dataset_path):
parquet_files = glob.glob(dataset_path)
table = self._load_parquet_files(parquet_files)
total_num_rows = 0
for batch in table.to_batches(max_chunksize=table.num_rows):
print(
f"batch.num_rows {batch.num_rows}, max_batch_row_size {table.num_rows}"
)
self.assertLessEqual(batch.num_rows, table.num_rows)
total_num_rows += batch.num_rows
self.assertEqual(total_num_rows, table.num_rows)
def test_arrow_schema_metadata(self):
table = self._load_parquet_files(glob.glob("tests/data/arrow/*.parquet"))
metadata = {b"a": b"1", b"b": b"2"}
table_with_meta = table.replace_schema_metadata(metadata)
print(f"table_with_meta.schema.metadata {table_with_meta.schema.metadata}")
with tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir:
self.assertTrue(
dump_to_parquet_files(
table_with_meta, output_dir, "arrow_schema_metadata", max_workers=2
)
)
parquet_files = glob.glob(
os.path.join(output_dir, "arrow_schema_metadata*.parquet")
)
loaded_table = load_from_parquet_files(
parquet_files, table.column_names[:1]
)
print(f"loaded_table.schema.metadata {loaded_table.schema.metadata}")
self.assertEqual(
table_with_meta.schema.metadata, loaded_table.schema.metadata
)
with parquet.ParquetFile(parquet_files[0]) as file:
print(f"file.schema_arrow.metadata {file.schema_arrow.metadata}")
self.assertEqual(
table_with_meta.schema.metadata, file.schema_arrow.metadata
)
def test_load_mixed_string_types(self):
parquet_paths = glob.glob("tests/data/arrow/*.parquet")
table = self._load_parquet_files(parquet_paths)
with tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir:
dump_to_parquet_files(cast_columns_to_large_string(table), output_dir)
parquet_paths += glob.glob(os.path.join(output_dir, "*.parquet"))
loaded_table = load_from_parquet_files(parquet_paths)
self.assertEqual(table.num_rows * 2, loaded_table.num_rows)
batch_reader = build_batch_reader_from_files(parquet_paths)
self.assertEqual(
table.num_rows * 2, sum(batch.num_rows for batch in batch_reader)
)
@logger.catch(reraise=True, message="failed to load parquet files")
def _load_from_parquet_files_with_log(self, paths, columns):
load_from_parquet_files(paths, columns)
def test_load_not_exist_column(self):
parquet_files = glob.glob("tests/data/arrow/*.parquet")
with self.assertRaises(AssertionError) as context:
self._load_from_parquet_files_with_log(parquet_files, ["not_exist_column"])
def test_change_ordering_of_columns(self):
parquet_files = glob.glob("tests/data/arrow/*.parquet")
loaded_table = load_from_parquet_files(parquet_files)
reversed_cols = list(reversed(loaded_table.column_names))
loaded_table = load_from_parquet_files(parquet_files, reversed_cols)
self.assertEqual(loaded_table.column_names, reversed_cols)

90
tests/test_bench.py Normal file
View File

@@ -0,0 +1,90 @@
import shutil
import unittest
from benchmarks.file_io_benchmark import file_io_benchmark
from benchmarks.gray_sort_benchmark import generate_random_records, gray_sort_benchmark
from benchmarks.hash_partition_benchmark import hash_partition_benchmark
from benchmarks.urls_sort_benchmark import urls_sort_benchmark
from smallpond.common import MB
from smallpond.logical.node import Context, LogicalPlan
from tests.test_fabric import TestFabric
class TestBench(TestFabric, unittest.TestCase):
fault_inject_prob = 0.05
def test_file_io_benchmark(self):
for io_engine in ("duckdb", "arrow", "stream"):
with self.subTest(io_engine=io_engine):
plan = file_io_benchmark(
["tests/data/mock_urls/*.parquet"],
npartitions=3,
io_engine=io_engine,
)
self.execute_plan(plan, enable_profiling=True)
def test_urls_sort_benchmark(self):
for engine_type in ("duckdb", "arrow"):
with self.subTest(engine_type=engine_type):
plan = urls_sort_benchmark(
["tests/data/mock_urls/*.tsv"],
num_data_partitions=3,
num_hash_partitions=3,
engine_type=engine_type,
)
self.execute_plan(plan, enable_profiling=True)
@unittest.skipIf(shutil.which("gensort") is None, "gensort not found")
def test_gray_sort_benchmark(self):
record_nbytes = 100
key_nbytes = 10
total_data_nbytes = 100 * MB
gensort_batch_nbytes = 10 * MB
num_data_partitions = 5
num_sort_partitions = 1 << 3
for shuffle_engine in ("duckdb", "arrow"):
for sort_engine in ("duckdb", "arrow", "polars"):
with self.subTest(
shuffle_engine=shuffle_engine, sort_engine=sort_engine
):
ctx = Context()
random_records = generate_random_records(
ctx,
record_nbytes,
key_nbytes,
total_data_nbytes,
gensort_batch_nbytes,
num_data_partitions,
num_sort_partitions,
)
plan = LogicalPlan(ctx, random_records)
exec_plan = self.execute_plan(plan, enable_profiling=True)
plan = gray_sort_benchmark(
record_nbytes,
key_nbytes,
total_data_nbytes,
gensort_batch_nbytes,
num_data_partitions,
num_sort_partitions,
input_paths=exec_plan.final_output.resolved_paths,
shuffle_engine=shuffle_engine,
sort_engine=sort_engine,
hive_partitioning=True,
validate_results=True,
)
self.execute_plan(plan, enable_profiling=True)
def test_hash_partition_benchmark(self):
for engine_type in ("duckdb", "arrow"):
with self.subTest(engine_type=engine_type):
plan = hash_partition_benchmark(
["tests/data/mock_urls/*.parquet"],
npartitions=5,
hash_columns=["url"],
engine_type=engine_type,
hive_partitioning=True,
partition_stats=True,
)
self.execute_plan(plan, enable_profiling=True)

73
tests/test_common.py Normal file
View File

@@ -0,0 +1,73 @@
import itertools
import unittest
import numpy as np
from hypothesis import given
from hypothesis import strategies as st
from smallpond.common import get_nth_partition, split_into_cols, split_into_rows
from tests.test_fabric import TestFabric
class TestCommon(TestFabric, unittest.TestCase):
def test_get_nth_partition(self):
items = [1, 2, 3]
# split into 1 partitions
self.assertListEqual([1, 2, 3], get_nth_partition(items, 0, 1))
# split into 2 partitions
self.assertListEqual([1, 2], get_nth_partition(items, 0, 2))
self.assertListEqual([3], get_nth_partition(items, 1, 2))
# split into 3 partitions
self.assertListEqual([1], get_nth_partition(items, 0, 3))
self.assertListEqual([2], get_nth_partition(items, 1, 3))
self.assertListEqual([3], get_nth_partition(items, 2, 3))
# split into 5 partitions
self.assertListEqual([1], get_nth_partition(items, 0, 5))
self.assertListEqual([2], get_nth_partition(items, 1, 5))
self.assertListEqual([3], get_nth_partition(items, 2, 5))
self.assertListEqual([], get_nth_partition(items, 3, 5))
self.assertListEqual([], get_nth_partition(items, 4, 5))
@given(st.data())
def test_split_into_rows(self, data: st.data):
nelements = data.draw(st.integers(1, 100))
npartitions = data.draw(st.integers(1, 2 * nelements))
items = list(range(nelements))
computed = split_into_rows(items, npartitions)
expected = [
get_nth_partition(items, n, npartitions) for n in range(npartitions)
]
self.assertEqual(expected, computed)
@given(st.data())
def test_split_into_cols(self, data: st.data):
nelements = data.draw(st.integers(1, 100))
npartitions = data.draw(st.integers(1, 2 * nelements))
items = list(range(nelements))
chunks = split_into_cols(items, npartitions)
self.assertEqual(npartitions, len(chunks))
self.assertListEqual(
items,
[x for row in itertools.zip_longest(*chunks) for x in row if x is not None],
)
chunk_sizes = set(len(chk) for chk in chunks)
if len(chunk_sizes) > 1:
small_size, large_size = sorted(chunk_sizes)
self.assertEqual(small_size + 1, large_size)
else:
(chunk_size,) = chunk_sizes
self.assertEqual(len(items), chunk_size * npartitions)
def test_split_into_rows_bench(self):
for nelements in [100000, 1000000]:
items = np.arange(nelements)
for npartitions in [1024, 4096, 10240, nelements, 2 * nelements]:
chunks = split_into_rows(items, npartitions)
self.assertEqual(npartitions, len(chunks))
def test_split_into_cols_bench(self):
for nelements in [100000, 1000000]:
items = np.arange(nelements)
for npartitions in [1024, 4096, 10240, nelements, 2 * nelements]:
chunks = split_into_cols(items, npartitions)
self.assertEqual(npartitions, len(chunks))

223
tests/test_dataframe.py Normal file
View File

@@ -0,0 +1,223 @@
from typing import List
import pandas as pd
import pyarrow as pa
import pytest
from smallpond.dataframe import Session
def test_pandas(sp: Session):
pandas_df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
df = sp.from_pandas(pandas_df)
assert df.to_pandas().equals(pandas_df)
def test_arrow(sp: Session):
arrow_table = pa.table({"a": [1, 2, 3], "b": [4, 5, 6]})
df = sp.from_arrow(arrow_table)
assert df.to_arrow() == arrow_table
def test_items(sp: Session):
df = sp.from_items([1, 2, 3])
assert df.take_all() == [{"item": 1}, {"item": 2}, {"item": 3}]
df = sp.from_items([{"a": 1, "b": 4}, {"a": 2, "b": 5}, {"a": 3, "b": 6}])
assert df.take_all() == [{"a": 1, "b": 4}, {"a": 2, "b": 5}, {"a": 3, "b": 6}]
def test_csv(sp: Session):
df = sp.read_csv(
"tests/data/mock_urls/*.tsv",
schema={"urlstr": "varchar", "valstr": "varchar"},
delim=r"\t",
)
assert df.count() == 1000
def test_parquet(sp: Session):
df = sp.read_parquet("tests/data/mock_urls/*.parquet")
assert df.count() == 1000
def test_take(sp: Session):
df = sp.from_pandas(pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}))
assert df.take(2) == [{"a": 1, "b": 4}, {"a": 2, "b": 5}]
assert df.take_all() == [{"a": 1, "b": 4}, {"a": 2, "b": 5}, {"a": 3, "b": 6}]
def test_map(sp: Session):
df = sp.from_arrow(pa.table({"a": [1, 2, 3], "b": [4, 5, 6]}))
df1 = df.map("a + b as c")
assert df1.to_arrow() == pa.table({"c": [5, 7, 9]})
df2 = df.map(lambda r: {"c": r["a"] + r["b"]})
assert df2.to_arrow() == pa.table({"c": [5, 7, 9]})
# user need to specify the schema if can not be inferred from the mapping values
df3 = df.map(
lambda r: {"c": None if r["a"] == 1 else r["a"] + r["b"]},
schema=pa.schema([("c", pa.int64())]),
)
assert df3.to_arrow() == pa.table({"c": pa.array([None, 7, 9], type=pa.int64())})
def test_flat_map(sp: Session):
df = sp.from_arrow(pa.table({"a": [1, 2, 3], "b": [4, 5, 6]}))
df1 = df.flat_map(lambda r: [{"c": r["a"]}, {"c": r["b"]}])
assert df1.to_arrow() == pa.table({"c": [1, 4, 2, 5, 3, 6]})
df2 = df.flat_map("unnest(array[a, b]) as c")
assert df2.to_arrow() == pa.table({"c": [1, 4, 2, 5, 3, 6]})
# user need to specify the schema if can not be inferred from the mapping values
df3 = df.flat_map(lambda r: [{"c": None}], schema=pa.schema([("c", pa.int64())]))
assert df3.to_arrow() == pa.table(
{"c": pa.array([None, None, None], type=pa.int64())}
)
def test_map_batches(sp: Session):
df = sp.read_parquet("tests/data/mock_urls/*.parquet")
df = df.map_batches(
lambda batch: pa.table({"num_rows": [batch.num_rows]}),
batch_size=350,
)
assert df.take_all() == [{"num_rows": 350}, {"num_rows": 350}, {"num_rows": 300}]
def test_filter(sp: Session):
df = sp.from_arrow(pa.table({"a": [1, 2, 3], "b": [4, 5, 6]}))
df1 = df.filter("a > 1")
assert df1.to_arrow() == pa.table({"a": [2, 3], "b": [5, 6]})
df2 = df.filter(lambda r: r["a"] > 1)
assert df2.to_arrow() == pa.table({"a": [2, 3], "b": [5, 6]})
def test_random_shuffle(sp: Session):
df = sp.from_items(list(range(1000))).repartition(10, by_rows=True)
df = df.random_shuffle()
shuffled = [d["item"] for d in df.take_all()]
assert sorted(shuffled) == list(range(1000))
def count_inversions(arr: List[int]) -> int:
return sum(
sum(1 for j in range(i + 1, len(arr)) if arr[i] > arr[j])
for i in range(len(arr))
)
# check the shuffle is random enough
# the expected number of inversions is n*(n-1)/4 = 249750
assert 220000 <= count_inversions(shuffled) <= 280000
def test_partition_by(sp: Session):
df = sp.from_items(list(range(1000))).repartition(10, by="item % 10")
df = df.map("min(item % 10) as min, max(item % 10) as max")
assert df.take_all() == [{"min": i, "max": i} for i in range(10)]
def test_partition_by_key_out_of_range(sp: Session):
df = sp.from_items(list(range(1000))).repartition(10, by="item % 11")
try:
df.to_arrow()
except Exception as ex:
assert "partition key 10 is out of range 0-9" in str(ex)
else:
assert False, "expected exception"
def test_partition_by_hash(sp: Session):
df = sp.from_items(list(range(1000))).repartition(10, hash_by="item")
items = [d["item"] for d in df.take_all()]
assert sorted(items) == list(range(1000))
def test_count(sp: Session):
df = sp.from_items([1, 2, 3])
assert df.count() == 3
def test_limit(sp: Session):
df = sp.from_items(list(range(1000))).repartition(10, by_rows=True)
assert df.limit(2).count() == 2
@pytest.mark.skip(reason="limit can not be pushed down to sql node for now")
@pytest.mark.timeout(10)
def test_limit_large(sp: Session):
# limit will be fused with the previous select
# otherwise, it will be timeout
df = sp.partial_sql("select * from range(1000000000)")
assert df.limit(2).count() == 2
def test_partial_sql(sp: Session):
# no input deps
df = sp.partial_sql("select * from range(3)")
assert df.to_arrow() == pa.table({"range": [0, 1, 2]})
# join
df1 = sp.from_arrow(pa.table({"id1": [1, 2, 3], "val1": ["a", "b", "c"]}))
df2 = sp.from_arrow(pa.table({"id2": [1, 2, 3], "val2": ["d", "e", "f"]}))
joined = sp.partial_sql(
"select id1, val1, val2 from {0} join {1} on id1 = id2", df1, df2
)
assert joined.to_arrow() == pa.table(
{"id1": [1, 2, 3], "val1": ["a", "b", "c"], "val2": ["d", "e", "f"]},
schema=pa.schema(
[
("id1", pa.int64()),
("val1", pa.large_string()),
("val2", pa.large_string()),
]
),
)
def test_error_message(sp: Session):
df = sp.from_items([1, 2, 3])
df = sp.partial_sql("select a,, from {0}", df)
try:
df.to_arrow()
except Exception as ex:
# sql query should be in the exception message
assert "select a,, from" in str(ex)
else:
assert False, "expected exception"
def test_unpicklable_task_exception(sp: Session):
from loguru import logger
df = sp.from_items([1, 2, 3])
try:
df.map(lambda x: logger.info("use outside logger")).to_arrow()
except Exception as ex:
assert "Can't pickle task" in str(ex)
assert (
"HINT: DO NOT use externally imported loguru logger in your task. Please import it within the task."
in str(ex)
)
else:
assert False, "expected exception"
def test_log(sp: Session):
df = sp.from_items([1, 2, 3])
def log_record(x):
import logging
import sys
from loguru import logger
print("stdout")
print("stderr", file=sys.stderr)
logger.info("loguru")
logging.info("logging")
return x
df.map(log_record).to_arrow()
# TODO: check logs should be see in the log file
# FIXME: logs in unit test are not written to the log file
# because we share the same ray instance for all tests

174
tests/test_dataset.py Normal file
View File

@@ -0,0 +1,174 @@
import glob
import os.path
import unittest
from pathlib import PurePath
import duckdb
import pandas
import pyarrow as arrow
import pytest
from loguru import logger
from smallpond.common import DEFAULT_ROW_GROUP_SIZE, MB
from smallpond.logical.dataset import ParquetDataSet
from smallpond.utility import ConcurrentIter
from tests.test_fabric import TestFabric
class TestDataSet(TestFabric, unittest.TestCase):
def test_parquet_file_created_by_pandas(self):
num_urls = 0
for txt_file in glob.glob("tests/data/mock_urls/*.tsv"):
urls = pandas.read_csv(txt_file, delimiter="\t", names=["url"])
urls.to_parquet(
os.path.join(
self.output_root_abspath,
PurePath(os.path.basename(txt_file)).with_suffix(".parquet"),
)
)
num_urls += urls.size
dataset = ParquetDataSet([os.path.join(self.output_root_abspath, "*.parquet")])
self.assertEqual(num_urls, dataset.num_rows)
def _generate_parquet_dataset(
self, output_path, npartitions, num_rows, row_group_size
):
duckdb.sql(
f"""copy (
select range as i, range % {npartitions} as partition from range(0, {num_rows}) )
to '{output_path}'
(FORMAT PARQUET, ROW_GROUP_SIZE {row_group_size}, PARTITION_BY partition, OVERWRITE_OR_IGNORE true)"""
)
return ParquetDataSet([f"{output_path}/**/*.parquet"])
def _check_partition_datasets(
self, orig_dataset: ParquetDataSet, partition_func, npartition
):
# build partitioned datasets
partitioned_datasets = partition_func(npartition)
self.assertEqual(npartition, len(partitioned_datasets))
self.assertEqual(
orig_dataset.num_rows,
sum(dataset.num_rows for dataset in partitioned_datasets),
)
# load as arrow table
loaded_table = arrow.concat_tables(
[dataset.to_arrow_table(max_workers=1) for dataset in partitioned_datasets]
)
self.assertEqual(orig_dataset.num_rows, loaded_table.num_rows)
# compare arrow tables
orig_table = orig_dataset.to_arrow_table(max_workers=1)
self.assertEqual(orig_table.shape, loaded_table.shape)
self.assertTrue(orig_table.sort_by("i").equals(loaded_table.sort_by("i")))
# compare sql query results
join_query = f"""
select count(a.i) as num_rows
from {orig_dataset.sql_query_fragment()} as a
join ( {' union all '.join([dataset.sql_query_fragment() for dataset in partitioned_datasets])} ) as b on a.i = b.i"""
results = duckdb.sql(join_query).fetchall()
self.assertEqual(orig_dataset.num_rows, results[0][0])
def test_num_rows(self):
dataset = ParquetDataSet(["tests/data/arrow/*.parquet"])
self.assertEqual(dataset.num_rows, 1000)
def test_partition_by_files(self):
output_path = os.path.join(self.output_root_abspath, "test_partition_by_files")
orig_dataset = self._generate_parquet_dataset(
output_path, npartitions=11, num_rows=170 * 1000, row_group_size=10 * 1000
)
num_files = len(orig_dataset.resolved_paths)
for npartition in range(1, num_files + 1):
for random_shuffle in (False, True):
with self.subTest(npartition=npartition, random_shuffle=random_shuffle):
orig_dataset.reset(orig_dataset.paths, orig_dataset.root_dir)
self._check_partition_datasets(
orig_dataset,
lambda n: orig_dataset.partition_by_files(
n, random_shuffle=random_shuffle
),
npartition,
)
def test_partition_by_rows(self):
output_path = os.path.join(self.output_root_abspath, "test_partition_by_rows")
orig_dataset = self._generate_parquet_dataset(
output_path, npartitions=11, num_rows=170 * 1000, row_group_size=10 * 1000
)
num_files = len(orig_dataset.resolved_paths)
for npartition in range(1, 2 * num_files + 1):
for random_shuffle in (False, True):
with self.subTest(npartition=npartition, random_shuffle=random_shuffle):
orig_dataset.reset(orig_dataset.paths, orig_dataset.root_dir)
self._check_partition_datasets(
orig_dataset,
lambda n: orig_dataset.partition_by_rows(
n, random_shuffle=random_shuffle
),
npartition,
)
def test_resolved_many_paths(self):
with open("tests/data/long_path_list.txt", buffering=16 * MB) as fin:
filenames = list(map(os.path.basename, map(str.strip, fin.readlines())))
logger.info(f"loaded {len(filenames)} filenames")
dataset = ParquetDataSet(filenames)
self.assertEqual(len(dataset.resolved_paths), len(filenames))
def test_paths_with_char_ranges(self):
dataset_with_char_ranges = ParquetDataSet(
["tests/data/arrow/data[0-9].parquet"]
)
dataset_with_wildcards = ParquetDataSet(["tests/data/arrow/*.parquet"])
self.assertEqual(
len(dataset_with_char_ranges.resolved_paths),
len(dataset_with_wildcards.resolved_paths),
)
def test_to_arrow_table_batch_reader(self):
memdb = duckdb.connect(
database=":memory:", config={"arrow_large_buffer_size": "true"}
)
for dataset_path in (
"tests/data/arrow/*.parquet",
"tests/data/large_array/*.parquet",
):
for conn in (None, memdb):
print(f"dataset_path: {dataset_path}, conn: {conn}")
with self.subTest(dataset_path=dataset_path, conn=conn):
dataset = ParquetDataSet([dataset_path])
to_batches = dataset.to_arrow_table(
max_workers=1, conn=conn
).to_batches(max_chunksize=DEFAULT_ROW_GROUP_SIZE * 2)
batch_reader = dataset.to_batch_reader(
batch_size=DEFAULT_ROW_GROUP_SIZE * 2, conn=conn
)
with ConcurrentIter(
batch_reader, max_buffer_size=2
) as batch_reader:
for batch_iter in (to_batches, batch_reader):
total_num_rows = 0
for batch in batch_iter:
print(
f"batch.num_rows {batch.num_rows}, max_batch_row_size {DEFAULT_ROW_GROUP_SIZE*2}"
)
self.assertLessEqual(
batch.num_rows, DEFAULT_ROW_GROUP_SIZE * 2
)
total_num_rows += batch.num_rows
print(f"{dataset_path}: total_num_rows {total_num_rows}")
self.assertEqual(total_num_rows, dataset.num_rows)
@pytest.mark.parametrize("reader", ["arrow", "duckdb"])
@pytest.mark.parametrize("dataset_path", ["tests/data/arrow/*.parquet"])
# @pytest.mark.parametrize("dataset_path", ["tests/data/arrow/*.parquet", "tests/data/large_array/*.parquet"])
def test_arrow_reader(benchmark, reader: str, dataset_path: str):
dataset = ParquetDataSet([dataset_path])
conn = None
if reader == "duckdb":
conn = duckdb.connect(
database=":memory:", config={"arrow_large_buffer_size": "true"}
)
benchmark(dataset.to_arrow_table, conn=conn)
# result: arrow reader is 4x faster than duckdb reader in small dataset, 1.4x faster in large dataset

60
tests/test_deltalake.py Normal file
View File

@@ -0,0 +1,60 @@
import glob
import importlib
import tempfile
import unittest
from smallpond.io.arrow import cast_columns_to_large_string
from tests.test_fabric import TestFabric
@unittest.skipUnless(
importlib.util.find_spec("deltalake") is not None, "cannot find deltalake"
)
class TestDeltaLake(TestFabric, unittest.TestCase):
def test_read_write_deltalake(self):
from deltalake import DeltaTable, write_deltalake
for dataset_path in (
"tests/data/arrow/*.parquet",
"tests/data/large_array/*.parquet",
):
parquet_files = glob.glob(dataset_path)
expected = self._load_parquet_files(parquet_files)
with self.subTest(dataset_path=dataset_path), tempfile.TemporaryDirectory(
dir=self.output_root_abspath
) as output_dir:
write_deltalake(output_dir, expected, large_dtypes=True)
dt = DeltaTable(output_dir)
self._compare_arrow_tables(expected, dt.to_pyarrow_table())
def test_load_mixed_large_dtypes(self):
from deltalake import DeltaTable, write_deltalake
for dataset_path in (
"tests/data/arrow/*.parquet",
"tests/data/large_array/*.parquet",
):
parquet_files = glob.glob(dataset_path)
with self.subTest(dataset_path=dataset_path), tempfile.TemporaryDirectory(
dir=self.output_root_abspath
) as output_dir:
table = cast_columns_to_large_string(
self._load_parquet_files(parquet_files)
)
write_deltalake(output_dir, table, large_dtypes=True, mode="overwrite")
write_deltalake(output_dir, table, large_dtypes=False, mode="append")
loaded_table = DeltaTable(output_dir).to_pyarrow_table()
print("table:\n", table.schema)
print("loaded_table:\n", loaded_table.schema)
self.assertEqual(table.num_rows * 2, loaded_table.num_rows)
def test_delete_update(self):
import pandas as pd
from deltalake import DeltaTable, write_deltalake
with tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir:
df = pd.DataFrame({"num": [1, 2, 3], "animal": ["cat", "dog", "snake"]})
write_deltalake(output_dir, df, mode="overwrite")
dt = DeltaTable(output_dir)
dt.delete("animal = 'cat'")
dt.update(predicate="num = 3", new_values={"animal": "fish"})

46
tests/test_driver.py Normal file
View File

@@ -0,0 +1,46 @@
import os.path
import unittest
import uuid
from loguru import logger
from benchmarks.gray_sort_benchmark import gray_sort_benchmark
from examples.sort_mock_urls import sort_mock_urls
from smallpond.common import GB, MB
from smallpond.execution.driver import Driver
from tests.test_fabric import TestFabric
@unittest.skipUnless(os.getenv("ENABLE_DRIVER_TEST"), "unit test disabled")
class TestDriver(TestFabric, unittest.TestCase):
fault_inject_prob = 0.05
def create_driver(self, num_executors: int):
cmdline = f"scheduler --job_id {str(uuid.uuid4())} --job_name {self._testMethodName} --data_root {self.output_root_abspath} --num_executors {num_executors} --fault_inject_prob {self.fault_inject_prob}"
driver = Driver()
driver.parse_arguments(args=cmdline.split())
logger.info(f"{cmdline=} {driver.mode=} {driver.job_id=} {driver.data_root=}")
return driver
def test_standalone_mode(self):
plan = sort_mock_urls(["tests/data/mock_urls/*.tsv"], npartitions=3)
driver = self.create_driver(num_executors=0)
exec_plan = driver.run(plan, stop_process_on_done=False)
self.assertTrue(exec_plan.successful)
self.assertGreater(exec_plan.final_output.num_files, 0)
def test_run_on_remote_executors(self):
driver = self.create_driver(num_executors=2)
plan = gray_sort_benchmark(
record_nbytes=100,
key_nbytes=10,
total_data_nbytes=1 * GB,
gensort_batch_nbytes=100 * MB,
num_data_partitions=10,
num_sort_partitions=10,
validate_results=True,
)
exec_plan = driver.run(plan, stop_process_on_done=False)
self.assertTrue(exec_plan.successful)
self.assertGreater(exec_plan.final_output.num_files, 0)

886
tests/test_execution.py Normal file
View File

@@ -0,0 +1,886 @@
import functools
import os.path
import socket
import tempfile
import time
import unittest
from datetime import datetime
from typing import Iterable, List, Tuple
import pandas
import pyarrow as arrow
from loguru import logger
from pandas.core.api import DataFrame as DataFrame
from smallpond.common import GB, MB, split_into_rows
from smallpond.execution.task import (
DataSinkTask,
DataSourceTask,
JobId,
PartitionInfo,
PythonScriptTask,
RuntimeContext,
StreamOutput,
)
from smallpond.execution.workqueue import WorkStatus
from smallpond.logical.dataset import (
ArrowTableDataSet,
DataSet,
ParquetDataSet,
SqlQueryDataSet,
)
from smallpond.logical.node import (
ArrowBatchNode,
ArrowComputeNode,
ArrowStreamNode,
Context,
DataSetPartitionNode,
DataSinkNode,
DataSourceNode,
EvenlyDistributedPartitionNode,
HashPartitionNode,
LogicalPlan,
Node,
PandasBatchNode,
PandasComputeNode,
ProjectionNode,
PythonScriptNode,
RootNode,
SqlEngineNode,
)
from smallpond.logical.udf import UDFListType, UDFType
from tests.test_fabric import TestFabric
class OutputMsgPythonTask(PythonScriptTask):
def __init__(self, msg: str, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.msg = msg
def initialize(self):
pass
def finalize(self):
pass
def process(
self,
runtime_ctx: RuntimeContext,
input_datasets: List[DataSet],
output_path: str,
) -> bool:
logger.info(
f"msg: {self.msg}, num files: {input_datasets[0].num_files}, local gpu ranks: {self.local_gpu_ranks}"
)
self.inject_fault()
return True
# method1: inherit Task class and override spawn method
class OutputMsgPythonNode(PythonScriptNode):
def spawn(self, *args, **kwargs) -> OutputMsgPythonTask:
return OutputMsgPythonTask("python script", *args, **kwargs)
# method2: override process method
# this usage is not recommended and only for testing. use `process_func` instead.
class OutputMsgPythonNode2(PythonScriptNode):
def __init__(self, ctx: Context, input_deps: Tuple[Node, ...], msg: str) -> None:
super().__init__(ctx, input_deps)
self.msg = msg
def process(
self,
runtime_ctx: RuntimeContext,
input_datasets: List[DataSet],
output_path: str,
) -> bool:
logger.info(f"msg: {self.msg}, num files: {input_datasets[0].num_files}")
return True
# this usage is not recommended and only for testing. use `process_func` instead.
class CopyInputArrowNode(ArrowComputeNode):
def __init__(self, ctx: Context, input_deps: Tuple[Node, ...], msg: str) -> None:
super().__init__(ctx, input_deps)
self.msg = msg
def process(
self, runtime_ctx: RuntimeContext, input_tables: List[arrow.Table]
) -> arrow.Table:
return copy_input_arrow(runtime_ctx, input_tables, self.msg)
# this usage is not recommended and only for testing. use `process_func` instead.
class CopyInputStreamNode(ArrowStreamNode):
def __init__(self, ctx: Context, input_deps: Tuple[Node, ...], msg: str) -> None:
super().__init__(ctx, input_deps)
self.msg = msg
def process(
self, runtime_ctx: RuntimeContext, input_readers: List[arrow.RecordBatchReader]
) -> Iterable[arrow.Table]:
return copy_input_stream(runtime_ctx, input_readers, self.msg)
def copy_input_arrow(
runtime_ctx: RuntimeContext, input_tables: List[arrow.Table], msg: str
) -> arrow.Table:
logger.info(f"msg: {msg}, num rows: {input_tables[0].num_rows}")
time.sleep(runtime_ctx.secs_executor_probe_interval)
runtime_ctx.task.inject_fault()
return input_tables[0]
def copy_input_stream(
runtime_ctx: RuntimeContext, input_readers: List[arrow.RecordBatchReader], msg: str
) -> Iterable[arrow.Table]:
for index, batch in enumerate(input_readers[0]):
logger.info(f"msg: {msg}, batch index: {index}, num rows: {batch.num_rows}")
time.sleep(runtime_ctx.secs_executor_probe_interval)
yield StreamOutput(
arrow.Table.from_batches([batch]),
batch_indices=[index],
force_checkpoint=True,
)
runtime_ctx.task.inject_fault()
def copy_input_batch(
runtime_ctx: RuntimeContext, input_batches: List[arrow.Table], msg: str
) -> arrow.Table:
logger.info(f"msg: {msg}, num rows: {input_batches[0].num_rows}")
time.sleep(runtime_ctx.secs_executor_probe_interval)
runtime_ctx.task.inject_fault()
return input_batches[0]
def copy_input_data_frame(
runtime_ctx: RuntimeContext, input_dfs: List[DataFrame]
) -> DataFrame:
runtime_ctx.task.inject_fault()
return input_dfs[0]
def copy_input_data_frame_batch(
runtime_ctx: RuntimeContext, input_dfs: List[DataFrame]
) -> DataFrame:
runtime_ctx.task.inject_fault()
return input_dfs[0]
def merge_input_tables(
runtime_ctx: RuntimeContext, input_batches: List[arrow.Table]
) -> arrow.Table:
runtime_ctx.task.inject_fault()
output = arrow.concat_tables(input_batches)
logger.info(
f"input rows: {[len(batch) for batch in input_batches]}, output rows: {len(output)}"
)
return output
def merge_input_data_frames(
runtime_ctx: RuntimeContext, input_dfs: List[DataFrame]
) -> DataFrame:
runtime_ctx.task.inject_fault()
output = pandas.concat(input_dfs)
logger.info(
f"input rows: {[len(df) for df in input_dfs]}, output rows: {len(output)}"
)
return output
def parse_url(
runtime_ctx: RuntimeContext, input_tables: List[arrow.Table]
) -> arrow.Table:
urls = input_tables[0].columns[0]
hosts = [url.as_py().split("/", maxsplit=2)[0] for url in urls]
return input_tables[0].append_column("host", arrow.array(hosts))
def nonzero_exit_code(
runtime_ctx: RuntimeContext, input_datasets: List[DataSet], output_path: str
) -> bool:
import sys
if runtime_ctx.task._memory_boost == 1:
sys.exit(1)
return True
# create an empty file with a fixed name
def empty_file(
runtime_ctx: RuntimeContext, input_datasets: List[DataSet], output_path: str
) -> bool:
import os
with open(os.path.join(output_path, "file"), "w") as fout:
pass
return True
def return_fake_gpus(count: int = 8):
import GPUtil
return [GPUtil.GPU(i, *list(range(11))) for i in range(count)]
def split_url(urls: arrow.array) -> arrow.array:
url_parts = [url.as_py().split("/") for url in urls]
return arrow.array(url_parts, type=arrow.list_(arrow.string()))
def choose_random_urls(
runtime_ctx: RuntimeContext, input_tables: List[arrow.Table], k: int = 5
) -> arrow.Table:
# get the current running task
runtime_task = runtime_ctx.task
# access task-specific attributes
cpu_limit = runtime_task.cpu_limit
random_gen = runtime_task.python_random_gen
# input data
(url_table,) = input_tables
hosts, urls = url_table.columns
logger.info(f"{cpu_limit=} {len(urls)=}")
# generate ramdom samples
random_urls = random_gen.choices(urls.to_pylist(), k=k)
return arrow.Table.from_arrays([arrow.array(random_urls)], names=["random_urls"])
class TestExecution(TestFabric, unittest.TestCase):
fault_inject_prob = 0.05
def test_arrow_task(self):
for use_duckdb_reader in (False, True):
with self.subTest(use_duckdb_reader=use_duckdb_reader):
with tempfile.TemporaryDirectory(
dir=self.output_root_abspath
) as output_dir:
ctx = Context()
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
data_table = dataset.to_arrow_table()
data_files = DataSourceNode(ctx, dataset)
data_partitions = DataSetPartitionNode(
ctx, (data_files,), npartitions=7
)
if use_duckdb_reader:
data_partitions = ProjectionNode(
ctx,
data_partitions,
columns=["*", "string_split(url, '/')[0] as parsed_host"],
)
arrow_compute = ArrowComputeNode(
ctx,
(data_partitions,),
process_func=functools.partial(
copy_input_arrow, msg="arrow compute"
),
use_duckdb_reader=use_duckdb_reader,
output_name="arrow_compute",
output_path=output_dir,
cpu_limit=2,
)
arrow_stream = ArrowStreamNode(
ctx,
(data_partitions,),
process_func=functools.partial(
copy_input_stream, msg="arrow stream"
),
streaming_batch_size=10,
secs_checkpoint_interval=0.5,
use_duckdb_reader=use_duckdb_reader,
output_name="arrow_stream",
output_path=output_dir,
cpu_limit=2,
)
arrow_batch = ArrowBatchNode(
ctx,
(data_partitions,),
process_func=functools.partial(
copy_input_batch, msg="arrow batch"
),
streaming_batch_size=10,
secs_checkpoint_interval=0.5,
use_duckdb_reader=use_duckdb_reader,
output_name="arrow_batch",
output_path=output_dir,
cpu_limit=2,
)
data_sink = DataSinkNode(
ctx,
(arrow_compute, arrow_stream, arrow_batch),
output_path=output_dir,
)
plan = LogicalPlan(ctx, data_sink)
exec_plan = self.execute_plan(
plan, fault_inject_prob=0.1, secs_executor_probe_interval=0.5
)
self.assertTrue(
all(map(os.path.exists, exec_plan.final_output.resolved_paths))
)
arrow_compute_output = ParquetDataSet(
[os.path.join(output_dir, "arrow_compute", "**/*.parquet")],
recursive=True,
)
arrow_stream_output = ParquetDataSet(
[os.path.join(output_dir, "arrow_stream", "**/*.parquet")],
recursive=True,
)
arrow_batch_output = ParquetDataSet(
[os.path.join(output_dir, "arrow_batch", "**/*.parquet")],
recursive=True,
)
self._compare_arrow_tables(
data_table,
arrow_compute_output.to_arrow_table().select(
data_table.column_names
),
)
self._compare_arrow_tables(
data_table,
arrow_stream_output.to_arrow_table().select(
data_table.column_names
),
)
self._compare_arrow_tables(
data_table,
arrow_batch_output.to_arrow_table().select(
data_table.column_names
),
)
def test_pandas_task(self):
with tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir:
ctx = Context()
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
data_table = dataset.to_arrow_table()
data_files = DataSourceNode(ctx, dataset)
data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=7)
pandas_compute = PandasComputeNode(
ctx,
(data_partitions,),
process_func=copy_input_data_frame,
output_name="pandas_compute",
output_path=output_dir,
cpu_limit=2,
)
pandas_batch = PandasBatchNode(
ctx,
(data_partitions,),
process_func=copy_input_data_frame_batch,
streaming_batch_size=10,
secs_checkpoint_interval=0.5,
output_name="pandas_batch",
output_path=output_dir,
cpu_limit=2,
)
data_sink = DataSinkNode(
ctx, (pandas_compute, pandas_batch), output_path=output_dir
)
plan = LogicalPlan(ctx, data_sink)
exec_plan = self.execute_plan(
plan, fault_inject_prob=0.1, secs_executor_probe_interval=0.5
)
self.assertTrue(
all(map(os.path.exists, exec_plan.final_output.resolved_paths))
)
pandas_compute_output = ParquetDataSet(
[os.path.join(output_dir, "pandas_compute", "**/*.parquet")],
recursive=True,
)
pandas_batch_output = ParquetDataSet(
[os.path.join(output_dir, "pandas_batch", "**/*.parquet")],
recursive=True,
)
self._compare_arrow_tables(
data_table, pandas_compute_output.to_arrow_table()
)
self._compare_arrow_tables(data_table, pandas_batch_output.to_arrow_table())
def test_variable_length_input_datasets(self):
ctx = Context()
small_dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
large_dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"] * 10)
small_partitions = DataSetPartitionNode(
ctx, (DataSourceNode(ctx, small_dataset),), npartitions=7
)
large_partitions = DataSetPartitionNode(
ctx, (DataSourceNode(ctx, large_dataset),), npartitions=7
)
arrow_batch = ArrowBatchNode(
ctx,
(small_partitions, large_partitions),
process_func=merge_input_tables,
streaming_batch_size=100,
secs_checkpoint_interval=0.5,
output_name="arrow_batch",
cpu_limit=2,
)
pandas_batch = PandasBatchNode(
ctx,
(small_partitions, large_partitions),
process_func=merge_input_data_frames,
streaming_batch_size=100,
secs_checkpoint_interval=0.5,
output_name="pandas_batch",
cpu_limit=2,
)
plan = LogicalPlan(ctx, RootNode(ctx, (arrow_batch, pandas_batch)))
exec_plan = self.execute_plan(
plan, fault_inject_prob=0.1, secs_executor_probe_interval=0.5
)
self.assertTrue(all(map(os.path.exists, exec_plan.final_output.resolved_paths)))
arrow_batch_output = ParquetDataSet(
[os.path.join(exec_plan.ctx.output_root, "arrow_batch", "**/*.parquet")],
recursive=True,
)
pandas_batch_output = ParquetDataSet(
[os.path.join(exec_plan.ctx.output_root, "pandas_batch", "**/*.parquet")],
recursive=True,
)
self.assertEqual(
small_dataset.num_rows + large_dataset.num_rows, arrow_batch_output.num_rows
)
self.assertEqual(
small_dataset.num_rows + large_dataset.num_rows,
pandas_batch_output.num_rows,
)
def test_projection_task(self):
ctx = Context()
# select columns when defining dataset
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"], columns=["url"])
data_files = DataSourceNode(ctx, dataset)
data_partitions = DataSetPartitionNode(
ctx, (data_files,), npartitions=3, partition_by_rows=True
)
# projection as input of arrow node
generated_columns = ["filename", "file_row_number"]
urls_with_host = ArrowComputeNode(
ctx,
(ProjectionNode(ctx, data_partitions, ["url"], generated_columns),),
process_func=parse_url,
use_duckdb_reader=True,
)
# projection as input of sql node
distinct_urls_with_host = SqlEngineNode(
ctx,
(
ProjectionNode(
ctx,
data_partitions,
["url", "string_split(url, '/')[0] as host"],
generated_columns,
),
),
r"select distinct host, url, filename from {0}",
)
# unify different schemas
merged_diff_schemas = ProjectionNode(
ctx,
DataSetPartitionNode(
ctx, (distinct_urls_with_host, urls_with_host), npartitions=1
),
union_by_name=True,
)
host_partitions = HashPartitionNode(
ctx,
(merged_diff_schemas,),
npartitions=3,
hash_columns=["host"],
engine_type="duckdb",
output_name="host_partitions",
)
host_partitions.max_num_producer_tasks = 1
plan = LogicalPlan(ctx, host_partitions)
final_output = self.execute_plan(plan, fault_inject_prob=0.1).final_output
final_table = final_output.to_arrow_table()
self.assertEqual(
sorted(
[
"url",
"host",
*generated_columns,
HashPartitionNode.default_data_partition_column,
]
),
sorted(final_table.column_names),
)
def test_arrow_type_in_udfs(self):
ctx = Context()
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
data_files = DataSourceNode(ctx, dataset)
data_partitions = DataSetPartitionNode(
ctx, (data_files,), npartitions=dataset.num_files
)
ctx.create_function(
"split_url",
split_url,
[UDFType.VARCHAR],
UDFListType(UDFType.VARCHAR),
use_arrow_type=True,
)
uniq_hosts = SqlEngineNode(
ctx,
(data_partitions,),
r"select split_url(url) as url_parts from {0}",
udfs=["split_url"],
)
plan = LogicalPlan(ctx, uniq_hosts)
self.execute_plan(plan)
def test_many_simple_tasks(self):
ctx = Context()
npartitions = 1000
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"] * npartitions)
data_files = DataSourceNode(ctx, dataset)
data_partitions = EvenlyDistributedPartitionNode(
ctx, (data_files,), npartitions=npartitions
)
output_msg = OutputMsgPythonNode(ctx, (data_partitions,))
plan = LogicalPlan(ctx, output_msg)
self.execute_plan(
plan,
num_executors=10,
secs_executor_probe_interval=5,
enable_profiling=True,
)
def test_many_producers_and_partitions(self):
ctx = Context()
npartitions = 10000
dataset = ParquetDataSet(
["tests/data/mock_urls/*.parquet"] * (npartitions * 10)
)
data_files = DataSourceNode(ctx, dataset)
data_partitions = EvenlyDistributedPartitionNode(
ctx, (data_files,), npartitions=npartitions, cpu_limit=1
)
data_partitions.max_num_producer_tasks = 20
output_msg = OutputMsgPythonNode(ctx, (data_partitions,))
plan = LogicalPlan(ctx, output_msg)
self.execute_plan(
plan,
num_executors=10,
secs_executor_probe_interval=5,
enable_profiling=True,
)
def test_local_gpu_rank(self):
ctx = Context()
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
data_files = DataSourceNode(ctx, dataset)
data_partitions = DataSetPartitionNode(
ctx, (data_files,), npartitions=dataset.num_files
)
output_msg = OutputMsgPythonNode(
ctx, (data_partitions,), cpu_limit=1, gpu_limit=0.5
)
plan = LogicalPlan(ctx, output_msg)
runtime_ctx = RuntimeContext(
JobId.new(),
datetime.now(),
self.output_root_abspath,
console_log_level="WARNING",
)
runtime_ctx.get_local_gpus = return_fake_gpus
runtime_ctx.initialize(socket.gethostname(), cleanup_root=True)
self.execute_plan(plan, runtime_ctx=runtime_ctx)
def test_python_node_with_process_method(self):
ctx = Context()
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
data_files = DataSourceNode(ctx, dataset)
copy_input_arrow_node = CopyInputArrowNode(ctx, (data_files,), "hello")
copy_input_stream_node = CopyInputStreamNode(ctx, (data_files,), "hello")
output_msg = OutputMsgPythonNode2(
ctx, (copy_input_arrow_node, copy_input_stream_node), "hello"
)
plan = LogicalPlan(ctx, output_msg)
self.execute_plan(plan)
def test_sql_engine_oom(self):
ctx = Context()
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
data_files = DataSourceNode(ctx, dataset)
uniq_urls = SqlEngineNode(
ctx, (data_files,), r"select distinct * from {0}", memory_limit=2 * MB
)
uniq_url_partitions = DataSetPartitionNode(ctx, (uniq_urls,), 2)
uniq_url_count = SqlEngineNode(
ctx,
(uniq_url_partitions,),
sql_query=r"select count(distinct columns(*)) from {0}",
memory_limit=2 * MB,
)
plan = LogicalPlan(ctx, uniq_url_count)
self.execute_plan(plan, max_fail_count=10)
@unittest.skip("flaky on CI")
def test_enforce_memory_limit(self):
ctx = Context()
dataset = ParquetDataSet(["tests/data/arrow/*.parquet"])
data_files = DataSourceNode(ctx, dataset)
arrow_compute = ArrowComputeNode(
ctx,
(data_files,),
process_func=functools.partial(copy_input_arrow, msg="arrow compute"),
memory_limit=1 * GB,
)
arrow_stream = ArrowStreamNode(
ctx,
(data_files,),
process_func=functools.partial(copy_input_stream, msg="arrow stream"),
memory_limit=1 * GB,
)
with tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir:
data_sink = DataSinkNode(
ctx, (arrow_compute, arrow_stream), output_path=output_dir
)
plan = LogicalPlan(ctx, data_sink)
self.execute_plan(
plan,
max_fail_count=10,
enforce_memory_limit=True,
nonzero_exitcode_as_oom=True,
)
def test_task_crash_as_oom(self):
ctx = Context()
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
data_files = DataSourceNode(ctx, dataset)
nonzero_exitcode = PythonScriptNode(
ctx, (data_files,), process_func=nonzero_exit_code
)
plan = LogicalPlan(ctx, nonzero_exitcode)
exec_plan = self.execute_plan(
plan, num_executors=1, check_result=False, nonzero_exitcode_as_oom=False
)
self.assertFalse(exec_plan.successful)
exec_plan = self.execute_plan(
plan, num_executors=1, check_result=False, nonzero_exitcode_as_oom=True
)
self.assertTrue(exec_plan.successful)
def test_manifest_only_data_sink(self):
with open("tests/data/long_path_list.txt", buffering=16 * MB) as fin:
filenames = list(map(str.strip, fin.readlines()))
logger.info(f"loaded {len(filenames)} filenames")
with tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir:
ctx = Context()
dataset = ParquetDataSet(filenames)
data_files = DataSourceNode(ctx, dataset)
data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=512)
data_sink = DataSinkNode(
ctx, (data_partitions,), output_path=output_dir, manifest_only=True
)
plan = LogicalPlan(ctx, data_sink)
self.execute_plan(plan)
with open(
os.path.join(output_dir, DataSinkTask.manifest_filename),
buffering=16 * MB,
) as fin:
num_lines = len(fin.readlines())
self.assertEqual(len(filenames), num_lines)
def test_sql_batched_processing(self):
for materialize_in_memory in (False, True):
with self.subTest(materialize_in_memory=materialize_in_memory):
ctx = Context()
dataset = ParquetDataSet(["tests/data/large_array/*.parquet"] * 2)
data_files = DataSourceNode(ctx, dataset)
content_length = SqlEngineNode(
ctx,
(data_files,),
r"select url, octet_length(content) as content_len from {0}",
materialize_in_memory=materialize_in_memory,
batched_processing=True,
cpu_limit=2,
memory_limit=2 * GB,
)
plan = LogicalPlan(ctx, content_length)
final_output: ParquetDataSet = self.execute_plan(plan).final_output
self.assertEqual(dataset.num_rows, final_output.num_rows)
def test_multiple_sql_queries(self):
for materialize_in_memory in (False, True):
with self.subTest(materialize_in_memory=materialize_in_memory):
ctx = Context()
dataset = ParquetDataSet(["tests/data/large_array/*.parquet"] * 2)
data_files = DataSourceNode(ctx, dataset)
content_length = SqlEngineNode(
ctx,
(data_files,),
[
r"create or replace temp table content_len_data as select url, octet_length(content) as content_len from {0}",
r"select * from content_len_data",
],
materialize_in_memory=materialize_in_memory,
batched_processing=True,
cpu_limit=2,
memory_limit=2 * GB,
)
plan = LogicalPlan(ctx, content_length)
final_output: ParquetDataSet = self.execute_plan(plan).final_output
self.assertEqual(dataset.num_rows, final_output.num_rows)
def test_temp_outputs_in_final_results(self):
ctx = Context()
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
data_files = DataSourceNode(ctx, dataset)
data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=10)
url_counts = SqlEngineNode(
ctx, (data_partitions,), r"select count(url) as cnt from {0}"
)
distinct_url_counts = SqlEngineNode(
ctx, (data_partitions,), r"select count(distinct url) as cnt from {0}"
)
merged_counts = DataSetPartitionNode(
ctx,
(
ProjectionNode(ctx, url_counts, ["cnt"]),
ProjectionNode(ctx, distinct_url_counts, ["cnt"]),
),
npartitions=1,
)
split_counts = DataSetPartitionNode(ctx, (merged_counts,), npartitions=10)
plan = LogicalPlan(ctx, split_counts)
final_output: ParquetDataSet = self.execute_plan(plan).final_output
self.assertEqual(data_partitions.npartitions * 2, final_output.num_rows)
def test_override_output_path(self):
ctx = Context()
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
data_files = DataSourceNode(ctx, dataset)
data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=10)
url_counts = SqlEngineNode(
ctx,
(data_partitions,),
r"select count(url) as cnt from {0}",
output_name="url_counts",
)
distinct_url_counts = SqlEngineNode(
ctx, (data_partitions,), r"select count(distinct url) as cnt from {0}"
)
merged_counts = DataSetPartitionNode(
ctx,
(
ProjectionNode(ctx, url_counts, ["cnt"]),
ProjectionNode(ctx, distinct_url_counts, ["cnt"]),
),
npartitions=1,
)
plan = LogicalPlan(ctx, merged_counts)
output_path = os.path.join(self.runtime_ctx.output_root, "final_output")
final_output = self.execute_plan(plan, output_path=output_path).final_output
self.assertTrue(os.path.exists(os.path.join(output_path, "url_counts")))
self.assertTrue(os.path.exists(os.path.join(output_path, "FinalResults")))
def test_data_sink_avoid_filename_conflicts(self):
ctx = Context()
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
data_files = DataSourceNode(ctx, dataset)
data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=10)
empty_files1 = PythonScriptNode(
ctx, (data_partitions,), process_func=empty_file
)
empty_files2 = PythonScriptNode(
ctx, (data_partitions,), process_func=empty_file
)
link_path = os.path.join(self.runtime_ctx.output_root, "link")
copy_path = os.path.join(self.runtime_ctx.output_root, "copy")
copy_input_path = os.path.join(self.runtime_ctx.output_root, "copy_input")
data_link = DataSinkNode(
ctx, (empty_files1, empty_files2), type="link", output_path=link_path
)
data_copy = DataSinkNode(
ctx, (empty_files1, empty_files2), type="copy", output_path=copy_path
)
data_copy_input = DataSinkNode(
ctx, (data_partitions,), type="copy", output_path=copy_input_path
)
plan = LogicalPlan(ctx, RootNode(ctx, (data_link, data_copy, data_copy_input)))
self.execute_plan(plan)
# there should be 21 files (20 input files + 1 manifest file) in the sink dir
self.assertEqual(21, len(os.listdir(link_path)))
self.assertEqual(21, len(os.listdir(copy_path)))
# file name should not be modified if no conflict
self.assertEqual(
set(
filename
for filename in os.listdir("tests/data/mock_urls")
if filename.endswith(".parquet")
),
set(
filename
for filename in os.listdir(copy_input_path)
if filename.endswith(".parquet")
),
)
def test_literal_datasets_as_data_sources(self):
ctx = Context()
num_rows = 10
query_dataset = SqlQueryDataSet(f"select i from range({num_rows}) as x(i)")
table_dataset = ArrowTableDataSet(
arrow.Table.from_arrays([list(range(num_rows))], names=["i"])
)
query_source = DataSourceNode(ctx, query_dataset)
table_source = DataSourceNode(ctx, table_dataset)
query_partitions = DataSetPartitionNode(
ctx, (query_source,), npartitions=num_rows, partition_by_rows=True
)
table_partitions = DataSetPartitionNode(
ctx, (table_source,), npartitions=num_rows, partition_by_rows=True
)
joined_rows = SqlEngineNode(
ctx,
(query_partitions, table_partitions),
r"select a.i as i, b.i as j from {0} as a join {1} as b on a.i = b.i",
)
plan = LogicalPlan(ctx, joined_rows)
final_output: ParquetDataSet = self.execute_plan(plan).final_output
self.assertEqual(num_rows, final_output.num_rows)
def test_partial_process_func(self):
ctx = Context()
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
data_files = DataSourceNode(ctx, dataset)
data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=3)
# use default value of k
random_urls_k5 = ArrowComputeNode(
ctx,
(data_partitions,),
process_func=choose_random_urls,
output_name="random_urls_k5",
)
# set value of k using functools.partial
random_urls_k10 = ArrowComputeNode(
ctx,
(data_partitions,),
process_func=functools.partial(choose_random_urls, k=10),
output_name="random_urls_k10",
)
random_urls_all = SqlEngineNode(
ctx,
(random_urls_k5, random_urls_k10),
r"select * from {0} union select * from {1}",
output_name="random_urls_all",
)
plan = LogicalPlan(ctx, random_urls_all)
exec_plan = self.execute_plan(plan)
self.assertEqual(
data_partitions.npartitions * 5,
exec_plan.get_output("random_urls_k5").to_arrow_table().num_rows,
)
self.assertEqual(
data_partitions.npartitions * 10,
exec_plan.get_output("random_urls_k10").to_arrow_table().num_rows,
)

294
tests/test_fabric.py Normal file
View File

@@ -0,0 +1,294 @@
import os.path
import queue
import sys
import unittest
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from multiprocessing import Manager, Process
from typing import List, Optional
import fsspec
import numpy as np
import psutil
import pyarrow as arrow
import pyarrow.compute as pc
import pyarrow.parquet as parquet
from loguru import logger
from smallpond.common import DEFAULT_MAX_FAIL_COUNT, DEFAULT_MAX_RETRY_COUNT, GB, MB
from smallpond.execution.executor import Executor
from smallpond.execution.scheduler import Scheduler
from smallpond.execution.task import ExecutionPlan, JobId, RuntimeContext
from smallpond.io.arrow import cast_columns_to_large_string
from smallpond.logical.node import LogicalPlan
from smallpond.logical.planner import Planner
from tests.datagen import generate_data
generate_data()
def run_scheduler(
runtime_ctx: RuntimeContext, scheduler: Scheduler, queue: queue.Queue
):
runtime_ctx.initialize("scheduler")
scheduler.add_state_observer(Scheduler.StateObserver(SaveSchedState(queue)))
retval = scheduler.run()
print(f"scheduler exited with value {retval}", file=sys.stderr)
def run_executor(runtime_ctx: RuntimeContext, executor: Executor):
runtime_ctx.initialize(executor.id)
retval = executor.run()
print(f"{executor.id} exited with value {retval}", file=sys.stderr)
class SaveSchedState:
"""
A state observer that push the scheduler state into a queue when finished.
"""
def __init__(self, queue: queue.Queue):
self.queue = queue
def __call__(self, sched_state: Scheduler) -> bool:
if sched_state.num_local_running_works == 0:
self.queue.put(sched_state)
return True
class TestFabric(unittest.TestCase):
"""
A helper class that includes boilerplate code to test a logical plan.
"""
runtime_root = os.getenv("TEST_RUNTIME_ROOT") or f"tests/runtime"
runtime_ctx = None
fault_inject_prob = 0.00
queue_manager = None
sched_states: queue.Queue = None
latest_state: Scheduler = None
executors: List[Executor] = None
processes: List[Process] = None
@property
def output_dir(self):
return os.path.join(self.__class__.__name__, self._testMethodName)
@property
def output_root_abspath(self):
output_root = os.path.abspath(os.path.join(self.runtime_root, self.output_dir))
os.makedirs(output_root, exist_ok=True)
return output_root
def setUp(self) -> None:
try:
from pytest_cov.embed import cleanup_on_sigterm
except ImportError:
pass
else:
cleanup_on_sigterm()
self.runtime_ctx = RuntimeContext(
JobId.new(),
datetime.now(),
self.output_root_abspath,
console_log_level="WARNING",
)
self.runtime_ctx.initialize("setup")
return super().setUp()
def tearDown(self) -> None:
if self.sched_states is not None:
self.get_latest_sched_state()
assert self.sched_states.qsize() == 0
self.sched_states = None
if self.queue_manager is not None:
self.queue_manager.shutdown()
self.queue_manager = None
return super().tearDown()
def get_latest_sched_state(self) -> Scheduler:
while True:
try:
self.latest_state = self.sched_states.get(block=False)
except queue.Empty:
return self.latest_state
def join_running_procs(self, timeout=30):
for i, process in enumerate(self.processes):
if process.is_alive():
logger.info(f"join #{i} process: {process.name}")
process.join(timeout=None if i == 0 else timeout)
if process.exitcode is None:
logger.info(f"terminate #{i} process: {process.name}")
process.terminate()
process.join(timeout=timeout)
if process.exitcode is None:
logger.info(f"kill #{i} process: {process.name}")
process.kill()
process.join()
logger.info(
f"#{i} process {process.name} exited with code {process.exitcode}"
)
def start_execution(
self,
plan: LogicalPlan,
num_executors: int = 2,
secs_wq_poll_interval: float = 0.1,
secs_executor_probe_interval: float = 1,
max_num_missed_probes: int = 10,
max_retry_count: int = DEFAULT_MAX_RETRY_COUNT,
max_fail_count: int = DEFAULT_MAX_FAIL_COUNT,
prioritize_retry=False,
speculative_exec="enable",
stop_executor_on_failure=False,
enforce_memory_limit=False,
nonzero_exitcode_as_oom=False,
fault_inject_prob=None,
enable_profiling=False,
enable_diagnostic_metrics=False,
remove_empty_parquet=False,
skip_task_with_empty_input=False,
console_log_level="WARNING",
file_log_level="DEBUG",
output_path: Optional[str] = None,
runtime_ctx: Optional[RuntimeContext] = None,
):
"""
Start a scheduler and `num_executors` executors to execute `plan`.
When this function returns, the execution is mostly still running.
Parameters
----------
plan
A logical plan.
num_executors, optional
The number of executors
console_log_level, optional
Set to logger.INFO if more verbose loguru is needed for debug, by default "WARNING".
Returns
-------
A 3-tuple of type (Scheduler, List[Executor], List[Process]).
"""
if runtime_ctx is None:
runtime_ctx = RuntimeContext(
JobId.new(),
datetime.now(),
self.output_root_abspath,
num_executors=num_executors,
random_seed=123456,
enforce_memory_limit=enforce_memory_limit,
max_usable_cpu_count=min(64, psutil.cpu_count(logical=False)),
max_usable_gpu_count=0,
max_usable_memory_size=min(64 * GB, psutil.virtual_memory().total),
secs_wq_poll_interval=secs_wq_poll_interval,
secs_executor_probe_interval=secs_executor_probe_interval,
max_num_missed_probes=max_num_missed_probes,
fault_inject_prob=(
fault_inject_prob
if fault_inject_prob is not None
else self.fault_inject_prob
),
enable_profiling=enable_profiling,
enable_diagnostic_metrics=enable_diagnostic_metrics,
remove_empty_parquet=remove_empty_parquet,
skip_task_with_empty_input=skip_task_with_empty_input,
console_log_level=console_log_level,
file_log_level=file_log_level,
output_path=output_path,
)
self.queue_manager = Manager()
self.sched_states = self.queue_manager.Queue()
exec_plan = Planner(runtime_ctx).create_exec_plan(plan)
scheduler = Scheduler(
exec_plan,
max_retry_count=max_retry_count,
max_fail_count=max_fail_count,
prioritize_retry=prioritize_retry,
speculative_exec=speculative_exec,
stop_executor_on_failure=stop_executor_on_failure,
nonzero_exitcode_as_oom=nonzero_exitcode_as_oom,
)
self.latest_state = scheduler
self.executors = [
Executor.create(runtime_ctx, f"executor-{i}") for i in range(num_executors)
]
self.processes = [
Process(
target=run_scheduler,
# XXX: on macOS, scheduler state observer will be cleared when cross-process
# so we pass the queue and add the observer in the new process
args=(runtime_ctx, scheduler, self.sched_states),
name="scheduler",
)
]
self.processes += [
Process(target=run_executor, args=(runtime_ctx, executor), name=executor.id)
for executor in self.executors
]
for process in reversed(self.processes):
process.start()
return self.sched_states, self.executors, self.processes
def execute_plan(self, *args, check_result=True, **kvargs) -> ExecutionPlan:
"""
Start a scheduler and `num_executors` executors to execute `plan`,
and wait the execution completed, then assert if it succeeds.
Parameters
----------
plan
A logical plan.
num_executors, optional
The number of executors
console_log_level, optional
Set to logger.INFO if more verbose loguru is needed for debug, by default "WARNING".
Returns
-------
The completed ExecutionPlan instance.
"""
self.start_execution(*args, **kvargs)
self.join_running_procs()
latest_state = self.get_latest_sched_state()
if check_result:
self.assertTrue(latest_state.success)
return latest_state.exec_plan
def _load_parquet_files(
self, paths, filesystem: fsspec.AbstractFileSystem = None
) -> arrow.Table:
def read_parquet_file(path):
return arrow.Table.from_batches(
parquet.ParquetFile(
path, buffer_size=16 * MB, filesystem=filesystem
).iter_batches()
)
with ThreadPoolExecutor(16) as pool:
return arrow.concat_tables(pool.map(read_parquet_file, paths))
def _compare_arrow_tables(self, expected: arrow.Table, actual: arrow.Table):
def sorted_table(t: arrow.Table):
return t.sort_by([(col, "ascending") for col in t.column_names])
self.assertEqual(expected.shape, actual.shape)
self.assertEqual(expected.column_names, actual.column_names)
expected = sorted_table(cast_columns_to_large_string(expected))
actual = sorted_table(cast_columns_to_large_string(actual))
for col, x, y in zip(expected.column_names, expected.columns, actual.columns):
if not pc.equal(x, y):
x = x.to_numpy(zero_copy_only=False)
y = y.to_numpy(zero_copy_only=False)
logger.error(f" expect {col}: {x}")
logger.error(f" actual {col}: {y}")
np.testing.assert_array_equal(x, y, verbose=True)

25
tests/test_filesystem.py Normal file
View File

@@ -0,0 +1,25 @@
import os.path
import tempfile
import threading
import unittest
from smallpond.io.filesystem import dump, load
from tests.test_fabric import TestFabric
class TestFilesystem(TestFabric, unittest.TestCase):
def test_pickle_runtime_ctx(self):
with tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir:
pickle_path = os.path.join(output_dir, "runtime_ctx.pickle")
dump(self.runtime_ctx, pickle_path)
runtime_ctx = load(pickle_path)
self.assertEqual(self.runtime_ctx.job_id, runtime_ctx.job_id)
def test_pickle_trace(self):
with self.assertRaises(TypeError) as context:
with tempfile.TemporaryDirectory(
dir=self.output_root_abspath
) as output_dir:
thread = threading.Thread()
pickle_path = os.path.join(output_dir, "thread.pickle")
dump(thread, pickle_path)

103
tests/test_logical.py Normal file
View File

@@ -0,0 +1,103 @@
import unittest
from loguru import logger
from smallpond.logical.dataset import ParquetDataSet
from smallpond.logical.node import (
Context,
DataSetPartitionNode,
DataSourceNode,
EvenlyDistributedPartitionNode,
HashPartitionNode,
LogicalPlan,
SqlEngineNode,
)
from smallpond.logical.planner import Planner
from tests.test_fabric import TestFabric
class TestLogicalPlan(TestFabric, unittest.TestCase):
def test_join_chunkmeta_inodes(self):
ctx = Context()
chunkmeta_dump = DataSourceNode(
ctx, dataset=ParquetDataSet(["tests/data/chunkmeta*.parquet"])
)
chunkmeta_partitions = HashPartitionNode(
ctx, (chunkmeta_dump,), npartitions=2, hash_columns=["inodeId"]
)
inodes_dump = DataSourceNode(
ctx, dataset=ParquetDataSet(["tests/data/inodes*.parquet"])
)
inodes_partitions = HashPartitionNode(
ctx, (inodes_dump,), npartitions=2, hash_columns=["inode_id"]
)
num_gc_chunks = SqlEngineNode(
ctx,
(chunkmeta_partitions, inodes_partitions),
r"""
select count(chunkmeta_chunkId) from {0}
where chunkmeta.chunkmeta_chunkId NOT LIKE "F%" AND
chunkmeta.inodeId not in ( select distinct inode_id from {1} )""",
)
plan = LogicalPlan(ctx, num_gc_chunks)
logger.info(str(plan))
exec_plan = Planner(self.runtime_ctx).create_exec_plan(plan)
logger.info(str(exec_plan))
def test_partition_dims_not_compatible(self):
ctx = Context()
parquet_dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
data_source = DataSourceNode(ctx, parquet_dataset)
partition_dim_a = EvenlyDistributedPartitionNode(
ctx, (data_source,), npartitions=parquet_dataset.num_files, dimension="A"
)
partition_dim_b = EvenlyDistributedPartitionNode(
ctx, (data_source,), npartitions=parquet_dataset.num_files, dimension="B"
)
join_two_inputs = SqlEngineNode(
ctx,
(partition_dim_a, partition_dim_b),
r"select a.* from {0} as a join {1} as b on a.host = b.host",
)
plan = LogicalPlan(ctx, join_two_inputs)
logger.info(str(plan))
with self.assertRaises(AssertionError) as context:
Planner(self.runtime_ctx).create_exec_plan(plan)
def test_npartitions_not_compatible(self):
ctx = Context()
parquet_dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
data_source = DataSourceNode(ctx, parquet_dataset)
partition_dim_a = EvenlyDistributedPartitionNode(
ctx, (data_source,), npartitions=parquet_dataset.num_files, dimension="A"
)
partition_dim_a2 = EvenlyDistributedPartitionNode(
ctx,
(data_source,),
npartitions=parquet_dataset.num_files * 2,
dimension="A",
)
join_two_inputs1 = SqlEngineNode(
ctx,
(partition_dim_a, partition_dim_a2),
r"select a.* from {0} as a join {1} as b on a.host = b.host",
)
join_two_inputs2 = SqlEngineNode(
ctx,
(partition_dim_a2, partition_dim_a),
r"select a.* from {0} as a join {1} as b on a.host = b.host",
)
plan = LogicalPlan(
ctx,
DataSetPartitionNode(
ctx, (join_two_inputs1, join_two_inputs2), npartitions=1
),
)
logger.info(str(plan))
with self.assertRaises(AssertionError) as context:
Planner(self.runtime_ctx).create_exec_plan(plan)

659
tests/test_partition.py Normal file
View File

@@ -0,0 +1,659 @@
import os.path
import tempfile
import unittest
from typing import List
import pyarrow.compute as pc
from smallpond.common import DATA_PARTITION_COLUMN_NAME, GB
from smallpond.execution.task import RuntimeContext
from smallpond.logical.dataset import DataSet, ParquetDataSet
from smallpond.logical.node import (
ArrowComputeNode,
ConsolidateNode,
Context,
DataSetPartitionNode,
DataSinkNode,
DataSourceNode,
EvenlyDistributedPartitionNode,
HashPartitionNode,
LoadPartitionedDataSetNode,
LogicalPlan,
ProjectionNode,
SqlEngineNode,
UnionNode,
UserDefinedPartitionNode,
UserPartitionedDataSourceNode,
)
from tests.test_execution import parse_url
from tests.test_fabric import TestFabric
class CalculatePartitionFromFilename(UserDefinedPartitionNode):
def partition(self, runtime_ctx: RuntimeContext, dataset: DataSet) -> List[DataSet]:
partitioned_datasets: List[ParquetDataSet] = [
ParquetDataSet([]) for _ in range(self.npartitions)
]
for path in dataset.resolved_paths:
partition_idx = hash(path) % self.npartitions
partitioned_datasets[partition_idx].paths.append(path)
return partitioned_datasets
class TestPartition(TestFabric, unittest.TestCase):
def test_many_file_partitions(self):
ctx = Context()
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"] * 10)
data_files = DataSourceNode(ctx, dataset)
data_partitions = DataSetPartitionNode(
ctx, (data_files,), npartitions=dataset.num_files
)
count_rows = SqlEngineNode(
ctx,
(data_partitions,),
"select count(*) from {0}",
cpu_limit=1,
memory_limit=1 * GB,
)
plan = LogicalPlan(ctx, count_rows)
self.execute_plan(plan)
def test_many_row_partitions(self):
ctx = Context()
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
data_files = DataSourceNode(ctx, dataset)
data_partitions = DataSetPartitionNode(
ctx, (data_files,), npartitions=dataset.num_rows, partition_by_rows=True
)
count_rows = SqlEngineNode(
ctx,
(data_partitions,),
"select count(*) from {0}",
cpu_limit=1,
memory_limit=1 * GB,
)
plan = LogicalPlan(ctx, count_rows)
exec_plan = self.execute_plan(plan, num_executors=5)
self.assertEqual(
exec_plan.final_output.to_arrow_table().num_rows, dataset.num_rows
)
def test_empty_dataset_partition(self):
ctx = Context()
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
data_files = DataSourceNode(ctx, dataset)
# create more partitions than files
data_partitions = EvenlyDistributedPartitionNode(
ctx, (data_files,), npartitions=dataset.num_files * 2
)
data_partitions.max_num_producer_tasks = 3
unique_urls = SqlEngineNode(
ctx,
(data_partitions,),
r"select distinct url from {0}",
cpu_limit=1,
memory_limit=1 * GB,
)
# nested partition
nested_partitioned_urls = EvenlyDistributedPartitionNode(
ctx, (unique_urls,), npartitions=3, dimension="nested", nested=True
)
parsed_urls = ArrowComputeNode(
ctx,
(nested_partitioned_urls,),
process_func=parse_url,
cpu_limit=1,
memory_limit=1 * GB,
)
plan = LogicalPlan(ctx, parsed_urls)
final_output = self.execute_plan(
plan, remove_empty_parquet=True, skip_task_with_empty_input=True
).final_output
self.assertTrue(isinstance(final_output, ParquetDataSet))
self.assertEqual(dataset.num_rows, final_output.num_rows)
def test_hash_partition(self):
for engine_type in ("duckdb", "arrow"):
for partition_by_rows in (False, True):
for hive_partitioning in (
(False, True) if engine_type == "duckdb" else (False,)
):
with self.subTest(
engine_type=engine_type,
partition_by_rows=partition_by_rows,
hive_partitioning=hive_partitioning,
):
ctx = Context()
dataset = ParquetDataSet(["tests/data/arrow/*.parquet"])
data_files = DataSourceNode(ctx, dataset)
npartitions = 3
data_partitions = DataSetPartitionNode(
ctx,
(data_files,),
npartitions=npartitions,
partition_by_rows=partition_by_rows,
)
hash_partitions = HashPartitionNode(
ctx,
(ProjectionNode(ctx, data_partitions, ["url"]),),
npartitions=npartitions,
hash_columns=["url"],
engine_type=engine_type,
hive_partitioning=hive_partitioning,
cpu_limit=2,
memory_limit=2 * GB,
output_name="hash_partitions",
)
row_count = SqlEngineNode(
ctx,
(hash_partitions,),
r"select count(*) as row_count from {0}",
cpu_limit=1,
memory_limit=1 * GB,
)
plan = LogicalPlan(ctx, row_count)
exec_plan = self.execute_plan(plan)
self.assertEqual(
dataset.num_rows,
pc.sum(
exec_plan.final_output.to_arrow_table().column(
"row_count"
)
).as_py(),
)
self.assertEqual(
npartitions,
len(
exec_plan.final_output.load_partitioned_datasets(
npartitions, DATA_PARTITION_COLUMN_NAME
)
),
)
self.assertEqual(
npartitions,
len(
exec_plan.get_output(
"hash_partitions"
).load_partitioned_datasets(
npartitions,
DATA_PARTITION_COLUMN_NAME,
hive_partitioning,
)
),
)
def test_empty_hash_partition(self):
for engine_type in ("duckdb", "arrow"):
for partition_by_rows in (False, True):
for hive_partitioning in (
(False, True) if engine_type == "duckdb" else (False,)
):
with self.subTest(
engine_type=engine_type,
partition_by_rows=partition_by_rows,
hive_partitioning=hive_partitioning,
):
ctx = Context()
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
data_files = DataSourceNode(ctx, dataset)
npartitions = 3
npartitions_nested = 4
num_rows = 1
head_rows = SqlEngineNode(
ctx, (data_files,), f"select * from {{0}} limit {num_rows}"
)
data_partitions = DataSetPartitionNode(
ctx,
(head_rows,),
npartitions=npartitions,
partition_by_rows=partition_by_rows,
)
hash_partitions = HashPartitionNode(
ctx,
(data_partitions,),
npartitions=npartitions,
hash_columns=["url"],
data_partition_column="hash_partitions",
engine_type=engine_type,
hive_partitioning=hive_partitioning,
output_name="hash_partitions",
cpu_limit=2,
memory_limit=1 * GB,
)
nested_hash_partitions = HashPartitionNode(
ctx,
(hash_partitions,),
npartitions=npartitions_nested,
hash_columns=["url"],
data_partition_column="nested_hash_partitions",
nested=True,
engine_type=engine_type,
hive_partitioning=hive_partitioning,
output_name="nested_hash_partitions",
cpu_limit=2,
memory_limit=1 * GB,
)
select_every_row = SqlEngineNode(
ctx,
(nested_hash_partitions,),
r"select * from {0}",
cpu_limit=1,
memory_limit=1 * GB,
)
plan = LogicalPlan(ctx, select_every_row)
exec_plan = self.execute_plan(
plan, skip_task_with_empty_input=True
)
self.assertEqual(num_rows, exec_plan.final_output.num_rows)
self.assertEqual(
npartitions,
len(
exec_plan.final_output.load_partitioned_datasets(
npartitions, "hash_partitions"
)
),
)
self.assertEqual(
npartitions_nested,
len(
exec_plan.final_output.load_partitioned_datasets(
npartitions_nested, "nested_hash_partitions"
)
),
)
self.assertEqual(
npartitions,
len(
exec_plan.get_output(
"hash_partitions"
).load_partitioned_datasets(
npartitions, "hash_partitions"
)
),
)
self.assertEqual(
npartitions_nested,
len(
exec_plan.get_output(
"nested_hash_partitions"
).load_partitioned_datasets(
npartitions_nested, "nested_hash_partitions"
)
),
)
if hive_partitioning:
self.assertEqual(
npartitions,
len(
exec_plan.get_output(
"hash_partitions"
).load_partitioned_datasets(
npartitions,
"hash_partitions",
hive_partitioning=True,
)
),
)
self.assertEqual(
npartitions_nested,
len(
exec_plan.get_output(
"nested_hash_partitions"
).load_partitioned_datasets(
npartitions_nested,
"nested_hash_partitions",
hive_partitioning=True,
)
),
)
def test_load_partitioned_datasets(self):
def run_test_plan(
npartitions: int,
data_partition_column: str,
engine_type: str,
hive_partitioning: bool,
):
ctx = Context()
input_dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
input_data_files = DataSourceNode(ctx, input_dataset)
# create hash partitions
input_partitions = HashPartitionNode(
ctx,
(input_data_files,),
npartitions=npartitions,
hash_columns=["url"],
data_partition_column=data_partition_column,
engine_type=engine_type,
hive_partitioning=hive_partitioning,
output_name="input_partitions",
cpu_limit=1,
memory_limit=1 * GB,
)
split_urls = SqlEngineNode(
ctx,
(input_partitions,),
f"select url, string_split(url, '/')[0] as host from {{0}}",
cpu_limit=1,
memory_limit=1 * GB,
)
plan = LogicalPlan(ctx, split_urls)
exec_plan = self.execute_plan(plan)
self.assertEqual(
npartitions,
len(
exec_plan.final_output.load_partitioned_datasets(
npartitions, data_partition_column
)
),
)
self.assertEqual(
npartitions,
len(
exec_plan.get_output("input_partitions").load_partitioned_datasets(
npartitions, data_partition_column, hive_partitioning
)
),
)
return exec_plan
npartitions = 5
data_partition_column = "_human_readable_column_name_"
for engine_type in ("duckdb", "arrow"):
with self.subTest(engine_type=engine_type):
exec_plan1 = run_test_plan(
npartitions,
data_partition_column,
engine_type,
hive_partitioning=engine_type == "duckdb",
)
exec_plan2 = run_test_plan(
npartitions,
data_partition_column,
engine_type,
hive_partitioning=False,
)
ctx = Context()
output1 = DataSourceNode(
ctx, dataset=exec_plan1.get_output("input_partitions")
)
output2 = DataSourceNode(
ctx, dataset=exec_plan2.get_output("input_partitions")
)
split_urls1 = LoadPartitionedDataSetNode(
ctx,
(output1,),
npartitions=npartitions,
data_partition_column=data_partition_column,
hive_partitioning=engine_type == "duckdb",
)
split_urls2 = LoadPartitionedDataSetNode(
ctx,
(output2,),
npartitions=npartitions,
data_partition_column=data_partition_column,
hive_partitioning=False,
)
split_urls3 = SqlEngineNode(
ctx,
(split_urls1, split_urls2),
f"""
select split_urls1.url, string_split(split_urls2.url, '/')[0] as host
from {{0}} as split_urls1
join {{1}} as split_urls2
on split_urls1.url = split_urls2.url
""",
cpu_limit=1,
memory_limit=1 * GB,
)
plan = LogicalPlan(ctx, split_urls3)
exec_plan3 = self.execute_plan(plan)
# load each partition as arrow table and compare
final_output_partitions1 = (
exec_plan1.final_output.load_partitioned_datasets(
npartitions, data_partition_column
)
)
final_output_partitions3 = (
exec_plan3.final_output.load_partitioned_datasets(
npartitions, data_partition_column
)
)
self.assertEqual(npartitions, len(final_output_partitions3))
for x, y in zip(final_output_partitions1, final_output_partitions3):
self._compare_arrow_tables(x.to_arrow_table(), y.to_arrow_table())
def test_nested_partition(self):
ctx = Context()
parquet_files = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
data_source = DataSourceNode(ctx, parquet_files)
SqlEngineNode.default_cpu_limit = 1
SqlEngineNode.default_memory_limit = 1 * GB
initial_reduce = r"select host, count(*) as cnt from {0} group by host"
combine_reduce_results = (
r"select host, cast(sum(cnt) as bigint) as cnt from {0} group by host"
)
join_query = r"select host, cnt from {0} where (exists (select * from {1} where {1}.host = {0}.host)) and (exists (select * from {2} where {2}.host = {0}.host))"
partition_by_hosts = HashPartitionNode(
ctx,
(data_source,),
npartitions=3,
hash_columns=["host"],
data_partition_column="host_partition",
)
partition_by_hosts_x_urls = HashPartitionNode(
ctx,
(partition_by_hosts,),
npartitions=5,
hash_columns=["url"],
data_partition_column="url_partition",
nested=True,
)
url_count_by_hosts_x_urls1 = SqlEngineNode(
ctx,
(partition_by_hosts_x_urls,),
initial_reduce,
output_name="url_count_by_hosts_x_urls1",
)
url_count_by_hosts1 = SqlEngineNode(
ctx,
(ConsolidateNode(ctx, url_count_by_hosts_x_urls1, ["host_partition"]),),
combine_reduce_results,
output_name="url_count_by_hosts1",
)
join_count_by_hosts_x_urls1 = SqlEngineNode(
ctx,
(url_count_by_hosts_x_urls1, url_count_by_hosts1, data_source),
join_query,
output_name="join_count_by_hosts_x_urls1",
)
partitioned_urls = LoadPartitionedDataSetNode(
ctx,
(partition_by_hosts_x_urls,),
data_partition_column="url_partition",
npartitions=5,
)
partitioned_hosts_x_urls = LoadPartitionedDataSetNode(
ctx,
(partitioned_urls,),
data_partition_column="host_partition",
npartitions=3,
nested=True,
)
partitioned_3dims = EvenlyDistributedPartitionNode(
ctx,
(partitioned_hosts_x_urls,),
npartitions=2,
dimension="inner_partition",
partition_by_rows=True,
nested=True,
)
url_count_by_3dims = SqlEngineNode(ctx, (partitioned_3dims,), initial_reduce)
url_count_by_hosts_x_urls2 = SqlEngineNode(
ctx,
(
ConsolidateNode(
ctx, url_count_by_3dims, ["host_partition", "url_partition"]
),
),
combine_reduce_results,
output_name="url_count_by_hosts_x_urls2",
)
url_count_by_hosts2 = SqlEngineNode(
ctx,
(ConsolidateNode(ctx, url_count_by_hosts_x_urls2, ["host_partition"]),),
combine_reduce_results,
output_name="url_count_by_hosts2",
)
url_count_by_hosts_expected = SqlEngineNode(
ctx,
(data_source,),
initial_reduce,
per_thread_output=False,
output_name="url_count_by_hosts_expected",
)
join_count_by_hosts_x_urls2 = SqlEngineNode(
ctx,
(url_count_by_hosts_x_urls2, url_count_by_hosts2, data_source),
join_query,
output_name="join_count_by_hosts_x_urls2",
)
union_url_count_by_hosts = UnionNode(
ctx, (url_count_by_hosts1, url_count_by_hosts2)
)
union_url_count_by_hosts_x_urls = UnionNode(
ctx,
(
url_count_by_hosts_x_urls1,
url_count_by_hosts_x_urls2,
join_count_by_hosts_x_urls1,
join_count_by_hosts_x_urls2,
),
)
with tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir:
data_sink = DataSinkNode(
ctx,
(
url_count_by_hosts_expected,
union_url_count_by_hosts,
union_url_count_by_hosts_x_urls,
),
output_path=output_dir,
manifest_only=True,
)
plan = LogicalPlan(ctx, data_sink)
exec_plan = self.execute_plan(plan, remove_empty_parquet=True)
# verify results
self._compare_arrow_tables(
exec_plan.get_output("url_count_by_hosts_x_urls1").to_arrow_table(),
exec_plan.get_output("url_count_by_hosts_x_urls2").to_arrow_table(),
)
self._compare_arrow_tables(
exec_plan.get_output("join_count_by_hosts_x_urls1").to_arrow_table(),
exec_plan.get_output("join_count_by_hosts_x_urls2").to_arrow_table(),
)
self._compare_arrow_tables(
exec_plan.get_output("url_count_by_hosts_x_urls1").to_arrow_table(),
exec_plan.get_output("join_count_by_hosts_x_urls1").to_arrow_table(),
)
self._compare_arrow_tables(
exec_plan.get_output("url_count_by_hosts1").to_arrow_table(),
exec_plan.get_output("url_count_by_hosts2").to_arrow_table(),
)
self._compare_arrow_tables(
exec_plan.get_output("url_count_by_hosts_expected").to_arrow_table(),
exec_plan.get_output("url_count_by_hosts1").to_arrow_table(),
)
def test_user_defined_partition(self):
ctx = Context()
parquet_files = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
data_source = DataSourceNode(ctx, parquet_files)
file_partitions1 = CalculatePartitionFromFilename(
ctx, (data_source,), npartitions=3, dimension="by_filename_hash1"
)
url_count1 = SqlEngineNode(
ctx,
(file_partitions1,),
r"select host, count(*) as cnt from {0} group by host",
output_name="url_count1",
)
file_partitions2 = CalculatePartitionFromFilename(
ctx, (url_count1,), npartitions=3, dimension="by_filename_hash2"
)
url_count2 = SqlEngineNode(
ctx,
(file_partitions2,),
r"select host, cnt from {0}",
output_name="url_count2",
)
plan = LogicalPlan(ctx, url_count2)
exec_plan = self.execute_plan(plan, enable_diagnostic_metrics=True)
self._compare_arrow_tables(
exec_plan.get_output("url_count1").to_arrow_table(),
exec_plan.get_output("url_count2").to_arrow_table(),
)
def test_user_partitioned_data_source(self):
ctx = Context()
parquet_dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
data_source = DataSourceNode(ctx, parquet_dataset)
evenly_dist_data_source = EvenlyDistributedPartitionNode(
ctx, (data_source,), npartitions=parquet_dataset.num_files
)
parquet_datasets = [ParquetDataSet([p]) for p in parquet_dataset.resolved_paths]
partitioned_data_source = UserPartitionedDataSourceNode(ctx, parquet_datasets)
url_count_by_host1 = SqlEngineNode(
ctx,
(evenly_dist_data_source,),
r"select host, count(*) as cnt from {0} group by host",
output_name="url_count_by_host1",
cpu_limit=1,
memory_limit=1 * GB,
)
url_count_by_host2 = SqlEngineNode(
ctx,
(evenly_dist_data_source, partitioned_data_source),
r"select {1}.host, count(*) as cnt from {0} join {1} on {0}.host = {1}.host group by {1}.host",
output_name="url_count_by_host2",
cpu_limit=1,
memory_limit=1 * GB,
)
plan = LogicalPlan(
ctx, UnionNode(ctx, [url_count_by_host1, url_count_by_host2])
)
exec_plan = self.execute_plan(plan, enable_diagnostic_metrics=True)
self._compare_arrow_tables(
exec_plan.get_output("url_count_by_host1").to_arrow_table(),
exec_plan.get_output("url_count_by_host2").to_arrow_table(),
)
def test_partition_info_in_sql_query(self):
"""
User can refer to the partition info in the SQL query.
"""
ctx = Context()
parquet_dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
data_source = DataSourceNode(ctx, parquet_dataset)
evenly_dist_data_source = EvenlyDistributedPartitionNode(
ctx, (data_source,), npartitions=parquet_dataset.num_files
)
sql_query = SqlEngineNode(
ctx,
(evenly_dist_data_source,),
r"select host, {__data_partition__} as partition_info from {0}",
)
plan = LogicalPlan(ctx, sql_query)
exec_plan = self.execute_plan(plan)

70
tests/test_plan.py Normal file
View File

@@ -0,0 +1,70 @@
import os
import tempfile
import unittest
from examples.fstest import fstest
from examples.shuffle_data import shuffle_data
from examples.shuffle_mock_urls import shuffle_mock_urls
from examples.sort_mock_urls import sort_mock_urls
from examples.sort_mock_urls_v2 import sort_mock_urls_v2
from smallpond.dataframe import Session
from tests.test_fabric import TestFabric
class TestPlan(TestFabric, unittest.TestCase):
def test_sort_mock_urls(self):
for engine_type in ("duckdb", "arrow"):
with self.subTest(engine_type=engine_type):
plan = sort_mock_urls(
["tests/data/mock_urls/*.tsv"],
npartitions=3,
engine_type=engine_type,
)
self.execute_plan(plan)
def test_sort_mock_urls_external_output_path(self):
with tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir:
plan = sort_mock_urls(
["tests/data/mock_urls/*.tsv"],
npartitions=3,
external_output_path=output_dir,
)
self.execute_plan(plan)
def test_shuffle_mock_urls(self):
for engine_type in ("duckdb", "arrow"):
with self.subTest(engine_type=engine_type):
plan = shuffle_mock_urls(
["tests/data/mock_urls/*.parquet"],
npartitions=3,
sort_rand_keys=True,
)
self.execute_plan(plan)
def test_shuffle_data(self):
for engine_type in ("duckdb", "arrow"):
with self.subTest(engine_type=engine_type):
plan = shuffle_data(
["tests/data/mock_urls/*.parquet"],
num_data_partitions=3,
num_out_data_partitions=3,
engine_type=engine_type,
)
self.execute_plan(plan)
def test_fstest(sp: Session):
path = sp._runtime_ctx.output_root
fstest(
sp,
input_path=os.path.join(path, "*"),
output_path=path,
size="10M",
npartitions=3,
)
def test_sort_mock_urls_v2(sp: Session):
sort_mock_urls_v2(
sp, ["tests/data/mock_urls/*.tsv"], sp._runtime_ctx.output_root, npartitions=3
)

180
tests/test_scheduler.py Normal file
View File

@@ -0,0 +1,180 @@
import os.path
import random
import time
import unittest
from typing import List, Tuple
from loguru import logger
from smallpond.execution.scheduler import ExecutorState
from smallpond.execution.task import PythonScriptTask, RuntimeContext
from smallpond.logical.dataset import DataSet, ParquetDataSet
from smallpond.logical.node import (
Context,
DataSetPartitionNode,
DataSourceNode,
LogicalPlan,
Node,
PythonScriptNode,
)
from tests.test_fabric import TestFabric
class RandomSleepTask(PythonScriptTask):
def __init__(
self, *args, sleep_secs: float, fail_first_try: bool, **kwargs
) -> None:
super().__init__(*args, **kwargs)
self.sleep_secs = sleep_secs
self.fail_first_try = fail_first_try
def process(
self,
runtime_ctx: RuntimeContext,
input_datasets: List[DataSet],
output_path: str,
) -> bool:
logger.info(f"sleeping {self.sleep_secs} secs")
time.sleep(self.sleep_secs)
with open(os.path.join(output_path, self.output_filename), "w") as fout:
fout.write(f"{repr(self)}")
if self.fail_first_try and self.retry_count == 0:
return False
return True
class RandomSleepNode(PythonScriptNode):
def __init__(
self,
ctx: Context,
input_deps: Tuple[Node, ...],
*,
max_sleep_secs=5,
fail_first_try=False,
**kwargs,
):
super().__init__(ctx, input_deps, **kwargs)
self.max_sleep_secs = max_sleep_secs
self.fail_first_try = fail_first_try
def spawn(self, *args, **kwargs) -> RandomSleepTask:
sleep_secs = (
random.random() if len(self.generated_tasks) % 20 else self.max_sleep_secs
)
return RandomSleepTask(
*args, **kwargs, sleep_secs=sleep_secs, fail_first_try=self.fail_first_try
)
class TestScheduler(TestFabric, unittest.TestCase):
def create_random_sleep_plan(
self, npartitions, max_sleep_secs, fail_first_try=False
):
ctx = Context()
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
data_files = DataSourceNode(ctx, dataset)
data_partitions = DataSetPartitionNode(
ctx, (data_files,), npartitions=npartitions, partition_by_rows=True
)
random_sleep = RandomSleepNode(
ctx,
(data_partitions,),
max_sleep_secs=max_sleep_secs,
fail_first_try=fail_first_try,
)
return LogicalPlan(ctx, random_sleep)
def check_executor_state(self, target_state: ExecutorState, nloops=200):
for _ in range(nloops):
latest_sched_state = self.get_latest_sched_state()
if any(
executor.state == target_state
for executor in latest_sched_state.remote_executors
):
logger.info(
f"found {target_state} executor in: {latest_sched_state.remote_executors}"
)
break
time.sleep(0.1)
else:
self.assertTrue(
False,
f"cannot find any executor in state {target_state}: {latest_sched_state.remote_executors}",
)
def test_standalone_mode(self):
plan = self.create_random_sleep_plan(npartitions=10, max_sleep_secs=1)
self.execute_plan(plan, num_executors=0)
def test_failed_executors(self):
num_exec = 6
num_fail = 4
plan = self.create_random_sleep_plan(npartitions=300, max_sleep_secs=10)
_, executors, processes = self.start_execution(
plan,
num_executors=num_exec,
secs_wq_poll_interval=0.1,
secs_executor_probe_interval=0.5,
console_log_level="WARNING",
)
latest_sched_state = self.get_latest_sched_state()
self.check_executor_state(ExecutorState.GOOD)
for i, (executor, process) in enumerate(
random.sample(list(zip(executors, processes[1:])), k=num_fail)
):
if i % 2 == 0:
logger.warning(f"kill executor: {executor}")
process.kill()
else:
logger.warning(f"skip probes: {executor}")
executor.skip_probes(latest_sched_state.ctx.max_num_missed_probes * 2)
self.join_running_procs()
latest_sched_state = self.get_latest_sched_state()
self.assertTrue(latest_sched_state.success)
self.assertGreater(len(latest_sched_state.abandoned_tasks), 0)
self.assertLessEqual(
1,
len(latest_sched_state.stopped_executors),
f"remote_executors: {latest_sched_state.remote_executors}",
)
self.assertLessEqual(
num_fail / 2,
len(latest_sched_state.failed_executors),
f"remote_executors: {latest_sched_state.remote_executors}",
)
def test_speculative_scheduling(self):
for speculative_exec in ("disable", "enable", "aggressive"):
with self.subTest(speculative_exec=speculative_exec):
plan = self.create_random_sleep_plan(npartitions=100, max_sleep_secs=10)
self.execute_plan(
plan,
num_executors=3,
secs_wq_poll_interval=0.1,
secs_executor_probe_interval=0.5,
prioritize_retry=(speculative_exec == "aggressive"),
speculative_exec=speculative_exec,
)
latest_sched_state = self.get_latest_sched_state()
if speculative_exec == "disable":
self.assertEqual(len(latest_sched_state.abandoned_tasks), 0)
else:
self.assertGreater(len(latest_sched_state.abandoned_tasks), 0)
def test_stop_executor_on_failure(self):
plan = self.create_random_sleep_plan(
npartitions=3, max_sleep_secs=5, fail_first_try=True
)
exec_plan = self.execute_plan(
plan,
num_executors=5,
secs_wq_poll_interval=0.1,
secs_executor_probe_interval=0.5,
check_result=False,
stop_executor_on_failure=True,
)
latest_sched_state = self.get_latest_sched_state()
self.assertGreater(len(latest_sched_state.abandoned_tasks), 0)

54
tests/test_session.py Normal file
View File

@@ -0,0 +1,54 @@
import os
from smallpond.dataframe import Session
def test_shutdown_cleanup(sp: Session):
assert os.path.exists(sp._runtime_ctx.queue_root), "queue directory should exist"
assert os.path.exists(
sp._runtime_ctx.staging_root
), "staging directory should exist"
assert os.path.exists(sp._runtime_ctx.temp_root), "temp directory should exist"
# create some tasks and complete them
df = sp.from_items([1, 2, 3])
df.write_parquet(sp._runtime_ctx.output_root)
sp.shutdown()
# shutdown should clean up directories
assert not os.path.exists(
sp._runtime_ctx.queue_root
), "queue directory should be cleared"
assert not os.path.exists(
sp._runtime_ctx.staging_root
), "staging directory should be cleared"
assert not os.path.exists(
sp._runtime_ctx.temp_root
), "temp directory should be cleared"
with open(sp._runtime_ctx.job_status_path) as fin:
assert "success" in fin.read(), "job status should be success"
def test_shutdown_no_cleanup_on_failure(sp: Session):
df = sp.from_items([1, 2, 3])
try:
# create a task that will fail
df.map(lambda x: x / 0).compute()
except Exception:
pass
else:
raise RuntimeError("task should fail")
sp.shutdown()
# shutdown should not clean up directories
assert os.path.exists(
sp._runtime_ctx.queue_root
), "queue directory should not be cleared"
assert os.path.exists(
sp._runtime_ctx.staging_root
), "staging directory should not be cleared"
assert os.path.exists(
sp._runtime_ctx.temp_root
), "temp directory should not be cleared"
with open(sp._runtime_ctx.job_status_path) as fin:
assert "failure" in fin.read(), "job status should be failure"

50
tests/test_utility.py Normal file
View File

@@ -0,0 +1,50 @@
import random
import subprocess
import time
import unittest
from typing import Iterable
from smallpond.utility import ConcurrentIter, execute_command
from tests.test_fabric import TestFabric
class TestUtility(TestFabric, unittest.TestCase):
def test_concurrent_iter_no_error(self):
def slow_iterator(iter: Iterable[int], sleep_ms: int):
for i in iter:
time.sleep(sleep_ms / 1000)
yield i
for n in [1, 5, 10, 50, 100]:
with ConcurrentIter(slow_iterator(range(n), 2)) as iter1:
with ConcurrentIter(slow_iterator(iter1, 5)) as iter2:
self.assertEqual(sum(slow_iterator(iter2, 1)), sum(range(n)))
def test_concurrent_iter_with_error(self):
def broken_iterator(iter: Iterable[int], sleep_ms: int):
for i in iter:
time.sleep(sleep_ms / 1000)
if random.randint(1, 10) == 1:
raise Exception("raised before yield")
yield i
if random.randint(1, 10) == 1:
raise Exception("raised after yield")
raise Exception("raised at the end")
for n in [1, 5, 10, 50, 100]:
with self.assertRaises(Exception):
with ConcurrentIter(range(n)) as iter:
print(sum(broken_iterator(iter, 1)))
with self.assertRaises(Exception):
with ConcurrentIter(broken_iterator(range(n), 2)) as iter1:
with ConcurrentIter(broken_iterator(iter1, 5)) as iter2:
print(sum(iter2))
def test_execute_command(self):
with self.assertRaises(subprocess.CalledProcessError):
for line in execute_command("ls non_existent_file"):
print(line)
for line in execute_command("echo hello"):
print(line)
for line in execute_command("cat /dev/null"):
print(line)

119
tests/test_workqueue.py Normal file
View File

@@ -0,0 +1,119 @@
import multiprocessing
import multiprocessing.dummy
import multiprocessing.queues
import queue
import tempfile
import time
import unittest
from loguru import logger
from smallpond.execution.workqueue import (
WorkItem,
WorkQueue,
WorkQueueInMemory,
WorkQueueOnFilesystem,
)
from tests.test_fabric import TestFabric
class PrintWork(WorkItem):
def __init__(self, name: str, message: str) -> None:
super().__init__(name, cpu_limit=1, gpu_limit=0, memory_limit=0)
self.message = message
def run(self) -> bool:
logger.debug(f"{self.key}: {self.message}")
return True
def producer(wq: WorkQueue, id: int, numItems: int, numConsumers: int) -> None:
print(f"wq.outbound_works: {wq.outbound_works}")
for i in range(numItems):
wq.push(PrintWork(f"item-{i}", message="hello"), buffering=(i % 3 == 1))
# wq.push(PrintWork(f"item-{i}", message="hello"))
if i % 5 == 0:
wq.flush()
for i in range(numConsumers):
wq.push(PrintWork(f"stop-{i}", message="stop"))
logger.success(f"producer {id} generated {numItems} items")
def consumer(wq: WorkQueue, id: int) -> int:
numItems = 0
numWaits = 0
running = True
while running:
items = wq.pop(count=1)
if not items:
numWaits += 1
time.sleep(0.01)
continue
for item in items:
assert isinstance(item, PrintWork)
if item.message == "stop":
running = False
break
item.exec()
numItems += 1
logger.success(f"consumer {id} collected {numItems} items, {numWaits} waits")
logger.complete()
return numItems
class WorkQueueTestBase(object):
wq: WorkQueue = None
pool: multiprocessing.Pool = None
def setUp(self) -> None:
logger.disable("smallpond.execution.workqueue")
return super().setUp()
def test_basics(self):
numItems = 200
for i in range(numItems):
self.wq.push(PrintWork(f"item-{i}", message="hello"))
numCollected = 0
for _ in range(numItems):
items = self.wq.pop()
logger.info(f"{len(items)} items")
numCollected += len(items)
if numItems == numCollected:
break
def test_multi_consumers(self):
numConsumers = 10
numItems = 200
result = self.pool.starmap_async(
consumer, [(self.wq, id) for id in range(numConsumers)]
)
producer(self.wq, 0, numItems, numConsumers)
logger.info("waiting for result")
numCollected = sum(result.get(timeout=20))
logger.info(f"expected vs collected: {numItems} vs {numCollected}")
self.assertEqual(numItems, numCollected)
logger.success("all done")
self.pool.terminate()
self.pool.join()
logger.success("workers stopped")
class TestWorkQueueInMemory(WorkQueueTestBase, TestFabric, unittest.TestCase):
def setUp(self) -> None:
super().setUp()
self.wq = WorkQueueInMemory(queue_type=queue.Queue)
self.pool = multiprocessing.dummy.Pool(10)
class TestWorkQueueOnFilesystem(WorkQueueTestBase, TestFabric, unittest.TestCase):
workq_root: str
def setUp(self) -> None:
super().setUp()
self.workq_root = tempfile.mkdtemp(dir=self.runtime_ctx.queue_root)
self.wq = WorkQueueOnFilesystem(self.workq_root, sort=True)
self.pool = multiprocessing.Pool(10)