commit 770aa417d59c3327fcbecc2f4652baf4aabba846 Author: Runji Wang Date: Tue Feb 25 18:16:31 2025 +0800 init diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..9ed7389 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,126 @@ +name: CI Pipeline + +on: + push: + branches: [ main ] + pull_request: + schedule: + - cron: '0 0 * * *' + +jobs: + fmt: + name: Format Check + runs-on: self-hosted + steps: + - uses: actions/checkout@v4 + + - name: Set up Python 3.12 + uses: actions/setup-python@v5 + with: + python-version: "3.12" + + - name: Install formatters + run: | + python -m pip install --upgrade pip + pip install black + + - name: Check code formatting + run: | + black --check . + + # - name: Check typos + # uses: crate-ci/typos@v1.29.10 + + test: + name: Test (${{ matrix.os }}, Python ${{ matrix.python-version }}) + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [self-hosted] # macos-latest + python-version: ["3.9", "3.10", "3.11", "3.12"] + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e .[dev] + + - name: Install gensort on Linux + if: matrix.os == 'self-hosted' + run: | + wget https://www.ordinal.com/try.cgi/gensort-linux-1.5.tar.gz + tar -xzf gensort-linux-1.5.tar.gz + chmod +x 64/gensort 64/valsort + export PATH=$PATH:$(pwd)/64 + + - name: Run tests + timeout-minutes: 20 + run: | + pytest -n 4 -v -x --durations=50 --timeout=600 \ + --junitxml=pytest.xml \ + --cov-report term --cov-report xml:coverage.xml \ + --cov=smallpond --cov=examples --cov=benchmarks --cov=tests \ + tests/test_*.py + + - name: Archive test results + uses: actions/upload-artifact@v4 + with: + name: test-results-${{ matrix.os }}-py${{ matrix.python-version }} + path: | + pytest.xml + coverage.xml + + build-docs: + name: Build Documentation + runs-on: self-hosted + steps: + - uses: actions/checkout@v4 + + - name: Set up Python 3.8 + uses: actions/setup-python@v5 + with: + python-version: "3.8" + + - name: Install dependencies + run: pip install -e .[docs] + + - name: Build HTML docs + run: | + cd docs + make html + + - name: Archive documentation + uses: actions/upload-artifact@v4 + with: + name: documentation + path: docs/build/html + + deploy-docs: + name: Deploy Documentation + runs-on: self-hosted + needs: [build-docs] + if: github.event_name == 'push' && github.ref == 'refs/heads/main' + steps: + - uses: actions/checkout@v4 + + - name: Download artifact + uses: actions/download-artifact@v4 + with: + name: documentation + path: docs/build/html + + - name: Deploy to docs branch + uses: peaceiris/actions-gh-pages@v4 + with: + github_token: ${{ secrets.GITHUB_TOKEN }} + publish_dir: docs/build/html + destination_dir: ./ + keep_files: true + branch: docs + force_orphan: true diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..747fcea --- /dev/null +++ b/.gitignore @@ -0,0 +1,18 @@ +__pycache__ +.ipynb_checkpoints +.tmp/ +dist/ +build/ +*.egg-info/ +tests/data/ +tests/runtime/ +*.log +*.pyc +*.xml +.tmp/ +.idea +.coverage +.vscode/ +.hypothesis/ +docs/*/generated/ +.venv*/ diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..e61b41d --- /dev/null +++ b/LICENSE @@ -0,0 +1,7 @@ +Copyright 2025 DeepSeek + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and 
associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000..1f79241 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1 @@ +exclude tests/data/** tests/runtime/** diff --git a/README.md b/README.md new file mode 100644 index 0000000..7a48af6 --- /dev/null +++ b/README.md @@ -0,0 +1,81 @@ +# smallpond + +[![CI](https://github.com/deepseek-ai/smallpond/actions/workflows/ci.yml/badge.svg?branch=main)](https://github.com/deepseek-ai/smallpond/actions/workflows/ci.yml) +[![PyPI](https://img.shields.io/pypi/v/smallpond)](https://pypi.org/project/smallpond/) +[![Docs](https://img.shields.io/badge/docs-latest-brightgreen.svg)](https://deepseek-ai.github.io/smallpond/) +[![License](https://img.shields.io/badge/license-MIT-blue.svg)](LICENSE) + +A lightweight data processing framework built on DuckDB and [3FS]. + +## Features + +- 🚀 High-performance data processing powered by DuckDB +- 🌍 Scalable to handle PB-scale datasets +- 🛠️ Easy operations with no long-running services + +## Installation + +Python 3.8 to 3.12 is supported. + +```bash +pip install smallpond +``` + +## Quick Start + +```bash +# Download example data +wget https://duckdb.org/data/prices.parquet +``` + +```python +import smallpond + +# Initialize session +sp = smallpond.init() + +# Load data +df = sp.read_parquet("prices.parquet") + +# Process data +df = df.repartition(3, hash_by="ticker") +df = sp.partial_sql("SELECT ticker, min(price), max(price) FROM {0} GROUP BY ticker", df) + +# Save results +df.write_parquet("output/") +# Show results +print(df.to_pandas()) +``` + +## Documentation + +For detailed guides and API reference: +- [Getting Started](docs/source/getstarted.rst) +- [API Reference](docs/source/api.rst) + +## Performance + +We executed the [Gray Sort benchmark] using [smallpond] on a cluster comprising 50 compute nodes and 25 storage nodes running [3FS]. The benchmark sorted 110.5TiB of data in 30 minutes and 14 seconds, achieving a throughput of 3.66TiB/min. + +[3FS]: https://github.com/deepseek-ai/3fs +[Gray Sort benchmark]: https://sortbenchmark.org/ +[smallpond]: benchmarks/gray_sort_benchmark.py + +## Development + +```bash +pip install .[dev] + +# run unit tests +pytest -v tests/test*.py + +# build documentation +pip install .[docs] +cd docs +make html +python -m http.server --directory build/html +``` + +## License + +This project is licensed under the [MIT License](LICENSE).
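As a companion to the Quick Start in the README above, here is a rough sketch of the same prices-by-ticker pipeline written against the low-level node API that this commit documents in docs/source/api.rst and exercises in the benchmark scripts. It is an illustration added by the editor, not part of the commit: the function name `my_pipeline` and the default values are invented, and the exact node signatures should be checked against smallpond/logical/node.py.

```python
from smallpond.execution.driver import Driver
from smallpond.logical.dataset import ParquetDataSet
from smallpond.logical.node import (
    Context,
    DataSourceNode,
    HashPartitionNode,
    LogicalPlan,
    SqlEngineNode,
)


def my_pipeline(input_paths, npartitions=3) -> LogicalPlan:
    ctx = Context()
    # read the downloaded parquet file(s)
    data_files = DataSourceNode(ctx, ParquetDataSet(input_paths))
    # hash-partition by the ticker column, mirroring df.repartition(3, hash_by="ticker")
    partitions = HashPartitionNode(
        ctx, (data_files,), npartitions=npartitions, hash_columns=["ticker"]
    )
    # run the same per-partition aggregation as sp.partial_sql(...)
    summary = SqlEngineNode(
        ctx,
        (partitions,),
        "SELECT ticker, min(price), max(price) FROM {0} GROUP BY ticker",
    )
    return LogicalPlan(ctx, summary)


if __name__ == "__main__":
    driver = Driver()
    driver.add_argument("-i", "--input_paths", nargs="+", default=["prices.parquet"])
    driver.add_argument("-n", "--npartitions", type=int, default=3)
    plan = my_pipeline(**driver.get_arguments())
    driver.run(plan)
```

Unlike the DataFrame version, this variant leaves its results under the job's data root (see docs/source/internals.rst) rather than pulling them back with to_pandas().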
diff --git a/benchmarks/file_io_benchmark.py b/benchmarks/file_io_benchmark.py new file mode 100644 index 0000000..e30c7a7 --- /dev/null +++ b/benchmarks/file_io_benchmark.py @@ -0,0 +1,84 @@ +from smallpond.common import DEFAULT_BATCH_SIZE, DEFAULT_ROW_GROUP_SIZE, GB +from smallpond.contrib.copy_table import CopyArrowTable, StreamCopy +from smallpond.execution.driver import Driver +from smallpond.logical.dataset import ParquetDataSet +from smallpond.logical.node import ( + Context, + DataSetPartitionNode, + DataSourceNode, + LogicalPlan, + SqlEngineNode, +) + + +def file_io_benchmark( + input_paths, + npartitions, + io_engine="duckdb", + batch_size=DEFAULT_BATCH_SIZE, + row_group_size=DEFAULT_ROW_GROUP_SIZE, + output_name="data", + **kwargs, +) -> LogicalPlan: + ctx = Context() + dataset = ParquetDataSet(input_paths) + data_files = DataSourceNode(ctx, dataset) + data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=npartitions) + + if io_engine == "duckdb": + data_copy = SqlEngineNode( + ctx, + (data_partitions,), + r"select * from {0}", + parquet_row_group_size=row_group_size, + per_thread_output=False, + output_name=output_name, + cpu_limit=1, + memory_limit=10 * GB, + ) + elif io_engine == "arrow": + data_copy = CopyArrowTable( + ctx, + (data_partitions,), + parquet_row_group_size=row_group_size, + output_name=output_name, + cpu_limit=1, + memory_limit=10 * GB, + ) + elif io_engine == "stream": + data_copy = StreamCopy( + ctx, + (data_partitions,), + streaming_batch_size=batch_size, + parquet_row_group_size=row_group_size, + output_name=output_name, + cpu_limit=1, + memory_limit=10 * GB, + ) + + plan = LogicalPlan(ctx, data_copy) + return plan + + +def main(): + driver = Driver() + driver.add_argument("-i", "--input_paths", nargs="+") + driver.add_argument("-n", "--npartitions", type=int, default=None) + driver.add_argument( + "-e", "--io_engine", default="duckdb", choices=("duckdb", "arrow", "stream") + ) + driver.add_argument("-b", "--batch_size", type=int, default=1024 * 1024) + driver.add_argument("-s", "--row_group_size", type=int, default=1024 * 1024) + driver.add_argument("-o", "--output_name", default="data") + driver.add_argument("-NC", "--cpus_per_node", type=int, default=128) + + user_args, driver_args = driver.parse_arguments() + total_num_cpus = driver_args.num_executors * user_args.cpus_per_node + user_args.npartitions = user_args.npartitions or total_num_cpus + + plan = file_io_benchmark(**driver.get_arguments()) + driver.run(plan) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/gray_sort_benchmark.py b/benchmarks/gray_sort_benchmark.py new file mode 100644 index 0000000..9cb28ad --- /dev/null +++ b/benchmarks/gray_sort_benchmark.py @@ -0,0 +1,417 @@ +import functools +import logging +import os.path +import shutil +import subprocess +import tempfile +from pathlib import PurePath +from typing import Iterable, List + +import duckdb +import polars +import psutil +import pyarrow as arrow +import pyarrow.compute as pc + +from smallpond.common import GB, MB, next_power_of_two, pytest_running +from smallpond.execution.driver import Driver +from smallpond.execution.task import ( + ArrowStreamTask, + PythonScriptTask, + RuntimeContext, + StreamOutput, +) +from smallpond.logical.dataset import ArrowTableDataSet, DataSet, ParquetDataSet +from smallpond.logical.node import ( + ArrowStreamNode, + Context, + DataSetPartitionNode, + DataSourceNode, + LogicalPlan, + ProjectionNode, + PythonScriptNode, + ShuffleNode, +) + + +class SortBenchTool(object): 
+ + gensort_path = shutil.which("gensort") + valsort_path = shutil.which("valsort") + + @staticmethod + def ensure_installed(): + if not SortBenchTool.gensort_path or not SortBenchTool.valsort_path: + raise Exception("gensort or valsort not found") + + +def generate_records( + runtime_ctx: RuntimeContext, + input_readers: List[arrow.RecordBatchReader], + record_nbytes=100, + key_nbytes=10, + bucket_nbits=12, + gensort_batch_nbytes=500 * MB, +) -> Iterable[arrow.Table]: + runtime_task: ArrowStreamTask = runtime_ctx.task + batch_size = gensort_batch_nbytes // record_nbytes + schema = arrow.schema( + [ + arrow.field("buckets", arrow.uint16()), + arrow.field("keys", arrow.binary()), + arrow.field("records", arrow.binary()), + ] + ) + + with tempfile.NamedTemporaryFile(dir="/dev/shm", buffering=0) as shm_file: + for batch_idx, batch in enumerate(input_readers[0]): + for begin_at, num_records in zip(*batch.columns): + begin_at, num_records = begin_at.as_py(), num_records.as_py() + for offset in range(begin_at, begin_at + num_records, batch_size): + record_count = min(batch_size, begin_at + num_records - offset) + gensort_cmd = f"{SortBenchTool.gensort_path} -t2 -b{offset} {record_count} {shm_file.name},buf,trans=100m" + subprocess.run(gensort_cmd.split()).check_returncode() + runtime_task.add_elapsed_time("generate records (secs)") + shm_file.seek(0) + buffer = arrow.py_buffer( + shm_file.read(record_count * record_nbytes) + ) + runtime_task.add_elapsed_time("read records (secs)") + # https://arrow.apache.org/docs/format/Columnar.html#fixed-size-primitive-layout + records = arrow.Array.from_buffers( + arrow.binary(record_nbytes), record_count, [None, buffer] + ) + keys = pc.binary_slice(records, 0, key_nbytes) + # get first 2 bytes and convert to big-endian uint16 + binary_prefix = pc.binary_slice(records, 0, 2).cast(arrow.binary()) + reversed_prefix = pc.binary_reverse(binary_prefix).cast( + arrow.binary(2) + ) + uint16_prefix = reversed_prefix.view(arrow.uint16()) + buckets = pc.shift_right(uint16_prefix, 16 - bucket_nbits) + runtime_task.add_elapsed_time("build arrow table (secs)") + yield arrow.Table.from_arrays( + [buckets, keys, records], schema=schema + ) + yield StreamOutput( + schema.empty_table(), + batch_indices=[batch_idx], + force_checkpoint=pytest_running(), + ) + + +def sort_records( + runtime_ctx: RuntimeContext, + input_datasets: List[DataSet], + output_path: str, + sort_engine="polars", + write_io_nbytes=500 * MB, +) -> bool: + runtime_task: PythonScriptTask = runtime_ctx.task + data_file_path = os.path.join( + runtime_task.runtime_output_abspath, f"{runtime_task.output_filename}.dat" + ) + + if sort_engine == "polars": + input_data = polars.read_parquet( + input_datasets[0].resolved_paths, + rechunk=False, + hive_partitioning=False, + columns=input_datasets[0].columns, + ) + runtime_task.perf_metrics["num input rows"] += len(input_data) + runtime_task.add_elapsed_time("input load time (secs)") + sorted_records = input_data.sort("keys").get_column("records") + runtime_task.add_elapsed_time("sort by keys (secs)") + record_arrays = [chunk.to_arrow() for chunk in sorted_records.get_chunks()] + runtime_task.add_elapsed_time("convert to chunks (secs)") + elif sort_engine == "arrow": + input_table = input_datasets[0].to_arrow_table(runtime_task.cpu_limit) + runtime_task.perf_metrics["num input rows"] += input_table.num_rows + runtime_task.add_elapsed_time("input load time (secs)") + sorted_table = input_table.sort_by("keys") + runtime_task.add_elapsed_time("sort by keys (secs)") + 
record_arrays = sorted_table.column("records").chunks + runtime_task.add_elapsed_time("convert to chunks (secs)") + elif sort_engine == "duckdb": + with duckdb.connect( + database=":memory:", config={"allow_unsigned_extensions": "true"} + ) as conn: + runtime_task.prepare_connection(conn) + input_views = runtime_task.create_input_views(conn, input_datasets) + sql_query = "select records from {0} order by keys".format(*input_views) + sorted_table = conn.sql(sql_query).to_arrow_table() + runtime_task.add_elapsed_time("sort by keys (secs)") + record_arrays = sorted_table.column("records").chunks + runtime_task.add_elapsed_time("convert to chunks (secs)") + else: + raise Exception(f"unknown sort engine: {sort_engine}") + + with open(data_file_path, "wb") as fout: + for record_array in record_arrays: + # https://arrow.apache.org/docs/format/Columnar.html#variable-size-binary-layout + validity_bitmap, offsets, values = record_array.buffers() + buffer_mem = memoryview(values) + + total_write_nbytes = sum( + fout.write(buffer_mem[offset : offset + write_io_nbytes]) + for offset in range(0, len(buffer_mem), write_io_nbytes) + ) + assert total_write_nbytes == len(buffer_mem) + + runtime_task.perf_metrics["num output rows"] += len(record_array) + runtime_task.add_elapsed_time("output dump time (secs)") + return True + + +def validate_records( + runtime_ctx: RuntimeContext, input_datasets: List[DataSet], output_path: str +) -> bool: + for data_path in input_datasets[0].resolved_paths: + summary_path = os.path.join( + output_path, PurePath(data_path).with_suffix(".sum").name + ) + cmdstr = ( + f"{SortBenchTool.valsort_path} -o {summary_path} {data_path},buf,trans=10m" + ) + logging.debug(f"running command: {cmdstr}") + result = subprocess.run(cmdstr.split(), capture_output=True, encoding="utf8") + if result.stderr: + logging.info(f"valsort stderr: {result.stderr}") + if result.stdout: + logging.info(f"valsort stdout: {result.stdout}") + if result.returncode != 0: + return False + return True + + +def validate_summary( + runtime_ctx: RuntimeContext, input_datasets: List[DataSet], output_path: str +) -> bool: + concated_summary_path = os.path.join(output_path, "merged.sum") + with open(concated_summary_path, "wb") as fout: + for path in input_datasets[0].resolved_paths: + with open(path, "rb") as fin: + fout.write(fin.read()) + cmdstr = f"{SortBenchTool.valsort_path} -s {concated_summary_path}" + logging.debug(f"running command: {cmdstr}") + result = subprocess.run(cmdstr.split(), capture_output=True, encoding="utf8") + if result.stderr: + logging.info(f"valsort stderr: {result.stderr}") + if result.stdout: + logging.info(f"valsort stdout: {result.stdout}") + return result.returncode == 0 + + +def generate_random_records( + ctx, + record_nbytes, + key_nbytes, + total_data_nbytes, + gensort_batch_nbytes, + num_data_partitions, + num_sort_partitions, + parquet_compression=None, + parquet_compression_level=None, +): + num_record_ranges = num_data_partitions * 10 + total_num_records = total_data_nbytes // record_nbytes + record_range_size = (total_num_records + num_record_ranges - 1) // num_record_ranges + logging.warning( + f"{record_nbytes} bytes/record x total {total_num_records:,d} records = " + f"{total_data_nbytes/GB:.3f}GB / {num_record_ranges} record ranges = " + f"{record_range_size * record_nbytes/GB:.3f}GB ({record_range_size:,d} records) per record range" + ) + + range_begin_at = [pos for pos in range(0, total_num_records, record_range_size)] + range_num_records = [ + min(total_num_records, 
record_range_size * (range_idx + 1)) - begin_at + for range_idx, begin_at in enumerate(range_begin_at) + ] + assert sum(range_num_records) == total_num_records + record_range = DataSourceNode( + ctx, + ArrowTableDataSet( + arrow.Table.from_arrays( + [range_begin_at, range_num_records], names=["begin_at", "num_records"] + ) + ), + ) + record_range_partitions = DataSetPartitionNode( + ctx, (record_range,), npartitions=num_data_partitions, partition_by_rows=True + ) + + random_records = ArrowStreamNode( + ctx, + (record_range_partitions,), + process_func=functools.partial( + generate_records, + record_nbytes=record_nbytes, + key_nbytes=key_nbytes, + bucket_nbits=num_sort_partitions.bit_length() - 1, + gensort_batch_nbytes=gensort_batch_nbytes, + ), + background_io_thread=True, + streaming_batch_size=10, + parquet_row_group_size=1024 * 1024, + parquet_compression=parquet_compression, + parquet_compression_level=parquet_compression_level, + output_name="random_records", + cpu_limit=2, + ) + return random_records + + +def gray_sort_benchmark( + record_nbytes, + key_nbytes, + total_data_nbytes, + gensort_batch_nbytes, + num_data_partitions, + num_sort_partitions, + input_paths=None, + shuffle_engine="duckdb", + sort_engine="polars", + hive_partitioning=False, + validate_results=False, + shuffle_cpu_limit=32, + shuffle_memory_limit=None, + sort_cpu_limit=8, + sort_memory_limit=None, + parquet_compression=None, + parquet_compression_level=None, + **kwargs, +) -> LogicalPlan: + ctx = Context() + num_sort_partitions = next_power_of_two(num_sort_partitions) + + if input_paths: + input_dataset = ParquetDataSet(input_paths) + input_nbytes = sum(os.path.getsize(p) for p in input_dataset.resolved_paths) + logging.warning( + f"input data size: {input_nbytes/GB:.3f}GB, {input_dataset.num_files} files" + ) + random_records = DataSourceNode(ctx, input_dataset) + else: + random_records = generate_random_records( + ctx, + record_nbytes, + key_nbytes, + total_data_nbytes, + gensort_batch_nbytes, + num_data_partitions, + num_sort_partitions, + parquet_compression, + parquet_compression_level, + ) + + partitioned_records = ShuffleNode( + ctx, + (random_records,), + npartitions=num_sort_partitions, + data_partition_column="buckets", + engine_type=shuffle_engine, + hive_partitioning=hive_partitioning, + parquet_row_group_size=10 * 1024 * 1024, + parquet_compression=parquet_compression, + parquet_compression_level=parquet_compression_level, + cpu_limit=shuffle_cpu_limit, + memory_limit=shuffle_memory_limit, + ) + + sorted_records = PythonScriptNode( + ctx, + (ProjectionNode(ctx, partitioned_records, ["keys", "records"]),), + process_func=functools.partial(sort_records, sort_engine=sort_engine), + output_name="sorted_records", + cpu_limit=sort_cpu_limit, + memory_limit=sort_memory_limit, + ) + + if validate_results: + partitioned_summaries = PythonScriptNode( + ctx, + (sorted_records,), + process_func=validate_records, + output_name="partitioned_summaries", + ) + merged_summaries = DataSetPartitionNode( + ctx, (partitioned_summaries,), npartitions=1 + ) + final_check = PythonScriptNode( + ctx, (merged_summaries,), process_func=validate_summary + ) + root = final_check + else: + root = sorted_records + + return LogicalPlan(ctx, root) + + +def main(): + SortBenchTool.ensure_installed() + + driver = Driver() + driver.add_argument("-R", "--record_nbytes", type=int, default=100) + driver.add_argument("-K", "--key_nbytes", type=int, default=10) + driver.add_argument("-T", "--total_data_nbytes", type=int, default=None) + 
driver.add_argument("-B", "--gensort_batch_nbytes", type=int, default=512 * MB) + driver.add_argument("-n", "--num_data_partitions", type=int, default=None) + driver.add_argument("-t", "--num_sort_partitions", type=int, default=None) + driver.add_argument("-i", "--input_paths", nargs="+", default=[]) + driver.add_argument( + "-e", "--shuffle_engine", default="duckdb", choices=("duckdb", "arrow") + ) + driver.add_argument( + "-s", "--sort_engine", default="duckdb", choices=("duckdb", "arrow", "polars") + ) + driver.add_argument("-H", "--hive_partitioning", action="store_true") + driver.add_argument("-V", "--validate_results", action="store_true") + driver.add_argument( + "-C", "--shuffle_cpu_limit", type=int, default=ShuffleNode.default_cpu_limit + ) + driver.add_argument( + "-M", + "--shuffle_memory_limit", + type=int, + default=ShuffleNode.default_memory_limit, + ) + driver.add_argument("-TC", "--sort_cpu_limit", type=int, default=8) + driver.add_argument("-TM", "--sort_memory_limit", type=int, default=None) + driver.add_argument( + "-NC", "--cpus_per_node", type=int, default=psutil.cpu_count(logical=False) + ) + driver.add_argument( + "-NM", "--memory_per_node", type=int, default=psutil.virtual_memory().total + ) + driver.add_argument("-CP", "--parquet_compression", default=None) + driver.add_argument("-LV", "--parquet_compression_level", type=int, default=None) + + user_args, driver_args = driver.parse_arguments() + assert len(user_args.input_paths) == 0 or user_args.num_sort_partitions is not None + + total_num_cpus = max(1, driver_args.num_executors) * user_args.cpus_per_node + memory_per_cpu = user_args.memory_per_node // user_args.cpus_per_node + + user_args.sort_cpu_limit = ( + 1 if user_args.sort_engine == "arrow" else user_args.sort_cpu_limit + ) + sort_memory_limit = ( + user_args.sort_memory_limit or user_args.sort_cpu_limit * memory_per_cpu + ) + user_args.total_data_nbytes = ( + user_args.total_data_nbytes + or max(1, driver_args.num_executors) * user_args.memory_per_node + ) + user_args.num_data_partitions = user_args.num_data_partitions or total_num_cpus // 2 + user_args.num_sort_partitions = user_args.num_sort_partitions or max( + total_num_cpus // user_args.sort_cpu_limit, + user_args.total_data_nbytes // (sort_memory_limit // 4), + ) + + plan = gray_sort_benchmark(**vars(user_args)) + driver.run(plan) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/hash_partition_benchmark.py b/benchmarks/hash_partition_benchmark.py new file mode 100644 index 0000000..3ca7bb4 --- /dev/null +++ b/benchmarks/hash_partition_benchmark.py @@ -0,0 +1,97 @@ +from smallpond.common import GB +from smallpond.contrib.log_dataset import LogDataSet +from smallpond.execution.driver import Driver +from smallpond.logical.dataset import ParquetDataSet +from smallpond.logical.node import ( + ConsolidateNode, + Context, + DataSourceNode, + HashPartitionNode, + LogicalPlan, + SqlEngineNode, +) + + +def hash_partition_benchmark( + input_paths, + npartitions, + hash_columns, + engine_type="duckdb", + use_parquet_writer=False, + hive_partitioning=False, + cpu_limit=None, + memory_limit=None, + partition_stats=True, + **kwargs, +) -> LogicalPlan: + ctx = Context() + dataset = ParquetDataSet(input_paths) + data_files = DataSourceNode(ctx, dataset) + + partitioned_datasets = HashPartitionNode( + ctx, + (data_files,), + npartitions=npartitions, + hash_columns=hash_columns, + data_partition_column="partition_keys", + engine_type=engine_type, + use_parquet_writer=use_parquet_writer, + 
hive_partitioning=hive_partitioning, + output_name="partitioned_datasets", + cpu_limit=cpu_limit, + memory_limit=memory_limit, + ) + + if partition_stats: + partition_stats = SqlEngineNode( + ctx, + (partitioned_datasets,), + f""" + select partition_keys, count(*) as row_cnt, count( distinct ( {', '.join(hash_columns)} ) ) as uniq_key_cnt from {{0}} + group by partition_keys""", + output_name="partition_stats", + cpu_limit=1, + memory_limit=10 * GB, + ) + sorted_stats = SqlEngineNode( + ctx, + (ConsolidateNode(ctx, partition_stats, []),), + r"select * from {0} order by row_cnt desc", + ) + plan = LogicalPlan(ctx, LogDataSet(ctx, (sorted_stats,), num_rows=npartitions)) + else: + plan = LogicalPlan(ctx, partitioned_datasets) + + return plan + + +def main(): + driver = Driver() + driver.add_argument("-i", "--input_paths", nargs="+", required=True) + driver.add_argument("-n", "--npartitions", type=int, default=None) + driver.add_argument("-c", "--hash_columns", nargs="+", required=True) + driver.add_argument( + "-e", "--engine_type", default="duckdb", choices=("duckdb", "arrow") + ) + driver.add_argument("-S", "--partition_stats", action="store_true") + driver.add_argument("-W", "--use_parquet_writer", action="store_true") + driver.add_argument("-H", "--hive_partitioning", action="store_true") + driver.add_argument( + "-C", "--cpu_limit", type=int, default=HashPartitionNode.default_cpu_limit + ) + driver.add_argument( + "-M", "--memory_limit", type=int, default=HashPartitionNode.default_memory_limit + ) + driver.add_argument("-NC", "--cpus_per_node", type=int, default=192) + driver.add_argument("-NM", "--memory_per_node", type=int, default=2000 * GB) + + user_args, driver_args = driver.parse_arguments() + total_num_cpus = driver_args.num_executors * user_args.cpus_per_node + user_args.npartitions = user_args.npartitions or total_num_cpus + + plan = hash_partition_benchmark(**vars(user_args)) + driver.run(plan) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/urls_sort_benchmark.py b/benchmarks/urls_sort_benchmark.py new file mode 100644 index 0000000..4065c04 --- /dev/null +++ b/benchmarks/urls_sort_benchmark.py @@ -0,0 +1,126 @@ +from typing import List, OrderedDict + +from smallpond.common import GB +from smallpond.dataframe import Session +from smallpond.execution.driver import Driver +from smallpond.logical.dataset import CsvDataSet +from smallpond.logical.node import ( + Context, + DataSetPartitionNode, + DataSourceNode, + HashPartitionNode, + LogicalPlan, + SqlEngineNode, +) + + +def urls_sort_benchmark( + input_paths: List[str], + num_data_partitions: int, + num_hash_partitions: int, + engine_type="duckdb", + sort_cpu_limit=8, + sort_memory_limit=None, +) -> LogicalPlan: + ctx = Context() + dataset = CsvDataSet( + input_paths, + schema=OrderedDict([("urlstr", "varchar"), ("valstr", "varchar")]), + delim=r"\t", + ) + data_files = DataSourceNode(ctx, dataset) + data_partitions = DataSetPartitionNode( + ctx, (data_files,), npartitions=num_data_partitions + ) + + imported_urls = SqlEngineNode( + ctx, + (data_partitions,), + r""" + select urlstr, valstr from {0} + """, + output_name="imported_urls", + parquet_row_group_size=1024 * 1024, + cpu_limit=1, + memory_limit=16 * GB, + ) + + urls_partitions = HashPartitionNode( + ctx, + (imported_urls,), + npartitions=num_hash_partitions, + hash_columns=["urlstr"], + engine_type=engine_type, + parquet_row_group_size=1024 * 1024, + cpu_limit=1, + memory_limit=16 * GB, + ) + + sorted_urls = SqlEngineNode( + ctx, + (urls_partitions,), + 
r"select * from {0} order by urlstr", + output_name="sorted_urls", + parquet_row_group_size=1024 * 1024, + cpu_limit=sort_cpu_limit, + memory_limit=sort_memory_limit, + ) + + plan = LogicalPlan(ctx, sorted_urls) + return plan + + +def urls_sort_benchmark_v2( + sp: Session, + input_paths: List[str], + output_path: str, + num_data_partitions: int, + num_hash_partitions: int, + engine_type="duckdb", + sort_cpu_limit=8, + sort_memory_limit=None, +): + dataset = sp.read_csv( + input_paths, schema={"urlstr": "varchar", "valstr": "varchar"}, delim=r"\t" + ) + data_partitions = dataset.repartition(num_data_partitions) + urls_partitions = data_partitions.repartition( + num_hash_partitions, hash_by="urlstr", engine_type=engine_type + ) + sorted_urls = urls_partitions.partial_sort( + by="urlstr", cpu_limit=sort_cpu_limit, memory_limit=sort_memory_limit + ) + sorted_urls.write_parquet(output_path) + + +def main(): + driver = Driver() + driver.add_argument("-i", "--input_paths", nargs="+") + driver.add_argument("-n", "--num_data_partitions", type=int, default=None) + driver.add_argument("-m", "--num_hash_partitions", type=int, default=None) + driver.add_argument("-e", "--engine_type", default="duckdb") + driver.add_argument("-TC", "--sort_cpu_limit", type=int, default=8) + driver.add_argument("-TM", "--sort_memory_limit", type=int, default=None) + user_args, driver_args = driver.parse_arguments() + + num_nodes = driver_args.num_executors + cpus_per_node = 120 + partition_rounds = 2 + user_args.num_data_partitions = ( + user_args.num_data_partitions or num_nodes * cpus_per_node * partition_rounds + ) + user_args.num_hash_partitions = ( + user_args.num_hash_partitions or num_nodes * cpus_per_node + ) + + # v1 + plan = urls_sort_benchmark(**vars(user_args)) + driver.run(plan) + + # v2 + # sp = smallpond.init() + # urls_sort_benchmark_v2(sp, **vars(user_args)) + + +if __name__ == "__main__": + main() diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 0000000..d0c3cbf --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,20 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = source +BUILDDIR = build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/make.bat b/docs/make.bat new file mode 100644 index 0000000..747ffb7 --- /dev/null +++ b/docs/make.bat @@ -0,0 +1,35 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=source +set BUILDDIR=build + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. 
+ echo.If you don't have Sphinx installed, grab it from + echo.https://www.sphinx-doc.org/ + exit /b 1 +) + +if "%1" == "" goto help + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% + +:end +popd diff --git a/docs/source/api.rst b/docs/source/api.rst new file mode 100644 index 0000000..3f734ab --- /dev/null +++ b/docs/source/api.rst @@ -0,0 +1,81 @@ +API Reference +============= + +Smallpond provides both high-level and low-level APIs. + +.. note:: + Currently, smallpond provides two different APIs, supporting dynamic and static construction of data flow graphs respectively. For historical reasons, these two APIs use different scheduler backends and support different configuration options. + + - The High-level API currently uses Ray as the backend, supporting dynamic construction and execution of data flow graphs. + - The Low-level API uses a built-in scheduler and only supports one-time execution of static data flow graphs. However, it offers more performance optimizations and richer configuration options. + + We are working to merge them so that in the future, you can use a unified high-level API and freely choose between Ray and the built-in scheduler. + +High-level API +-------------- + +The high-level API is centered around :ref:`dataframe`. It allows dynamic construction of data flow graphs, execution, and result retrieval. + +A typical workflow looks like this: + +.. code-block:: python + + import smallpond + + sp = smallpond.init() + + df = sp.read_parquet("path/to/dataset/*.parquet") + df = df.repartition(10) + df = df.map("x + 1") + df.write_parquet("path/to/output") + +.. toctree:: + :maxdepth: 2 + + api/dataframe + +It is recommended to use the DataFrame API. + +Low-level API +------------- + +In the low-level API, users manually create :ref:`nodes` to construct static data flow graphs, then submit them to smallpond to generate :ref:`tasks` and wait for all tasks to complete. + +A complete example is shown below. + +.. code-block:: python + + from typing import List + + from smallpond.logical.dataset import ParquetDataSet + from smallpond.logical.node import Context, DataSourceNode, DataSetPartitionNode, SqlEngineNode, LogicalPlan + from smallpond.execution.driver import Driver + + def my_pipeline(input_paths: List[str], npartitions: int): + ctx = Context() + dataset = ParquetDataSet(input_paths) + node = DataSourceNode(ctx, dataset) + node = DataSetPartitionNode(ctx, (node,), npartitions=npartitions) + node = SqlEngineNode(ctx, (node,), "SELECT * FROM {0}") + return LogicalPlan(ctx, node) + + if __name__ == "__main__": + driver = Driver() + driver.add_argument("-i", "--input_paths", nargs="+") + driver.add_argument("-n", "--npartitions", type=int, default=10) + + plan = my_pipeline(**driver.get_arguments()) + driver.run(plan) + +To run this script: + +.. code-block:: bash + + python script.py -i "path/to/*.parquet" -n 10 + + +.. toctree:: + :maxdepth: 2 + + api/dataset + api/nodes + api/tasks + api/execution diff --git a/docs/source/api/dataframe.rst b/docs/source/api/dataframe.rst new file mode 100644 index 0000000..7846d12 --- /dev/null +++ b/docs/source/api/dataframe.rst @@ -0,0 +1,104 @@ +.. _dataframe: + +DataFrame +========= + +DataFrame is the main class in smallpond. It represents a lazily computed, partitioned data set. + +A typical workflow looks like this: + +..
code-block:: python + + import smallpond + + sp = smallpond.init() + + df = sp.read_parquet("path/to/dataset/*.parquet") + df = df.repartition(10) + df = df.map("x + 1") + df.write_parquet("path/to/output") + +Initialization +-------------- + +.. autosummary:: + :toctree: ../generated + + smallpond.init + +.. currentmodule:: smallpond.dataframe + +.. _loading_data: + +Loading Data +------------ + +.. autosummary:: + :toctree: ../generated + + Session.from_items + Session.from_arrow + Session.from_pandas + Session.read_csv + Session.read_json + Session.read_parquet + +.. _partitioning_data: + +Partitioning Data +----------------- + +.. autosummary:: + :toctree: ../generated + + DataFrame.repartition + +.. _transformations: + +Transformations +--------------- + +Apply transformations and return a new DataFrame. + +.. autosummary:: + :toctree: ../generated + + Session.partial_sql + DataFrame.map + DataFrame.map_batches + DataFrame.flat_map + DataFrame.filter + DataFrame.limit + DataFrame.partial_sort + DataFrame.random_shuffle + +.. _consuming_data: + +Consuming Data +-------------- + +These operations will trigger execution of the lazy transformations performed on this DataFrame. + +.. autosummary:: + :toctree: ../generated + + DataFrame.count + DataFrame.take + DataFrame.take_all + DataFrame.to_arrow + DataFrame.to_pandas + DataFrame.write_parquet + DataFrame.write_parquet_lazy + +Execution +--------- + +DataFrames are lazily computed. You can use these methods to manually trigger computation. + +.. autosummary:: + :toctree: ../generated + + DataFrame.compute + DataFrame.is_computed + DataFrame.recompute + Session.wait diff --git a/docs/source/api/dataset.rst b/docs/source/api/dataset.rst new file mode 100644 index 0000000..2050590 --- /dev/null +++ b/docs/source/api/dataset.rst @@ -0,0 +1,28 @@ +.. currentmodule:: smallpond.logical.dataset + +Dataset +======= + +Dataset represents a collection of files. + +To create a dataset: + +.. code-block:: python + + dataset = ParquetDataSet("path/to/dataset/*.parquet") + +DataSets +-------- + +.. autosummary:: + :toctree: ../generated + + DataSet + FileSet + ParquetDataSet + CsvDataSet + JsonDataSet + ArrowTableDataSet + PandasDataSet + PartitionedDataSet + SqlQueryDataSet diff --git a/docs/source/api/execution.rst b/docs/source/api/execution.rst new file mode 100644 index 0000000..f851fe0 --- /dev/null +++ b/docs/source/api/execution.rst @@ -0,0 +1,85 @@ +.. currentmodule:: smallpond.execution + +.. _execution: + +Execution +========= + +Submit a Job +------------ + +After constructing the LogicalPlan, you can use the JobManager to create a Job in the cluster to execute it. However, in most cases, you only need to use the Driver as the entry point of the entire script and then submit the plan. The Driver is a simple wrapper around the JobManager. It reads the configuration from the command line arguments and passes it to the JobManager. + +.. code-block:: python + + from smallpond.execution.driver import Driver + + if __name__ == "__main__": + driver = Driver() + # add your own arguments + driver.add_argument("-i", "--input_paths", nargs="+") + driver.add_argument("-n", "--npartitions", type=int, default=10) + # build and run logical plan + plan = my_pipeline(**driver.get_arguments()) + driver.run(plan) + + +.. autosummary:: + :toctree: ../generated + + ~driver.Driver + ~manager.JobManager + +Scheduler and Executor +---------------------- + +Scheduler and Executor are lower-level APIs. 
They are directly responsible for scheduling and executing tasks, respectively. Generally, users do not need to use them directly. + +.. autosummary:: + :toctree: ../generated + + ~scheduler.Scheduler + ~executor.Executor + +.. _platform: + +Customize Platform +------------------ + +Smallpond supports user-defined task execution platforms. A Platform includes methods for submitting jobs and a series of default configurations. By default, smallpond automatically detects the current environment and selects the most suitable platform. If it cannot detect one, it uses the default platform. + +You can specify a built-in platform via parameters: + +.. code-block:: bash + + # run with your platform + python script.py --platform mpi + + +Or implement your own Platform class: + +.. code-block:: python + + # path/to/my/platform.py + from smallpond.platform import Platform + + class MyPlatform(Platform): + def start_job(self, ...) -> List[str]: + ... + +.. code-block:: bash + + # run with your platform + # if using Driver + python script.py --platform path.to.my.platform + + # if using smallpond.init + SP_PLATFORM=path.to.my.platform python script.py + +.. currentmodule:: smallpond + +.. autosummary:: + :toctree: ../generated + + ~platform.Platform + ~platform.MPI diff --git a/docs/source/api/nodes.rst b/docs/source/api/nodes.rst new file mode 100644 index 0000000..8c47b87 --- /dev/null +++ b/docs/source/api/nodes.rst @@ -0,0 +1,89 @@ +.. currentmodule:: smallpond.logical.node + +.. _nodes: + +Nodes +===== + +Nodes represent the fundamental building blocks of a data processing pipeline. Each node encapsulates a specific operation or transformation that can be applied to a dataset. +Nodes can be chained together to form a logical plan, which is a directed acyclic graph (DAG) of nodes that represent the overall data processing workflow. + +A typical workflow to create a logical plan is as follows: + +.. code-block:: python + + # Create a global context + ctx = Context() + + # Create a dataset + dataset = ParquetDataSet("path/to/dataset/*.parquet") + + # Create a data source node + node = DataSourceNode(ctx, dataset) + + # Partition the data + node = DataSetPartitionNode(ctx, (node,), npartitions=2) + + # Create a SQL engine node to transform the data + node = SqlEngineNode(ctx, (node,), "SELECT * FROM {0}") + + # Create a logical plan from the root node + plan = LogicalPlan(ctx, node) + +You can then create tasks from the logical plan, see :ref:`tasks`. + +Notable properties of Node: + +1. Nodes are partitioned. Each Node generates a series of tasks, with each task processing one partition of data. +2. The input and output of a Node are a series of partitioned Datasets. A Node may write data to shared storage and return a new Dataset, or it may simply recombine the input Datasets. + +Context +------- + +.. autosummary:: + :toctree: ../generated + + Context + NodeId + +LogicalPlan +----------- + +.. autosummary:: + :toctree: ../generated + + LogicalPlan + LogicalPlanVisitor + .. Planner + +Nodes +----- + +.. 
autosummary:: + :toctree: ../generated + + Node + DataSetPartitionNode + ArrowBatchNode + ArrowComputeNode + ArrowStreamNode + ConsolidateNode + DataSinkNode + DataSourceNode + EvenlyDistributedPartitionNode + HashPartitionNode + LimitNode + LoadPartitionedDataSetNode + PandasBatchNode + PandasComputeNode + PartitionNode + ProjectionNode + PythonScriptNode + RangePartitionNode + RepeatPartitionNode + RootNode + ShuffleNode + SqlEngineNode + UnionNode + UserDefinedPartitionNode + UserPartitionedDataSourceNode diff --git a/docs/source/api/tasks.rst b/docs/source/api/tasks.rst new file mode 100644 index 0000000..a7ae9ef --- /dev/null +++ b/docs/source/api/tasks.rst @@ -0,0 +1,76 @@ +.. currentmodule:: smallpond.execution.task + +.. _tasks: + +Tasks +===== + +.. code-block:: python + + # create a runtime context + runtime_ctx = RuntimeContext(JobId.new(), data_root) + runtime_ctx.initialize(socket.gethostname(), cleanup_root=True) + + # create a logical plan + plan = create_logical_plan() + + # create an execution plan + planner = Planner(runtime_ctx) + exec_plan = planner.create_exec_plan(plan) + +You can then execute the tasks in a scheduler, see :ref:`execution`. + +RuntimeContext +-------------- + +.. autosummary:: + :toctree: ../generated + + RuntimeContext + JobId + TaskId + TaskRuntimeId + PartitionInfo + PerfStats + + +ExecutionPlan +------------- + +.. autosummary:: + :toctree: ../generated + + ExecutionPlan + + +Tasks +----- + +.. autosummary:: + :toctree: ../generated + + Task + ArrowBatchTask + ArrowComputeTask + ArrowStreamTask + DataSinkTask + DataSourceTask + EvenlyDistributedPartitionProducerTask + HashPartitionArrowTask + HashPartitionDuckDbTask + HashPartitionTask + LoadPartitionedDataSetProducerTask + MergeDataSetsTask + PandasBatchTask + PandasComputeTask + PartitionConsumerTask + PartitionProducerTask + ProjectionTask + PythonScriptTask + RangePartitionTask + RepeatPartitionProducerTask + RootTask + SplitDataSetTask + SqlEngineTask + UserDefinedPartitionProducerTask + diff --git a/docs/source/conf.py b/docs/source/conf.py new file mode 100644 index 0000000..a3af3a5 --- /dev/null +++ b/docs/source/conf.py @@ -0,0 +1,44 @@ +# Configuration file for the Sphinx documentation builder. 
+# +# For the full list of built-in configuration values, see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +# -- Project information ----------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information + +project = "smallpond" +copyright = "2025, deepseek" +author = "deepseek" + +# -- General configuration --------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration + +extensions = [ + "sphinx.ext.autodoc", + "sphinx.ext.autosummary", +] + +templates_path = ["_templates"] +exclude_patterns = [] + + +# -- Options for HTML output ------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output + +html_theme = "pydata_sphinx_theme" +html_static_path = ["_static"] + +html_theme_options = { + "icon_links": [ + { + "name": "GitHub", + "url": "https://github.com/deepseek-ai/smallpond", + "icon": "fa-brands fa-square-github", + } + ] +} + +import os +import sys + +sys.path.insert(0, os.path.abspath("../..")) diff --git a/docs/source/getstarted.rst b/docs/source/getstarted.rst new file mode 100644 index 0000000..83a6d03 --- /dev/null +++ b/docs/source/getstarted.rst @@ -0,0 +1,83 @@ +Getting Started +=============== + +Installation +------------ + +Python 3.8+ is required. + +.. code-block:: bash + + pip install smallpond + +Initialization +-------------- + +The first step is to initialize the smallpond session: + +.. code-block:: python + + import smallpond + + sp = smallpond.init() + +Loading Data +------------ + +Create a DataFrame from a set of files: + +.. code-block:: python + + df = sp.read_parquet("path/to/dataset/*.parquet") + +To learn more about loading data, please refer to :ref:`loading_data`. + +Partitioning Data +----------------- + +Smallpond requires users to manually specify data partitions for now. + +.. code-block:: python + + df = df.repartition(3) # repartition by files + df = df.repartition(3, by_row=True) # repartition by rows + df = df.repartition(3, hash_by="host") # repartition by hash of column + +To learn more about partitioning data, please refer to :ref:`partitioning_data`. + +Transforming Data +----------------- + +Apply python functions or SQL expressions to transform data. + +.. code-block:: python + + df = df.map('a + b as c') + df = df.map(lambda row: {'c': row['a'] + row['b']}) + +To learn more about transforming data, please refer to :ref:`transformations`. + +Saving Data +----------- + +Save the transformed data to a set of files: + +.. code-block:: python + + df.write_parquet("path/to/output") + +To learn more about saving data, please refer to :ref:`consuming_data`. + +Monitoring +---------- + +Smallpond uses `Ray Core`_ as the task scheduler. You can use `Ray Dashboard`_ to monitor the task execution. + +.. _Ray Core: https://docs.ray.io/en/latest/ray-core/walkthrough.html +.. _Ray Dashboard: https://docs.ray.io/en/latest/ray-observability/getting-started.html + +When smallpond starts, it will print the Ray Dashboard URL: + +.. code-block:: bash + + ... Started a local Ray instance. View the dashboard at http://127.0.0.1:8008 diff --git a/docs/source/index.rst b/docs/source/index.rst new file mode 100644 index 0000000..37a172f --- /dev/null +++ b/docs/source/index.rst @@ -0,0 +1,27 @@ +smallpond +========= + +Smallpond is a lightweight distributed data processing framework. 
+It uses `duckdb`_ as the compute engine and stores data in `parquet`_ format on a distributed file system (e.g. `3FS`_). + +.. _duckdb: https://duckdb.org/ +.. _parquet: https://parquet.apache.org/ +.. _3FS: https://github.com/deepseek-ai/3fs + +Why smallpond? +-------------- + +- **Performance**: Smallpond uses DuckDB to deliver native-level performance for efficient data processing. +- **Scalability**: Leverages high-performance distributed file systems for intermediate storage, enabling PB-scale data handling without memory bottlenecks. +- **Simplicity**: No long-running services or complex dependencies, making it easy to deploy and maintain. + +.. toctree:: + :maxdepth: 1 + + getstarted + internals + +.. toctree:: + :maxdepth: 3 + + api diff --git a/docs/source/internals.rst b/docs/source/internals.rst new file mode 100644 index 0000000..3ef3d2d --- /dev/null +++ b/docs/source/internals.rst @@ -0,0 +1,37 @@ +Internals +========= + +Data Root +--------- + +Smallpond stores all data in a single directory called the data root. + +This directory has the following structure: + +.. code-block:: bash + + data_root + └── 2024-12-11-12-00-28.2cc39990-296f-48a3-8063-78cf6dca460b # job_time.job_id + ├── config # configuration and state + │ ├── exec_plan.pickle + │ ├── logical_plan.pickle + │ └── runtime_ctx.pickle + ├── log # logs + │ ├── graph.png + │ └── scheduler.log + ├── queue # message queue between scheduler and workers + ├── output # output data + ├── staging # intermediate data + │ ├── DataSourceTask.000001 + │ ├── EvenlyDistributedPartitionProducerTask.000002 + │ ├── completed_tasks # output dataset of completed tasks + │ └── started_tasks # used for checkpoint + └── temp # temporary data + ├── DataSourceTask.000001 + └── EvenlyDistributedPartitionProducerTask.000002 + +Failure Recovery +---------------- + +Smallpond can recover from failure and resume execution from the last checkpoint. +Checkpointing is task-level. A few tasks, such as `ArrowBatchTask`, support checkpointing at the batch level. diff --git a/examples/__init__.py b/examples/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/fstest.py b/examples/fstest.py new file mode 100644 index 0000000..21967b9 --- /dev/null +++ b/examples/fstest.py @@ -0,0 +1,225 @@ +# Test the correctness of file system read and write. +# +# This script runs multiple tasks to write and read data to/from the file system. +# Each task writes to an individual file in the given directory. +# Then it reads the data back and verifies the correctness.
+ +import argparse +import glob +import logging +import os +import random +import time +from typing import Any, Dict, Iterator, Optional, Tuple, Union + +import numpy as np + +import smallpond +from smallpond.dataframe import Session + + +def fswrite( + path: str, + size: int, + blocksize: Union[int, Tuple[int, int]], +) -> Dict[str, Any]: + t0 = time.time() + with open(path, "wb") as f: + for start, length in iter_io_slice(0, size, blocksize): + logging.debug(f"writing {length} bytes at offset {start}") + f.write(generate_data(start, length)) + t1 = time.time() + elapsed = t1 - t0 + logging.info(f"write done: {path} in {elapsed:.2f}s") + return { + "path": path, + "size": size, + "elapsed(s)": elapsed, + "throughput(MB/s)": size / elapsed / 1024 / 1024, + } + + +def fsread( + path: str, + blocksize: Union[int, Tuple[int, int]], + randread: bool, +) -> Dict[str, Any]: + t0 = time.time() + with open(path, "rb") as f: + size = os.path.getsize(path) + slices = list(iter_io_slice(0, size, blocksize)) + if randread: + random.shuffle(slices) + for start, length in slices: + logging.debug(f"reading {length} bytes at offset {start}") + f.seek(start) + data = f.read(length) + expected_data = generate_data(start, length) + check_data(data, expected_data, start) + t1 = time.time() + elapsed = t1 - t0 + logging.info(f"read done: {path} in {elapsed:.2f}s") + return { + "path": path, + "size": size, + "elapsed(s)": elapsed, + "throughput(MB/s)": size / elapsed / 1024 / 1024, + } + + +def check_data(actual: bytes, expected: bytes, offset: int) -> None: + """ + Check if the expected data matches the actual data. Raise an error if there is a mismatch. + """ + if expected == actual: + return + # find the first mismatch + index = next( + (i for i, (b1, b2) in enumerate(zip(actual, expected)) if b1 != b2), + min(len(actual), len(expected)), + ) + expected = expected[index : index + 16] + actual = actual[index : index + 16] + raise ValueError( + f"Data mismatch at offset {offset + index}.\nexpect: {expected}\nactual: {actual}" + ) + + +def generate_data(offset: int, length: int) -> bytes: + """ + Generate data for the slice [offset, offset + length). + The full data is a repeated sequence of [0x00000000, 0x00000001, ..., 0xffffffff] in little-endian. + """ + istart = offset // 4 + iend = (offset + length + 3) // 4 + return ( + np.arange(istart, iend) + .astype(np.uint32) + .tobytes()[offset % 4 : offset % 4 + length] + ) + + +def iter_io_slice( + offset: int, length: int, block_size: Union[int, Tuple[int, int]] +) -> Iterator[Tuple[int, int]]: + """ + Generate the IO (offset, size) for the slice [offset, offset + length) with the given block size. + `block_size` can be an integer or a range [start, end]. If a range is provided, the IO size will be randomly selected from the range. + """ + start = offset + end = offset + length + while start < end: + if isinstance(block_size, int): + size = block_size + else: + smin, smax = block_size + size = random.randint(smin, smax) + size = min(size, end - start) + yield (start, size) + start += size + + +def size_str_to_bytes(size_str: str) -> int: + """ + Parse size string to bytes. + e.g. 
1k -> 1024, 1M -> 1024^2, 1G -> 1024^3, 1T -> 1024^4 + """ + if size_str.endswith("k"): + return int(size_str[:-1]) * 1024 + elif size_str.endswith("M"): + return int(size_str[:-1]) * 1024 * 1024 + elif size_str.endswith("G"): + return int(size_str[:-1]) * 1024 * 1024 * 1024 + elif size_str.endswith("T"): + return int(size_str[:-1]) * 1024 * 1024 * 1024 * 1024 + else: + return int(size_str) + + +def fstest( + sp: Session, + input_path: Optional[str], + output_path: Optional[str], + size: Optional[str], + npartitions: int, + blocksize: Optional[str] = "4k", + blocksize_range: Optional[str] = None, + randread: bool = False, +) -> None: + # preprocess arguments + if output_path is not None and size is None: + raise ValueError("--size is required if --output_path is provided") + if size is not None: + size = size_str_to_bytes(size) + if blocksize_range is not None: + start, end = blocksize_range.split("-") + blocksize = (size_str_to_bytes(start), size_str_to_bytes(end)) + elif blocksize is not None: + blocksize = size_str_to_bytes(blocksize) + else: + raise ValueError("either --blocksize or --blocksize_range must be provided") + + if output_path is not None: + os.makedirs(output_path, exist_ok=True) + df = sp.from_items( + [{"path": os.path.join(output_path, f"{i}")} for i in range(npartitions)] + ) + df = df.repartition(npartitions, by_rows=True) + stats = df.map(lambda x: fswrite(x["path"], size, blocksize)).to_pandas() + logging.info(f"write stats:\n{stats}") + + if input_path is not None: + paths = list(glob.glob(input_path)) + df = sp.from_items([{"path": path} for path in paths]) + df = df.repartition(len(paths), by_rows=True) + stats = df.map(lambda x: fsread(x["path"], blocksize, randread)).to_pandas() + logging.info(f"read stats:\n{stats}") + + +if __name__ == "__main__": + """ + Example usage: + - write only: + python examples/fstest.py -o 'fstest' -j 8 -s 1G + - read only: + python examples/fstest.py -i 'fstest/*' + - write and then read: + python examples/fstest.py -o 'fstest' -j 8 -s 1G -i 'fstest/*' + """ + parser = argparse.ArgumentParser() + parser.add_argument( + "-o", "--output_path", type=str, help="The output path to write data to." + ) + parser.add_argument( + "-i", + "--input_path", + type=str, + help="The input path to read data from. Can be combined with -o to write first and then read back.", + ) + parser.add_argument( + "-j", "--npartitions", type=int, help="The number of parallel jobs", default=10 + ) + parser.add_argument( + "-s", + "--size", + type=str, + help="The size for each file. Required if -o is provided.", + ) + parser.add_argument("-bs", "--blocksize", type=str, help="Block size", default="4k") + parser.add_argument( + "-bsrange", + "--blocksize_range", + type=str, + help="A range of I/O block sizes. e.g.
4k-128k", + ) + parser.add_argument( + "-randread", + "--randread", + action="store_true", + help="Whether to read data randomly", + default=False, + ) + args = parser.parse_args() + + sp = smallpond.init() + fstest(sp, **vars(args)) diff --git a/examples/shuffle_data.py b/examples/shuffle_data.py new file mode 100644 index 0000000..82cf493 --- /dev/null +++ b/examples/shuffle_data.py @@ -0,0 +1,78 @@ +from smallpond.contrib.copy_table import StreamCopy +from smallpond.execution.driver import Driver +from smallpond.logical.dataset import ParquetDataSet +from smallpond.logical.node import ( + Context, + DataSetPartitionNode, + DataSourceNode, + HashPartitionNode, + LogicalPlan, + SqlEngineNode, +) + + +def shuffle_data( + input_paths, + num_out_data_partitions: int = 0, + num_data_partitions: int = 10, + num_hash_partitions: int = 10, + engine_type="duckdb", + skip_hash_partition=False, +) -> LogicalPlan: + ctx = Context() + dataset = ParquetDataSet(input_paths, union_by_name=True) + data_files = DataSourceNode(ctx, dataset) + data_partitions = DataSetPartitionNode( + ctx, + (data_files,), + npartitions=num_data_partitions, + partition_by_rows=True, + random_shuffle=skip_hash_partition, + ) + if skip_hash_partition: + urls_partitions = data_partitions + else: + urls_partitions = HashPartitionNode( + ctx, + (data_partitions,), + npartitions=num_hash_partitions, + hash_columns=None, + random_shuffle=True, + engine_type=engine_type, + ) + shuffled_urls = SqlEngineNode( + ctx, + (urls_partitions,), + r"select *, cast(random() * 2147483647 as integer) as sort_key from {0} order by sort_key", + cpu_limit=16, + ) + repartitioned = DataSetPartitionNode( + ctx, + (shuffled_urls,), + npartitions=num_out_data_partitions, + partition_by_rows=True, + ) + shuffled_urls = StreamCopy( + ctx, (repartitioned,), output_name="data_copy", cpu_limit=1 + ) + + plan = LogicalPlan(ctx, shuffled_urls) + return plan + + +def main(): + driver = Driver() + driver.add_argument("-i", "--input_paths", nargs="+") + driver.add_argument("-nd", "--num_data_partitions", type=int, default=1024) + driver.add_argument("-nh", "--num_hash_partitions", type=int, default=3840) + driver.add_argument("-no", "--num_out_data_partitions", type=int, default=1920) + driver.add_argument( + "-e", "--engine_type", default="duckdb", choices=("duckdb", "arrow") + ) + driver.add_argument("-x", "--skip_hash_partition", action="store_true") + plan = shuffle_data(**driver.get_arguments()) + driver.run(plan) + + +if __name__ == "__main__": + main() diff --git a/examples/shuffle_mock_urls.py b/examples/shuffle_mock_urls.py new file mode 100644 index 0000000..ffde6e6 --- /dev/null +++ b/examples/shuffle_mock_urls.py @@ -0,0 +1,73 @@ +from smallpond.common import GB +from smallpond.execution.driver import Driver +from smallpond.logical.dataset import ParquetDataSet +from smallpond.logical.node import ( + Context, + DataSetPartitionNode, + DataSourceNode, + HashPartitionNode, + LogicalPlan, + SqlEngineNode, +) + + +def shuffle_mock_urls( + input_paths, npartitions: int = 10, sort_rand_keys=True, engine_type="duckdb" +) -> LogicalPlan: + ctx = Context() + dataset = ParquetDataSet(input_paths) + data_files = DataSourceNode(ctx, dataset) + data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=npartitions) + + urls_partitions = HashPartitionNode( + ctx, + (data_partitions,), + npartitions=npartitions, + hash_columns=None, + random_shuffle=True, + engine_type=engine_type, + output_name="urls_partitions", + cpu_limit=1, + memory_limit=20 * GB, 
+ ) + + if sort_rand_keys: + # shuffle as sorting partition keys + shuffled_urls = SqlEngineNode( + ctx, + (urls_partitions,), + r"select *, random() as partition_key from {0} order by partition_key", + output_name="shuffled_urls", + cpu_limit=1, + memory_limit=40 * GB, + ) + else: + # shuffle as reservoir sampling + shuffled_urls = SqlEngineNode( + ctx, + (urls_partitions,), + r"select * from {0} using sample 100% (reservoir, {rand_seed})", + output_name="shuffled_urls", + cpu_limit=1, + memory_limit=40 * GB, + ) + + plan = LogicalPlan(ctx, shuffled_urls) + return plan + + +def main(): + driver = Driver() + driver.add_argument("-i", "--input_paths", nargs="+") + driver.add_argument("-n", "--npartitions", type=int, default=500) + driver.add_argument("-s", "--sort_rand_keys", action="store_true") + driver.add_argument( + "-e", "--engine_type", default="duckdb", choices=("duckdb", "arrow") + ) + + plan = shuffle_mock_urls(**driver.get_arguments()) + driver.run(plan) + + +if __name__ == "__main__": + main() diff --git a/examples/sort_mock_urls.py b/examples/sort_mock_urls.py new file mode 100644 index 0000000..eb4042e --- /dev/null +++ b/examples/sort_mock_urls.py @@ -0,0 +1,104 @@ +import logging +import os.path +from typing import List, Optional, OrderedDict + +import pyarrow as arrow + +from smallpond.execution.driver import Driver +from smallpond.execution.task import RuntimeContext +from smallpond.logical.dataset import CsvDataSet +from smallpond.logical.node import ( + ArrowComputeNode, + Context, + DataSetPartitionNode, + DataSinkNode, + DataSourceNode, + HashPartitionNode, + LogicalPlan, + SqlEngineNode, +) + + +class SortUrlsNode(ArrowComputeNode): + def process( + self, runtime_ctx: RuntimeContext, input_tables: List[arrow.Table] + ) -> arrow.Table: + logging.info(f"sorting urls by 'host', table shape: {input_tables[0].shape}") + return input_tables[0].sort_by("host") + + +def sort_mock_urls( + input_paths, + npartitions: int, + engine_type="duckdb", + external_output_path: Optional[str] = None, +) -> LogicalPlan: + ctx = Context() + dataset = CsvDataSet( + input_paths, + schema=OrderedDict([("urlstr", "varchar"), ("valstr", "varchar")]), + delim=r"\t", + ) + data_files = DataSourceNode(ctx, dataset) + data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=npartitions) + imported_urls = SqlEngineNode( + ctx, + (data_partitions,), + r""" + select split_part(urlstr, '/', 1) as host, split_part(urlstr, ' ', 1) as url, from_base64(valstr) AS payload from {0} + """, + output_name="imported_urls", + output_path=external_output_path, + ) + urls_partitions = HashPartitionNode( + ctx, + (imported_urls,), + npartitions=npartitions, + hash_columns=["host"], + engine_type=engine_type, + output_name="urls_partitions", + output_path=external_output_path, + ) + + if engine_type == "duckdb": + sorted_urls = SqlEngineNode( + ctx, + (urls_partitions,), + r"select * from {0} order by host", + output_name="sorted_urls", + ) + else: + sorted_urls = SortUrlsNode( + ctx, + (urls_partitions,), + output_name="sorted_urls", + output_path=external_output_path, + ) + + final_result = DataSetPartitionNode(ctx, (sorted_urls,), npartitions=1) + + if external_output_path: + final_result = DataSinkNode( + ctx, + (final_result,), + output_path=os.path.join(external_output_path, "data_sink"), + ) + + plan = LogicalPlan(ctx, final_result) + return plan + + +def main(): + driver = Driver() + driver.add_argument( + "-i", "--input_paths", nargs="+", default=["tests/data/mock_urls/*.tsv"] + ) + 
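+    # Illustrative invocation ("ray" is one of the modes accepted by the Driver in
+    # smallpond/execution/driver.py; the input glob is the default defined above):
+    #   python examples/sort_mock_urls.py ray -i "tests/data/mock_urls/*.tsv" -n 10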
driver.add_argument("-n", "--npartitions", type=int, default=10) + driver.add_argument("-e", "--engine_type", default="duckdb") + + plan = sort_mock_urls(**driver.get_arguments()) + driver.run(plan) + + +if __name__ == "__main__": + main() diff --git a/examples/sort_mock_urls_v2.py b/examples/sort_mock_urls_v2.py new file mode 100644 index 0000000..d642c3f --- /dev/null +++ b/examples/sort_mock_urls_v2.py @@ -0,0 +1,36 @@ +import argparse +from typing import List + +import smallpond +from smallpond.dataframe import Session + + +def sort_mock_urls_v2( + sp: Session, input_paths: List[str], output_path: str, npartitions: int +): + dataset = sp.read_csv( + input_paths, schema={"urlstr": "varchar", "valstr": "varchar"}, delim=r"\t" + ).repartition(npartitions) + urls = dataset.map( + """ + split_part(urlstr, '/', 1) as host, + split_part(urlstr, ' ', 1) as url, + from_base64(valstr) AS payload + """ + ) + urls = urls.repartition(npartitions, hash_by="host") + sorted_urls = urls.partial_sort(by=["host"]) + sorted_urls.write_parquet(output_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "-i", "--input_paths", nargs="+", default=["tests/data/mock_urls/*.tsv"] + ) + parser.add_argument("-o", "--output_path", type=str, default="sort_mock_urls") + parser.add_argument("-n", "--npartitions", type=int, default=10) + args = parser.parse_args() + + sp = smallpond.init() + sort_mock_urls_v2(sp, **vars(args)) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..624d7f7 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,79 @@ +[build-system] +requires = ["setuptools", "setuptools-scm"] +build-backend = "setuptools.build_meta" + +[project] +name = "smallpond" +version = "0.15.0" +description = "A lightweight data processing framework built on DuckDB and shared file system." 
+authors = [ + { name = "DeepSeek-AI", email = "research@deepseek.com" }, + { name = "Runji Wang" }, + { name = "Yiliang Xiong" }, + { name = "Yiyuan Liu" }, + { name = "Yuheng Zou" }, + { name = "Yichao Zhang" }, + { name = "Wenjun Gao" }, + { name = "Wentao Zhang" }, + { name = "Xiaotao Nie" }, + { name = "Minghua Zhang" }, + { name = "Zhewen Hao" }, +] +urls = { Homepage = "https://github.com/deepseek-ai/smallpond" } +keywords = ["distributed query processing", "SQL", "parquet"] +requires-python = ">=3.8" +dependencies = [ + "duckdb >= 1.2.0", + "pyarrow ~= 16.1.0", + "polars ~= 0.20.9", + "pandas >= 1.3.4", + "plotly >= 5.22.0", + "lxml >= 4.9.3", + "cloudpickle >= 2.0.0", + "zstandard >= 0.22.0", + "loguru >= 0.7.2", + "psutil >= 5.9.8", + "GPUtil >= 1.4.0", + "py-libnuma >= 1.2", + "fsspec >= 2023.12.2", + "ray[default] >= 2.10.0", + "graphviz >= 0.19.1", +] + +[project.optional-dependencies] +dev = [ + "coverage~=7.4.4", + "hypothesis~=6.100.0", + "pytest==8.2.1", + "pytest-cov==5.0.0", + "pytest-forked==1.6.0", + "pytest-xdist==3.6.1", + "pytest-timeout==2.3.1", + "pytest-benchmark==4.0.0", + "setproctitle==1.3.3", + "soupsieve~=2.5", + "setuptools-scm==8.1.0", + "packaging==24.2", + "jaraco.functools==4.1.0", +] +docs = [ + "sphinx==7.1.2", + "pydata-sphinx-theme==0.14.4", +] +warc = [ + "warcio >= 1.7.4", + "beautifulsoup4 >= 4.12.2", +] + +[tool.setuptools] +packages = [ + "smallpond", + "smallpond.io", + "smallpond.execution", + "smallpond.logical", + "smallpond.contrib", + "smallpond.platform", +] + +[tool.setuptools_scm] +fallback_version = "0.0.0" diff --git a/smallpond/__init__.py b/smallpond/__init__.py new file mode 100644 index 0000000..2ff6755 --- /dev/null +++ b/smallpond/__init__.py @@ -0,0 +1,89 @@ +from importlib.metadata import PackageNotFoundError, version +from typing import Optional + +try: + __version__ = version("smallpond") +except PackageNotFoundError: + # package is not installed + __version__ = "unknown" + + +def init( + job_id: Optional[str] = None, + job_time: Optional[float] = None, + job_name: Optional[str] = None, + data_root: Optional[str] = None, + num_executors: Optional[int] = None, + ray_address: Optional[str] = None, + bind_numa_node: Optional[bool] = None, + platform: Optional[str] = None, + _remove_output_root: bool = True, + **kwargs, +) -> "Session": + """ + Initialize smallpond environment. + + This is the entry point for smallpond:: + + import smallpond + sp = smallpond.init() + + By default, it will use a local ray cluster as worker node. + To use more worker nodes, please specify the argument:: + + sp = smallpond.init(num_executors=10) + + It will create an task to run the workers. + + Parameters + ---------- + All parameters are optional. If not specified, read from environment variables. If not set, use default values. + + job_id (SP_JOBID) + Unique job id. Default to a random uuid. + job_time (SP_JOB_TIME) + Job create time (seconds since epoch). Default to current time. + job_name (SP_JOB_NAME) + Job display name. Default to the filename of the current script. + data_root (SP_DATA_ROOT) + The root folder for all files generated at runtime. + num_executors (SP_NUM_EXECUTORS) + The number of executors. + Default to 0, which means all tasks will be run on the current node. + ray_address (SP_RAY_ADDRESS) + If specified, use the given address to connect to an existing ray cluster. + Otherwise, create a new ray cluster. + bind_numa_node (SP_BIND_NUMA_NODE) + If true, bind executor processes to numa nodes. 
+ memory_allocator (SP_MEMORY_ALLOCATOR) + The memory allocator used by worker processes. + Choices: "system", "jemalloc", "mimalloc". Default to "mimalloc". + platform (SP_PLATFORM) + The platform to use. Choices: "mpi". + By default, it will automatically detect the environment and choose the most suitable platform. + _remove_output_root + If true, remove the "{data_root}/output" directory after the job is finished. + Default to True. This is only used for compatibility. User should use `write_parquet` for saving outputs. + + Spawning a new job + ------------------ + If the environment variable `SP_SPAWN` is set to "1", it will spawn a new job to run the current script. + """ + import atexit + + from smallpond.dataframe import Session + + session = Session( + job_id=job_id, + job_time=job_time, + job_name=job_name, + data_root=data_root, + num_executors=num_executors, + ray_address=ray_address, + bind_numa_node=bind_numa_node, + platform=platform, + _remove_output_root=_remove_output_root, + **kwargs, + ) + atexit.register(session.shutdown) + return session diff --git a/smallpond/common.py b/smallpond/common.py new file mode 100644 index 0000000..b73981b --- /dev/null +++ b/smallpond/common.py @@ -0,0 +1,117 @@ +import itertools +import math +import sys +from typing import Dict, List, TypeVar + +import numpy as np + +from smallpond.logical.udf import * + +KB = 1024 +MB = 1024 * KB +GB = 1024 * MB +TB = 1024 * GB + +DEFAULT_MAX_RETRY_COUNT = 5 +DEFAULT_MAX_FAIL_COUNT = 3 +# duckdb default row group size https://duckdb.org/docs/data/parquet/tips#selecting-a-row_group_size +MAX_ROW_GROUP_SIZE = 10 * 1024 * 1024 +MAX_ROW_GROUP_BYTES = 2 * GB +MAX_NUM_ROW_GROUPS = 256 +MAX_PARQUET_FILE_BYTES = 8 * GB +DEFAULT_ROW_GROUP_SIZE = 122880 +DEFAULT_ROW_GROUP_BYTES = DEFAULT_ROW_GROUP_SIZE * 4 * KB +DEFAULT_BATCH_SIZE = 122880 +RAND_SEED_BYTE_LEN = 128 +DATA_PARTITION_COLUMN_NAME = "__data_partition__" +PARQUET_METADATA_KEY_PREFIX = "SMALLPOND:" +INPUT_VIEW_PREFIX = "__input" +GENERATED_COLUMNS = ("filename", "file_row_number") + + +def pytest_running(): + return "pytest" in sys.modules + + +def clamp_value(val, minval, maxval): + return max(minval, min(val, maxval)) + + +def clamp_row_group_size(val, minval=DEFAULT_ROW_GROUP_SIZE, maxval=MAX_ROW_GROUP_SIZE): + return clamp_value(val, minval, maxval) + + +def clamp_row_group_bytes( + val, minval=DEFAULT_ROW_GROUP_BYTES, maxval=MAX_ROW_GROUP_BYTES +): + return clamp_value(val, minval, maxval) + + +class SmallpondError(Exception): + """Base class for all errors in smallpond.""" + + +class InjectedFault(SmallpondError): + pass + + +class OutOfMemory(SmallpondError): + pass + + +class NonzeroExitCode(SmallpondError): + pass + + +K = TypeVar("K") +V = TypeVar("V") + + +def first_value_in_dict(d: Dict[K, V]) -> V: + return next(iter(d.values())) if d else None + + +def split_into_cols(items: List[V], npartitions: int) -> List[List[V]]: + none = object() + chunks = [items[i : i + npartitions] for i in range(0, len(items), npartitions)] + return [ + [x for x in col if x is not none] + for col in itertools.zip_longest([none] * npartitions, *chunks, fillvalue=none) + ] + + +def split_into_rows(items: List[V], npartitions: int) -> List[List[V]]: + """ + Evenly split items into npartitions. 
+ + Example:: + >>> split_into_rows(list(range(10)), 3) + [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]] + """ + split_idxs = np.array_split(np.arange(len(items)), npartitions) + return [[items[i] for i in idxs] for idxs in split_idxs] + + +def get_nth_partition(items: List[V], n: int, npartitions: int) -> List[V]: + num_items = len(items) + large_partition_size = (num_items + npartitions - 1) // npartitions + small_partition_size = num_items // npartitions + num_large_partitions = num_items - small_partition_size * npartitions + if n < num_large_partitions: + start = n * large_partition_size + items_in_partition = items[start : start + large_partition_size] + else: + start = ( + large_partition_size * num_large_partitions + + (n - num_large_partitions) * small_partition_size + ) + items_in_partition = items[start : start + small_partition_size] + return items_in_partition + + +def next_power_of_two(x) -> int: + return 2 ** (x - 1).bit_length() + + +def round_up(x, align_size=MB) -> int: + return math.ceil(x / align_size) * align_size diff --git a/smallpond/contrib/__init__.py b/smallpond/contrib/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/smallpond/contrib/copy_table.py b/smallpond/contrib/copy_table.py new file mode 100644 index 0000000..1b309b0 --- /dev/null +++ b/smallpond/contrib/copy_table.py @@ -0,0 +1,24 @@ +from typing import Iterable, List + +import pyarrow as arrow +from loguru import logger + +from smallpond.execution.task import RuntimeContext +from smallpond.logical.node import ArrowComputeNode, ArrowStreamNode + + +class CopyArrowTable(ArrowComputeNode): + def process( + self, runtime_ctx: RuntimeContext, input_tables: List[arrow.Table] + ) -> arrow.Table: + logger.info(f"copying table: {input_tables[0].num_rows} rows ...") + return input_tables[0] + + +class StreamCopy(ArrowStreamNode): + def process( + self, runtime_ctx: RuntimeContext, input_readers: List[arrow.RecordBatchReader] + ) -> Iterable[arrow.Table]: + for batch in input_readers[0]: + logger.info(f"copying batch: {batch.num_rows} rows ...") + yield arrow.Table.from_batches([batch]) diff --git a/smallpond/contrib/log_dataset.py b/smallpond/contrib/log_dataset.py new file mode 100644 index 0000000..e2b20d9 --- /dev/null +++ b/smallpond/contrib/log_dataset.py @@ -0,0 +1,37 @@ +from typing import List, Tuple + +from smallpond.execution.task import PythonScriptTask, RuntimeContext +from smallpond.logical.dataset import DataSet +from smallpond.logical.node import Context, Node, PythonScriptNode + + +class LogDataSetTask(PythonScriptTask): + + num_rows = 200 + + @property + def exec_on_scheduler(self) -> bool: + return True + + def process( + self, + runtime_ctx: RuntimeContext, + input_datasets: List[DataSet], + output_path: str, + ) -> bool: + for dataset in input_datasets: + dataset.log(self.num_rows) + return True + + +class LogDataSet(PythonScriptNode): + def __init__( + self, ctx: Context, input_deps: Tuple[Node, ...], num_rows=200, **kwargs + ) -> None: + super().__init__(ctx, input_deps, **kwargs) + self.num_rows = num_rows + + def spawn(self, *args, **kwargs) -> LogDataSetTask: + task = LogDataSetTask(*args, **kwargs) + task.num_rows = self.num_rows + return task diff --git a/smallpond/contrib/warc.py b/smallpond/contrib/warc.py new file mode 100644 index 0000000..7f4e8dc --- /dev/null +++ b/smallpond/contrib/warc.py @@ -0,0 +1,143 @@ +import string +import sys +import unicodedata +from pathlib import PurePath +from typing import Iterable, List, Tuple +from urllib.parse import urlparse + +import 
pyarrow as arrow +import zstandard as zstd +from bs4 import BeautifulSoup +from loguru import logger +from warcio import ArchiveIterator + +from smallpond.common import MB +from smallpond.execution.task import RuntimeContext +from smallpond.io.arrow import dump_to_parquet_files +from smallpond.logical.dataset import DataSet +from smallpond.logical.node import ArrowStreamNode, PythonScriptNode + + +class ImportWarcFiles(PythonScriptNode): + + schema = arrow.schema( + [ + arrow.field("url", arrow.string()), + arrow.field("domain", arrow.string()), + arrow.field("date", arrow.string()), + arrow.field("content", arrow.binary()), + ] + ) + + def import_warc_file( + self, warc_path: PurePath, parquet_path: PurePath + ) -> Tuple[int, int]: + total_size = 0 + docs = [] + + with open(warc_path, "rb") as warc_file: + zstd_reader = zstd.ZstdDecompressor().stream_reader( + warc_file, read_size=16 * MB + ) + for record in ArchiveIterator(zstd_reader): + if record.rec_type == "response": + url = record.rec_headers.get_header("WARC-Target-URI") + domain = urlparse(url).netloc + date = record.rec_headers.get_header("WARC-Date") + content = record.content_stream().read() + total_size += len(content) + docs.append((url, domain, date, content)) + + table = arrow.Table.from_arrays( + [arrow.array(column) for column in zip(*docs)], schema=self.schema + ) + dump_to_parquet_files(table, parquet_path.parent, parquet_path.name) + return len(docs), total_size + + def process( + self, + runtime_ctx: RuntimeContext, + input_datasets: List[DataSet], + output_path: str, + ) -> bool: + warc_paths = [ + PurePath(warc_path) + for dataset in input_datasets + for warc_path in dataset.resolved_paths + ] + parquet_paths = [ + PurePath(output_path) + / f"data{path_index}-{PurePath(warc_path.name).with_suffix('.parquet')}" + for path_index, warc_path in enumerate(warc_paths) + ] + + logger.info(f"importing web pages from {len(warc_paths)} warc files...") + for warc_path, parquet_path in zip(warc_paths, parquet_paths): + try: + doc_count, total_size = self.import_warc_file(warc_path, parquet_path) + logger.info( + f"imported {doc_count} web pages ({total_size/MB:.3f}MB) from file '{warc_path}' to '{parquet_path}'" + ) + except Exception as ex: + logger.opt(exception=ex).error( + f"failed to import web pages from file '{warc_path}'" + ) + return False + return True + + +class ExtractHtmlBody(ArrowStreamNode): + + unicode_punctuation = "".join( + chr(i) + for i in range(sys.maxunicode) + if unicodedata.category(chr(i)).startswith("P") + ) + separator_str = string.whitespace + string.punctuation + unicode_punctuation + translator = str.maketrans(separator_str, " " * len(separator_str)) + + schema = arrow.schema( + [ + arrow.field("url", arrow.string()), + arrow.field("domain", arrow.string()), + arrow.field("date", arrow.string()), + arrow.field("tokens", arrow.list_(arrow.string())), + ] + ) + + def split_string(self, s: str): + return s.translate(self.translator).split() + + def extract_tokens(self, url: arrow.string, content: arrow.binary) -> List[str]: + tokens = [] + try: + doc = BeautifulSoup(content.as_py(), "lxml") + # if doc.title is not None and doc.title.string is not None: + # tokens.extend(self.split_string(doc.title.string.lower())) + tokens.extend(self.split_string(doc.get_text(" ", strip=True).lower())) + return tokens + except Exception as ex: + logger.opt(exception=ex).error( + f"failed to extract tokens from {url.as_py()}" + ) + return [] + + def process( + self, runtime_ctx: RuntimeContext, input_readers: 
List[arrow.RecordBatchReader] + ) -> Iterable[arrow.Table]: + for batch in input_readers[0]: + urls, domains, dates, contents = batch.columns + doc_tokens = [] + try: + for i, (url, content) in enumerate(zip(urls, contents)): + tokens = self.extract_tokens(url, content) + logger.info( + f"#{i}/{len(urls)} extracted {len(tokens)} tokens from {url}" + ) + doc_tokens.append(tokens) + yield arrow.Table.from_arrays( + [urls, domains, dates, arrow.array(doc_tokens)], schema=self.schema + ) + except Exception as ex: + logger.opt(exception=ex).error(f"failed to extract tokens") + break diff --git a/smallpond/dataframe.py b/smallpond/dataframe.py new file mode 100644 index 0000000..4974910 --- /dev/null +++ b/smallpond/dataframe.py @@ -0,0 +1,715 @@ +from __future__ import annotations + +import os +import time +from collections import OrderedDict +from concurrent.futures import ThreadPoolExecutor +from datetime import datetime +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import pandas as pd +import pyarrow as arrow +import ray +import ray.exceptions +from loguru import logger + +from smallpond.execution.task import Task +from smallpond.io.filesystem import remove_path +from smallpond.logical.dataset import * +from smallpond.logical.node import * +from smallpond.logical.optimizer import Optimizer +from smallpond.logical.planner import Planner +from smallpond.session import SessionBase + + +class Session(SessionBase): + # Extended session class with additional methods to create DataFrames. + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._nodes: List[Node] = [] + + self._node_to_tasks: Dict[Node, List[Task]] = {} + """ + When a DataFrame is evaluated, the tasks of the logical plan are stored here. + Subsequent DataFrames can reuse the tasks to avoid recomputation. + """ + + def read_csv( + self, paths: Union[str, List[str]], schema: Dict[str, str], delim="," + ) -> DataFrame: + """ + Create a DataFrame from CSV files. + """ + dataset = CsvDataSet(paths, OrderedDict(schema), delim) + plan = DataSourceNode(self._ctx, dataset) + return DataFrame(self, plan) + + def read_parquet( + self, + paths: Union[str, List[str]], + recursive: bool = False, + columns: Optional[List[str]] = None, + union_by_name: bool = False, + ) -> DataFrame: + """ + Create a DataFrame from Parquet files. + """ + dataset = ParquetDataSet( + paths, columns=columns, union_by_name=union_by_name, recursive=recursive + ) + plan = DataSourceNode(self._ctx, dataset) + return DataFrame(self, plan) + + def read_json( + self, paths: Union[str, List[str]], schema: Dict[str, str] + ) -> DataFrame: + """ + Create a DataFrame from JSON files. + """ + dataset = JsonDataSet(paths, schema) + plan = DataSourceNode(self._ctx, dataset) + return DataFrame(self, plan) + + def from_items(self, items: List[Any]) -> DataFrame: + """ + Create a DataFrame from a list of local Python objects. + """ + + assert isinstance(items, list), "items must be a list" + assert len(items) > 0, "items must not be empty" + if isinstance(items[0], dict): + return self.from_arrow(arrow.Table.from_pylist(items)) + else: + return self.from_arrow(arrow.table({"item": items})) + + def from_pandas(self, df: pd.DataFrame) -> DataFrame: + """ + Create a DataFrame from a pandas DataFrame. + """ + plan = DataSourceNode(self._ctx, PandasDataSet(df)) + return DataFrame(self, plan) + + def from_arrow(self, table: arrow.Table) -> DataFrame: + """ + Create a DataFrame from a pyarrow Table. 
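+
+        Examples
+        --------
+        A minimal sketch; ``sp`` is an initialized session and the column values
+        are placeholders:
+
+        .. code-block::
+
+            table = arrow.Table.from_pylist([{"x": 1}, {"x": 2}])
+            df = sp.from_arrow(table)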
+ """ + plan = DataSourceNode(self._ctx, ArrowTableDataSet(table)) + return DataFrame(self, plan) + + def partial_sql(self, query: str, *inputs: DataFrame, **kwargs) -> DataFrame: + """ + Execute a SQL query on each partition of the input DataFrames. + + The query can contain placeholder `{0}`, `{1}`, etc. for the input DataFrames. + If multiple DataFrames are provided, they must have the same number of partitions. + + Examples + -------- + Join two datasets. You need to make sure the join key is correctly partitioned. + + .. code-block:: + + a = sp.read_parquet("a/*.parquet").repartition(10, hash_by="id") + b = sp.read_parquet("b/*.parquet").repartition(10, hash_by="id") + c = sp.partial_sql("select * from {0} join {1} on a.id = b.id", a, b) + """ + + plan = SqlEngineNode( + self._ctx, tuple(input.plan for input in inputs), query, **kwargs + ) + recompute = any(input.need_recompute for input in inputs) + return DataFrame(self, plan, recompute=recompute) + + def wait(self, *dfs: DataFrame): + """ + Wait for all DataFrames to be computed. + + Example + ------- + This can be used to wait for multiple outputs from a pipeline: + + .. code-block:: + + df = sp.read_parquet("input/*.parquet") + output1 = df.write_parquet("output1") + output2 = df.map("col1, col2").write_parquet("output2") + sp.wait(output1, output2) + """ + ray.get([task.run_on_ray() for df in dfs for task in df._get_or_create_tasks()]) + + def graph(self) -> Digraph: + """ + Get the DataFrame graph. + """ + dot = Digraph(comment="SmallPond") + for node in self._nodes: + dot.node(str(node.id), repr(node)) + for dep in node.input_deps: + dot.edge(str(dep.id), str(node.id)) + return dot + + def shutdown(self): + """ + Shutdown the session. + """ + # prevent shutdown from being called multiple times + if hasattr(self, "_shutdown_called"): + return + self._shutdown_called = True + + # log status + finished = self._all_tasks_finished() + with open(self._runtime_ctx.job_status_path, "a") as fout: + status = "success" if finished else "failure" + fout.write(f"{status}@{datetime.now():%Y-%m-%d-%H-%M-%S}\n") + + # clean up runtime directories if success + if finished: + logger.info("all tasks are finished, cleaning up") + self._runtime_ctx.cleanup(remove_output_root=self.config.remove_output_root) + else: + logger.warning("tasks are not finished!") + + super().shutdown() + + def _summarize_task(self) -> Tuple[int, int]: + """ + Return the total number of tasks and the number of tasks that are finished. + """ + dataset_refs = [ + task._dataset_ref + for tasks in self._node_to_tasks.values() + for task in tasks + if task._dataset_ref is not None + ] + ready_tasks, _ = ray.wait( + dataset_refs, num_returns=len(dataset_refs), timeout=0, fetch_local=False + ) + return len(dataset_refs), len(ready_tasks) + + def _all_tasks_finished(self) -> bool: + """ + Check if all tasks are finished. + """ + dataset_refs = [ + task._dataset_ref + for tasks in self._node_to_tasks.values() + for task in tasks + ] + try: + ray.get(dataset_refs, timeout=0) + except Exception: + # GetTimeoutError is raised if any task is not finished + # RuntimeError is raised if any task failed + return False + return True + + +class DataFrame: + """ + A distributed data collection. It represents a 2 dimensional table of rows and columns. + + Internally, it's a wrapper around a `Node` and a `Session` required for execution. 
+ """ + + def __init__(self, session: Session, plan: Node, recompute: bool = False): + self.session = session + self.plan = plan + self.optimized_plan: Optional[Node] = None + self.need_recompute = recompute + """Whether to recompute the data regardless of whether it's already computed.""" + + session._nodes.append(plan) + + def __str__(self) -> str: + return repr(self.plan) + + def _get_or_create_tasks(self) -> List[Task]: + """ + Get or create tasks to compute the data. + """ + # optimize the plan + if self.optimized_plan is None: + logger.info(f"optimizing\n{LogicalPlan(self.session._ctx, self.plan)}") + self.optimized_plan = Optimizer( + exclude_nodes=set(self.session._node_to_tasks.keys()) + ).visit(self.plan) + logger.info( + f"optimized\n{LogicalPlan(self.session._ctx, self.optimized_plan)}" + ) + # return the tasks if already created + if tasks := self.session._node_to_tasks.get(self.optimized_plan): + return tasks + + # remove all completed task files if recompute is needed + if self.need_recompute: + remove_path( + os.path.join( + self.session._runtime_ctx.completed_task_dir, + str(self.optimized_plan.id), + ) + ) + logger.info(f"cleared all results of {self.optimized_plan!r}") + + # create tasks for the optimized plan + planner = Planner(self.session._runtime_ctx) + # let planner update self.session._node_to_tasks + planner.node_to_tasks = self.session._node_to_tasks + return planner.visit(self.optimized_plan) + + def is_computed(self) -> bool: + """ + Check if the data is ready on disk. + """ + if tasks := self.session._node_to_tasks.get(self.plan): + _, unready_tasks = ray.wait(tasks, timeout=0) + return len(unready_tasks) == 0 + return False + + def compute(self) -> None: + """ + Compute the data. + + This operation will trigger execution of the lazy transformations performed on this DataFrame. + """ + self._compute() + + def _compute(self) -> List[DataSet]: + """ + Compute the data and return the datasets. + """ + for retry_count in range(3): + try: + return ray.get( + [task.run_on_ray() for task in self._get_or_create_tasks()] + ) + except ray.exceptions.RuntimeEnvSetupError as e: + # XXX: Ray may raise this error when a worker is interrupted. + # ``` + # ray.exceptions.RuntimeEnvSetupError: Failed to set up runtime environment. + # Failed to create runtime env for job 01000000, status = IOError: + # on_read bad version, maybe there are some network problems, will fail the request. + # ``` + # This is a bug of Ray and has been fixed in Ray 2.24: + # However, since Ray dropped support for Python 3.8 since 2.11, we can not upgrade Ray. + # So we catch this error and retry by ourselves. + logger.error(f"found ray RuntimeEnvSetupError, retrying...\n{e}") + time.sleep(10 << retry_count) + raise RuntimeError("Failed to compute data after 3 retries") + + # operations + + def recompute(self) -> DataFrame: + """ + Always recompute the data regardless of whether it's already computed. + + Examples + -------- + Modify the code as follows and rerun: + + .. code-block:: diff + + - df = input.select('a') + + df = input.select('b').recompute() + + The result of `input` can be reused. + """ + self.need_recompute = True + return self + + def repartition( + self, + npartitions: int, + hash_by: Union[str, List[str], None] = None, + by: Optional[str] = None, + by_rows: bool = False, + **kwargs, + ) -> DataFrame: + """ + Repartition the data into the given number of partitions. + + Parameters + ---------- + npartitions + The dataset would be split and distributed to `npartitions` partitions. 
+ If not specified, the number of partitions would be the default partition size of the context. + hash_by, optional + If specified, the dataset would be repartitioned by the hash of the given columns. + by, optional + If specified, the dataset would be repartitioned by the given column. + by_rows, optional + If specified, the dataset would be repartitioned by rows instead of by files. + + Examples + -------- + .. code-block:: + + df = df.repartition(10) # evenly distributed + df = df.repartition(10, by_rows=True) # evenly distributed by rows + df = df.repartition(10, hash_by='host') # hash partitioned + df = df.repartition(10, by='bucket') # partitioned by column + """ + if by is not None: + assert hash_by is None, "cannot specify both by and hash_by" + plan = ShuffleNode( + self.session._ctx, + (self.plan,), + npartitions, + data_partition_column=by, + **kwargs, + ) + elif hash_by is not None: + hash_columns = [hash_by] if isinstance(hash_by, str) else hash_by + plan = HashPartitionNode( + self.session._ctx, (self.plan,), npartitions, hash_columns, **kwargs + ) + else: + plan = EvenlyDistributedPartitionNode( + self.session._ctx, + (self.plan,), + npartitions, + partition_by_rows=by_rows, + **kwargs, + ) + return DataFrame(self.session, plan, recompute=self.need_recompute) + + def random_shuffle(self, **kwargs) -> DataFrame: + """ + Randomly shuffle all rows globally. + """ + + repartition = HashPartitionNode( + self.session._ctx, + (self.plan,), + self.plan.num_partitions, + random_shuffle=True, + **kwargs, + ) + plan = SqlEngineNode( + self.session._ctx, + (repartition,), + r"select * from {0} order by random()", + **kwargs, + ) + return DataFrame(self.session, plan, recompute=self.need_recompute) + + def partial_sort(self, by: Union[str, List[str]], **kwargs) -> DataFrame: + """ + Sort rows by the given columns in each partition. + + Parameters + ---------- + by + A column or a list of columns to sort by. + + Examples + -------- + .. code-block:: + + df = df.partial_sort(by='a') + df = df.partial_sort(by=['a', 'b desc']) + """ + + by = [by] if isinstance(by, str) else by + plan = SqlEngineNode( + self.session._ctx, + (self.plan,), + f"select * from {{0}} order by {', '.join(by)}", + **kwargs, + ) + return DataFrame(self.session, plan, recompute=self.need_recompute) + + def filter( + self, sql_or_func: Union[str, Callable[[Dict[str, Any]], bool]], **kwargs + ) -> DataFrame: + """ + Filter out rows that don't satisfy the given predicate. + + Parameters + ---------- + sql_or_func + A SQL expression or a predicate function. + For functions, it should take a dictionary of columns as input and returns a boolean. + SQL expression is preferred as it's more efficient. + + Examples + -------- + .. 
code-block:: + + df = df.filter('a > 1') + df = df.filter(lambda r: r['a'] > 1) + """ + if isinstance(sql := sql_or_func, str): + plan = SqlEngineNode( + self.session._ctx, + (self.plan,), + f"select * from {{0}} where ({sql})", + **kwargs, + ) + elif isinstance(func := sql_or_func, Callable): + + def process_func(_runtime_ctx, tables: List[arrow.Table]) -> arrow.Table: + table = tables[0] + return table.filter([func(row) for row in table.to_pylist()]) + + plan = ArrowBatchNode( + self.session._ctx, (self.plan,), process_func=process_func, **kwargs + ) + else: + raise ValueError( + "condition must be a SQL expression or a predicate function" + ) + return DataFrame(self.session, plan, recompute=self.need_recompute) + + def map( + self, + sql_or_func: Union[str, Callable[[Dict[str, Any]], Dict[str, Any]]], + *, + schema: Optional[arrow.Schema] = None, + **kwargs, + ) -> DataFrame: + """ + Apply a function to each row. + + Parameters + ---------- + sql_or_func + A SQL expression or a function to apply to each row. + For functions, it should take a dictionary of columns as input and returns a dictionary of columns. + SQL expression is preferred as it's more efficient. + schema, optional + The schema of the output DataFrame. + If not passed, will be inferred from the first row of the mapping values. + udfs, optional + A list of user defined functions to be referenced in the SQL expression. + + Examples + -------- + .. code-block:: + + df = df.map('a, b') + df = df.map('a + b as c') + df = df.map(lambda row: {'c': row['a'] + row['b']}) + + + Use user-defined functions in SQL expression: + + .. code-block:: + + @udf(params=[UDFType.INT, UDFType.INT], return_type=UDFType.INT) + def gcd(a: int, b: int) -> int: + while b: + a, b = b, a % b + return a + # load python udf + df = df.map('gcd(a, b)', udfs=[gcd]) + + # load udf from duckdb extension + df = df.map('gcd(a, b)', udfs=['path/to/udf.duckdb_extension']) + + """ + if isinstance(sql := sql_or_func, str): + plan = SqlEngineNode( + self.session._ctx, (self.plan,), f"select {sql} from {{0}}", **kwargs + ) + elif isinstance(func := sql_or_func, Callable): + + def process_func(_runtime_ctx, tables: List[arrow.Table]) -> arrow.Table: + output_rows = [func(row) for row in tables[0].to_pylist()] + return arrow.Table.from_pylist(output_rows, schema=schema) + + plan = ArrowBatchNode( + self.session._ctx, (self.plan,), process_func=process_func, **kwargs + ) + else: + raise ValueError(f"must be a SQL expression or a function: {sql_or_func!r}") + return DataFrame(self.session, plan, recompute=self.need_recompute) + + def flat_map( + self, + sql_or_func: Union[str, Callable[[Dict[str, Any]], List[Dict[str, Any]]]], + *, + schema: Optional[arrow.Schema] = None, + **kwargs, + ) -> DataFrame: + """ + Apply a function to each row and flatten the result. + + Parameters + ---------- + sql_or_func + A SQL expression or a function to apply to each row. + For functions, it should take a dictionary of columns as input and returns a list of dictionaries. + SQL expression is preferred as it's more efficient. + schema, optional + The schema of the output DataFrame. + If not passed, will be inferred from the first row of the mapping values. + + Examples + -------- + .. 
code-block:: + + df = df.flat_map('unnest(array[a, b]) as c') + df = df.flat_map(lambda row: [{'c': row['a']}, {'c': row['b']}]) + """ + if isinstance(sql := sql_or_func, str): + + plan = SqlEngineNode( + self.session._ctx, (self.plan,), f"select {sql} from {{0}}", **kwargs + ) + elif isinstance(func := sql_or_func, Callable): + + def process_func(_runtime_ctx, tables: List[arrow.Table]) -> arrow.Table: + output_rows = [ + item for row in tables[0].to_pylist() for item in func(row) + ] + return arrow.Table.from_pylist(output_rows, schema=schema) + + plan = ArrowBatchNode( + self.session._ctx, (self.plan,), process_func=process_func, **kwargs + ) + else: + raise ValueError(f"must be a SQL expression or a function: {sql_or_func!r}") + return DataFrame(self.session, plan, recompute=self.need_recompute) + + def map_batches( + self, + func: Callable[[arrow.Table], arrow.Table], + *, + batch_size: int = 122880, + **kwargs, + ) -> DataFrame: + """ + Apply the given function to batches of data. + + Parameters + ---------- + func + A function or a callable class to apply to each batch of data. + It should take a `arrow.Table` as input and returns a `arrow.Table`. + batch_size, optional + The number of rows in each batch. Defaults to 122880. + """ + + def process_func(_runtime_ctx, tables: List[arrow.Table]) -> arrow.Table: + return func(tables[0]) + + plan = ArrowBatchNode( + self.session._ctx, + (self.plan,), + process_func=process_func, + streaming_batch_size=batch_size, + **kwargs, + ) + return DataFrame(self.session, plan, recompute=self.need_recompute) + + def limit(self, limit: int) -> DataFrame: + """ + Limit the number of rows to the given number. + + Unlike `take`, this method doesn't trigger execution. + """ + plan = LimitNode(self.session._ctx, self.plan, limit) + return DataFrame(self.session, plan, recompute=self.need_recompute) + + def write_parquet(self, path: str) -> None: + """ + Write data to a series of parquet files under the given path. + + This is a blocking operation. See :func:`write_parquet_lazy` for a non-blocking version. + + Examples + -------- + .. code-block:: + + df.write_parquet('output') + """ + self.write_parquet_lazy(path).compute() + + def write_parquet_lazy(self, path: str) -> DataFrame: + """ + Write data to a series of parquet files under the given path. + + This is a non-blocking operation. See :func:`write_parquet` for a blocking version. + + Examples + -------- + .. code-block:: + + o1 = df.write_parquet_lazy('output1') + o2 = df.write_parquet_lazy('output2') + sp.wait(o1, o2) + """ + + plan = DataSinkNode( + self.session._ctx, (self.plan,), os.path.abspath(path), type="link_or_copy" + ) + return DataFrame(self.session, plan, recompute=self.need_recompute) + + # inspection + + def count(self) -> int: + """ + Count the number of rows. + + If this dataframe consists of more than a read, or if the row count can't be determined from + the metadata provided by the datasource, then this operation will trigger execution of the + lazy transformations performed on this dataframe. + """ + datasets = self._compute() + # FIXME: don't use ThreadPoolExecutor because duckdb results will be mixed up + return sum(dataset.num_rows for dataset in datasets) + + def take(self, limit: int) -> List[Dict[str, Any]]: + """ + Return up to `limit` rows. + + This operation will trigger execution of the lazy transformations performed on this DataFrame. 
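+
+        Examples
+        --------
+        A minimal sketch; ``df`` is any DataFrame built in this session:
+
+        .. code-block::
+
+            rows = df.take(10)   # a list of at most 10 row dicts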
+ """ + if self.is_computed() or isinstance(self.plan, DataSourceNode): + datasets = self._compute() + else: + datasets = self.limit(limit)._compute() + rows = [] + for dataset in datasets: + for batch in dataset.to_batch_reader(): + rows.extend(batch.to_pylist()) + if len(rows) >= limit: + return rows[:limit] + return rows + + def take_all(self) -> List[Dict[str, Any]]: + """ + Return all rows. + + This operation will trigger execution of the lazy transformations performed on this DataFrame. + """ + datasets = self._compute() + rows = [] + for dataset in datasets: + for batch in dataset.to_batch_reader(): + rows.extend(batch.to_pylist()) + return rows + + def to_pandas(self) -> pd.DataFrame: + """ + Convert to a pandas DataFrame. + + This operation will trigger execution of the lazy transformations performed on this DataFrame. + """ + datasets = self._compute() + with ThreadPoolExecutor() as pool: + return pd.concat(pool.map(lambda dataset: dataset.to_pandas(), datasets)) + + def to_arrow(self) -> arrow.Table: + """ + Convert to an arrow Table. + + This operation will trigger execution of the lazy transformations performed on this DataFrame. + """ + datasets = self._compute() + with ThreadPoolExecutor() as pool: + return arrow.concat_tables( + pool.map(lambda dataset: dataset.to_arrow_table(), datasets) + ) diff --git a/smallpond/execution/__init__.py b/smallpond/execution/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/smallpond/execution/driver.py b/smallpond/execution/driver.py new file mode 100644 index 0000000..4c17b46 --- /dev/null +++ b/smallpond/execution/driver.py @@ -0,0 +1,433 @@ +import argparse +import os +import socket +import sys +from multiprocessing import Process +from typing import List, Optional + +from loguru import logger + +import smallpond +from smallpond.common import DEFAULT_MAX_FAIL_COUNT, DEFAULT_MAX_RETRY_COUNT +from smallpond.dataframe import DataFrame +from smallpond.execution.task import ExecutionPlan +from smallpond.io.filesystem import load +from smallpond.logical.node import LogicalPlan + + +class Driver(object): + """ + A helper class that includes boilerplate code to execute a logical plan. 
+ """ + + def __init__(self) -> None: + self.driver_args_parser = self._create_driver_args_parser() + self.user_args_parser = argparse.ArgumentParser(add_help=False) + self.driver_args = None + self.user_args = None + self.all_args = None + + def _create_driver_args_parser(self): + parser = argparse.ArgumentParser( + prog="driver.py", description="Smallpond Driver", add_help=False + ) + parser.add_argument( + "mode", choices=["executor", "scheduler", "ray"], default="executor" + ) + parser.add_argument( + "--exec_id", default=socket.gethostname(), help="Unique executor id" + ) + parser.add_argument("--job_id", type=str, help="Unique job id") + parser.add_argument( + "--job_time", type=float, help="Job create time (seconds since epoch)" + ) + parser.add_argument( + "--job_name", default="smallpond", help="Display name of the job" + ) + parser.add_argument( + "--job_priority", + type=int, + help="Job priority", + ) + parser.add_argument("--resource_group", type=str, help="Resource group") + parser.add_argument( + "--env_variables", nargs="*", default=[], help="Env variables for the job" + ) + parser.add_argument( + "--sidecars", nargs="*", default=[], help="Sidecars for the job" + ) + parser.add_argument( + "--tags", nargs="*", default=[], help="Tags for submitted platform task" + ) + parser.add_argument( + "--task_image", default="default", help="Container image of platform task" + ) + parser.add_argument( + "--python_venv", type=str, help="Python virtual env for the job" + ) + parser.add_argument( + "--data_root", + type=str, + help="The root folder for all files generated at runtime", + ) + parser.add_argument( + "--runtime_ctx_path", + default=None, + help="The path of pickled runtime context passed from scheduler to executor", + ) + parser.add_argument( + "--num_executors", + default=0, + type=int, + help="The number of nodes/executors (run all tasks on scheduler if set to zero)", + ) + parser.add_argument( + "--num_executors_per_task", + default=5, + type=int, + help="The number of nodes/executors in each platform task.", + ) + parser.add_argument( + "--random_seed", + type=int, + default=None, + help="Random seed for the job, default: int.from_bytes((os.urandom(128))", + ) + parser.add_argument( + "--max_retry", + "--max_retry_count", + dest="max_retry_count", + default=DEFAULT_MAX_RETRY_COUNT, + type=int, + help="The max number of times a task is retried by speculative execution", + ) + parser.add_argument( + "--max_fail", + "--max_fail_count", + dest="max_fail_count", + default=DEFAULT_MAX_FAIL_COUNT, + type=int, + help="The number of times a task is allowed to fail or crash before giving up", + ) + parser.add_argument( + "--prioritize_retry", + action="store_true", + help="Prioritize retry tasks in scheduling", + ) + parser.add_argument( + "--speculative_exec", + dest="speculative_exec", + choices=["disable", "enable", "aggressive"], + help="Level of speculative execution", + ) + parser.add_argument( + "--enable_speculative_exec", + dest="speculative_exec", + action="store_const", + const="enable", + help="Enable speculative execution", + ) + parser.add_argument( + "--disable_speculative_exec", + dest="speculative_exec", + action="store_const", + const="disable", + help="Disable speculative execution", + ) + parser.set_defaults(speculative_exec="enable") + parser.add_argument( + "--stop_executor_on_failure", + action="store_true", + help="Stop an executor if any task fails on it", + ) + parser.add_argument( + "--fail_fast", + "--fail_fast_on_failure", + dest="fail_fast_on_failure", 
+ action="store_true", + help="Stop all executors if any task fails", + ) + parser.add_argument( + "--nonzero_exitcode_as_oom", + action="store_true", + help="Treat task crash as oom error", + ) + parser.add_argument( + "--fault_inject", + "--fault_inject_prob", + dest="fault_inject_prob", + type=float, + default=0.0, + help="Inject random errors at runtime (for test)", + ) + parser.add_argument( + "--enable_profiling", + action="store_true", + help="Enable Python profiling for each task", + ) + parser.add_argument( + "--enable_diagnostic_metrics", + action="store_true", + help="Enable diagnostic metrcis which may have performance impact", + ) + parser.add_argument( + "--disable_sched_state_dump", + dest="enable_sched_state_dump", + action="store_false", + help="Disable periodic dump of scheduler state", + ) + parser.add_argument( + "--enable_sched_state_dump", + dest="enable_sched_state_dump", + action="store_true", + help="Enable periodic dump of scheduler state so that scheduler can resume execution after restart", + ) + parser.set_defaults(enable_sched_state_dump=False) + parser.add_argument( + "--remove_empty_parquet", + action="store_true", + help="Remove empty parquet files from hash partition output", + ) + parser.add_argument( + "--remove_output_root", + action="store_true", + help="Remove all output files after job completed (for test)", + ) + parser.add_argument( + "--skip_task_with_empty_input", + action="store_true", + help="Skip running a task if any of its input datasets is empty", + ) + parser.add_argument( + "--self_contained_final_results", + action="store_true", + help="Build self-contained final results, i.e., create hard/symbolic links in output folder of final results", + ) + parser.add_argument( + "--malloc", + "--memory_allocator", + dest="memory_allocator", + default="mimalloc", + choices=["system", "jemalloc", "mimalloc"], + help="Override memory allocator used by worker processes", + ) + parser.add_argument( + "--memory_purge_delay", + default=10000, + help="The delay in milliseconds after which jemalloc/mimalloc will purge memory pages that are not in use.", + ) + parser.add_argument( + "--bind_numa_node", + action="store_true", + help="Bind executor processes to numa nodes.", + ) + parser.add_argument( + "--enforce_memory_limit", + action="store_true", + help="Set soft/hard memory limit for each task process", + ) + parser.add_argument( + "--enable_log_analytics", + dest="share_log_analytics", + action="store_true", + help="Share log analytics with smallpond team", + ) + parser.add_argument( + "--disable_log_analytics", + dest="share_log_analytics", + action="store_false", + help="Do not share log analytics with smallpond team", + ) + log_level_choices = [ + "CRITICAL", + "ERROR", + "WARNING", + "SUCCESS", + "INFO", + "DEBUG", + "TRACE", + ] + parser.add_argument( + "--console_log_level", + default="INFO", + choices=log_level_choices, + ) + parser.add_argument( + "--file_log_level", + default="DEBUG", + choices=log_level_choices, + ) + parser.add_argument( + "--disable_log_rotation", action="store_true", help="Disable log rotation" + ) + parser.add_argument( + "--output_path", + help="Set the output directory of final results and all nodes that have output_name but no output_path specified", + ) + parser.add_argument( + "--platform", + type=str, + help="The platform to use for the job. available: mpi", + ) + return parser + + def add_argument(self, *args, **kwargs): + """ + Add a command-line argument. 
This is a wrapper of `argparse.ArgumentParser.add_argument(...)`. + """ + self.user_args_parser.add_argument(*args, **kwargs) + + def parse_arguments(self, args=None): + if self.user_args is None or self.driver_args is None: + args_parser = argparse.ArgumentParser( + parents=[self.driver_args_parser, self.user_args_parser] + ) + self.all_args = args_parser.parse_args(args) + self.user_args, other_args = self.user_args_parser.parse_known_args(args) + self.driver_args = self.driver_args_parser.parse_args(other_args) + return self.user_args, self.driver_args + + def get_user_arguments(self, to_dict=True): + """ + Get user-defined arguments. + """ + user_args, _ = self.parse_arguments() + return vars(user_args) if to_dict else user_args + + get_arguments = get_user_arguments + + def get_driver_arguments(self, to_dict=True): + _, driver_args = self.parse_arguments() + return vars(driver_args) if to_dict else driver_args + + @property + def mode(self) -> str: + return self.get_driver_arguments(to_dict=False).mode + + @property + def job_id(self) -> str: + return self.get_driver_arguments(to_dict=False).job_id + + @property + def data_root(self) -> str: + return self.get_driver_arguments(to_dict=False).data_root + + @property + def num_executors(self) -> str: + return self.get_driver_arguments(to_dict=False).num_executors + + def run( + self, + plan: LogicalPlan, + stop_process_on_done=True, + tags: List[str] = None, + ) -> Optional[ExecutionPlan]: + """ + The entry point for executor and scheduler of `plan`. + """ + from smallpond.execution.executor import Executor + from smallpond.execution.manager import JobManager + from smallpond.execution.task import RuntimeContext + + _, args = self.parse_arguments() + retval = None + + def run_executor(runtime_ctx: RuntimeContext, exec_id: str, numa_node_id=None): + if numa_node_id is not None: + import numa + + exec_id += f".numa{numa_node_id}" + numa.schedule.bind(numa_node_id) + runtime_ctx.numa_node_id = numa_node_id + runtime_ctx.initialize(exec_id) + executor = Executor.create(runtime_ctx, exec_id) + return executor.run() + + if args.mode == "ray": + assert plan is not None + sp = smallpond.init(_remove_output_root=args.remove_output_root) + DataFrame(sp, plan.root_node).compute() + retval = True + elif args.mode == "executor": + assert os.path.isfile( + args.runtime_ctx_path + ), f"cannot find runtime context: {args.runtime_ctx_path}" + runtime_ctx: RuntimeContext = load(args.runtime_ctx_path) + + if runtime_ctx.bind_numa_node: + exec_procs = [ + Process( + target=run_executor, + args=(runtime_ctx, args.exec_id, numa_node_id), + ) + for numa_node_id in range(runtime_ctx.numa_node_count) + ] + for proc in exec_procs: + proc.start() + for proc in exec_procs: + proc.join() + retval = all(proc.exitcode == 0 for proc in exec_procs) + else: + retval = run_executor(runtime_ctx, args.exec_id) + elif args.mode == "scheduler": + assert plan is not None + jobmgr = JobManager( + args.data_root, args.python_venv, args.task_image, args.platform + ) + exec_plan = jobmgr.run( + plan, + job_id=args.job_id, + job_time=args.job_time, + job_name=args.job_name, + job_priority=args.job_priority, + num_executors=args.num_executors, + num_executors_per_task=args.num_executors_per_task, + resource_group=args.resource_group, + env_variables=args.env_variables, + sidecars=args.sidecars, + user_tags=args.tags + (tags or []), + random_seed=args.random_seed, + max_retry_count=args.max_retry_count, + max_fail_count=args.max_fail_count, + prioritize_retry=args.prioritize_retry, + 
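+                # the options below are taken from the CLI flags defined in
+                # _create_driver_args_parser and handed to the job manager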
speculative_exec=args.speculative_exec, + stop_executor_on_failure=args.stop_executor_on_failure, + fail_fast_on_failure=args.fail_fast_on_failure, + nonzero_exitcode_as_oom=args.nonzero_exitcode_as_oom, + fault_inject_prob=args.fault_inject_prob, + enable_profiling=args.enable_profiling, + enable_diagnostic_metrics=args.enable_diagnostic_metrics, + enable_sched_state_dump=args.enable_sched_state_dump, + remove_empty_parquet=args.remove_empty_parquet, + remove_output_root=args.remove_output_root, + skip_task_with_empty_input=args.skip_task_with_empty_input, + manifest_only_final_results=not args.self_contained_final_results, + memory_allocator=args.memory_allocator, + memory_purge_delay=args.memory_purge_delay, + bind_numa_node=args.bind_numa_node, + enforce_memory_limit=args.enforce_memory_limit, + share_log_analytics=args.share_log_analytics, + console_log_level=args.console_log_level, + file_log_level=args.file_log_level, + disable_log_rotation=args.disable_log_rotation, + output_path=args.output_path, + ) + retval = exec_plan if exec_plan.successful else None + + if stop_process_on_done: + exit_code = os.EX_OK if retval else os.EX_SOFTWARE + logger.info(f"exit code: {exit_code}") + sys.exit(exit_code) + else: + logger.info(f"return value: {repr(retval)}") + return retval + + +def main(): + # run in executor mode + driver = Driver() + driver.run(plan=None) + + +if __name__ == "__main__": + main() diff --git a/smallpond/execution/executor.py b/smallpond/execution/executor.py new file mode 100755 index 0000000..890a69a --- /dev/null +++ b/smallpond/execution/executor.py @@ -0,0 +1,341 @@ +import multiprocessing as mp +import os.path +import socket +import time +from collections import OrderedDict +from typing import Dict, List, Tuple + +from GPUtil import GPU +from loguru import logger + +from smallpond.common import NonzeroExitCode, pytest_running +from smallpond.execution.task import Probe, RuntimeContext +from smallpond.execution.workqueue import ( + StopExecutor, + StopWorkItem, + WorkItem, + WorkQueue, + WorkQueueOnFilesystem, + WorkStatus, +) + + +class SimplePoolTask(object): + def __init__(self, func, args, name: str): + self.proc: mp.Process = mp.Process(target=func, args=args, name=name) + self.stopping = False + + def start(self): + self.proc.start() + + def terminate(self): + self.proc.terminate() + self.stopping = True + + def join(self, timeout=None): + self.proc.join(timeout) + if not self.ready() and timeout is not None: + logger.warning( + f"worker process {self.proc.name}({self.proc.pid}) does not exit after {timeout} secs, stopping it" + ) + self.terminate() + self.proc.join() + + def ready(self): + return self.proc.pid and not self.proc.is_alive() + + def exitcode(self): + assert ( + self.ready() + ), f"worker process {self.proc.name}({self.proc.pid}) has not exited yet" + if self.stopping: + logger.info( + f"worker process stopped: {self.proc.name}({self.proc.pid}), exitcode: {self.proc.exitcode}" + ) + elif self.proc.exitcode != 0: + logger.error( + f"worker process crashed: {self.proc.name}({self.proc.pid}), exitcode: {self.proc.exitcode}" + ) + return self.proc.exitcode + + +class SimplePool(object): + def __init__(self, pool_size: int): + self.pool_size = pool_size + self.queued_tasks: List[SimplePoolTask] = [] + self.running_tasks: List[SimplePoolTask] = [] + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.join(1) + + def apply_async(self, func, args, name: str = None): + task = SimplePoolTask(func, args, name) 
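+        # the task is only queued here; update_queue() starts it once the number
+        # of running tasks drops below pool_size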
+ self.queued_tasks.append(task) + return task + + def update_queue(self): + self.running_tasks = [t for t in self.running_tasks if not t.ready()] + tasks_to_run = self.queued_tasks[: self.pool_size - len(self.running_tasks)] + self.queued_tasks = self.queued_tasks[ + self.pool_size - len(self.running_tasks) : + ] + for task in tasks_to_run: + task.start() + self.running_tasks += tasks_to_run + + def join(self, timeout=None): + for task in self.running_tasks: + logger.info(f"joining process: {task.proc.name}({task.proc.pid})") + task.join(timeout) + + +class Executor(object): + """ + The task executor. + """ + + def __init__( + self, ctx: RuntimeContext, id: str, wq: WorkQueue, cq: WorkQueue + ) -> None: + self.ctx = ctx + self.id = id + self.wq = wq + self.cq = cq + self.running_works: Dict[str, Tuple[SimplePoolTask, WorkItem]] = OrderedDict() + self.running = True + self.epochs_to_skip = 0 + self.numa_node = self.ctx.numa_node_id + self.local_gpus = {gpu: 1.0 for gpu in self.ctx.get_local_gpus()} + """ { GPU: available_quota } """ + + def __str__(self) -> str: + return f"Executor({self.id}), running_works[{len(self.running_works)}]={list(self.running_works.keys())[:3]}..." + + @property + def busy(self) -> bool: + return len(self.running_works) > 0 + + @property + def available_gpu_quota(self) -> float: + return sum(self.local_gpus.values()) + + @staticmethod + def create(ctx: RuntimeContext, id: str) -> "Executor": + queue_dir = os.path.join(ctx.queue_root, id) + wq = WorkQueueOnFilesystem(os.path.join(queue_dir, "wq")) + cq = WorkQueueOnFilesystem(os.path.join(queue_dir, "cq")) + executor = Executor(ctx, id, wq, cq) + return executor + + @staticmethod + @logger.catch(reraise=True, message="work item failed unexpectedly") + def process_work(item: WorkItem, cq: WorkQueue): + item.exec(cq) + cq.push(item) + logger.info( + f"finished work: {repr(item)}, status: {item.status}, elapsed time: {item.elapsed_time:.3f} secs" + ) + logger.complete() + + # for test + def stop(self): + self.wq.push(StopExecutor(f".FailStop-{self.id}", ack=False)) + + # for test + def skip_probes(self, epochs: int): + self.wq.push( + Probe(self.ctx, f".FalseFail-{self.id}", epoch=0, epochs_to_skip=epochs) + ) + + @logger.catch(reraise=True, message="executor terminated unexpectedly") + def run(self) -> bool: + mp.current_process().name = "ExecutorMainProcess" + logger.info( + f"start to run executor {self.id} on numa node #{self.ctx.numa_node_id} of {socket.gethostname()}" + ) + + with SimplePool(self.ctx.usable_cpu_count + 1) as pool: + retval = self.exec_loop(pool) + + logger.info(f"executor exits: {self}") + logger.complete() + return retval + + def exec_loop(self, pool: SimplePool) -> bool: + stop_request = None + latest_probe_time = time.time() + + while self.running: + # get new work items + try: + items = self.wq.pop(count=self.ctx.usable_cpu_count) + except Exception as ex: + logger.opt(exception=ex).critical( + f"failed to pop from work queue: {self.wq}" + ) + self.running = False + items = [] + + if not items: + secs_quiet_period = time.time() - latest_probe_time + if ( + secs_quiet_period > self.ctx.secs_executor_probe_interval * 2 + and os.path.exists(self.ctx.job_status_path) + ): + with open(self.ctx.job_status_path) as status_file: + if ( + status := status_file.read().strip() + ) and not status.startswith("running"): + logger.critical( + f"job scheduler already stopped: {status}, stopping executor" + ) + self.running = False + break + if ( + secs_quiet_period > self.ctx.secs_executor_probe_timeout * 
2 + and not pytest_running() + ): + logger.critical( + f"no probe received for {secs_quiet_period:.1f} secs, stopping executor" + ) + self.running = False + break + # no pending works, so wait a few seconds before checking results + time.sleep(self.ctx.secs_wq_poll_interval) + + for item in items: + if isinstance(item, StopExecutor): + logger.info(f"stop request received from scheduler: {item}") + stop_request = item + self.running = False + break + + if isinstance(item, StopWorkItem): + running_work = self.running_works.get(item.work_to_stop, None) + if running_work is None: + logger.debug( + f"cannot find {item.work_to_stop} in running works of {self.id}" + ) + self.cq.push(item) + else: + logger.info(f"stopping work: {item.work_to_stop}") + task, _ = running_work + task.terminate() + continue + + if isinstance(item, Probe): + latest_probe_time = time.time() + if item.epochs_to_skip > 0: + self.epochs_to_skip += item.epochs_to_skip + if self.epochs_to_skip > 0: + self.epochs_to_skip -= 1 + continue + + if self.numa_node is not None: + item._numa_node = self.numa_node + + # wait and allocate GPU to work item + if item.gpu_limit > 0: + if item.gpu_limit > len(self.local_gpus): + logger.warning( + f"task {item.key} requires more GPUs than physical GPUs, downgrading from {item.gpu_limit} to {len(self.local_gpus)}" + ) + item.gpu_limit = len(self.local_gpus) + # FIXME: this will block the executor if there is no available GPU + while not (granted_gpus := self.acquire_gpu(item.gpu_limit)): + logger.info(f"collecting finished works to find available GPUs") + self.collect_finished_works() + time.sleep(self.ctx.secs_wq_poll_interval) + item._local_gpu = granted_gpus + logger.info( + f"{repr(item)} is assigned to run on GPU: { {gpu.id: quota for gpu, quota in item._local_gpu.items()} }" + ) + + # enqueue work item to the pool + self.running_works[item.key] = ( + pool.apply_async( + func=Executor.process_work, args=(item, self.cq), name=item.key + ), + item, + ) + logger.info( + f"started work: {repr(item)}, {len(self.running_works)} running works: {list(self.running_works.keys())[:10]}..." + ) + + # start to run works + pool.update_queue() + self.collect_finished_works() + + pool.join(self.ctx.secs_executor_probe_interval) + + if stop_request and stop_request.ack: + self.collect_finished_works() + stop_request.exec() + self.cq.push(stop_request) + + return True + + def collect_finished_works(self): + finished_works: List[WorkItem] = [] + for work, item in self.running_works.values(): + if not work.ready(): + continue + else: + work.join() + if (exitcode := work.exitcode()) != 0: + item.status = WorkStatus.CRASHED + item.exception = NonzeroExitCode( + f"worker process {work.proc.name}({work.proc.pid}) exited with non-zero code {exitcode}" + ) + try: + self.cq.push(item) + except Exception as ex: + logger.opt(exception=ex).critical( + f"failed to push into completion queue: {self.cq}" + ) + self.running = False + finished_works.append(item) + + # remove finished works + for item in finished_works: + self.running_works.pop(item.key) + if item._local_gpu: + self.release_gpu(item._local_gpu) + logger.info( + f"{repr(item)} released GPU: { {gpu.id: quota for gpu, quota in item._local_gpu.items()} }" + ) + + def acquire_gpu(self, quota: float) -> Dict[GPU, float]: + """ + Acquire GPU resources with the given quota. + Return a dict of granted GPUs with their quotas. + `release_gpu` should be called later to release GPUs. 
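+
+        Note that the granted quota can be spread across multiple GPUs when no
+        single GPU has enough free quota, e.g. a request for 1.0 may be granted
+        as 0.6 on one GPU and 0.4 on another. An empty dict is returned if the
+        total available quota on this executor is less than the requested quota.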
+ """ + if self.available_gpu_quota < quota: + # no enough GPU resources, return empty dict + return {} + granted_gpus: Dict[GPU, float] = {} + for gpu in self.local_gpus: + gpu_available_quota = self.local_gpus[gpu] + if gpu_available_quota <= 0: + continue + granted_quota = min(gpu_available_quota, quota) + granted_gpus[gpu] = granted_quota + self.local_gpus[gpu] -= granted_quota + quota -= granted_quota + if quota <= 0: + break + return granted_gpus + + def release_gpu(self, gpus: Dict[GPU, float]): + """ + Release GPU resources to the pool. + """ + for gpu, quota in gpus.items(): + self.local_gpus[gpu] += quota + assert ( + self.local_gpus[gpu] <= 1.0 + ), f"GPU {gpu} quota is greater than 1.0: {self.local_gpus[gpu]}" diff --git a/smallpond/execution/manager.py b/smallpond/execution/manager.py new file mode 100644 index 0000000..a516f5e --- /dev/null +++ b/smallpond/execution/manager.py @@ -0,0 +1,283 @@ +import os +import shutil +import socket +import sys +from datetime import datetime +from typing import Dict, List, Literal, Optional + +from loguru import logger + +import smallpond +from smallpond.common import DEFAULT_MAX_FAIL_COUNT, DEFAULT_MAX_RETRY_COUNT, MB +from smallpond.execution.scheduler import Scheduler +from smallpond.execution.task import ExecutionPlan, JobId, RuntimeContext +from smallpond.io.filesystem import dump, load +from smallpond.logical.node import LogicalPlan +from smallpond.logical.planner import Planner +from smallpond.platform import get_platform + + +class SchedStateExporter(Scheduler.StateObserver): + def __init__(self, sched_state_path: str) -> None: + super().__init__() + self.sched_state_path = sched_state_path + + def update(self, sched_state: Scheduler): + if sched_state.large_runtime_state: + logger.debug(f"pause exporting scheduler state") + elif sched_state.num_local_running_works > 0: + logger.debug( + f"pause exporting scheduler state: {sched_state.num_local_running_works} local running works" + ) + else: + dump( + sched_state, self.sched_state_path, buffering=32 * MB, atomic_write=True + ) + sched_state.log_overall_progress() + logger.debug(f"exported scheduler state to {self.sched_state_path}") + + +class JobManager(object): + + jemalloc_filename = "libjemalloc.so.2" + mimalloc_filename = "libmimalloc.so.2.1" + + env_template = r""" + ARROW_DEFAULT_MEMORY_POOL={arrow_default_malloc} + ARROW_IO_THREADS=2 + OMP_NUM_THREADS=2 + POLARS_MAX_THREADS=2 + NUMEXPR_MAX_THREADS=2 +""" + + def __init__( + self, + data_root: Optional[str] = None, + python_venv: Optional[str] = None, + task_image: str = "default", + platform: Optional[str] = None, + ) -> None: + self.platform = get_platform(platform) + self.data_root = os.path.abspath(data_root or self.platform.default_data_root()) + self.python_venv = python_venv + self.task_image = task_image + + @logger.catch(reraise=True, message="job manager terminated unexpectedly") + def run( + self, + plan: LogicalPlan, + job_id: Optional[str] = None, + job_time: Optional[float] = None, + job_name: str = "smallpond", + job_priority: Optional[int] = None, + num_executors: int = 1, + num_executors_per_task: int = 5, + resource_group: str = "localhost", + env_variables: List[str] = None, + sidecars: List[str] = None, + user_tags: List[str] = None, + random_seed: int = None, + max_retry_count: int = DEFAULT_MAX_RETRY_COUNT, + max_fail_count: int = DEFAULT_MAX_FAIL_COUNT, + prioritize_retry=False, + speculative_exec: Literal["disable", "enable", "aggressive"] = "enable", + stop_executor_on_failure=False, + 
fail_fast_on_failure=False, + nonzero_exitcode_as_oom=False, + fault_inject_prob: float = 0.0, + enable_profiling=False, + enable_diagnostic_metrics=False, + enable_sched_state_dump=False, + remove_empty_parquet=False, + remove_output_root=False, + skip_task_with_empty_input=False, + manifest_only_final_results=True, + memory_allocator: Literal["system", "jemalloc", "mimalloc"] = "mimalloc", + memory_purge_delay: int = 10000, + bind_numa_node=False, + enforce_memory_limit=False, + share_log_analytics: Optional[bool] = None, + console_log_level: Literal[ + "CRITICAL", "ERROR", "WARNING", "SUCCESS", "INFO", "DEBUG", "TRACE" + ] = "INFO", + file_log_level: Literal[ + "CRITICAL", "ERROR", "WARNING", "SUCCESS", "INFO", "DEBUG", "TRACE" + ] = "DEBUG", + disable_log_rotation=False, + sched_state_observers: Optional[List[Scheduler.StateObserver]] = None, + output_path: Optional[str] = None, + **kwargs, + ) -> ExecutionPlan: + logger.info(f"using platform: {self.platform}") + + job_id = JobId(hex=job_id or self.platform.default_job_id()) + job_time = ( + datetime.fromtimestamp(job_time) + if job_time is not None + else self.platform.default_job_time() + ) + + malloc_path = "" + if memory_allocator == "system": + malloc_path = "" + elif memory_allocator == "jemalloc": + malloc_path = shutil.which(self.jemalloc_filename) + elif memory_allocator == "mimalloc": + malloc_path = shutil.which(self.mimalloc_filename) + else: + logger.critical(f"failed to find memory allocator: {memory_allocator}") + + env_overrides = self.env_template.format( + arrow_default_malloc=memory_allocator, + ).splitlines() + env_overrides = env_overrides + (env_variables or []) + env_overrides = dict( + tuple(kv.strip().split("=", maxsplit=1)) + for kv in filter(None, env_overrides) + ) + + share_log_analytics = ( + share_log_analytics + if share_log_analytics is not None + else self.platform.default_share_log_analytics() + ) + shared_log_root = ( + self.platform.shared_log_root() if share_log_analytics else None + ) + + runtime_ctx = RuntimeContext( + job_id, + job_time, + self.data_root, + num_executors=num_executors, + random_seed=random_seed, + env_overrides=env_overrides, + bind_numa_node=bind_numa_node, + enforce_memory_limit=enforce_memory_limit, + fault_inject_prob=fault_inject_prob, + enable_profiling=enable_profiling, + enable_diagnostic_metrics=enable_diagnostic_metrics, + remove_empty_parquet=remove_empty_parquet, + skip_task_with_empty_input=skip_task_with_empty_input, + shared_log_root=shared_log_root, + console_log_level=console_log_level, + file_log_level=file_log_level, + disable_log_rotation=disable_log_rotation, + output_path=output_path, + **kwargs, + ) + runtime_ctx.initialize(socket.gethostname(), root_exist_ok=True) + logger.info( + f"command-line arguments: {' '.join([os.path.basename(sys.argv[0]), *sys.argv[1:]])}" + ) + + dump(runtime_ctx, runtime_ctx.runtime_ctx_path, atomic_write=True) + logger.info(f"saved runtime context at {runtime_ctx.runtime_ctx_path}") + + dump(plan, runtime_ctx.logcial_plan_path, atomic_write=True) + logger.info(f"saved logcial plan at {runtime_ctx.logcial_plan_path}") + + plan.graph().render(runtime_ctx.logcial_plan_graph_path, format="png") + logger.info( + f"saved logcial plan graph at {runtime_ctx.logcial_plan_graph_path}.png" + ) + + exec_plan = Planner(runtime_ctx).create_exec_plan( + plan, manifest_only_final_results + ) + dump(exec_plan, runtime_ctx.exec_plan_path, atomic_write=True) + logger.info(f"saved execution plan at {runtime_ctx.exec_plan_path}") + + sidecar_list 
= sidecars or [] + + fs_name, cluster = self.data_root.split("/")[1:3] + tag_list = [ + "smallpond", + "executor", + smallpond.__version__, + job_name, + fs_name, + cluster, + f"malloc:{memory_allocator}", + ] + (user_tags or []) + + if self.python_venv: + tag_list.append(self.python_venv) + if fail_fast_on_failure: + max_fail_count = 0 + if max_fail_count == 0: + tag_list.append("fail_fast") + tag_list.append(f"max_fail:{max_fail_count}") + + tag_list.append(f"speculative_exec:{speculative_exec}") + tag_list.append(f"max_retry:{max_retry_count}") + + if prioritize_retry: + tag_list.append("prioritize_retry") + if stop_executor_on_failure: + tag_list.append("stop_executor") + if bind_numa_node: + tag_list.append("bind_numa_node") + if enforce_memory_limit: + tag_list.append("enforce_memory_limit") + nonzero_exitcode_as_oom = True + + sched_state_observers = sched_state_observers or [] + + if enable_sched_state_dump: + sched_state_exporter = SchedStateExporter(runtime_ctx.sched_state_path) + sched_state_observers.insert(0, sched_state_exporter) + + if os.path.exists(runtime_ctx.sched_state_path): + logger.warning( + f"loading scheduler state from: {runtime_ctx.sched_state_path}" + ) + scheduler: Scheduler = load(runtime_ctx.sched_state_path) + scheduler.sched_epoch += 1 + scheduler.sched_state_observers = sched_state_observers + else: + scheduler = Scheduler( + exec_plan, + max_retry_count, + max_fail_count, + prioritize_retry, + speculative_exec, + stop_executor_on_failure, + nonzero_exitcode_as_oom, + remove_output_root, + sched_state_observers, + ) + # start executors + self.platform.start_job( + num_nodes=num_executors, + entrypoint=os.path.join(os.path.dirname(__file__), "driver.py"), + args=[ + "--job_id", + str(job_id), + "--data_root", + self.data_root, + "--runtime_ctx_path", + runtime_ctx.runtime_ctx_path, + "executor", + ], + envs={ + "LD_PRELOAD": malloc_path, + "MALLOC_CONF": f"percpu_arena:percpu,background_thread:true,metadata_thp:auto,dirty_decay_ms:{memory_purge_delay},muzzy_decay_ms:{memory_purge_delay},oversize_threshold:0,lg_tcache_max:16", + "MIMALLOC_PURGE_DELAY": memory_purge_delay, + }, + extra_opts=dict( + job_id=job_id, + job_name=job_name, + job_priority=job_priority, + num_executors_per_task=num_executors_per_task, + resource_group=resource_group, + image=self.task_image, + python_venv=self.python_venv, + tags=tag_list, + sidecars=sidecar_list, + ), + ) + + # run scheduler + scheduler.run() + return scheduler.exec_plan diff --git a/smallpond/execution/scheduler.py b/smallpond/execution/scheduler.py new file mode 100644 index 0000000..f64135a --- /dev/null +++ b/smallpond/execution/scheduler.py @@ -0,0 +1,1369 @@ +import copy +import cProfile +import itertools +import multiprocessing as mp +import os +import queue +import shutil +import socket +import sys +import time +from collections import OrderedDict +from concurrent.futures import ThreadPoolExecutor +from enum import Enum +from functools import cached_property +from typing import Any, Callable, Dict, Iterable, List, Literal, Set, Tuple, Union + +import numpy as np +from loguru import logger + +from smallpond.common import ( + DEFAULT_MAX_FAIL_COUNT, + DEFAULT_MAX_RETRY_COUNT, + GB, + MB, + pytest_running, +) +from smallpond.execution.task import ( + ExecutionPlan, + Probe, + RuntimeContext, + Task, + TaskRuntimeId, + WorkStatus, +) +from smallpond.execution.workqueue import ( + StopExecutor, + StopWorkItem, + WorkItem, + WorkQueue, + WorkQueueInMemory, + WorkQueueOnFilesystem, +) +from smallpond.io.filesystem 
import dump, remove_path +from smallpond.logical.node import LogicalPlan, Node +from smallpond.utility import cprofile_to_string + + +class ExecutorState(Enum): + GOOD = 1 + FAIL = 2 + RESOURCE_LOW = 3 + STOPPING = 4 + STOPPED = 5 + + +class RemoteExecutor(object): + def __init__( + self, ctx: RuntimeContext, id: str, wq: WorkQueue, cq: WorkQueue, init_epoch=0 + ) -> None: + self.ctx = ctx + self.id = id + self.wq = wq + self.cq = cq + self.running_works: Dict[str, WorkItem] = {} + self.state = ExecutorState.RESOURCE_LOW + self.last_acked_probe = Probe(self.ctx, f"Probe-{self.id}#{0}", init_epoch) + self.stop_request_sent = False + self.stop_request_acked = False + self._allocated_cpus = 0 + self._allocated_gpus = 0 + self._allocated_memory = 0 + + def __str__(self) -> str: + return f"{self.__class__.__name__}({self.id}), running_works[{len(self.running_works)}]={list(self.running_works.keys())[:3]}..., \ +allocated={self.allocated_cpus}CPUs/{self.allocated_gpus}GPUs/{self.allocated_memory//GB}GB, \ +state={self.state}, probe={self.last_acked_probe}" + + def __repr__(self) -> str: + return f"RemoteExecutor({self.id}):{self.state}" + + @staticmethod + def create( + ctx: RuntimeContext, id: str, queue_dir: str, init_epoch=0 + ) -> "RemoteExecutor": + wq = WorkQueueOnFilesystem(os.path.join(queue_dir, "wq")) + cq = WorkQueueOnFilesystem(os.path.join(queue_dir, "cq")) + return RemoteExecutor(ctx, id, wq, cq, init_epoch) + + @property + def idle(self) -> bool: + return len(self.running_works) == 0 + + @property + def busy(self) -> bool: + return ( + len(self.running_works) >= self.max_running_works + or (self.cpu_count > 0 and self.allocated_cpus >= self.cpu_count) + or (self.gpu_count > 0 and self.allocated_gpus >= self.gpu_count) + or (self.memory_size > 0 and self.allocated_memory >= self.memory_size) + ) + + @property + def local(self) -> bool: + return False + + @property + def good(self) -> bool: + return self.state == ExecutorState.GOOD + + @property + def fail(self) -> bool: + return self.state == ExecutorState.FAIL + + @property + def stopping(self) -> bool: + return self.state == ExecutorState.STOPPING + + @property + def stopped(self) -> bool: + return self.state == ExecutorState.STOPPED + + @property + def resource_low(self) -> bool: + return self.state == ExecutorState.RESOURCE_LOW + + @property + def working(self) -> bool: + return self.state in (ExecutorState.GOOD, ExecutorState.RESOURCE_LOW) + + @property + def alive(self) -> bool: + return self.state in ( + ExecutorState.GOOD, + ExecutorState.RESOURCE_LOW, + ExecutorState.STOPPING, + ) + + @property + def cpu_count(self) -> int: + return self.last_acked_probe.cpu_count + + @property + def gpu_count(self) -> int: + return self.last_acked_probe.gpu_count + + @property + def memory_size(self) -> int: + return self.last_acked_probe.total_memory + + @property + def allocated_cpus(self) -> int: + return self._allocated_cpus + + @property + def allocated_gpus(self) -> int: + return self._allocated_gpus + + @property + def allocated_memory(self) -> int: + return self._allocated_memory + + @property + def available_cpus(self) -> int: + return self.cpu_count - self.allocated_cpus + + @property + def available_memory(self) -> int: + return self.memory_size - self.allocated_memory + + @property + def max_running_works(self) -> int: + # limit max number of running works on an executor: reserve 1/16 cpu cores for filesystem and others + return self.cpu_count - self.cpu_count // 16 + + def add_running_work(self, item: WorkItem): + assert ( + 
item.key not in self.running_works + ), f"duplicate work item assigned to {repr(self)}: {item.key}" + self.running_works[item.key] = item + self._allocated_cpus += item.cpu_limit + self._allocated_gpus += item.gpu_limit + self._allocated_memory += item.memory_limit + + def pop_running_work(self, key: str): + if (item := self.running_works.pop(key, None)) is None: + logger.debug(f"work item {key} not found in running works of {self}") + else: + self._allocated_cpus -= item.cpu_limit + self._allocated_gpus -= item.gpu_limit + self._allocated_memory -= item.memory_limit + return item + + def pop(self) -> List[Task]: + finished_items = self.cq.pop(count=max(1, len(self.running_works))) + finished_tasks = [] + + for item in finished_items: + if isinstance(item, Probe): + self.last_acked_probe = max(self.last_acked_probe, item) + continue + + if isinstance(item, StopWorkItem): + self.pop_running_work(item.work_to_stop) + logger.debug(f"work item stopped: {item}") + continue + + if isinstance(item, StopExecutor): + self.stop_request_acked |= self.state != ExecutorState.STOPPED + logger.info(f"executor stopped: {item}") + else: + assert isinstance(item, Task), f"unexpected work item type: {item}" + item.finish_time = item.finish_time or time.time() + finished_tasks.append(item) + + if item.status != WorkStatus.INCOMPLETE: + self.pop_running_work(item.key) + + return finished_tasks + + def push(self, item: WorkItem, buffering=False) -> bool: + if item.key in self.running_works: + logger.warning( + f"work item {item.key} already exists in running works of {self}" + ) + return False + item.start_time = time.time() + item.exec_id = self.id + self.add_running_work(item) + return self.wq.push(item, buffering) + + def flush(self) -> bool: + return self.wq.flush() + + def probe(self, epoch: int): + self.wq.push(Probe(self.ctx, f".Probe-{self.id}#{epoch:06d}", epoch)) + + def stop(self): + if self.working and not self.stop_request_sent: + logger.info(f"stopping remote executor: {self}") + self.push(StopExecutor(f".StopExecutor-{self.id}")) + self.stop_request_sent = True + + def reset_state(self, current_epoch: int): + self.__init__(self.ctx, self.id, self.wq, self.cq, current_epoch) + + def update_state(self, current_epoch: int) -> bool: + num_missed_probes = current_epoch - self.last_acked_probe.epoch + if self.state == ExecutorState.STOPPED: + return False + elif num_missed_probes > self.ctx.max_num_missed_probes: + if self.state != ExecutorState.FAIL: + self.state = ExecutorState.FAIL + logger.error( + f"find failed executor: {self}, missed probes: {num_missed_probes}, current epoch: {current_epoch}" + ) + return True + elif self.state == ExecutorState.STOPPING: + if self.stop_request_acked: + self.state = ExecutorState.STOPPED + logger.info(f"find stopped executor: {self}") + return True + elif self.stop_request_sent: + if self.state != ExecutorState.STOPPING: + self.state = ExecutorState.STOPPING + return True + elif self.last_acked_probe.resource_low: + if self.state != ExecutorState.RESOURCE_LOW: + self.state = ExecutorState.RESOURCE_LOW + logger.warning(f"find low-resource executor: {self}") + return True + elif self.last_acked_probe.status == WorkStatus.SUCCEED: + if self.state != ExecutorState.GOOD: + self.state = ExecutorState.GOOD + logger.info(f"find working executor: {self}") + return True + return False + + +class LocalExecutor(RemoteExecutor): + def __init__( + self, ctx: RuntimeContext, id: str, wq: WorkQueue, cq: WorkQueue + ) -> None: + super().__init__(ctx, id, wq, cq) + self.work = 
None + self.running = False + + def __getstate__(self): + state = self.__dict__.copy() + del state["wq"] + del state["cq"] + del state["work"] + return state + + def __setstate__(self, state): + self.__dict__.update(state) + self.wq = WorkQueueInMemory(queue_type=queue.Queue) + self.cq = WorkQueueInMemory(queue_type=queue.Queue) + self.work = None + + @staticmethod + def create(ctx: RuntimeContext, id: str) -> "LocalExecutor": + wq = WorkQueueInMemory(queue_type=queue.Queue) + cq = WorkQueueInMemory(queue_type=queue.Queue) + return LocalExecutor(ctx, id, wq, cq) + + @logger.catch(reraise=True, message="local executor terminated unexpectedly") + def run(self): + logger.info(f"local executor started: {self.id}") + local_gpus = self.ctx.get_local_gpus() + + while self.running: + items = self.wq.pop() + + if len(items) == 0: + time.sleep(self.ctx.secs_wq_poll_interval) + continue + + for item in items: + if not self.running: + break + if item.gpu_limit > 0: + assert len(local_gpus) > 0 + item._local_gpu = local_gpus[0] + logger.info( + f"{repr(item)} is assigned to run on GPU #{item.local_rank}: {item.local_gpu}" + ) + + item = copy.copy(item) + item.exec() + self.cq.push(item, buffering=True) + + self.cq.flush() + + logger.info(f"local executor exits: {self.id}") + logger.complete() + + @property + def local(self) -> bool: + return True + + def start(self, pool: ThreadPoolExecutor): + self.running = True + self.running_works.clear() + self.work = pool.submit(self.run) + self.state = ExecutorState.GOOD + + def stop(self): + if self.working: + logger.info(f"stopping local executor: {self}") + self.running = False + self.state = ExecutorState.STOPPING + self.work.result(timeout=self.ctx.secs_executor_probe_interval) + self.state = ExecutorState.STOPPED + + +class Scheduler(object): + """ + The task scheduler. 
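+
+    The scheduler discovers executors by scanning the filesystem-backed work
+    queues under the job's queue root, dispatches queued tasks to them in a
+    round-robin fashion, and tracks executor liveness with periodic probes.
+    Failed tasks are re-enqueued for retry (the scheduler stops once a task
+    fails more than `max_fail_count` times), and long-running tasks may be
+    retried speculatively unless speculative execution is disabled. Tasks
+    flagged `exec_on_scheduler`, as well as all tasks in standalone mode,
+    run on a local in-process executor instead of being dispatched remotely.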
+ """ + + large_num_nontrivial_tasks = 200 if pytest_running() else 20000 + StateCallback = Callable[["Scheduler"], Any] + + class StateObserver(object): + def __init__(self, callback: "Scheduler.StateCallback" = None) -> None: + assert callback is None or isinstance(callback, Callable) + self.enabled = True + self.callback = callback + + def __repr__(self) -> str: + return ( + repr(self.callback) if self.callback is not None else super().__repr__() + ) + + __str__ = __repr__ + + def update(self, sched_state: "Scheduler"): + assert self.callback is not None + self.callback(sched_state) + + def __init__( + self, + exec_plan: ExecutionPlan, + max_retry_count: int = DEFAULT_MAX_RETRY_COUNT, + max_fail_count: int = DEFAULT_MAX_FAIL_COUNT, + prioritize_retry=False, + speculative_exec: Literal["disable", "enable", "aggressive"] = "enable", + stop_executor_on_failure=False, + nonzero_exitcode_as_oom=False, + remove_output_root=False, + sched_state_observers=None, + ) -> None: + self.ctx = exec_plan.ctx + self.exec_plan = exec_plan + self.logical_plan: LogicalPlan = self.exec_plan.logical_plan + self.logical_nodes = self.logical_plan.nodes + self.max_retry_count = max_retry_count + self.max_fail_count = max_fail_count + self.standalone_mode = self.ctx.num_executors == 0 + self.prioritize_retry = prioritize_retry + self.disable_speculative_exec = speculative_exec == "disable" + self.aggressive_speculative_exec = speculative_exec == "aggressive" + self.stop_executor_on_failure = stop_executor_on_failure + self.nonzero_exitcode_as_oom = nonzero_exitcode_as_oom + self.remove_output_root = remove_output_root + self.sched_state_observers: List[Scheduler.StateObserver] = ( + sched_state_observers or [] + ) + self.secs_state_notify_interval = self.ctx.secs_executor_probe_interval * 2 + # task states + self.local_queue: List[Task] = [] + self.sched_queue: List[Task] = [] + self.tasks: Dict[str, Task] = self.exec_plan.tasks + self.scheduled_tasks: Dict[TaskRuntimeId, Task] = OrderedDict() + self.finished_tasks: Dict[TaskRuntimeId, Task] = OrderedDict() + self.succeeded_tasks: Dict[str, Task] = OrderedDict() + self.nontrivial_tasks = dict( + (key, task) + for (key, task) in self.tasks.items() + if not task.exec_on_scheduler + ) + self.succeeded_nontrivial_tasks: Dict[str, Task] = OrderedDict() + # executor pool + self.local_executor = LocalExecutor.create(self.ctx, "localhost") + self.available_executors = {self.local_executor.id: self.local_executor} + # other runtime states + self.sched_running = False + self.sched_start_time = 0 + self.last_executor_probe_time = 0 + self.last_state_notify_time = 0 + self.probe_epoch = 0 + self.sched_epoch = 0 + + def __call__(self, *args: Any, **kwds: Any) -> Any: + return self.run() + + def __getstate__(self): + state = self.__dict__.copy() + del state["sched_state_observers"] + return state + + def __setstate__(self, state): + self.__dict__.update(state) + self.sched_state_observers = [] + + @property + def elapsed_time(self): + return time.time() - self.sched_start_time + + @property + def success(self) -> bool: + return self.exec_plan.root_task.key in self.succeeded_tasks + + @property + def progress(self) -> Tuple[int, int, float]: + num_succeeded = len(self.succeeded_nontrivial_tasks) + num_tasks = len(self.nontrivial_tasks) + return num_succeeded, num_tasks, num_succeeded / num_tasks * 100.0 + + @property + def large_runtime_state(self) -> bool: + return len(self.nontrivial_tasks) > self.large_num_nontrivial_tasks + + @property + def running_works(self) -> 
Iterable[WorkItem]: + return ( + work + for executor in (self.alive_executors + self.local_executors) + for work in executor.running_works.values() + ) + + @property + def num_running_works(self) -> int: + return sum( + len(executor.running_works) + for executor in (self.alive_executors + self.local_executors) + ) + + @property + def num_local_running_works(self) -> int: + return sum(len(executor.running_works) for executor in self.local_executors) + + @property + def num_pending_tasks(self) -> int: + assert len(self.tasks) >= len( + self.succeeded_tasks + ), f"number of tasks {len(self.tasks)} < number of succeeded tasks {len(self.succeeded_tasks)}" + return len(self.tasks) - len(self.succeeded_tasks) + + @property + def pending_nontrivial_tasks(self) -> Dict[str, Task]: + return dict( + (key, task) + for key, task in self.nontrivial_tasks.items() + if key not in self.succeeded_nontrivial_tasks + ) + + @property + def num_pending_nontrivial_tasks(self) -> int: + assert len(self.nontrivial_tasks) >= len( + self.succeeded_nontrivial_tasks + ), f"number of nontrivial tasks {len(self.nontrivial_tasks)} < number of succeeded nontrivial tasks {len(self.succeeded_nontrivial_tasks)}" + return len(self.nontrivial_tasks) - len(self.succeeded_nontrivial_tasks) + + @property + def succeeded_task_ids(self) -> Set[TaskRuntimeId]: + return set( + TaskRuntimeId(task.id, task.sched_epoch, task.retry_count) + for task in self.succeeded_tasks.values() + ) + + @property + def abandoned_tasks(self) -> List[Task]: + succeeded_task_ids = self.succeeded_task_ids + return [ + task + for task in {**self.scheduled_tasks, **self.finished_tasks}.values() + if task.runtime_id not in succeeded_task_ids + ] + + @cached_property + def remote_executors(self) -> List[RemoteExecutor]: + return [ + executor + for executor in self.available_executors.values() + if not executor.local + ] + + @cached_property + def local_executors(self) -> List[RemoteExecutor]: + return [ + executor for executor in self.available_executors.values() if executor.local + ] + + @cached_property + def working_executors(self) -> List[RemoteExecutor]: + return [executor for executor in self.remote_executors if executor.working] + + @cached_property + def alive_executors(self) -> List[RemoteExecutor]: + return [executor for executor in self.remote_executors if executor.alive] + + @cached_property + def good_executors(self) -> List[RemoteExecutor]: + return [executor for executor in self.remote_executors if executor.good] + + @cached_property + def failed_executors(self) -> List[RemoteExecutor]: + return [executor for executor in self.remote_executors if executor.fail] + + @cached_property + def stopped_executors(self) -> List[RemoteExecutor]: + return [executor for executor in self.remote_executors if executor.stopped] + + @cached_property + def stopping_executors(self) -> List[RemoteExecutor]: + return [executor for executor in self.remote_executors if executor.stopping] + + @cached_property + def low_resource_executors(self) -> List[RemoteExecutor]: + return [executor for executor in self.remote_executors if executor.resource_low] + + def suspend_good_executors(self): + for executor in self.good_executors: + executor.reset_state(self.probe_epoch) + self.clear_cached_executor_lists() + + def clear_cached_executor_lists(self): + if hasattr(self, "remote_executors"): + del self.remote_executors + if hasattr(self, "local_executors"): + del self.local_executors + if hasattr(self, "working_executors"): + del self.working_executors + if hasattr(self, 
"alive_executors"): + del self.alive_executors + if hasattr(self, "good_executors"): + del self.good_executors + if hasattr(self, "failed_executors"): + del self.failed_executors + if hasattr(self, "stopped_executors"): + del self.stopped_executors + if hasattr(self, "stopping_executors"): + del self.stopping_executors + if hasattr(self, "low_resource_executors"): + del self.low_resource_executors + + def stop_executors(self): + for exec in self.available_executors.values(): + exec.stop() + + def start_speculative_execution(self): + for executor in self.working_executors: + for idx, item in enumerate(executor.running_works.values()): + aggressive_retry = ( + self.aggressive_speculative_exec + and len(self.good_executors) >= self.ctx.num_executors + ) + short_sched_queue = len(self.sched_queue) < len(self.good_executors) + if ( + isinstance(item, Task) + and item.key not in self.succeeded_tasks + and item.allow_speculative_exec + and item.retry_count < self.max_retry_count + and item.retry_count == self.tasks[item.key].retry_count + and (logical_node := self.logical_nodes.get(item.node_id, None)) + is not None + ): + perf_stats = logical_node.get_perf_stats("elapsed wall time (secs)") + if perf_stats is not None and perf_stats.cnt >= 20: + if short_sched_queue: + retry_threshold = max( + self.ctx.secs_executor_probe_timeout, + perf_stats.p95 - perf_stats.p50, + ) + elif aggressive_retry: + retry_threshold = max( + self.ctx.secs_executor_probe_timeout, + perf_stats.p99 - perf_stats.p50, + ) + else: + retry_threshold = max( + self.ctx.secs_executor_probe_timeout, + perf_stats.p99 - perf_stats.p50, + ) * (2 + item.retry_count) + excess_time = item.elapsed_time - perf_stats.p50 + if excess_time >= retry_threshold: + logger.warning( + f"retry long-running task: {repr(item)} on {repr(executor)}, elapsed time: {item.elapsed_time:.1f} secs, elapsed time stats: {perf_stats}" + ) + self.try_enqueue(self.get_retry_task(item.key)) + + def probe_executors(self): + secs_since_last_executor_probe = time.time() - self.last_executor_probe_time + if secs_since_last_executor_probe >= self.ctx.secs_executor_probe_interval: + # discover new executors + with os.scandir(self.ctx.queue_root) as dir_iter: + for entry in dir_iter: + if entry.is_dir(): + _, exec_id = os.path.split(entry.path) + if exec_id not in self.available_executors: + self.available_executors[exec_id] = RemoteExecutor.create( + self.ctx, exec_id, entry.path, self.probe_epoch + ) + logger.info( + f"find a new executor #{len(self.available_executors)}: {self.available_executors[exec_id]}" + ) + self.clear_cached_executor_lists() + # start a new probe epoch + self.last_executor_probe_time = time.time() + self.probe_epoch += 1 + logger.info( + f"send a new round of probes #{self.probe_epoch} to {len(self.working_executors)} working executors: {self.working_executors}" + ) + for executor in self.working_executors: + executor.probe(self.probe_epoch) + # start speculative execution of tasks + if not self.disable_speculative_exec: + self.start_speculative_execution() + + def update_executor_states(self): + executor_state_changed = [] + for executor in self.alive_executors: + old_state = executor.state + executor_state_changed.append(executor.update_state(self.probe_epoch)) + if executor.state == ExecutorState.FAIL and executor.state != old_state: + for item in executor.running_works.values(): + item.status = WorkStatus.EXEC_FAILED + item.finish_time = time.time() + if isinstance(item, Task) and item.key not in self.succeeded_tasks: + logger.warning( + 
f"reschedule {repr(item)} on failed executor: {repr(executor)}" + ) + self.try_enqueue(self.get_retry_task(item.key)) + + if any(executor_state_changed): + self.clear_cached_executor_lists() + logger.info( + f"in total {len(self.available_executors)} executors: " + f"{len(self.local_executors)} local, " + f"{len(self.good_executors)} good, " + f"{len(self.failed_executors)} failed, " + f"{len(self.stopped_executors)} stopped, " + f"{len(self.stopping_executors)} stopping, " + f"{len(self.low_resource_executors)} low-resource" + ) + + def copy_task_for_execution(self, task: Task) -> Task: + task = copy.copy(task) + # remove the reference to input deps + task.input_deps = {dep_key: None for dep_key in task.input_deps} + # feed input datasets + task.input_datasets = [ + self.succeeded_tasks[dep_key].output for dep_key in task.input_deps + ] + task.sched_epoch = self.sched_epoch + return task + + def save_task_final_state(self, finished_task: Task): + # update perf metrics of logical node + logical_node: Node = self.logical_nodes.get(finished_task.node_id, None) + if logical_node is not None: + for name, value in finished_task.perf_metrics.items(): + logical_node.add_perf_metrics(name, value) + + # update task instance in execution plan + task = self.tasks[finished_task.key] + task.status = finished_task.status + task.start_time = finished_task.start_time + task.finish_time = finished_task.finish_time + task.retry_count = finished_task.retry_count + task.sched_epoch = finished_task.sched_epoch + task.dataset = finished_task.dataset + + def get_runnable_tasks(self, finished_task: Task) -> Iterable[Task]: + assert ( + finished_task.status == WorkStatus.SUCCEED + ), f"task not succeeded: {finished_task}" + for output_key in finished_task.output_deps: + output_dep = self.tasks[output_key] + if all(key in self.succeeded_tasks for key in output_dep.input_deps): + logger.trace( + "{} initiates a new runnable task: {}", + repr(finished_task), + repr(output_dep), + ) + yield output_dep + + def stop_running_tasks(self, task_key: str): + for executor in self.remote_executors: + running_task = executor.running_works.get(task_key, None) + if running_task is not None: + logger.info( + f"try to stop {repr(running_task)} running on {repr(executor)}" + ) + executor.wq.push( + StopWorkItem( + f".StopWorkItem-{repr(running_task)}", running_task.key + ) + ) + + def try_relax_memory_limit(self, task: Task, executor: RemoteExecutor) -> bool: + if task.memory_limit >= executor.memory_size: + logger.warning(f"failed to relax memory limit of {task}") + return False + relaxed_memory_limit = min(executor.memory_size, task.memory_limit * 2) + task._memory_boost = relaxed_memory_limit / task._memory_limit + logger.warning( + f"relax memory limit of {task.key} to {task.memory_limit/GB:.3f}GB and retry ..." 
+ ) + return True + + def try_boost_resource(self, item: WorkItem, executor: RemoteExecutor): + if ( + item._cpu_boost == 1 + and item._memory_boost == 1 + and isinstance(item, Task) + and item.node_id in self.logical_nodes + and self.logical_nodes[item.node_id].enable_resource_boost + ): + boost_cpu = max( + item._cpu_limit, + min( + item._cpu_limit * 2, + executor.available_cpus, + executor.cpu_count // 2, + ), + ) + boost_mem = max( + item._memory_limit, + min( + item._memory_limit * 2, + executor.available_memory, + executor.memory_size // 2, + ), + ) + if item._cpu_limit < boost_cpu or item._memory_limit < boost_mem: + item._cpu_boost = boost_cpu / item._cpu_limit + item._memory_boost = boost_mem / item._memory_limit + logger.info( + f"boost resource usage of {repr(item)}: {item.cpu_limit} CPUs, {item.memory_limit/GB:.3f}GB" + ) + + def get_retry_task(self, key: str) -> Task: + task = self.tasks[key] + task.retry_count += 1 + assert task.status != WorkStatus.SUCCEED or task.sched_epoch != self.sched_epoch + return task + + @logger.catch(reraise=pytest_running(), message="failed to clean temp files") + def clean_temp_files(self, pool: ThreadPoolExecutor): + remove_path(self.ctx.queue_root) + remove_path(self.ctx.temp_root) + remove_path(self.ctx.staging_root) + + if abandoned_tasks := self.abandoned_tasks: + logger.info( + f"removing outputs of {len(abandoned_tasks)} abandoned tasks: {abandoned_tasks[:3]} ..." + ) + assert list(pool.map(lambda t: t.clean_output(force=True), abandoned_tasks)) + + @logger.catch(reraise=pytest_running(), message="failed to export task metrics") + def export_task_metrics(self): + import pyarrow as arrow + import pyarrow.csv as csv + + def pristine_attrs_dict(task: Task): + return { + key: str(val) if isinstance(val, Enum) else val + for key in task._pristine_attrs + if isinstance( + val := getattr(task, key), + (bool, str, int, float, Enum, np.integer, np.floating), + ) + } + + dump( + self.finished_tasks, + os.path.join(self.ctx.config_root, "finished_tasks.pickle"), + buffering=32 * MB, + ) + dump( + self.scheduled_tasks, + os.path.join(self.ctx.config_root, "scheduled_tasks.pickle"), + buffering=32 * MB, + ) + + task_props = arrow.array( + pristine_attrs_dict(task) for task in self.finished_tasks.values() + ) + partition_infos = arrow.array( + task.partition_infos_as_dict for task in self.finished_tasks.values() + ) + perf_metrics = arrow.array( + dict(task.perf_metrics) for task in self.finished_tasks.values() + ) + task_metrics = arrow.Table.from_arrays( + [task_props, partition_infos, perf_metrics], + names=["task_props", "partition_infos", "perf_metrics"], + ) + + task_metrics_csv = os.path.join(self.ctx.log_root, "task_metrics.csv") + csv.write_csv(task_metrics.flatten(), task_metrics_csv) + + if self.ctx.shared_log_root: + shutil.copy(task_metrics_csv, self.ctx.shared_log_root) + logger.debug(f"exported task metrics to {task_metrics_csv}") + + @logger.catch(reraise=pytest_running(), message="failed to export timeline figures") + def export_timeline_figs(self): + from datetime import datetime + + import pandas as pd + import plotly.express as px + + if self.large_runtime_state: + logger.debug(f"pause exporting timeline figure") + return + + now = datetime.now() + task_data = pd.DataFrame( + [ + dict( + task=repr(task), + node=( + repr(node) + if (node := self.logical_nodes.get(task.node_id, None)) + is not None + else "StandaloneTasks" + ), + status=str(task.status), + executor=task.exec_id, + start_time=datetime.fromtimestamp(task.start_time), + 
finish_time=datetime.fromtimestamp( + max( + task.finish_time or now.timestamp(), + task.start_time + 0.0001, + ) + ), + elapsed_time=task.elapsed_time, + partition=str(task.partition_infos), + cpu_limit=task.cpu_limit, + gpu_limit=task.gpu_limit, + mem_limit=task.memory_limit, + ) + for task in {**self.scheduled_tasks, **self.finished_tasks}.values() + if task.start_time is not None + ] + ) + + if task_data.empty: + return + + timeline_figs = [ + px.timeline( + task_data, + x_start="start_time", + x_end="finish_time", + y="node", + color="executor", + hover_name="task", + hover_data=task_data.columns, + title="plan_timeline - progress: {}/{} ({:.1f}%), elapsed: {:.1f} secs, job: {}".format( + *self.progress, self.elapsed_time, self.ctx.job_id + ), + opacity=0.3, + ), + px.timeline( + task_data, + x_start="start_time", + x_end="finish_time", + y="executor", + color="node", + hover_name="task", + hover_data=task_data.columns, + title="exec_timeline - progress: {}/{} ({:.1f}%), elapsed: {:.1f} secs, job: {}".format( + *self.progress, self.elapsed_time, self.ctx.job_id + ), + opacity=0.3, + ), + ] + + for fig in timeline_figs: + fig_title = str(fig.layout["title_text"]) + fig_filename, _ = fig_title.split(" - ", maxsplit=1) + fig_filename += ".html" + fig_path = os.path.join(self.ctx.log_root, fig_filename) + fig.update_yaxes( + autorange="reversed" + ) # otherwise tasks are listed from the bottom up + fig.update_traces(marker_line_color="black", marker_line_width=1, opacity=1) + fig.write_html( + fig_path, include_plotlyjs="cdn" if pytest_running() else True + ) + if self.ctx.shared_log_root: + shutil.copy(fig_path, self.ctx.shared_log_root) + logger.debug(f"exported timeline figure to {fig_path}") + + def notify_state_observers(self, force_notify=False) -> bool: + secs_since_last_state_notify = time.time() - self.last_state_notify_time + if ( + force_notify + or secs_since_last_state_notify >= self.secs_state_notify_interval + ): + self.last_state_notify_time = time.time() + for observer in self.sched_state_observers: + if force_notify or observer.enabled: + start_time = time.time() + observer.update(self) + elapsed_time = time.time() - start_time + if elapsed_time >= self.ctx.secs_executor_probe_interval / 2: + self.secs_state_notify_interval = ( + self.ctx.secs_executor_probe_timeout + ) + if elapsed_time >= self.ctx.secs_executor_probe_interval: + observer.enabled = False + logger.warning( + f"disabled slow scheduler state observer (elapsed time: {elapsed_time:.1f} secs): {observer}" + ) + return True + else: + return False + + def add_state_observer(self, observer: StateObserver): + self.sched_state_observers.append(observer) + logger.info(f"added a scheduler state observer: {observer}") + + def log_overall_progress(self): + logger.info( + "overall progress: {}/{} ({:.1f}%), ".format(*self.progress) + + f"{len(self.local_queue) + len(self.sched_queue)} queued works: {self.local_queue[:3] + self.sched_queue[:3]}, " + + f"{self.num_running_works} running works: {list(itertools.islice(self.running_works, 3))} ..." 
+ ) + + def log_current_status(self): + with open(self.ctx.job_status_path, "w") as fout: + if self.sched_running: + status = "running" + elif self.success: + status = "success" + else: + status = "failure" + fout.write(f"{status}@{int(time.time())}") + + def run(self) -> bool: + mp.current_process().name = f"SchedulerMainProcess#{self.sched_epoch}" + logger.info( + f"start to run scheduler #{self.sched_epoch} on {socket.gethostname()}" + ) + + perf_profile = None + if self.ctx.enable_profiling: + perf_profile = cProfile.Profile() + perf_profile.enable() + + with ThreadPoolExecutor(32) as pool: + self.sched_running = True + self.sched_start_time = time.time() + self.last_executor_probe_time = 0 + self.last_state_notify_time = 0 + self.prioritize_retry |= self.sched_epoch > 0 + + if self.local_queue or self.sched_queue: + pending_tasks = [ + item + for item in self.local_queue + self.sched_queue + if isinstance(item, Task) + ] + self.local_queue.clear() + self.sched_queue.clear() + logger.info( + f"requeue {len(pending_tasks)} pending tasks with latest epoch #{self.sched_epoch}: {pending_tasks[:3]} ..." + ) + self.try_enqueue(pending_tasks) + + if self.sched_epoch == 0: + leaf_tasks = self.exec_plan.leaves + logger.info( + f"enqueue {len(leaf_tasks)} leaf tasks: {leaf_tasks[:3]} ..." + ) + self.try_enqueue(leaf_tasks) + + self.log_overall_progress() + while (num_finished_tasks := self.process_finished_tasks(pool)) > 0: + logger.info( + f"processed {num_finished_tasks} finished tasks during startup" + ) + self.log_overall_progress() + + earlier_running_tasks = [ + item for item in self.running_works if isinstance(item, Task) + ] + if earlier_running_tasks: + logger.info( + f"enqueue {len(earlier_running_tasks)} earlier running tasks: {earlier_running_tasks[:3]} ..." 
+ ) + self.try_enqueue(earlier_running_tasks) + + self.suspend_good_executors() + self.add_state_observer( + Scheduler.StateObserver(Scheduler.log_current_status) + ) + self.add_state_observer( + Scheduler.StateObserver(Scheduler.export_timeline_figs) + ) + self.notify_state_observers(force_notify=True) + + try: + self.local_executor.start(pool) + self.sched_loop(pool) + finally: + logger.info(f"schedule loop stopped") + self.sched_running = False + self.notify_state_observers(force_notify=True) + self.export_task_metrics() + self.stop_executors() + + # if --output_path is specified, remove the output root as well + if self.remove_output_root or self.ctx.final_output_path: + remove_path(self.ctx.staging_root) + remove_path(self.ctx.output_root) + + if self.success: + self.clean_temp_files(pool) + logger.success(f"final output path: {self.exec_plan.final_output_path}") + logger.info( + f"analyzed plan:{os.linesep}{self.exec_plan.analyzed_logical_plan.explain_str()}" + ) + + if perf_profile is not None: + logger.debug( + f"scheduler perf profile:{os.linesep}{cprofile_to_string(perf_profile)}" + ) + + logger.info(f"scheduler of job {self.ctx.job_id} exits") + logger.complete() + return self.success + + def try_enqueue(self, tasks: Union[Iterable[Task], Task]): + tasks = tasks if isinstance(tasks, Iterable) else [tasks] + for task in tasks: + task = self.copy_task_for_execution(task) + if task.key in self.succeeded_tasks: + logger.debug(f"task {repr(task)} already succeeded, skipping") + self.try_enqueue( + self.get_runnable_tasks(self.succeeded_tasks[task.key]) + ) + continue + if task.runtime_id in self.scheduled_tasks: + logger.debug(f"task {repr(task)} already scheduled, skipping") + continue + # save enqueued task + self.scheduled_tasks[task.runtime_id] = task + if ( + self.standalone_mode + or task.exec_on_scheduler + or task.skip_when_any_input_empty + ): + self.local_queue.append(task) + else: + self.sched_queue.append(task) + + def sched_loop(self, pool: ThreadPoolExecutor) -> bool: + has_progress = True + do_notify = False + + if self.success: + logger.success(f"job already succeeded, stopping scheduler ...") + return True + + while self.sched_running: + self.probe_executors() + self.update_executor_states() + + if self.local_queue: + assert self.local_executor.alive + logger.info( + f"running {len(self.local_queue)} works on local executor: {self.local_queue[:3]} ..." 
+ ) + self.local_queue = [ + item + for item in self.local_queue + if not self.local_executor.push(item, buffering=True) + ] + self.local_executor.flush() + + has_progress |= self.dispatch_tasks(pool) > 0 + + if len( + self.sched_queue + ) == 0 and self.num_pending_nontrivial_tasks + 1 < len(self.good_executors): + for executor in self.good_executors: + if executor.idle: + logger.info( + f"{len(self.good_executors)} remote executors running, stopping {executor}" + ) + executor.stop() + break + + if ( + len(self.sched_queue) == 0 + and len(self.local_queue) == 0 + and self.num_running_works == 0 + ): + self.log_overall_progress() + assert ( + self.num_pending_tasks == 0 + ), f"scheduler stuck in idle state, but still have {self.num_pending_tasks} pending tasks: {self.tasks.keys() - self.succeeded_tasks.keys()}" + logger.info(f"no queued or running works, stopping scheduler ...") + break + + if has_progress: + has_progress = False + do_notify = True + self.log_overall_progress() + else: + time.sleep(self.ctx.secs_wq_poll_interval) + + if do_notify: + do_notify = not self.notify_state_observers() + + has_progress |= self.process_finished_tasks(pool) > 0 + + # out of loop + return self.success + + def dispatch_tasks(self, pool: ThreadPoolExecutor): + # sort pending tasks + item_sort_key = ( + (lambda item: (-item.retry_count, item.id)) + if self.prioritize_retry + else (lambda item: (item.retry_count, item.id)) + ) + items_sorted_by_node_id = sorted(self.sched_queue, key=lambda t: t.node_id) + items_group_by_node_id = itertools.groupby( + items_sorted_by_node_id, key=lambda t: t.node_id + ) + sorted_item_groups = [ + sorted(items, key=item_sort_key) for _, items in items_group_by_node_id + ] + self.sched_queue = [ + item + for batch in itertools.zip_longest(*sorted_item_groups, fillvalue=None) + for item in batch + if item is not None + ] + + final_phase = ( + self.num_pending_nontrivial_tasks - self.num_running_works + <= len(self.good_executors) * 2 + ) + num_dispatched_tasks = 0 + unassigned_tasks = [] + + while self.sched_queue and self.good_executors: + first_item = self.sched_queue[0] + + # assign tasks to executors in round-robin fashion + usable_executors = [ + executor for executor in self.good_executors if not executor.busy + ] + for executor in sorted( + usable_executors, key=lambda exec: len(exec.running_works) + ): + if not self.sched_queue: + break + item = self.sched_queue[0] + + if item._memory_limit is None: + item._memory_limit = np.int64( + executor.memory_size * item._cpu_limit // executor.cpu_count + ) + + if item.key in self.succeeded_tasks: + logger.debug(f"task {repr(item)} already succeeded, skipping") + self.sched_queue.pop(0) + self.try_enqueue( + self.get_runnable_tasks(self.succeeded_tasks[item.key]) + ) + elif ( + len(executor.running_works) < executor.max_running_works + and executor.allocated_cpus + item.cpu_limit <= executor.cpu_count + and executor.allocated_gpus + item.gpu_limit <= executor.gpu_count + and executor.allocated_memory + item.memory_limit + <= executor.memory_size + and item.key not in executor.running_works + ): + if final_phase: + self.try_boost_resource(item, executor) + # push to wq of executor but not flushed yet + executor.push(item, buffering=True) + logger.info( + f"appended {repr(item)} ({item.cpu_limit} CPUs, {item.memory_limit/GB:.3f}GB) to the queue of {executor}" + ) + self.sched_queue.pop(0) + num_dispatched_tasks += 1 + + if self.sched_queue and self.sched_queue[0] is first_item: + unassigned_tasks.append(self.sched_queue.pop(0)) + 
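+        # tasks that no usable executor accepted in this pass were popped into
+        # `unassigned_tasks` (see the `first_item` check above) so that the
+        # tasks behind them could still be considered; put them back so they
+        # are retried on the next call to dispatch_tasks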
+ # append unassigned tasks to the queue + self.sched_queue.extend(unassigned_tasks) + + # flush the buffered work items into wq + assert all( + pool.map(RemoteExecutor.flush, self.good_executors) + ), f"failed to flush work queues" + return num_dispatched_tasks + + def process_finished_tasks(self, pool: ThreadPoolExecutor) -> int: + pop_results = pool.map(RemoteExecutor.pop, self.available_executors.values()) + num_finished_tasks = 0 + + for executor, finished_tasks in zip( + self.available_executors.values(), pop_results + ): + + for finished_task in finished_tasks: + assert isinstance(finished_task, Task) + + scheduled_task = self.scheduled_tasks.get( + finished_task.runtime_id, None + ) + if scheduled_task is None: + logger.info( + f"task not initiated by current scheduler: {finished_task}" + ) + if finished_task.status != WorkStatus.SUCCEED and ( + missing_inputs := [ + key + for key in finished_task.input_deps + if key not in self.succeeded_tasks + ] + ): + logger.info( + f"ignore {repr(finished_task)} since some of the input deps are missing: {missing_inputs}" + ) + continue + + if finished_task.status == WorkStatus.INCOMPLETE: + logger.trace( + f"{repr(finished_task)} checkpoint created on {executor.id}: {finished_task.runtime_state}" + ) + self.tasks[finished_task.key].runtime_state = ( + finished_task.runtime_state + ) + continue + + prior_task = self.finished_tasks.get(finished_task.runtime_id, None) + if prior_task is not None: + logger.info( + f"found duplicate tasks, current: {repr(finished_task)}, prior: {repr(prior_task)}" + ) + continue + else: + self.finished_tasks[finished_task.runtime_id] = finished_task + num_finished_tasks += 1 + + succeeded_task = self.succeeded_tasks.get(finished_task.key, None) + if succeeded_task is not None: + logger.info( + f"task already succeeded, current: {repr(finished_task)}, succeeded: {repr(succeeded_task)}" + ) + continue + + if finished_task.status in (WorkStatus.FAILED, WorkStatus.CRASHED): + logger.warning( + f"task failed on {executor.id}: {finished_task}, error: {finished_task.exception}" + ) + finished_task.dump() + + task = self.tasks[finished_task.key] + task.fail_count += 1 + + if task.fail_count > self.max_fail_count: + logger.critical( + f"task failed too many times: {finished_task}, stopping ..." 
+ ) + self.stop_executors() + self.sched_running = False + + if not executor.local and finished_task.oom( + self.nonzero_exitcode_as_oom + ): + if task._memory_limit is None: + task._memory_limit = finished_task._memory_limit + self.try_relax_memory_limit(task, executor) + + if not executor.local and self.stop_executor_on_failure: + logger.warning(f"stop executor: {executor}") + executor.stop() + + self.try_enqueue(self.get_retry_task(finished_task.key)) + else: + assert ( + finished_task.status == WorkStatus.SUCCEED + ), f"unexpected task status: {finished_task}" + logger.log( + "TRACE" if finished_task.exec_on_scheduler else "INFO", + "task succeeded on {}: {}", + finished_task.exec_id, + finished_task, + ) + + self.succeeded_tasks[finished_task.key] = finished_task + if not finished_task.exec_on_scheduler: + self.succeeded_nontrivial_tasks[finished_task.key] = ( + finished_task + ) + + # stop the redundant retries of finished task + self.stop_running_tasks(finished_task.key) + self.save_task_final_state(finished_task) + self.try_enqueue(self.get_runnable_tasks(finished_task)) + + if finished_task.id == self.exec_plan.root_task.id: + self.sched_queue = [] + self.stop_executors() + logger.success( + f"all tasks completed, root task: {finished_task}" + ) + logger.success( + f"{len(self.succeeded_tasks)} succeeded tasks, success: {self.success}, elapsed time: {self.elapsed_time:.3f} secs" + ) + + # clear inputs since they are not needed anymore + finished_task.input_datasets = [] + + return num_finished_tasks diff --git a/smallpond/execution/task.py b/smallpond/execution/task.py new file mode 100644 index 0000000..43d8d22 --- /dev/null +++ b/smallpond/execution/task.py @@ -0,0 +1,3671 @@ +import contextlib +import copy +import cProfile +import io +import itertools +import json +import logging +import math +import os +import pprint +import random +import resource +import shutil +import sys +import time +import uuid +from collections import OrderedDict, defaultdict, namedtuple +from concurrent.futures import Future, ThreadPoolExecutor +from dataclasses import dataclass +from datetime import datetime +from functools import cached_property +from pathlib import Path, PurePath +from typing import ( + Any, + BinaryIO, + Callable, + Dict, + Iterable, + List, + Literal, + Optional, + Set, + Tuple, + Union, +) + +import duckdb +import fsspec +import GPUtil +import numpy as np +import pandas as pd +import psutil +import pyarrow as arrow +import pyarrow.parquet as parquet +import ray +from loguru import logger + +from smallpond.common import ( + DATA_PARTITION_COLUMN_NAME, + DEFAULT_BATCH_SIZE, + DEFAULT_MAX_RETRY_COUNT, + DEFAULT_ROW_GROUP_SIZE, + GB, + GENERATED_COLUMNS, + INPUT_VIEW_PREFIX, + KB, + MAX_NUM_ROW_GROUPS, + MAX_ROW_GROUP_BYTES, + MAX_ROW_GROUP_SIZE, + MB, + PARQUET_METADATA_KEY_PREFIX, + RAND_SEED_BYTE_LEN, + TB, + InjectedFault, + OutOfMemory, + clamp_row_group_bytes, + clamp_row_group_size, + pytest_running, + round_up, + split_into_rows, +) +from smallpond.execution.workqueue import WorkItem, WorkStatus +from smallpond.io.arrow import ( + cast_columns_to_large_string, + dump_to_parquet_files, + filter_schema, +) +from smallpond.io.filesystem import dump, find_mount_point, load, remove_path +from smallpond.logical.dataset import ( + ArrowTableDataSet, + CsvDataSet, + DataSet, + FileSet, + JsonDataSet, + ParquetDataSet, + PartitionedDataSet, + SqlQueryDataSet, +) +from smallpond.logical.udf import UDFContext +from smallpond.utility import ConcurrentIter, InterceptHandler, 
cprofile_to_string + + +class JobId(uuid.UUID): + """ + A unique identifier for a job. + """ + + @staticmethod + def new(): + return JobId(int=uuid.uuid4().int) + + +class TaskId(int): + """ + A unique identifier for a task. + """ + + def __str__(self) -> str: + return f"{self:06d}" + + +@dataclass(frozen=True) +class TaskRuntimeId: + """ + A unique identifier for a task at runtime. + + Besides the task id, it also includes the epoch and retry count. + """ + + id: TaskId + epoch: int + retry: int + """How many times the task has been retried.""" + + def __str__(self) -> str: + return f"{self.id}.{self.epoch}.{self.retry}" + + +class PerfStats( + namedtuple( + "PerfStats", ("cnt", "sum", "min", "max", "avg", "p50", "p75", "p95", "p99") + ) +): + """ + Performance statistics for a task. + """ + + def __str__(self) -> str: + return ", ".join([f"{k}={v:,.1f}" for k, v in self._asdict().items()]) + + __repr__ = __str__ + + +class RuntimeContext(object): + """ + The configuration and state for a running job. + """ + + def __init__( + self, + job_id: JobId, + job_time: datetime, + data_root: str, + *, + num_executors: int = 1, + random_seed: int = None, + env_overrides: Dict[str, str] = None, + bind_numa_node=False, + enforce_memory_limit=False, + max_usable_cpu_count: int = 1024, + max_usable_gpu_count: int = 1024, + max_usable_memory_size: int = 16 * TB, + secs_wq_poll_interval: float = 0.5, + secs_executor_probe_interval: float = 30, + max_num_missed_probes: int = 6, + fault_inject_prob=0.0, + enable_profiling=False, + enable_diagnostic_metrics=False, + remove_empty_parquet=False, + skip_task_with_empty_input=False, + shared_log_root: Optional[str] = None, + console_log_level="INFO", + file_log_level="DEBUG", + disable_log_rotation=False, + output_path: Optional[str] = None, + ) -> None: + self.job_id = job_id + self.job_time = job_time + self.data_root = data_root + self.next_task_id = 0 + self.num_executors = num_executors + self.random_seed: int = random_seed or int.from_bytes( + os.urandom(RAND_SEED_BYTE_LEN), byteorder=sys.byteorder + ) + self.env_overrides = env_overrides or {} + self.bind_numa_node = bind_numa_node + self.numa_node_id: Optional[int] = None + self.enforce_memory_limit = enforce_memory_limit + self.max_usable_cpu_count = max_usable_cpu_count + self.max_usable_gpu_count = max_usable_gpu_count + self.max_usable_memory_size = max_usable_memory_size + self.secs_wq_poll_interval = secs_wq_poll_interval + self.secs_executor_probe_interval = secs_executor_probe_interval + self.max_num_missed_probes = max_num_missed_probes + self.fault_inject_prob = fault_inject_prob + self.enable_profiling = enable_profiling + self.enable_diagnostic_metrics = enable_diagnostic_metrics + self.remove_empty_parquet = remove_empty_parquet + self.skip_task_with_empty_input = skip_task_with_empty_input + + self.shared_log_root = ( + os.path.join(shared_log_root, self.job_root_dirname) + if shared_log_root + else None + ) + self.console_log_level = console_log_level + self.file_log_level = file_log_level + self.disable_log_rotation = disable_log_rotation + + self.job_root = os.path.abspath(os.path.join(data_root, self.job_root_dirname)) + self.config_root = os.path.join(self.job_root, "config") + self.queue_root = os.path.join(self.job_root, "queue") + self.output_root = os.path.join(self.job_root, "output") + self.staging_root = os.path.join(self.job_root, "staging") + self.temp_root = os.path.join(self.job_root, "temp") + self.log_root = os.path.join(self.job_root, "log") + self.final_output_path = 
os.path.abspath(output_path) if output_path else None + self.current_task: Task = None + + # used by ray executors to checkpoint task states + self.started_task_dir = os.path.join(self.staging_root, "started_tasks") + self.completed_task_dir = os.path.join(self.staging_root, "completed_tasks") + + @property + def job_root_dirname(self): + return f"{self.job_time:%Y-%m-%d-%H-%M-%S}.{self.job_id}" + + @property + def job_status_path(self): + return os.path.join(self.log_root, ".STATUS") + + @property + def runtime_ctx_path(self): + return os.path.join(self.config_root, f"runtime_ctx.pickle") + + @property + def logcial_plan_path(self): + return os.path.join(self.config_root, f"logical_plan.pickle") + + @property + def logcial_plan_graph_path(self): + return os.path.join(self.log_root, "graph") + + @property + def ray_log_path(self): + return os.path.join(self.log_root, "ray.log") + + @property + def exec_plan_path(self): + return os.path.join(self.config_root, f"exec_plan.pickle") + + @property + def sched_state_path(self): + return os.path.join(self.config_root, f"sched_state.pickle") + + @property + def numa_node_count(self): + if sys.platform == "darwin": + # numa is not supported on macos + return 1 + import numa + + return numa.info.get_num_configured_nodes() + + @property + def physical_cpu_count(self): + cpu_count = psutil.cpu_count(logical=False) + return cpu_count // self.numa_node_count if self.bind_numa_node else cpu_count + + @property + def available_memory(self): + available_memory = psutil.virtual_memory().available + return ( + available_memory // self.numa_node_count + if self.bind_numa_node + else available_memory + ) + + @property + def total_memory(self): + total_memory = psutil.virtual_memory().total + return ( + total_memory // self.numa_node_count + if self.bind_numa_node + else total_memory + ) + + @property + def usable_cpu_count(self): + return min(self.max_usable_cpu_count, self.physical_cpu_count) + + @property + def usable_memory_size(self): + return min(self.max_usable_memory_size, self.total_memory) + + @property + def secs_executor_probe_timeout(self): + return self.secs_executor_probe_interval * self.max_num_missed_probes + + def get_local_gpus(self) -> List[GPUtil.GPU]: + gpus = GPUtil.getGPUs() + gpus_on_node = split_into_rows(gpus, self.numa_node_count) + return ( + gpus_on_node[self.numa_node_id] + if self.bind_numa_node and self.numa_node_id is not None + else gpus + ) + + @property + def usable_gpu_count(self): + return min(self.max_usable_gpu_count, len(self.get_local_gpus())) + + @property + def task(self): + return self.current_task + + def set_current_task(self, task: "Task" = None) -> "RuntimeContext": + self.current_task = None + ctx = copy.copy(self) + ctx.current_task = copy.copy(task) + return ctx + + def new_task_id(self) -> TaskId: + self.next_task_id += 1 + return TaskId(self.next_task_id) + + def initialize(self, exec_id: str, root_exist_ok=True, cleanup_root=False) -> None: + import smallpond + + self._make_dirs(root_exist_ok, cleanup_root) + self._init_logs(exec_id) + self._init_envs() + logger.info(f"smallpond version: {smallpond.__version__}") + logger.info(f"runtime context:{os.linesep}{pprint.pformat(vars(self))}") + logger.info(f"local GPUs: {[gpu.id for gpu in self.get_local_gpus()]}") + + def cleanup_root(self): + if os.path.exists(self.job_root): + remove_path(self.job_root) + + def _make_dirs(self, root_exist_ok, cleanup_root): + if os.path.exists(self.job_root): + if cleanup_root: + remove_path(self.job_root) + elif not 
root_exist_ok: + raise FileExistsError(f"Folder already exists: {self.job_root}") + os.makedirs(self.config_root, exist_ok=root_exist_ok) + os.makedirs(self.queue_root, exist_ok=root_exist_ok) + os.makedirs(self.output_root, exist_ok=root_exist_ok) + os.makedirs(self.staging_root, exist_ok=root_exist_ok) + os.makedirs(self.temp_root, exist_ok=root_exist_ok) + os.makedirs(self.log_root, exist_ok=root_exist_ok) + os.makedirs(self.started_task_dir, exist_ok=root_exist_ok) + os.makedirs(self.completed_task_dir, exist_ok=root_exist_ok) + + def _init_envs(self): + sys.setrecursionlimit(100000) + env_overrides = copy.copy(self.env_overrides) + ld_library_path = os.getenv("LD_LIBRARY_PATH", "") + py_library_path = os.path.join(sys.exec_prefix, "lib") + if py_library_path not in ld_library_path: + env_overrides["LD_LIBRARY_PATH"] = ":".join( + [py_library_path, ld_library_path] + ) + for key, val in env_overrides.items(): + if (old := os.getenv(key, None)) is not None: + logger.info( + f"overwrite environment variable '{key}': '{old}' -> '{val}'" + ) + else: + logger.info(f"set environment variable '{key}': '{val}'") + os.environ[key] = val + logger.info( + f"RANK='{os.getenv('RANK', '')}' LD_LIBRARY_PATH='{os.getenv('LD_LIBRARY_PATH', '')}' LD_PRELOAD='{os.getenv('LD_PRELOAD', '')}' MALLOC_CONF='{os.getenv('MALLOC_CONF', '')}'" + ) + + def _init_logs(self, exec_id: str, capture_stdout_stderr: bool = False) -> None: + log_rotation = ( + {"rotation": "100 MB", "retention": 5} + if not self.disable_log_rotation + else {} + ) + log_file_paths = [os.path.join(self.log_root, f"{exec_id}.log")] + user_log_only = {"": self.file_log_level, "smallpond": False} + user_log_path = os.path.join(self.log_root, f"{exec_id}-user.log") + # create shared log dir + if self.shared_log_root is not None: + os.makedirs(self.shared_log_root, exist_ok=True) + shared_log_path = os.path.join(self.shared_log_root, f"{exec_id}.log") + log_file_paths.append(shared_log_path) + # remove existing handlers + logger.remove() + # register stdout log handler + format_str = f"[{{time:%Y-%m-%d %H:%M:%S.%f}}] [{exec_id}] [{{process.name}}({{process.id}})] [{{file}}:{{line}}] {{level}} {{message}}" + logger.add( + sys.stdout, + format=format_str, + colorize=False, + enqueue=True, + backtrace=False, + level=self.console_log_level, + ) + # register file log handlers + for log_path in log_file_paths: + logger.add( + log_path, + format=format_str, + colorize=False, + enqueue=True, + backtrace=False, + level=self.file_log_level, + **log_rotation, + ) + logger.info(f"initialized logging to file: {log_path}") + # register user log handler + logger.add( + user_log_path, + format=format_str, + colorize=False, + enqueue=True, + backtrace=False, + level=self.file_log_level, + filter=user_log_only, + **log_rotation, + ) + logger.info(f"initialized user logging to file: {user_log_path}") + # intercept messages from logging + logging.basicConfig( + handlers=[InterceptHandler()], level=logging.INFO, force=True + ) + # capture stdout as INFO level + # https://loguru.readthedocs.io/en/stable/resources/recipes.html#capturing-standard-stdout-stderr-and-warnings + if capture_stdout_stderr: + + class StreamToLogger(io.TextIOBase): + def __init__(self, level="INFO"): + super().__init__() + self._level = level + + def write(self, buffer): + for line in buffer.rstrip().splitlines(): + logger.opt(depth=1).log(self._level, line.rstrip()) + + def flush(self): + pass + + sys.stdout = StreamToLogger() + sys.stderr = StreamToLogger() + + def cleanup(self, 
remove_output_root: bool = True): + """ + Clean up the runtime directory. This will be called when the job is finished. + """ + remove_path(self.queue_root) + remove_path(self.temp_root) + remove_path(self.staging_root) + if remove_output_root: + remove_path(self.output_root) + + +class Probe(WorkItem): + def __init__( + self, ctx: RuntimeContext, key: str, epoch: int, epochs_to_skip=0 + ) -> None: + super().__init__(key, cpu_limit=0, gpu_limit=0, memory_limit=0) + self.ctx = ctx + self.epoch = epoch + self.epochs_to_skip = epochs_to_skip + self.resource_low = True + self.cpu_count = 0 + self.gpu_count = 0 + self.cpu_percent = 0 + self.total_memory = 0 + self.available_memory = 0 + + def __str__(self) -> str: + return ( + super().__str__() + + f", epoch={self.epoch}, resource_low={self.resource_low}, cpu_count={self.cpu_count}, gpu_count={self.gpu_count}, cpu_usage={self.cpu_percent}%, available_memory={self.available_memory//GB}GB/{self.total_memory//GB}GB" + ) + + def __lt__(self, other: "Probe") -> bool: + return self.epoch < other.epoch + + def run(self) -> bool: + self.cpu_percent = psutil.cpu_percent( + interval=min(self.ctx.secs_executor_probe_interval / 2, 3) + ) + self.total_memory = self.ctx.usable_memory_size + self.available_memory = self.ctx.available_memory + self.resource_low = ( + self.cpu_percent >= 80.0 or self.available_memory < self.total_memory // 16 + ) + self.cpu_count = self.ctx.usable_cpu_count + self.gpu_count = self.ctx.usable_gpu_count + logger.info("resource usage: {}", self) + return True + + +class PartitionInfo(object): + """ + Information about a partition of a dataset. + """ + + toplevel_dimension = "@toplevel@" + default_dimension = DATA_PARTITION_COLUMN_NAME + + __slots__ = ( + "index", + "npartitions", + "dimension", + ) + + def __init__( + self, index: int = 0, npartitions: int = 1, dimension: str = toplevel_dimension + ) -> None: + self.index = index + self.npartitions = npartitions + self.dimension = dimension + + def __str__(self): + return f"{self.dimension}[{self.index}/{self.npartitions}]" + + __repr__ = __str__ + + def __lt__(self, other: "PartitionInfo"): + return (self.dimension, self.index) < (other.dimension, other.index) + + def __eq__(self, other: "PartitionInfo"): + return (self.dimension, self.index) == (other.dimension, other.index) + + def __hash__(self): + return hash((self.dimension, self.index)) + + +class Task(WorkItem): + """ + The base class for all tasks. + + Task is the basic unit of work in smallpond. + Each task represents a specific operation that takes a series of input datasets and produces an output dataset. + Tasks can depend on other tasks, forming a directed acyclic graph (DAG). + Tasks can specify resource requirements such as CPU, GPU, and memory limits. + Tasks should be idempotent. They can be retried if they fail. 
+ + Lifetime of a task object: + + - instantiated at planning time on the scheduler node + - pickled and sent to a worker node + - `initialize()` is called to prepare for execution + - `run()` is called to execute the task + - `finalize()` or `cleanup()` is called to release resources + - pickled and sent back to the scheduler node + """ + + __slots__ = ( + "ctx", + "id", + "node_id", + "sched_epoch", + "output_name", + "output_root", + "_temp_output", + "dataset", + "input_deps", + "output_deps", + "_np_randgen", + "_py_randgen", + "_timer_start", + "perf_metrics", + "perf_profile", + "_partition_infos", + "runtime_state", + "input_datasets", + "_dataset_ref", + ) + + def __init__( + self, + ctx: RuntimeContext, + input_deps: "List[Task]", + partition_infos: List[PartitionInfo], + output_name: Optional[str] = None, + output_path: Optional[str] = None, + cpu_limit: int = 1, + gpu_limit: float = 0, + memory_limit: Optional[int] = None, + ) -> None: + assert isinstance(input_deps, Iterable), f"{input_deps} is not iterable" + assert all( + isinstance(dep, Task) for dep in input_deps + ), f"not every element of {input_deps} is a task" + assert isinstance( + partition_infos, Iterable + ), f"{partition_infos} is not iterable" + assert all( + isinstance(info, PartitionInfo) for info in partition_infos + ), f"not every element of {partition_infos} is a partition info" + assert any( + info.dimension == PartitionInfo.toplevel_dimension + for info in partition_infos + ), f"cannot find toplevel partition dimension: {partition_infos}" + assert cpu_limit > 0, f"cpu_limit should be greater than zero" + self.ctx = ctx + self.id = ctx.new_task_id() + self.node_id = 0 + self.sched_epoch = 0 + self._np_randgen = None + self._py_randgen = None + self._timer_start = None + self.perf_metrics: Dict[str, np.int64] = defaultdict(np.int64) + self.perf_profile = None + self._partition_infos = sorted(partition_infos) or [] + assert len(self.partition_dims) == len( + set(self.partition_dims) + ), f"found duplicate partition dimensions: {self.partition_dims}" + super().__init__( + f"{self.__class__.__name__}-{self.id}", cpu_limit, gpu_limit, memory_limit + ) + self.output_name = output_name + self.output_root = output_path + self._temp_output = output_name is None and output_path is None + + # dependency references + # NOTICE: `input_deps` is only used to maintain the task graph at planning time. + # before execution, references to dependencies are cleared so that the + # task can be sent to executors independently. + # DO NOT use `input_deps.values()` in execution time. 
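+        # Illustrative note (added, not from the original comment): for a
+        # hypothetical upstream task `upstream` and this task `t`, the wiring
+        # below yields roughly
+        #     t.input_deps == {upstream.key: upstream}
+        #     upstream.output_deps contains t.key
+        # Before `t` is shipped to an executor these object references are
+        # dropped and the executor only receives the materialized
+        # `t.input_datasets` list, so executors never traverse the task graph.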
+ self.input_deps = dict((dep.key, dep) for dep in input_deps) + self.output_deps: Set[str] = set() + for dep in input_deps: + dep.output_deps.add(self.key) + + # input datasets for each dependency + # implementor should read input from here + self.input_datasets: List[DataSet] = None + # the output dataset + # implementor should set this variable as the output + # if not set, the output dataset will be all parquet files in the output directory + self.dataset: Optional[DataSet] = None + # runtime state + # implementor can use this variable as a checkpoint and restore from it after interrupted + self.runtime_state = None + + # if the task is executed by ray, this is the reference to the output dataset + # do not use this variable directly, use `self.run_on_ray()` instead + self._dataset_ref: Optional[ray.ObjectRef] = None + + def __repr__(self) -> str: + return f"{self.key}.{self.sched_epoch}.{self.retry_count},{self.node_id}" + + def __str__(self) -> str: + return ( + f"{repr(self)}: status={self.status}, elapsed_time={self.elapsed_time:.3f}s, epoch={self.sched_epoch}, #retries={self.retry_count}, " + f"input_deps[{len(self.input_deps)}]={list(self.input_deps.keys())[:3]}..., output_dir={self.output_dirname}, " + f"resource_limit={self.cpu_limit}CPUs/{self.gpu_limit}GPUs/{(self.memory_limit or 0)//GB}GB, " + f"partition_infos={self.partition_infos}" + ) + + def __call__(self, *args: Any, **kwds: Any) -> Any: + return self.exec() + + @property + def _pristine_attrs(self) -> Set[str]: + """All attributes in __slots__.""" + return set( + itertools.chain.from_iterable( + getattr(cls, "__slots__", []) for cls in type(self).__mro__ + ) + ) + + @property + def partition_infos(self) -> Tuple[PartitionInfo]: + return tuple(self._partition_infos) + + @property + def partition_infos_as_dict(self): + return dict((info.dimension, info.index) for info in self.partition_infos) + + def parquet_kv_metadata_str(self, extra_partitions: List[PartitionInfo] = None): + task_infos = [ + ("__task_key__", self.key), + ("__task_id__", str(self.id)), + ("__node_id__", str(self.node_id)), + ("__job_id__", str(self.ctx.job_id)), + ("__job_root__", self.ctx.job_root), + ] + partition_infos = [ + (info.dimension, str(info.index)) + for info in self._partition_infos + (extra_partitions or []) + ] + return dict( + (PARQUET_METADATA_KEY_PREFIX + k, v) + for k, v in task_infos + partition_infos + ) + + def parquet_kv_metadata_bytes(self, extra_partitions: List[PartitionInfo] = None): + return dict( + (k.encode("utf-8"), v.encode("utf-8")) + for k, v in self.parquet_kv_metadata_str(extra_partitions).items() + ) + + @property + def partition_dims(self): + return tuple(info.dimension for info in self._partition_infos) + + def get_partition_info(self, dimension: str): + for info in self._partition_infos: + if info.dimension == dimension: + return info + raise KeyError(f"cannot find dimension {dimension} in {self._partition_infos}") + + @property + def any_input_empty(self) -> bool: + return any(dataset.empty for dataset in self.input_datasets) + + @property + def skip_when_any_input_empty(self) -> bool: + return self.ctx.skip_task_with_empty_input and self.any_input_empty + + @property + def runtime_id(self): + return TaskRuntimeId(self.id, self.sched_epoch, self.retry_count) + + @property + def default_output_name(self): + return ".".join(map(str, filter(None, [self.__class__.__name__, self.node_id]))) + + @property + def output_filename(self): + output_name = self.output_name or self.default_output_name + return 
f"{output_name}-{self.ctx.job_id}.{self.runtime_id}" + + @property + def output_dirname(self): + output_name = self.output_name or self.default_output_name + return os.path.join(output_name, f"{self.ctx.job_id}.{self.runtime_id}") + + @property + def staging_root(self) -> Optional[str]: + """ + If the task has a special output directory, its runtime output directory will be under it. + """ + return ( + None + if self.output_root is None + else os.path.join(self.output_root, ".staging") + ) + + @property + def _final_output_root(self): + return ( + self.ctx.staging_root + if self.temp_output + else (self.output_root or self.ctx.output_root) + ) + + @property + def _runtime_output_root(self): + return self.staging_root or self.ctx.staging_root + + @property + def final_output_abspath(self) -> str: + return os.path.join(self._final_output_root, self.output_dirname) + + @property + def runtime_output_abspath(self) -> str: + """ + Output data will be produced in this directory at runtime. + + When the task is finished, this directory will be atomically moved to `final_output_abspath`. + """ + return os.path.join(self._runtime_output_root, self.output_dirname) + + @property + def temp_abspath(self) -> str: + return os.path.join(self.ctx.temp_root, self.output_dirname) + + @property + def output(self) -> DataSet: + return self.dataset or ParquetDataSet(["*"], root_dir=self.final_output_abspath) + + @property + def self_contained_output(self) -> bool: + """ + Whether the output of this node is not dependent on any input nodes. + """ + return True + + @property + def temp_output(self) -> bool: + """ + Whether the output of this node is stored in a temporary directory. + """ + return self._temp_output + + @temp_output.setter + def temp_output(self, temp_output: bool): + assert temp_output == False, f"cannot change temp_output to True in {self}" + self._temp_output = False + if not self.self_contained_output: + for task in self.input_deps.values(): + if task.temp_output: + task.temp_output = False + + @property + def allow_speculative_exec(self) -> bool: + """ + Whether the task is allowed to be executed by speculative execution. + """ + return True + + @property + def ray_marker_path(self) -> str: + """ + The path of an empty file that is used to determine if the task has been started. + Only used by the ray executor. + """ + return os.path.join( + self.ctx.started_task_dir, f"{self.node_id}.{self.key}.{self.retry_count}" + ) + + @property + def ray_dataset_path(self) -> str: + """ + The path of a pickle file that contains the output dataset of the task. + If this file exists, the task is considered finished. + Only used by the ray executor. 
+ """ + return os.path.join( + self.ctx.completed_task_dir, str(self.node_id), f"{self.key}.pickle" + ) + + @property + def random_seed_bytes(self) -> bytes: + return self.id.to_bytes(4, sys.byteorder) + self.ctx.random_seed.to_bytes( + RAND_SEED_BYTE_LEN, sys.byteorder + ) + + @property + def numpy_random_gen(self): + if self._np_randgen is None: + self._np_randgen = np.random.default_rng( + int.from_bytes(self.random_seed_bytes, sys.byteorder) + ) + return self._np_randgen + + @property + def python_random_gen(self): + if self._py_randgen is None: + self._py_randgen = random.Random( + int.from_bytes(self.random_seed_bytes, sys.byteorder) + ) + return self._py_randgen + + def random_uint32(self) -> int: + return self.python_random_gen.randint(0, 0x7FFFFFFF) + + def random_float(self) -> float: + return self.python_random_gen.random() + + @property + def uniform_failure_prob(self): + return 1.0 / (self.ctx.next_task_id - self.id + 1) + + def inject_fault(self): + if self.ctx.fault_inject_prob > 0 and self.fail_count <= 1: + random_value = self.random_float() + if ( + random_value < self.uniform_failure_prob + and random_value < self.ctx.fault_inject_prob + ): + raise InjectedFault( + f"inject fault to {repr(self)}, uniform_failure_prob={self.uniform_failure_prob:.6f}, fault_inject_prob={self.ctx.fault_inject_prob:.6f}" + ) + + def compute_avg_row_size(self, nbytes, num_rows): + return max(1, nbytes // num_rows) if num_rows > 0 else 1 + + def adjust_row_group_size( + self, + nbytes, + num_rows, + max_row_group_size=MAX_ROW_GROUP_SIZE, + max_row_group_bytes=MAX_ROW_GROUP_BYTES, + max_num_row_groups=MAX_NUM_ROW_GROUPS, + ): + parquet_row_group_size = self.parquet_row_group_size + num_row_groups = num_rows // parquet_row_group_size + + if num_row_groups > max_num_row_groups: + parquet_row_group_size = round_up( + clamp_row_group_size( + num_rows // max_num_row_groups, maxval=max_row_group_size + ), + KB, + ) + avg_row_size = self.compute_avg_row_size(nbytes, num_rows) + parquet_row_group_size = round_up( + min(parquet_row_group_size, max_row_group_bytes // avg_row_size), KB + ) + + if self.parquet_row_group_size != parquet_row_group_size: + parquet_row_group_bytes = round_up( + clamp_row_group_bytes( + parquet_row_group_size * avg_row_size, maxval=max_row_group_bytes + ), + MB, + ) + logger.info( + f"adjust row group size for dataset ({num_rows} rows, {nbytes/MB:.3f}MB): {self.parquet_row_group_size} -> {parquet_row_group_size} rows, {parquet_row_group_bytes/MB:.1f}MB" + ) + self.parquet_row_group_size = parquet_row_group_size + self.parquet_row_group_bytes = parquet_row_group_bytes + + def run(self) -> bool: + return True + + def set_memory_limit(self, soft_limit: int, hard_limit: int): + soft_old, hard_old = resource.getrlimit(resource.RLIMIT_DATA) + resource.setrlimit(resource.RLIMIT_DATA, (soft_limit, hard_limit)) + logger.debug( + f"updated memory limit from ({soft_old/GB:.3f}GB, {hard_old/GB:.3f}GB) to ({soft_limit/GB:.3f}GB, {hard_limit/GB:.3f}GB)" + ) + + def initialize(self): + self.inject_fault() + + if self._memory_limit is None: + self._memory_limit = np.int64( + self.ctx.usable_memory_size + * self._cpu_limit + // self.ctx.usable_cpu_count + ) + assert self.partition_infos, f"empty partition infos: {self}" + os.makedirs(self.runtime_output_abspath, exist_ok=self.output_root is not None) + os.makedirs(self.temp_abspath, exist_ok=False) + + if not self.exec_on_scheduler: + if self.ctx.enable_profiling: + self.perf_profile = cProfile.Profile() + self.perf_profile.enable() + if 
self.ctx.enforce_memory_limit: + self.set_memory_limit( + round_up(self.memory_limit * 1.2), round_up(self.memory_limit * 1.5) + ) + if self.ctx.remove_empty_parquet: + for dataset in self.input_datasets: + if isinstance(dataset, ParquetDataSet): + dataset.remove_empty_files() + logger.info("running task: {}", self) + logger.debug("input datasets: {}", self.input_datasets) + logger.trace(f"final output directory: {self.final_output_abspath}") + logger.trace(f"runtime output directory: {self.runtime_output_abspath}") + logger.trace( + f"resource limit: {self.cpu_limit} cpus, {self.gpu_limit} gpus, {self.memory_limit/GB:.3f}GB memory" + ) + random.seed(self.random_seed_bytes) + arrow.set_cpu_count(self.cpu_limit) + arrow.set_io_thread_count(self.cpu_limit) + os.environ["OMP_NUM_THREADS"] = str(self.cpu_limit) + os.environ["POLARS_MAX_THREADS"] = str(self.cpu_limit) + + def finalize(self): + self.inject_fault() + assert self.status == WorkStatus.SUCCEED + logger.info("finished task: {}", self) + + # move the task output from staging dir to output dir + if self.runtime_output_abspath != self.final_output_abspath and os.path.exists( + self.runtime_output_abspath + ): + os.makedirs(os.path.dirname(self.final_output_abspath), exist_ok=True) + os.rename(self.runtime_output_abspath, self.final_output_abspath) + + def collect_file_sizes(file_paths): + if not file_paths: + return [] + try: + with ThreadPoolExecutor(min(32, len(file_paths))) as pool: + file_sizes = list(pool.map(os.path.getsize, file_paths)) + except FileNotFoundError: + logger.warning( + f"some of the output files not found: {file_paths[:3]}..." + ) + file_sizes = [] + return file_sizes + + if self.ctx.enable_diagnostic_metrics: + input_file_paths = [ + path + for dataset in self.input_datasets + for path in dataset.resolved_paths + ] + output_file_paths = self.output.resolved_paths + for metric_name, file_paths in [ + ("input", input_file_paths), + ("output", output_file_paths), + ]: + file_sizes = collect_file_sizes(file_paths) + if file_paths and file_sizes: + self.perf_metrics[f"num {metric_name} files"] += len(file_paths) + self.perf_metrics[f"total {metric_name} size (MB)"] += ( + sum(file_sizes) / MB + ) + + self.perf_metrics["elapsed wall time (secs)"] += self.elapsed_time + if not self.exec_on_scheduler: + resource_usage = resource.getrusage(resource.RUSAGE_SELF) + self.perf_metrics["max resident set size (MB)"] += ( + resource_usage.ru_maxrss / 1024 + ) + self.perf_metrics["user mode cpu time (secs)"] += resource_usage.ru_utime + self.perf_metrics["system mode cpu time (secs)"] += resource_usage.ru_stime + logger.debug( + f"{self.key} perf metrics:{os.linesep}{os.linesep.join(f'{name}: {value}' for name, value in self.perf_metrics.items())}" + ) + + if self.perf_profile is not None and self.elapsed_time > 3: + logger.debug( + f"{self.key} perf profile:{os.linesep}{cprofile_to_string(self.perf_profile)}" + ) + + def cleanup(self): + if self.perf_profile is not None: + self.perf_profile.disable() + self.perf_profile = None + self.clean_complex_attrs() + + def clean_complex_attrs(self): + + self._np_randgen = None + self._py_randgen = None + self.perf_profile = None + + def is_primitive(obj: Any): + return isinstance(obj, (bool, str, int, float)) + + def is_primitive_iterable(obj: Any): + if isinstance(obj, dict): + return all( + is_primitive(key) and is_primitive(value) + for key, value in obj.items() + ) + elif isinstance(obj, Iterable): + return all(is_primitive(elem) for elem in obj) + return False + + if hasattr(self, 
"__dict__"): + complex_attrs = [ + attr + for attr, obj in vars(self).items() + if not ( + attr in self._pristine_attrs + or is_primitive(obj) + or is_primitive_iterable(obj) + ) + ] + if complex_attrs: + logger.debug( + f"removing complex attributes not explicitly declared in __slots__: {complex_attrs}" + ) + for attr in complex_attrs: + delattr(self, attr) + + def clean_output(self, force=False) -> None: + if force or self.temp_output: + remove_path(self.runtime_output_abspath) + remove_path(self.final_output_abspath) + + @logger.catch(reraise=pytest_running(), message="failed to dump task") + def dump(self): + os.makedirs(self.temp_abspath, exist_ok=True) + dump_path = os.path.join(self.temp_abspath, f"{self.key}.pickle") + dump(self, dump_path) + logger.info(f"{self.key} saved to {dump_path}") + + def add_elapsed_time(self, metric_name: str = None) -> float: + """ + Start or stop the timer. If `metric_name` is provided, add the elapsed time to the task's performance metrics. + + Example: + ``` + task.add_elapsed_time() # @t0 start timer + e1 = task.add_elapsed_time("input load time (secs)") # @t1 stop timer and add elapsed time e1=t1-t0 to metric + e2 = task.add_elapsed_time("compute time (secs)") # @t2 stop timer and add elapsed time e2=t2-t1 to metric + ``` + """ + self.inject_fault() + assert ( + self._timer_start is not None or metric_name is None + ), f"timer not started, cannot save '{metric_name}'" + if self._timer_start is None or metric_name is None: + self._timer_start = time.time() + return 0.0 + else: + current_time = time.time() + elapsed_time = current_time - self._timer_start + self.perf_metrics[metric_name] += elapsed_time + self._timer_start = current_time + return elapsed_time + + def merge_metrics(self, metrics: Dict[str, int]): + for name, value in metrics.items(): + self.perf_metrics[name] += value + + def run_on_ray(self) -> ray.ObjectRef: + """ + Run the task on Ray. + Return an `ObjectRef`, which can be used with `ray.get` to wait for the output dataset. 
+ """ + if self._dataset_ref is not None: + # already started + return self._dataset_ref + + # read the output dataset if the task has already finished + if os.path.exists(self.ray_dataset_path): + logger.info(f"task {self.key} already finished, skipping") + output = load(self.ray_dataset_path) + self._dataset_ref = ray.put(output) + return self._dataset_ref + + task = copy.copy(self) + task.input_deps = {dep_key: None for dep_key in task.input_deps} + + @ray.remote + def exec_task(task: Task, *inputs: DataSet) -> DataSet: + import multiprocessing as mp + import os + from pathlib import Path + + from loguru import logger + + # ray use a process pool to execute tasks + # we set the current process name to the task name + # so that we can see task name in the logs + mp.current_process().name = task.key + + # probe the retry count + task.retry_count = 0 + while os.path.exists(task.ray_marker_path): + task.retry_count += 1 + if task.retry_count > DEFAULT_MAX_RETRY_COUNT: + raise RuntimeError( + f"task {task.key} failed after {task.retry_count} retries" + ) + if task.retry_count > 0: + logger.warning( + f"task {task.key} is being retried for the {task.retry_count}th time" + ) + # create the marker file + Path(task.ray_marker_path).touch() + + # put the inputs into the task + assert len(inputs) == len(task.input_deps) + task.input_datasets = list(inputs) + # execute the task + status = task.exec() + if status != WorkStatus.SUCCEED: + raise task.exception or RuntimeError( + f"task {task.key} failed with status {status}" + ) + + # dump the output dataset atomically + os.makedirs(os.path.dirname(task.ray_dataset_path), exist_ok=True) + dump(task.output, task.ray_dataset_path, atomic_write=True) + return task.output + + # this shows as {"name": ...} in timeline + exec_task._function_name = repr(task) + + remote_function = exec_task.options( + # ray task name + # do not include task id so that they can be grouped by node in ray dashboard + name=f"{task.node_id}.{self.__class__.__name__}", + num_cpus=self.cpu_limit, + num_gpus=self.gpu_limit, + memory=int(self.memory_limit), + # note: `exec_on_scheduler` is ignored here, + # because dataset is distributed on ray + ) + try: + self._dataset_ref = remote_function.remote( + task, *[dep.run_on_ray() for dep in self.input_deps.values()] + ) + except RuntimeError as e: + if ( + "SimpleQueue objects should only be shared between processes through inheritance" + in str(e) + ): + raise RuntimeError( + f"Can't pickle task '{task.key}'. Please check if your function has captured unpicklable objects. {task.location}\n" + f"HINT: DO NOT use externally imported loguru logger in your task. Please import it within the task." 
+ ) from e + raise e + return self._dataset_ref + + +class ExecSqlQueryMixin(Task): + + enable_temp_directory = False + cpu_overcommit_ratio = 1.0 + memory_overcommit_ratio = 1.0 + input_view_index = 0 + query_udfs: List[UDFContext] = [] + parquet_compression = None + parquet_compression_level = None + + @cached_property + def rand_seed_float(self) -> int: + return self.random_float() + + @cached_property + def rand_seed_uint32(self) -> int: + return self.random_uint32() + + @property + def input_udfs(self) -> List[UDFContext]: + if self.input_datasets is None: + return [] + return [udf for dataset in self.input_datasets for udf in dataset.udfs] + + @property + def udfs(self): + return set(self.query_udfs + self.input_udfs) + + @property + def compression_type_str(self): + return ( + f"COMPRESSION '{self.parquet_compression}'" + if self.parquet_compression is not None + else "COMPRESSION 'uncompressed'" + ) + + @property + def compression_level_str(self): + return ( + f"COMPRESSION_LEVEL {self.parquet_compression_level}" + if self.parquet_compression == "ZSTD" + and self.parquet_compression_level is not None + else "" + ) + + @property + def compression_options(self): + return ", ".join( + filter(None, (self.compression_type_str, self.compression_level_str)) + ) + + def prepare_connection(self, conn: duckdb.DuckDBPyConnection): + logger.debug(f"duckdb version: {duckdb.__version__}") + # set random seed + self.exec_query(conn, f"select setseed({self.rand_seed_float})") + # prepare connection + effective_cpu_count = math.ceil(self.cpu_limit * self.cpu_overcommit_ratio) + effective_memory_size = round_up( + self.memory_limit * self.memory_overcommit_ratio, MB + ) + self.exec_query( + conn, + f""" + SET threads TO {effective_cpu_count}; + SET memory_limit='{effective_memory_size // MB}MB'; + SET temp_directory='{self.temp_abspath if self.enable_temp_directory else ""}'; + SET enable_object_cache=true; + SET arrow_large_buffer_size=true; + SET preserve_insertion_order=false; + SET max_expression_depth=10000; +""", + ) + for udf in self.udfs: + logger.debug("bind udf: {}", udf) + udf.bind(conn) + + def create_input_views( + self, + conn: duckdb.DuckDBPyConnection, + input_datasets: List[DataSet], + filesystem: fsspec.AbstractFileSystem = None, + ) -> List[str]: + input_views = OrderedDict() + for input_dataset in input_datasets: + self.input_view_index += 1 + view_name = f"{INPUT_VIEW_PREFIX}_{self.id}_{self.input_view_index:06d}" + input_views[view_name] = ( + f"CREATE VIEW {view_name} AS {input_dataset.sql_query_fragment(filesystem, conn)};" + ) + logger.debug(f"create input view '{view_name}': {input_views[view_name]}") + conn.sql(input_views[view_name]) + return list(input_views.keys()) + + def exec_query( + self, + conn: duckdb.DuckDBPyConnection, + query_statement: str, + enable_profiling=False, + log_query=True, + log_output=False, + ) -> Dict[str, int]: + perf_metrics: Dict[str, np.int64] = defaultdict(np.int64) + + try: + if log_query: + logger.debug(f"running sql query: {query_statement}") + start_time = time.time() + query_output = conn.sql( + "SET enable_profiling='json';" + if enable_profiling + else "RESET enable_profiling;" + ) + query_output = conn.sql(query_statement) + elapsed_time = time.time() - start_time + if log_query: + logger.debug(f"query elapsed time: {elapsed_time:.3f} secs") + except duckdb.OutOfMemoryException as ex: + raise OutOfMemory(f"{self.key} failed with OOM error") from ex + except Exception as ex: + # attach the query statement to the exception + raise 
RuntimeError(f"failed to run query: {query_statement}") from ex + + def sum_children_metrics(obj: Dict, metric: str): + value = obj.get(metric, None) + if value is not None: + return value + if "children" not in obj: + return 0 + return sum(sum_children_metrics(child, metric) for child in obj["children"]) + + def extract_perf_metrics(obj: Dict): + name = obj.get("operator_type", "") + if name.startswith("TABLE_SCAN"): + perf_metrics["num input rows"] += obj["operator_cardinality"] + perf_metrics["input load time (secs)"] += obj["operator_timing"] + elif name.startswith("COPY_TO_FILE"): + perf_metrics["num output rows"] += sum( + sum_children_metrics(child, "operator_cardinality") + for child in obj["children"] + ) + perf_metrics["output dump time (secs)"] += obj["operator_timing"] + return obj + + if query_output is not None: + output_rows = query_output.fetchall() + if log_output or (enable_profiling and self.ctx.enable_profiling): + for row in output_rows: + logger.debug( + f"query output:{os.linesep}{''.join(filter(None, row))}" + ) + if enable_profiling: + _, json_str = output_rows[0] + json.loads(json_str, object_hook=extract_perf_metrics) + + return perf_metrics + + +class DataSourceTask(Task): + def __init__( + self, + ctx: RuntimeContext, + dataset: DataSet, + partition_infos: List[PartitionInfo], + ) -> None: + super().__init__(ctx, [], partition_infos) + self.dataset = dataset + + def __str__(self) -> str: + return super().__str__() + f", dataset=<{self.dataset}>" + + @property + def exec_on_scheduler(self) -> bool: + return True + + def run(self) -> bool: + logger.info(f"added data source: {self.dataset}") + if isinstance(self.dataset, (SqlQueryDataSet, ArrowTableDataSet)): + self.dataset = ParquetDataSet.create_from( + self.dataset.to_arrow_table(), self.runtime_output_abspath + ) + return True + + +class MergeDataSetsTask(Task): + @property + def exec_on_scheduler(self) -> bool: + return True + + @property + def self_contained_output(self): + return False + + def initialize(self): + pass + + def finalize(self): + pass + + def run(self) -> bool: + datasets = self.input_datasets + assert datasets, f"empty list of input datasets: {self}" + assert all( + isinstance(dataset, (DataSet, type(datasets[0]))) for dataset in datasets + ) + self.dataset = datasets[0].merge(datasets) + logger.info(f"created merged dataset: {self.dataset}") + return True + + +class SplitDataSetTask(Task): + + __slots__ = ( + "partition", + "npartitions", + ) + + def __init__( + self, + ctx: RuntimeContext, + input_deps: List[Task], + partition_infos: List[PartitionInfo], + ) -> None: + assert ( + len(input_deps) == 1 + ), f"wrong number of input deps for data set partition: {input_deps}" + super().__init__(ctx, input_deps, partition_infos) + self.partition = partition_infos[-1].index + self.npartitions = partition_infos[-1].npartitions + + @property + def exec_on_scheduler(self) -> bool: + return True + + @property + def self_contained_output(self): + return False + + def initialize(self): + pass + + def finalize(self): + pass + + def run(self) -> bool: + self.dataset = self.input_datasets[0].partition_by_files(self.npartitions)[ + self.partition + ] + return True + + +class PartitionProducerTask(Task): + + __slots__ = ( + "npartitions", + "dimension", + "partitioned_datasets", + ) + + def __init__( + self, + ctx: RuntimeContext, + input_deps: List[Task], + partition_infos: List[PartitionInfo], + npartitions: int, + dimension: str, + output_name: str = None, + output_path: str = None, + cpu_limit: int = 
None, + memory_limit: int = None, + ) -> None: + assert len(input_deps) == 1, f"wrong number of inputs: {input_deps}" + assert isinstance( + npartitions, int + ), f"npartitions is not an integer: {npartitions}" + super().__init__( + ctx, + input_deps, + partition_infos, + output_name, + output_path, + cpu_limit, + memory_limit=memory_limit, + ) + self.npartitions = npartitions + self.dimension = dimension + # implementor should set this rather than `dataset` + self.partitioned_datasets: List[DataSet] = None + + def __str__(self) -> str: + return ( + super().__str__() + + f", npartitions={self.npartitions}, dimension={self.dimension}" + ) + + def _create_empty_file(self, partition_idx: int, dataset: DataSet) -> str: + """ + Create an empty file for a partition according to the schema of the dataset. + Return the path relative to the output directory. + """ + if isinstance(self, HashPartitionDuckDbTask) and self.hive_partitioning: + empty_file_prefix = os.path.join( + self.runtime_output_abspath, + f"{self.data_partition_column}={partition_idx}", + f"{self.output_filename}-{partition_idx}-empty", + ) + Path(empty_file_prefix).parent.mkdir(exist_ok=True) + else: + empty_file_prefix = os.path.join( + self.runtime_output_abspath, + f"{self.output_filename}-{partition_idx}-empty", + ) + + if isinstance(dataset, CsvDataSet): + empty_file_path = Path(empty_file_prefix + ".csv") + empty_file_path.touch() + elif isinstance(dataset, JsonDataSet): + empty_file_path = Path(empty_file_prefix + ".json") + empty_file_path.touch() + elif isinstance(dataset, ParquetDataSet): + with duckdb.connect(database=":memory:") as conn: + conn.sql(f"SET threads TO 1") + dataset_schema = dataset.to_batch_reader(batch_size=1, conn=conn).schema + extra_partitions = ( + [PartitionInfo(partition_idx, self.npartitions, self.dimension)] + if not isinstance(self, HashPartitionTask) + else [ + PartitionInfo(partition_idx, self.npartitions, self.dimension), + PartitionInfo( + partition_idx, self.npartitions, self.data_partition_column + ), + ] + ) + schema_with_metadata = filter_schema( + dataset_schema, excluded_cols=GENERATED_COLUMNS + ).with_metadata(self.parquet_kv_metadata_bytes(extra_partitions)) + empty_file_path = Path(empty_file_prefix + ".parquet") + parquet.ParquetWriter(empty_file_path, schema_with_metadata).close() + else: + raise ValueError(f"unsupported dataset type: {type(dataset)}") + + return str(empty_file_path.relative_to(self.runtime_output_abspath)) + + def finalize(self): + assert ( + len(self.partitioned_datasets) == self.npartitions + ), f"number of partitions {len(self.partitioned_datasets)} not equal to {self.npartitions}" + is_empty_partition = [dataset.empty for dataset in self.partitioned_datasets] + + if all(is_empty_partition): + for dataset in self.partitioned_datasets: + dataset.paths.clear() + else: + # Create an empty file for each empty partition. + # This is to ensure that partition consumers have at least one file to read. 
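+            # Explanatory note (added): the placeholder files created below copy
+            # their schema from the first non-empty partition, so downstream
+            # consumers see a consistent schema even for empty partitions;
+            # `_create_empty_file` writes a zero-row .csv/.json/.parquet file
+            # depending on the dataset type.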
+ empty_partitions = [ + idx for idx, empty in enumerate(is_empty_partition) if empty + ] + nonempty_partitions = [ + idx for idx, empty in enumerate(is_empty_partition) if not empty + ] + first_nonempty_dataset = self.partitioned_datasets[nonempty_partitions[0]] + if empty_partitions: + with ThreadPoolExecutor(self.cpu_limit) as pool: + empty_file_paths = list( + pool.map( + lambda idx: self._create_empty_file( + idx, first_nonempty_dataset + ), + empty_partitions, + ) + ) + for partition_idx, empty_file_path in zip( + empty_partitions, empty_file_paths + ): + self.partitioned_datasets[partition_idx].reset( + [empty_file_path], self.runtime_output_abspath + ) + logger.debug( + f"created empty output files in partitions {empty_partitions} of {repr(self)}: {empty_file_paths[:3]}..." + ) + + # reset root_dir from runtime_output_abspath to final_output_abspath + for dataset in self.partitioned_datasets: + # XXX: if the task has output in `runtime_output_abspath`, + # `root_dir` must be set and all row ranges must be full ranges. + if dataset.root_dir == self.runtime_output_abspath: + dataset.reset( + dataset.paths, self.final_output_abspath, dataset.recursive + ) + # XXX: otherwise, we assume there is no output in `runtime_output_abspath`. + # do nothing to the dataset. + self.dataset = PartitionedDataSet(self.partitioned_datasets) + + super().finalize() + + def run(self) -> bool: + raise NotImplementedError + + +class RepeatPartitionProducerTask(PartitionProducerTask): + def __init__( + self, + ctx: RuntimeContext, + input_deps: List[Task], + partition_infos: List[PartitionInfo], + npartitions: int, + dimension: str, + cpu_limit: int = None, + memory_limit: int = None, + ) -> None: + super().__init__( + ctx, + input_deps, + partition_infos, + npartitions, + dimension, + cpu_limit=cpu_limit, + memory_limit=memory_limit, + ) + + @property + def exec_on_scheduler(self) -> bool: + return True + + @property + def self_contained_output(self): + return False + + def initialize(self): + pass + + def run(self) -> bool: + self.partitioned_datasets = [ + self.input_datasets[0] for _ in range(self.npartitions) + ] + return True + + +class UserDefinedPartitionProducerTask(PartitionProducerTask): + + __slots__ = ("partition_func",) + + def __init__( + self, + ctx: RuntimeContext, + input_deps: List[Task], + partition_infos: List[PartitionInfo], + npartitions: int, + dimension: str, + partition_func: Callable[[RuntimeContext, DataSet], List[DataSet]], + cpu_limit: int = None, + memory_limit: int = None, + ) -> None: + super().__init__( + ctx, + input_deps, + partition_infos, + npartitions, + dimension, + cpu_limit=cpu_limit, + memory_limit=memory_limit, + ) + self.partition_func = partition_func + + def run(self) -> bool: + try: + self.partitioned_datasets = self.partition_func( + self.ctx, self.input_datasets[0] + ) + return True + finally: + self.partition_func = None + + +class EvenlyDistributedPartitionProducerTask(PartitionProducerTask): + + __slots__ = ( + "partition_by_rows", + "random_shuffle", + ) + + def __init__( + self, + ctx: RuntimeContext, + input_deps: List[Task], + partition_infos: List[PartitionInfo], + npartitions: int, + dimension: str, + partition_by_rows=False, + random_shuffle=False, + cpu_limit: int = None, + memory_limit: int = None, + ) -> None: + super().__init__( + ctx, + input_deps, + partition_infos, + npartitions, + dimension, + cpu_limit=cpu_limit, + memory_limit=memory_limit, + ) + self.partition_by_rows = partition_by_rows + self.random_shuffle = random_shuffle + + 
@property + def exec_on_scheduler(self) -> bool: + return True + + @property + def self_contained_output(self): + return False + + def run(self) -> bool: + input_dataset = self.input_datasets[0] + assert not ( + self.partition_by_rows and not isinstance(input_dataset, ParquetDataSet) + ), f"Only parquet dataset supports partition by rows, found: {input_dataset}" + if isinstance(input_dataset, ParquetDataSet) and self.partition_by_rows: + self.partitioned_datasets = input_dataset.partition_by_rows( + self.npartitions, self.random_shuffle + ) + else: + self.partitioned_datasets = input_dataset.partition_by_files( + self.npartitions, self.random_shuffle + ) + return True + + +class LoadPartitionedDataSetProducerTask(PartitionProducerTask): + + __slots__ = ( + "data_partition_column", + "hive_partitioning", + ) + + def __init__( + self, + ctx: RuntimeContext, + input_deps: List[Task], + partition_infos: List[PartitionInfo], + npartitions: int, + dimension: str, + data_partition_column: str, + hive_partitioning: bool, + cpu_limit: int = None, + memory_limit: int = None, + ) -> None: + super().__init__( + ctx, + input_deps, + partition_infos, + npartitions, + dimension, + cpu_limit=cpu_limit, + memory_limit=memory_limit, + ) + self.data_partition_column = data_partition_column + self.hive_partitioning = hive_partitioning + + def run(self) -> bool: + input_dataset = self.input_datasets[0] + assert isinstance( + input_dataset, ParquetDataSet + ), f"Not parquet dataset: {input_dataset}" + self.partitioned_datasets = input_dataset.load_partitioned_datasets( + self.npartitions, self.data_partition_column, self.hive_partitioning + ) + return True + + +class PartitionConsumerTask(Task): + + __slots__ = ("last_partition",) + + def __init__( + self, + ctx: RuntimeContext, + input_deps: List[PartitionProducerTask], + partition_infos: List[PartitionInfo], + ) -> None: + assert all(isinstance(task, PartitionProducerTask) for task in input_deps) + super().__init__(ctx, input_deps, partition_infos) + self.last_partition = partition_infos[-1] + + def __str__(self) -> str: + return super().__str__() + f", dataset=<{self.dataset}>" + + @property + def exec_on_scheduler(self) -> bool: + return not self.ctx.remove_empty_parquet + + @property + def self_contained_output(self): + return False + + def initialize(self): + pass + + def finalize(self): + pass + + def run(self) -> bool: + # Build the dataset only after all `input_deps` finished, since `input_deps` could be tried multiple times. + # Consumers always follow producers, so the input is a list of partitioned datasets. 
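+        # Illustrative example (added note): with two producer deps and
+        # `last_partition.index == 3`, `self.input_datasets` looks roughly like
+        #     [PartitionedDataSet([p0, p1, p2, p3, ...]),
+        #      PartitionedDataSet([q0, q1, q2, q3, ...])]
+        # and the code below selects element 3 from each producer's output and
+        # merges the selected datasets into this consumer's input dataset.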
+ assert all( + isinstance(dataset, PartitionedDataSet) for dataset in self.input_datasets + ) + datasets = [ + dataset[self.last_partition.index] for dataset in self.input_datasets + ] + self.dataset = datasets[0].merge(datasets) + + if self.ctx.remove_empty_parquet and isinstance(self.dataset, ParquetDataSet): + self.dataset.remove_empty_files() + + assert ( + self.ctx.skip_task_with_empty_input or not self.dataset.empty + ), f"found unexpected empty partition {self.last_partition} generated by {self.input_deps.keys()}" + return True + + +class RangePartitionTask(Task): + def __init__( + self, + ctx: RuntimeContext, + input_deps: List[Task], + partition_infos: List[PartitionInfo], + ) -> None: + super().__init__(ctx, input_deps, partition_infos) + + +class PythonScriptTask(ExecSqlQueryMixin, Task): + + __slots__ = ("process_func",) + + def __init__( + self, + ctx: RuntimeContext, + input_deps: List[Task], + partition_infos: List[PartitionInfo], + process_func: Callable[[RuntimeContext, List[DataSet], str], bool] = None, + output_name: str = None, + output_path: str = None, + cpu_limit: int = None, + gpu_limit: float = None, + memory_limit: int = None, + ) -> None: + super().__init__( + ctx, + input_deps, + partition_infos, + output_name, + output_path, + cpu_limit, + gpu_limit, + memory_limit, + ) + self.process_func = process_func + + def process( + self, + runtime_ctx: RuntimeContext, + input_datasets: List[DataSet], + output_path: str, + ) -> bool: + """ + This method can be overridden in subclass of `PythonScriptTask`. + + Parameters + ---------- + runtime_ctx + The runtime context, which defines a few global configuration info. + input_datasets + A list of input datasets. The number of datasets equal to the number of input_deps. + output_path + The absolute path of output directory created for each task generated from this node. + The outputs generated by this node would be consumed by tasks of downstream nodes. + + Returns + ------- + Return true if success. Return false or throw an exception if there is any error. 
+ """ + return self.process_func(runtime_ctx, input_datasets, output_path) + + def run(self) -> bool: + try: + self.add_elapsed_time() + if self.skip_when_any_input_empty: + return True + return self.process( + self.ctx.set_current_task(self), + self.input_datasets, + self.runtime_output_abspath, + ) + finally: + self.process_func = None + self.dataset = FileSet(["*"], root_dir=self.final_output_abspath) + + +class ArrowComputeTask(ExecSqlQueryMixin, Task): + + cpu_overcommit_ratio = 0.5 + memory_overcommit_ratio = 0.5 + + __slots__ = ( + "process_func", + "parquet_row_group_size", + "parquet_row_group_bytes", + "parquet_dictionary_encoding", + "use_duckdb_reader", + ) + + def __init__( + self, + ctx: RuntimeContext, + input_deps: List[Task], + partition_infos: List[PartitionInfo], + process_func: Callable[[RuntimeContext, List[arrow.Table]], arrow.Table] = None, + parquet_row_group_size: int = DEFAULT_ROW_GROUP_SIZE, + parquet_dictionary_encoding=False, + parquet_compression="ZSTD", + parquet_compression_level=3, + use_duckdb_reader=False, + output_name: str = None, + output_path: str = None, + cpu_limit: int = None, + gpu_limit: float = None, + memory_limit: int = None, + ) -> None: + super().__init__( + ctx, + input_deps, + partition_infos, + output_name, + output_path, + cpu_limit, + gpu_limit, + memory_limit, + ) + self.process_func = process_func + self.parquet_row_group_size = parquet_row_group_size + self.parquet_row_group_bytes = clamp_row_group_bytes( + parquet_row_group_size * 4 * KB + ) + self.parquet_dictionary_encoding = parquet_dictionary_encoding + self.parquet_compression = parquet_compression + self.parquet_compression_level = parquet_compression_level + self.use_duckdb_reader = use_duckdb_reader + + def clean_complex_attrs(self): + self.exec_cq = None + self.process_func = None + super().clean_complex_attrs() + + def _call_process( + self, runtime_ctx: RuntimeContext, input_tables: List[arrow.Table] + ) -> arrow.Table: + return self.process(runtime_ctx, input_tables) + + def process( + self, runtime_ctx: RuntimeContext, input_tables: List[arrow.Table] + ) -> arrow.Table: + """ + This method can be overridden in subclass of `ArrowComputeTask`. + + Parameters + ---------- + runtime_ctx + The runtime context, which defines a few global configuration info. + input_tables + A list of arrow tables. The number of arrow tables equal to the number of input_deps. + + Returns + ------- + Return the output as a arrow table. Throw an exception if there is any error. 
+ """ + return self.process_func(runtime_ctx, input_tables) + + def run(self) -> bool: + try: + conn = None + self.add_elapsed_time() + if self.skip_when_any_input_empty: + return True + + if self.use_duckdb_reader: + conn = duckdb.connect( + database=":memory:", config={"allow_unsigned_extensions": "true"} + ) + self.prepare_connection(conn) + + input_tables = [ + dataset.to_arrow_table(max_workers=self.cpu_limit, conn=conn) + for dataset in self.input_datasets + ] + self.perf_metrics["num input rows"] += sum( + table.num_rows for table in input_tables + ) + self.add_elapsed_time("input load time (secs)") + if conn is not None: + conn.close() + + output_table = self._call_process( + self.ctx.set_current_task(self), input_tables + ) + self.add_elapsed_time("compute time (secs)") + + return self.dump_output(output_table) + except arrow.lib.ArrowMemoryError as ex: + raise OutOfMemory(f"{self.key} failed with OOM error") from ex + finally: + if conn is not None: + conn.close() + + def dump_output(self, output_table: arrow.Table): + if output_table is None: + logger.warning(f"user's process method returns none") + return True + + if self.parquet_row_group_size == DEFAULT_ROW_GROUP_SIZE: + # adjust row group size if it is not set by user + self.adjust_row_group_size(output_table.nbytes, output_table.num_rows) + + # write arrow table to parquet files + dump_to_parquet_files( + output_table.replace_schema_metadata(self.parquet_kv_metadata_bytes()), + self.runtime_output_abspath, + self.output_filename, + compression=( + self.parquet_compression + if self.parquet_compression is not None + else "NONE" + ), + compression_level=self.parquet_compression_level, + row_group_size=self.parquet_row_group_size, + row_group_bytes=self.parquet_row_group_bytes, + use_dictionary=self.parquet_dictionary_encoding, + max_workers=self.cpu_limit, + ) + self.perf_metrics["num output rows"] += output_table.num_rows + self.add_elapsed_time("output dump time (secs)") + + return True + + +class StreamOutput(object): + + __slots__ = ( + "output_table", + "batch_indices", + "force_checkpoint", + ) + + def __init__( + self, + output_table: arrow.Table, + batch_indices: List[int] = None, + force_checkpoint=False, + ) -> None: + self.output_table = cast_columns_to_large_string(output_table) + self.batch_indices = batch_indices or [] + self.force_checkpoint = force_checkpoint and bool(batch_indices) + + +class ArrowStreamTask(ExecSqlQueryMixin, Task): + + cpu_overcommit_ratio = 0.5 + memory_overcommit_ratio = 0.5 + + __slots__ = ( + "process_func", + "background_io_thread", + "streaming_batch_size", + "streaming_batch_count", + "parquet_row_group_size", + "parquet_row_group_bytes", + "parquet_dictionary_encoding", + "parquet_compression", + "parquet_compression_level", + "secs_checkpoint_interval", + ) + + class RuntimeState(object): + + __slots__ = ( + "last_batch_indices", + "input_batch_offsets", + "streaming_output_paths", + "streaming_batch_size", + "streaming_batch_count", + ) + + def __init__( + self, streaming_batch_size: int, streaming_batch_count: int + ) -> None: + self.last_batch_indices: List[int] = None + self.input_batch_offsets: List[int] = None + self.streaming_output_paths: List[str] = [] + self.streaming_batch_size: int = streaming_batch_size + self.streaming_batch_count: int = streaming_batch_count + + def __str__(self) -> str: + return f"streaming_batch_size={self.streaming_batch_size}, input_batch_offsets={self.input_batch_offsets}, 
streaming_output_paths[{len(self.streaming_output_paths)}]={self.streaming_output_paths[:3]}..." + + @property + def max_batch_offsets(self): + return max(self.input_batch_offsets) + + def update_batch_offsets(self, batch_indices: Optional[List[int]]): + if batch_indices is None: + return + if self.last_batch_indices is None: + self.last_batch_indices = [-1] * len(batch_indices) + if self.input_batch_offsets is None: + self.input_batch_offsets = [0] * len(batch_indices) + self.input_batch_offsets = [ + i + j - k + for i, j, k in zip( + self.input_batch_offsets, batch_indices, self.last_batch_indices + ) + ] + self.last_batch_indices = batch_indices + + def reset(self): + self.input_batch_offsets.clear() + self.streaming_output_paths.clear() + + def __init__( + self, + ctx: RuntimeContext, + input_deps: List[Task], + partition_infos: List[PartitionInfo], + process_func: Callable[ + [RuntimeContext, List[arrow.RecordBatchReader]], Iterable[arrow.Table] + ] = None, + background_io_thread=True, + streaming_batch_size: int = DEFAULT_BATCH_SIZE, + secs_checkpoint_interval: int = None, + parquet_row_group_size: int = DEFAULT_ROW_GROUP_SIZE, + parquet_dictionary_encoding=False, + parquet_compression="ZSTD", + parquet_compression_level=3, + use_duckdb_reader=False, + output_name: str = None, + output_path: str = None, + cpu_limit: int = None, + gpu_limit: float = None, + memory_limit: int = None, + ) -> None: + super().__init__( + ctx, + input_deps, + partition_infos, + output_name, + output_path, + cpu_limit, + gpu_limit, + memory_limit, + ) + self.process_func = process_func + self.background_io_thread = background_io_thread + self.streaming_batch_size = streaming_batch_size + self.streaming_batch_count = 1 + self.parquet_row_group_size = parquet_row_group_size + self.parquet_row_group_bytes = clamp_row_group_bytes( + parquet_row_group_size * 4 * KB + ) + self.parquet_dictionary_encoding = parquet_dictionary_encoding + self.parquet_compression = parquet_compression + self.parquet_compression_level = parquet_compression_level + self.use_duckdb_reader = use_duckdb_reader + self.secs_checkpoint_interval = ( + secs_checkpoint_interval or self.ctx.secs_executor_probe_timeout + ) + self.runtime_state: Optional[ArrowStreamTask.RuntimeState] = None + + def __str__(self) -> str: + return ( + super().__str__() + + f", background_io_thread={self.background_io_thread}, streaming_batch_size={self.streaming_batch_size}, checkpoint_interval={self.secs_checkpoint_interval}s" + ) + + @property + def max_batch_size(self) -> int: + return self._memory_limit // 2 + + def finalize(self): + if self.runtime_state is not None: + for path in self.runtime_state.streaming_output_paths: + if not path.startswith(self.runtime_output_abspath): + os.link( + path, + os.path.join( + self.runtime_output_abspath, os.path.basename(path) + ), + ) + self.runtime_state = None + super().finalize() + + def clean_complex_attrs(self): + self.exec_cq = None + self.process_func = None + super().clean_complex_attrs() + + def _wrap_output( + self, output: Union[arrow.Table, StreamOutput], batch_indices: List[int] = None + ) -> StreamOutput: + if isinstance(output, StreamOutput): + assert len(output.batch_indices) == 0 or len(output.batch_indices) == len( + self.input_deps + ), f"num of batch indices {len(output.batch_indices)} not equal to num of inputs {len(self.input_deps)}" + return output + else: + assert isinstance(output, arrow.Table) + return StreamOutput(output, batch_indices) + + def _call_process( + self, runtime_ctx: RuntimeContext, 
input_readers: List[arrow.RecordBatchReader] + ) -> Iterable[StreamOutput]: + for output in self.process(runtime_ctx, input_readers): + yield self._wrap_output(output) + + def process( + self, runtime_ctx: RuntimeContext, input_readers: List[arrow.RecordBatchReader] + ) -> Iterable[arrow.Table]: + """ + This method can be overridden in subclasses of `ArrowStreamTask`. + + Parameters + ---------- + runtime_ctx + The runtime context, which provides global configuration info. + input_readers + A list of RecordBatchReader objects. The number of readers equals the number of input_deps. + + Returns + ------- + Return the output as an arrow table. Throw an exception if there is any error. + """ + return self.process_func(runtime_ctx, input_readers) + + def restore_input_state( + self, runtime_state: RuntimeState, input_readers: List[arrow.RecordBatchReader] + ): + logger.info(f"restore input state to: {runtime_state}") + assert len(runtime_state.input_batch_offsets) == len( + input_readers + ), f"num of batch offsets {len(runtime_state.input_batch_offsets)} not equal to num of input readers {len(input_readers)}" + + for batch_offset, input_reader in zip( + runtime_state.input_batch_offsets, input_readers + ): + if batch_offset <= 0: + continue + for ( + batch_index, + input_batch, + ) in enumerate(input_reader): + logger.debug( + f"skipped input batch #{batch_index}: {input_batch.num_rows} rows" + ) + if batch_index + 1 == batch_offset: + break + assert batch_index + 1 <= batch_offset + + def run(self) -> bool: + self.add_elapsed_time() + if self.skip_when_any_input_empty: + return True + + input_row_ranges = [ + dataset.resolved_row_ranges + for dataset in self.input_datasets + if isinstance(dataset, ParquetDataSet) + ] + input_byte_size = [ + sum(row_range.estimated_data_size for row_range in row_ranges) + for row_ranges in input_row_ranges + ] + input_num_rows = [ + sum(row_range.num_rows for row_range in row_ranges) + for row_ranges in input_row_ranges + ] + input_files = [ + set(row_range.path for row_range in row_ranges) + for row_ranges in input_row_ranges + ] + self.perf_metrics["num input rows"] += sum(input_num_rows) + self.perf_metrics["input data size (MB)"] += sum(input_byte_size) / MB + + # calculate the max streaming batch size based on memory limit + avg_input_row_size = sum( + self.compute_avg_row_size(nbytes, num_rows) + for nbytes, num_rows in zip(input_byte_size, input_num_rows) + ) + max_batch_rows = self.max_batch_size // avg_input_row_size + + if self.runtime_state is None: + if self.streaming_batch_size > max_batch_rows: + logger.warning( + f"reduce streaming batch size from {self.streaming_batch_size} to {max_batch_rows} (approx. 
{self.max_batch_size/GB:.3f}GB)" + ) + self.streaming_batch_size = max_batch_rows + self.streaming_batch_count = max( + 1, + max(map(len, input_files)), + math.ceil(max(input_num_rows) / self.streaming_batch_size), + ) + else: + self.streaming_batch_size = self.runtime_state.streaming_batch_size + self.streaming_batch_count = self.runtime_state.streaming_batch_count + + try: + conn = None + if self.use_duckdb_reader: + conn = duckdb.connect( + database=":memory:", config={"allow_unsigned_extensions": "true"} + ) + self.prepare_connection(conn) + + input_readers = [ + dataset.to_batch_reader( + batch_size=self.streaming_batch_size, + conn=conn, + ) + for dataset in self.input_datasets + ] + + if self.runtime_state is None: + self.runtime_state = ArrowStreamTask.RuntimeState( + self.streaming_batch_size, self.streaming_batch_count + ) + else: + self.restore_input_state(self.runtime_state, input_readers) + self.runtime_state.last_batch_indices = None + + output_iter = self._call_process( + self.ctx.set_current_task(self), input_readers + ) + self.add_elapsed_time("compute time (secs)") + + if self.background_io_thread: + with ConcurrentIter(output_iter) as concurrent_iter: + return self.dump_output(concurrent_iter) + else: + return self.dump_output(output_iter) + except arrow.lib.ArrowMemoryError as ex: + raise OutOfMemory(f"{self.key} failed with OOM error") from ex + finally: + if conn is not None: + conn.close() + + def dump_output(self, output_iter: Iterable[StreamOutput]): + def write_table(writer: parquet.ParquetWriter, table: arrow.Table): + if table.num_rows == 0: + return + writer.write_table(table, self.parquet_row_group_size) + self.perf_metrics["num output rows"] += table.num_rows + self.add_elapsed_time("output dump time (secs)") + + create_checkpoint = False + last_checkpoint_time = ( + time.time() - self.random_float() * self.secs_checkpoint_interval / 2 + ) + + output: StreamOutput = next(output_iter, None) + self.add_elapsed_time("compute time (secs)") + + if output is None: + logger.warning(f"user's process method returns none") + return True + + if self.parquet_row_group_size == DEFAULT_ROW_GROUP_SIZE: + # adjust row group size if it is not set by user + self.adjust_row_group_size( + self.streaming_batch_count * output.output_table.nbytes, + self.streaming_batch_count * output.output_table.num_rows, + ) + + output_iter = itertools.chain([output], output_iter) + buffered_output = output.output_table.slice(length=0) + + for output_file_idx in itertools.count(): + output_path = os.path.join( + self.runtime_output_abspath, + f"{self.output_filename}-{output_file_idx}.parquet", + ) + output_file = open(output_path, "wb", buffering=32 * MB) + + try: + with parquet.ParquetWriter( + where=output_file, + schema=buffered_output.schema.with_metadata( + self.parquet_kv_metadata_bytes() + ), + use_dictionary=self.parquet_dictionary_encoding, + compression=( + self.parquet_compression + if self.parquet_compression is not None + else "NONE" + ), + compression_level=self.parquet_compression_level, + write_batch_size=max(16 * 1024, self.parquet_row_group_size // 8), + data_page_size=max(64 * MB, self.parquet_row_group_bytes // 8), + ) as writer: + + while (output := next(output_iter, None)) is not None: + self.add_elapsed_time("compute time (secs)") + + if ( + buffered_output.num_rows + output.output_table.num_rows + < self.parquet_row_group_size + ): + buffered_output = arrow.concat_tables( + (buffered_output, output.output_table) + ) + else: + write_table(writer, buffered_output) + 
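# start the next row group buffer from the current output table +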
buffered_output = output.output_table + + periodic_checkpoint = ( + bool(output.batch_indices) + and (time.time() - last_checkpoint_time) + >= self.secs_checkpoint_interval + ) + create_checkpoint = ( + output.force_checkpoint or periodic_checkpoint + ) + + if create_checkpoint: + self.runtime_state.update_batch_offsets( + output.batch_indices + ) + last_checkpoint_time = time.time() + break + + if buffered_output is not None: + write_table(writer, buffered_output) + buffered_output = buffered_output.slice(length=0) + + finally: + if isinstance(output_file, io.IOBase): + output_file.close() + + assert buffered_output is None or buffered_output.num_rows == 0 + self.runtime_state.streaming_output_paths.append(output_path) + + if output is None: + break + + if create_checkpoint and self.exec_cq is not None: + checkpoint = copy.copy(self) + checkpoint.clean_complex_attrs() + self.exec_cq.push(checkpoint, buffering=False) + logger.debug( + f"created and sent checkpoint #{self.runtime_state.max_batch_offsets}/{self.streaming_batch_count}: {self.runtime_state}" + ) + + return True + + +class ArrowBatchTask(ArrowStreamTask): + @property + def max_batch_size(self) -> int: + return self._memory_limit // 3 + + def _call_process( + self, runtime_ctx: RuntimeContext, input_readers: List[arrow.RecordBatchReader] + ) -> Iterable[arrow.Table]: + with contextlib.ExitStack() as stack: + opened_readers = [ + stack.enter_context( + ConcurrentIter(reader) if self.background_io_thread else reader + ) + for reader in input_readers + ] + for batch_index, input_batches in enumerate( + itertools.zip_longest(*opened_readers, fillvalue=None) + ): + input_tables = [ + ( + reader.schema.empty_table() + if batch is None + else arrow.Table.from_batches([batch], reader.schema) + ) + for reader, batch in zip(input_readers, input_batches) + ] + output_table = self._process_batches(runtime_ctx, input_tables) + yield self._wrap_output( + output_table, [batch_index] * len(input_batches) + ) + + def _process_batches( + self, runtime_ctx: RuntimeContext, input_tables: List[arrow.Table] + ) -> arrow.Table: + return self.process(runtime_ctx, input_tables) + + def process( + self, runtime_ctx: RuntimeContext, input_tables: List[arrow.Table] + ) -> arrow.Table: + return self.process_func(runtime_ctx, input_tables) + + +class PandasComputeTask(ArrowComputeTask): + def _call_process( + self, runtime_ctx: RuntimeContext, input_tables: List[arrow.Table] + ) -> arrow.Table: + input_dfs = [table.to_pandas() for table in input_tables] + output_df = self.process(runtime_ctx, input_dfs) + return ( + arrow.Table.from_pandas(output_df, preserve_index=False) + if output_df is not None + else None + ) + + def process( + self, runtime_ctx: RuntimeContext, input_dfs: List[pd.DataFrame] + ) -> pd.DataFrame: + return self.process_func(runtime_ctx, input_dfs) + + +class PandasBatchTask(ArrowBatchTask): + def _process_batches( + self, runtime_ctx: RuntimeContext, input_tables: List[arrow.Table] + ) -> arrow.Table: + input_dfs = [table.to_pandas() for table in input_tables] + output_df = self.process(runtime_ctx, input_dfs) + return arrow.Table.from_pandas(output_df, preserve_index=False) + + def process( + self, runtime_ctx: RuntimeContext, input_dfs: List[pd.DataFrame] + ) -> pd.DataFrame: + return self.process_func(runtime_ctx, input_dfs) + + +class SqlEngineTask(ExecSqlQueryMixin, Task): + + __slots__ = ( + "sql_queries", + "per_thread_output", + "materialize_output", + "materialize_in_memory", + "batched_processing", + "parquet_row_group_size", + 
"parquet_row_group_bytes", + "parquet_dictionary_encoding", + "parquet_compression", + "parquet_compression_level", + ) + + memory_overcommit_ratio = 0.9 + + def __init__( + self, + ctx: RuntimeContext, + input_deps: List[Task], + partition_infos: List[PartitionInfo], + sql_queries: List[str], + udfs: List[UDFContext] = None, + per_thread_output=True, + materialize_output=True, + materialize_in_memory=False, + batched_processing=False, + enable_temp_directory=False, + parquet_row_group_size: int = DEFAULT_ROW_GROUP_SIZE, + parquet_dictionary_encoding=False, + parquet_compression="ZSTD", + parquet_compression_level=3, + output_name: str = None, + output_path: str = None, + cpu_limit: int = None, + gpu_limit: float = None, + memory_limit: int = None, + cpu_overcommit_ratio: float = 1.0, + memory_overcommit_ratio: float = 0.9, + ) -> None: + super().__init__( + ctx, + input_deps, + partition_infos, + output_name, + output_path, + cpu_limit, + gpu_limit, + memory_limit, + ) + self.cpu_overcommit_ratio = cpu_overcommit_ratio + self.memory_overcommit_ratio = memory_overcommit_ratio + self.sql_queries = sql_queries + self.query_udfs: List[UDFContext] = udfs or [] + self.per_thread_output = per_thread_output + self.materialize_output = materialize_output + self.materialize_in_memory = materialize_in_memory + self.batched_processing = batched_processing and len(self.input_deps) == 1 + self.enable_temp_directory = enable_temp_directory + self.parquet_row_group_size = parquet_row_group_size + self.parquet_row_group_bytes = clamp_row_group_bytes( + parquet_row_group_size * 4 * KB + ) + self.parquet_dictionary_encoding = parquet_dictionary_encoding + self.parquet_compression = parquet_compression + self.parquet_compression_level = parquet_compression_level + + def __str__(self) -> str: + return ( + super().__str__() + + f", sql_query=<{self.oneline_query[:100]}...>, udfs={self.udfs}, batched_processing={self.batched_processing}" + ) + + @property + def oneline_query(self) -> str: + return "; ".join( + " ".join(filter(None, map(str.strip, query.splitlines()))) + for query in self.sql_queries + ) + + @property + def max_batch_size(self) -> int: + return self._memory_limit // 2 + + def cleanup(self): + self.udfs.clear() + super().cleanup() + + def run(self) -> bool: + if self.skip_when_any_input_empty: + return True + + if self.batched_processing and isinstance( + self.input_datasets[0], ParquetDataSet + ): + input_batches = [ + [batch] + for batch in self.input_datasets[0].partition_by_size( + self.max_batch_size + ) + ] + else: + input_batches = [self.input_datasets] + + for batch_index, input_batch in enumerate(input_batches): + with duckdb.connect( + database=":memory:", config={"allow_unsigned_extensions": "true"} + ) as conn: + self.prepare_connection(conn) + self.process_batch(batch_index, input_batch, conn) + + return True + + def process_batch( + self, + batch_index: int, + input_datasets: List[DataSet], + conn: duckdb.DuckDBPyConnection, + ): + # define inputs as views + input_views = self.create_input_views(conn, input_datasets) + + if isinstance(self.parquet_dictionary_encoding, bool): + dictionary_encoding_cfg = ( + "DICTIONARY_ENCODING TRUE," if self.parquet_dictionary_encoding else "" + ) + else: + dictionary_encoding_cfg = "DICTIONARY_ENCODING ({}),".format( + ", ".join(self.parquet_dictionary_encoding) + ) + + for query_index, sql_query in enumerate(self.sql_queries): + last_query = query_index + 1 == len(self.sql_queries) + output_filename = 
f"{self.output_filename}-{batch_index}.{query_index}" + output_path = self.runtime_output_abspath + + if not self.per_thread_output: + output_path = os.path.join(output_path, f"{output_filename}.parquet") + + sql_query = sql_query.format( + *input_views, + batch_index=batch_index, + query_index=query_index, + cpu_limit=self.cpu_limit, + memory_limit=self.memory_limit, + rand_seed=self.rand_seed_uint32, + output_filename=output_filename, + **self.partition_infos_as_dict, + ) + + if last_query and self.materialize_in_memory: + self.merge_metrics( + self.exec_query( + conn, + f"EXPLAIN ANALYZE CREATE OR REPLACE TEMP TABLE temp_query_result AS {sql_query}", + enable_profiling=True, + ) + ) + sql_query = f"select * from temp_query_result" + + if last_query and self.materialize_output: + sql_query = f""" + COPY ( + {sql_query} + ) TO '{output_path}' ( + FORMAT PARQUET, + KV_METADATA {self.parquet_kv_metadata_str()}, + {self.compression_options}, + ROW_GROUP_SIZE {self.parquet_row_group_size}, + ROW_GROUP_SIZE_BYTES {self.parquet_row_group_bytes}, + {dictionary_encoding_cfg} + PER_THREAD_OUTPUT {self.per_thread_output}, + FILENAME_PATTERN '{output_filename}.{{i}}', + OVERWRITE_OR_IGNORE true) + """ + + self.merge_metrics( + self.exec_query( + conn, f"EXPLAIN ANALYZE {sql_query}", enable_profiling=True + ) + ) + + +class HashPartitionTask(PartitionProducerTask): + + __slots__ = ( + "hash_columns", + "data_partition_column", + "random_shuffle", + "shuffle_only", + "drop_partition_column", + "use_parquet_writer", + "hive_partitioning", + "parquet_row_group_size", + "parquet_row_group_bytes", + "parquet_dictionary_encoding", + "parquet_compression", + "parquet_compression_level", + "partitioned_datasets", + "_io_workers", + "_partition_files", + "_partition_writers", + "_pending_write_works", + "_file_writer_closed", + ) + + def __init__( + self, + ctx: RuntimeContext, + input_deps: List[Task], + partition_infos: List[PartitionInfo], + npartitions: int, + dimension: str, + hash_columns: List[str], + data_partition_column: str, + random_shuffle: bool = False, + shuffle_only: bool = False, + drop_partition_column=False, + use_parquet_writer=False, + hive_partitioning=False, + parquet_row_group_size: int = DEFAULT_ROW_GROUP_SIZE, + parquet_dictionary_encoding=False, + parquet_compression="ZSTD", + parquet_compression_level=3, + output_name: str = None, + output_path: str = None, + cpu_limit: int = None, + memory_limit: int = None, + ) -> None: + self.hash_columns = ["random()"] if random_shuffle else hash_columns + self.data_partition_column = data_partition_column + self.random_shuffle = random_shuffle + self.shuffle_only = shuffle_only + self.drop_partition_column = drop_partition_column + self.use_parquet_writer = use_parquet_writer + self.hive_partitioning = hive_partitioning + self.parquet_row_group_size = parquet_row_group_size + self.parquet_row_group_bytes = clamp_row_group_bytes( + parquet_row_group_size * 4 * KB + ) + self.parquet_dictionary_encoding = parquet_dictionary_encoding + self.parquet_compression = parquet_compression + self.parquet_compression_level = parquet_compression_level + super().__init__( + ctx, + input_deps, + partition_infos, + npartitions, + dimension, + output_name, + output_path, + cpu_limit, + memory_limit, + ) + self.partitioned_datasets = None + self._io_workers: ThreadPoolExecutor = None + self._partition_files: List[BinaryIO] = None + self._partition_writers: List[parquet.ParquetWriter] = None + self._pending_write_works: List[Future] = None + 
self._file_writer_closed = True + + def __str__(self) -> str: + return ( + super().__str__() + + f", hash_columns={self.hash_columns}, data_partition_column={self.data_partition_column}" + ) + + @staticmethod + def create( + engine_type: Literal["duckdb", "arrow"], *args, **kwargs + ) -> "HashPartitionTask": + if engine_type == "duckdb": + return HashPartitionDuckDbTask(*args, **kwargs) + if engine_type == "arrow": + return HashPartitionArrowTask(*args, **kwargs) + raise ValueError(f"Unknown hash partition engine: '{engine_type}'") + + @property + def max_batch_size(self) -> int: + return self._memory_limit // 6 + + @property + def write_buffer_size(self) -> int: + write_buffer_size = min( + 4 * MB, + round_up(min(16 * GB, self.max_batch_size) // self.npartitions, 16 * KB), + ) + return ( + write_buffer_size if write_buffer_size >= 128 * KB else -1 + ) # disable write buffer if too small + + @property + def num_workers(self) -> int: + return min(self.npartitions, self.cpu_limit) + + @property + def io_workers(self): + if self._io_workers is None: + self._io_workers = ThreadPoolExecutor(self.num_workers) + return self._io_workers + + def _wait_pending_writes(self): + for i in range(len(self._pending_write_works)): + if self._pending_write_works[i] is not None: + self._pending_write_works[i].result() + self._pending_write_works[i] = None + + def _close_file_writers(self): + if self._file_writer_closed: + return + self._file_writer_closed = True + self.add_elapsed_time() + self._wait_pending_writes() + if self._io_workers is not None: + list( + self._io_workers.map( + lambda w: w.close(), filter(None, self._partition_writers) + ) + ) + list( + self._io_workers.map( + lambda f: f.close(), filter(None, self._partition_files) + ) + ) + self._io_workers.shutdown(wait=True) + self.add_elapsed_time("output dump time (secs)") + + def _create_file_writer(self, partition_idx: int, schema: arrow.Schema): + partition_filename = f"{self.output_filename}-{partition_idx}.parquet" + partition_path = os.path.join(self.runtime_output_abspath, partition_filename) + + self._partition_files[partition_idx] = open( + partition_path, "wb", buffering=self.write_buffer_size + ) + output_file = self._partition_files[partition_idx] + + self.partitioned_datasets[partition_idx].paths.append(partition_filename) + self._partition_writers[partition_idx] = parquet.ParquetWriter( + where=output_file, + schema=schema.with_metadata( + self.parquet_kv_metadata_bytes( + [ + PartitionInfo(partition_idx, self.npartitions, self.dimension), + PartitionInfo( + partition_idx, self.npartitions, self.data_partition_column + ), + ] + ) + ), + use_dictionary=self.parquet_dictionary_encoding, + compression=( + self.parquet_compression + if self.parquet_compression is not None + else "NONE" + ), + compression_level=self.parquet_compression_level, + write_batch_size=max(16 * 1024, self.parquet_row_group_size // 8), + data_page_size=max(64 * MB, self.parquet_row_group_bytes // 8), + ) + return self._partition_writers[partition_idx] + + def _write_to_partition( + self, partition_idx, partition, pending_write: Future = None + ): + if pending_write is not None: + pending_write.result() + if partition is not None: + writer = self._partition_writers[partition_idx] or self._create_file_writer( + partition_idx, partition.schema + ) + writer.write_table(partition, self.parquet_row_group_size) + + def _write_partitioned_tables(self, partitioned_tables): + assert len(partitioned_tables) == self.npartitions + assert len(self._pending_write_works) == 
self.npartitions + + self._pending_write_works = [ + self.io_workers.submit( + self._write_to_partition, partition_idx, partition, pending_write + ) + for partition_idx, (partition, pending_write) in enumerate( + zip(partitioned_tables, self._pending_write_works) + ) + ] + self.perf_metrics["num output rows"] += sum( + partition.num_rows + for partition in partitioned_tables + if partition is not None + ) + self._wait_pending_writes() + + def initialize(self): + super().initialize() + if isinstance(self, HashPartitionDuckDbTask) and self.hive_partitioning: + self.partitioned_datasets = [ + ParquetDataSet( + [ + os.path.join( + f"{self.data_partition_column}={partition_idx}", "*.parquet" + ) + ], + root_dir=self.runtime_output_abspath, + ) + for partition_idx in range(self.npartitions) + ] + else: + self.partitioned_datasets = [ + ParquetDataSet([], root_dir=self.runtime_output_abspath) + for _ in range(self.npartitions) + ] + self._partition_files = [None] * self.npartitions + self._partition_writers = [None] * self.npartitions + self._pending_write_works = [None] * self.npartitions + self._file_writer_closed = False + + def finalize(self): + # first close all writers + self._close_file_writers() + assert ( + self.perf_metrics["num input rows"] == self.perf_metrics["num output rows"] + ), f'num input rows {self.perf_metrics["num input rows"]} != num output rows {self.perf_metrics["num output rows"]}' + super().finalize() + + def cleanup(self): + self._close_file_writers() + self._io_workers = None + self._partition_files = None + self._partition_writers = None + super().cleanup() + + def partition(self, input_dataset: ParquetDataSet): + raise NotImplementedError + + def run(self) -> bool: + self.add_elapsed_time() + if self.skip_when_any_input_empty: + return True + + input_dataset = self.input_datasets[0] + assert isinstance( + input_dataset, ParquetDataSet + ), f"only parquet dataset supported, found {input_dataset}" + input_paths = input_dataset.resolved_paths + input_byte_size = input_dataset.estimated_data_size + input_num_rows = input_dataset.num_rows + + logger.info( + f"partitioning dataset: {len(input_paths)} files, {input_byte_size/GB:.3f}GB, {input_num_rows} rows" + ) + input_batches = input_dataset.partition_by_size(self.max_batch_size) + + for batch_index, input_batch in enumerate(input_batches): + batch_start_time = time.time() + batch_byte_size = input_batch.estimated_data_size + batch_num_rows = input_batch.num_rows + logger.info( + f"start to partition batch #{batch_index+1}/{len(input_batches)}: {len(input_batch.resolved_paths)} files, {batch_byte_size/GB:.3f}GB, {batch_num_rows} rows" + ) + self.partition(batch_index, input_batch) + logger.info( + f"finished to partition batch #{batch_index+1}/{len(input_batches)}: {time.time() - batch_start_time:.3f} secs" + ) + + return True + + +class HashPartitionDuckDbTask(ExecSqlQueryMixin, HashPartitionTask): + + memory_overcommit_ratio = 1.0 + + def __str__(self) -> str: + return super().__str__() + f", hive_partitioning={self.hive_partitioning}" + + @property + def partition_query(self): + if self.shuffle_only: + partition_query = r"SELECT * FROM {0}" + else: + if self.random_shuffle: + hash_values = ( + f"random() * {2147483647 // self.npartitions * self.npartitions}" + ) + else: + hash_values = ( + f"hash( concat_ws( '##', {', '.join(self.hash_columns)} ) )" + ) + partition_keys = f"CAST({hash_values} AS UINT64) % {self.npartitions}::UINT64 AS {self.data_partition_column}" + partition_query = f""" + SELECT *, + 
{partition_keys} + FROM ( + SELECT COLUMNS(c -> c != '{self.data_partition_column}') FROM {{0}} + )""" + return partition_query + + def partition(self, batch_index: int, input_dataset: ParquetDataSet): + with duckdb.connect( + database=":memory:", config={"allow_unsigned_extensions": "true"} + ) as conn: + self.prepare_connection(conn) + if self.hive_partitioning: + self.load_input_batch( + conn, batch_index, input_dataset, sort_by_partition_key=True + ) + self.write_hive_partitions(conn, batch_index, input_dataset) + else: + self.load_input_batch( + conn, batch_index, input_dataset, sort_by_partition_key=True + ) + self.write_flat_partitions(conn, batch_index, input_dataset) + + def load_input_batch( + self, + conn: duckdb.DuckDBPyConnection, + batch_index: int, + input_dataset: ParquetDataSet, + sort_by_partition_key=False, + ): + input_views = self.create_input_views(conn, [input_dataset]) + partition_query = self.partition_query.format( + *input_views, **self.partition_infos_as_dict + ) + if sort_by_partition_key: + partition_query += f" ORDER BY {self.data_partition_column}" + + perf_metrics = self.exec_query( + conn, + f"EXPLAIN ANALYZE CREATE OR REPLACE TABLE temp_query_result AS {partition_query}", + enable_profiling=True, + ) + self.perf_metrics["num input rows"] += perf_metrics["num input rows"] + elapsed_time = self.add_elapsed_time("input load time (secs)") + + # make sure partition keys are in the range of [0, npartitions) + min_partition_key, max_partition_key = conn.sql( + f"SELECT MIN({self.data_partition_column}), MAX({self.data_partition_column}) FROM temp_query_result" + ).fetchall()[0] + assert ( + min_partition_key >= 0 + ), f"partition key {min_partition_key} is out of range 0-{self.npartitions-1}" + assert ( + max_partition_key < self.npartitions + ), f"partition key {max_partition_key} is out of range 0-{self.npartitions-1}" + + logger.debug(f"load input dataset #{batch_index+1}: {elapsed_time:.3f} secs") + + def write_hive_partitions( + self, + conn: duckdb.DuckDBPyConnection, + batch_index: int, + input_dataset: ParquetDataSet, + ): + batch_num_rows = input_dataset.num_rows + self.exec_query( + conn, + f"SET partitioned_write_flush_threshold={round_up(batch_num_rows / self.cpu_limit / 4, KB)}", + ) + copy_query_result = f""" + COPY ( + SELECT * FROM temp_query_result + ) TO '{self.runtime_output_abspath}' ( + FORMAT PARQUET, + OVERWRITE_OR_IGNORE, + WRITE_PARTITION_COLUMNS, + PARTITION_BY {self.data_partition_column}, + KV_METADATA {self.parquet_kv_metadata_str()}, + {self.compression_options}, + ROW_GROUP_SIZE {self.parquet_row_group_size}, + ROW_GROUP_SIZE_BYTES {self.parquet_row_group_bytes}, + {"DICTIONARY_ENCODING TRUE," if self.parquet_dictionary_encoding else ""} + FILENAME_PATTERN '{self.output_filename}-{batch_index}.{{i}}') + """ + perf_metrics = self.exec_query( + conn, f"EXPLAIN ANALYZE {copy_query_result}", enable_profiling=True + ) + self.perf_metrics["num output rows"] += perf_metrics["num output rows"] + elapsed_time = self.add_elapsed_time("output dump time (secs)") + logger.debug(f"write partition data #{batch_index+1}: {elapsed_time:.3f} secs") + + def write_flat_partitions( + self, + conn: duckdb.DuckDBPyConnection, + batch_index: int, + input_dataset: ParquetDataSet, + ): + def write_partition_data( + conn: duckdb.DuckDBPyConnection, partition_batch: List[Tuple[int, str]] + ) -> int: + total_num_rows = 0 + for partition_idx, partition_filter in partition_batch: + if self.use_parquet_writer: + partition_data = 
conn.sql(partition_filter).fetch_arrow_table() + self._write_to_partition(partition_idx, partition_data) + total_num_rows += partition_data.num_rows + else: + partition_filename = ( + f"{self.output_filename}-{partition_idx}.{batch_index}.parquet" + ) + partition_path = os.path.join( + self.runtime_output_abspath, partition_filename + ) + self.partitioned_datasets[partition_idx].paths.append( + partition_filename + ) + perf_metrics = self.exec_query( + conn, + f""" + EXPLAIN ANALYZE + COPY ( + {partition_filter} + ) TO '{partition_path}' ( + FORMAT PARQUET, + KV_METADATA {self.parquet_kv_metadata_str( + [PartitionInfo(partition_idx, self.npartitions, self.dimension), PartitionInfo(partition_idx, self.npartitions, self.data_partition_column)])}, + {self.compression_options}, + ROW_GROUP_SIZE {self.parquet_row_group_size}, + ROW_GROUP_SIZE_BYTES {self.parquet_row_group_bytes}, + {"DICTIONARY_ENCODING TRUE," if self.parquet_dictionary_encoding else ""} + OVERWRITE_OR_IGNORE true) + """, + enable_profiling=True, + log_query=partition_idx == 0, + log_output=False, + ) # avoid duplicate logs + total_num_rows += perf_metrics["num output rows"] + return total_num_rows + + column_projection = ( + f"* EXCLUDE ({self.data_partition_column})" + if self.drop_partition_column + else "*" + ) + partition_filters = [ + ( + partition_idx, + f"SELECT {column_projection} FROM temp_query_result WHERE {self.data_partition_column} = {partition_idx}", + ) + for partition_idx in range(self.npartitions) + ] + partition_batches = split_into_rows(partition_filters, self.num_workers) + + with contextlib.ExitStack() as stack: + db_conns = [ + stack.enter_context(conn.cursor()) for _ in range(self.num_workers) + ] + self.perf_metrics["num output rows"] += sum( + self.io_workers.map(write_partition_data, db_conns, partition_batches) + ) + elapsed_time = self.add_elapsed_time("output dump time (secs)") + logger.debug(f"write partition data #{batch_index+1}: {elapsed_time:.3f} secs") + + +class HashPartitionArrowTask(HashPartitionTask): + + # WARNING: totally different hash partitions are generated if the random seeds changed. 
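+ # these seeds are kept fixed so that hash_rows maps the same rows to the same partitions across runs (see warning above)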
+ fixed_rand_seeds = ( + 14592751030717519312, + 9336845975743342460, + 1211974630270170534, + 6392954943940246686, + ) + + def partition(self, batch_index: int, input_dataset: ParquetDataSet): + import polars + + self.add_elapsed_time() + table = input_dataset.to_arrow_table(max_workers=self.cpu_limit) + self.perf_metrics["num input rows"] += table.num_rows + elapsed_time = self.add_elapsed_time("input load time (secs)") + logger.debug( + f"load input dataset: {table.nbytes/MB:.3f}MB, {table.num_rows} rows, {elapsed_time:.3f} secs" + ) + + if self.shuffle_only: + partition_keys = table.column(self.data_partition_column) + elif self.random_shuffle: + partition_keys = arrow.array( + self.numpy_random_gen.integers(self.npartitions, size=table.num_rows) + ) + else: + hash_columns = polars.from_arrow(table.select(self.hash_columns)) + hash_values = hash_columns.hash_rows(*self.fixed_rand_seeds) + partition_keys = (hash_values % self.npartitions).to_arrow() + + if self.data_partition_column in table.column_names: + table = table.drop_columns(self.data_partition_column) + table = table.append_column(self.data_partition_column, partition_keys) + elapsed_time = self.add_elapsed_time("compute time (secs)") + logger.debug(f"generate partition keys: {elapsed_time:.3f} secs") + + table_slice_size = max( + DEFAULT_BATCH_SIZE, min(table.num_rows // 2, 100 * 1024 * 1024) + ) + num_iterations = math.ceil(table.num_rows / table_slice_size) + + def write_partition_data( + partition_batch: List[Tuple[int, polars.DataFrame]], + ) -> int: + total_num_rows = 0 + for partition_idx, partition_data in partition_batch: + total_num_rows += len(partition_data) + self._write_to_partition(partition_idx, partition_data.to_arrow()) + return total_num_rows + + for table_slice_idx, table_slice_offset in enumerate( + range(0, table.num_rows, table_slice_size) + ): + table_slice = table.slice(table_slice_offset, table_slice_size) + logger.debug( + f"table slice #{table_slice_idx+1}/{num_iterations}: {table_slice.nbytes/MB:.3f}MB, {table_slice.num_rows} rows" + ) + + df = polars.from_arrow(table_slice) + del table_slice + elapsed_time = self.add_elapsed_time("compute time (secs)") + logger.debug( + f"convert from arrow table #{table_slice_idx+1}/{num_iterations}: {elapsed_time:.3f} secs" + ) + + partitioned_dfs = df.partition_by( + [self.data_partition_column], + maintain_order=False, + include_key=not self.drop_partition_column, + as_dict=True, + ) + partitioned_dfs = [ + (partition_idx, df) for (partition_idx,), df in partitioned_dfs.items() + ] + del df + elapsed_time = self.add_elapsed_time("compute time (secs)") + logger.debug( + f"build partition data #{table_slice_idx+1}/{num_iterations}: {elapsed_time:.3f} secs" + ) + + partition_batches = split_into_rows(partitioned_dfs, self.num_workers) + self.perf_metrics["num output rows"] += sum( + self.io_workers.map(write_partition_data, partition_batches) + ) + elapsed_time = self.add_elapsed_time("output dump time (secs)") + logger.debug( + f"write partition data #{table_slice_idx+1}/{num_iterations}: {elapsed_time:.3f} secs" + ) + + +class ProjectionTask(Task): + + __slots__ = ( + "columns", + "generated_columns", + "union_by_name", + ) + + def __init__( + self, + ctx: RuntimeContext, + input_deps: List[Task], + partition_infos: List[PartitionInfo], + columns: List[str], + generated_columns: List[str], + union_by_name: Optional[bool], + ) -> None: + super().__init__(ctx, input_deps, partition_infos) + self.columns = list(columns) + self.generated_columns = 
list(generated_columns) + self.union_by_name = union_by_name + + @property + def exec_on_scheduler(self) -> bool: + return True + + @property + def self_contained_output(self): + return False + + def initialize(self): + pass + + def finalize(self): + pass + + def run(self) -> bool: + self.dataset = copy.copy(self.input_datasets[0]) + assert not self.generated_columns or isinstance( + self.dataset, ParquetDataSet + ), f"generated columns can be only applied to parquet dataset, but found: {self.dataset}" + self.dataset.columns = self.columns + self.dataset.generated_columns = self.generated_columns + if self.union_by_name is not None: + self.dataset._union_by_name = self.union_by_name + return True + + +DataSinkType = Literal["link_manifest", "copy", "link_or_copy", "manifest"] + + +class DataSinkTask(Task): + + __slots__ = ( + "type", + "is_final_node", + ) + + manifest_filename = ".MANIFEST.txt" + + def __init__( + self, + ctx: RuntimeContext, + input_deps: List[Task], + partition_infos: List[PartitionInfo], + output_path: str, + type: DataSinkType = "link_manifest", + manifest_only=False, + is_final_node=False, + ) -> None: + assert type in ( + "link_manifest", + "copy", + "link_or_copy", + "manifest", + ), f"invalid sink type: {type}" + assert output_path is not None, f"output path is required for {repr(self)}" + super().__init__(ctx, input_deps, partition_infos, output_path=output_path) + self.type = "manifest" if manifest_only else type + # just a flag to indicate if this is the final results + # there should be only one final results in the execution plan + self.is_final_node = is_final_node + self.temp_output = False + + @property + def allow_speculative_exec(self) -> bool: + return False + + @property + def self_contained_output(self) -> bool: + return self.type != "manifest" + + @property + def final_output_abspath(self) -> str: + if self.type in ("copy", "link_or_copy"): + # in the first phase, we copy or link files to the staging directory + return os.path.join(self.staging_root, self.output_dirname) + else: + # in the second phase, these files will be linked to the output directory + return self.output_root + + @property + def runtime_output_abspath(self) -> str: + return self.final_output_abspath + + def clean_output(self, force=False) -> None: + pass + + def run(self) -> bool: + with ThreadPoolExecutor(min(32, len(self.input_datasets))) as pool: + return self.collect_output_files(pool) + + def collect_output_files(self, pool: ThreadPoolExecutor) -> bool: + final_output_dir = PurePath(self.final_output_abspath) + runtime_output_dir = Path(self.runtime_output_abspath) + dst_mount_point = find_mount_point(self.runtime_output_abspath) + sink_type = self.type + + src_paths = [ + p + for paths in pool.map( + lambda dataset: [Path(path) for path in dataset.resolved_paths], + self.input_datasets, + ) + for p in paths + ] + logger.info( + f"collected {len(src_paths)} files from {len(self.input_datasets)} input datasets" + ) + + if len(set(p.name for p in src_paths)) == len(src_paths): + dst_paths = [runtime_output_dir / p.name for p in src_paths] + else: + logger.warning(f"found duplicate filenames, appending index to filename...") + dst_paths = [ + runtime_output_dir / f"{p.stem}.{idx}{p.suffix}" + for idx, p in enumerate(src_paths) + ] + + output_paths = ( + src_paths + if sink_type == "manifest" + else [final_output_dir / p.name for p in dst_paths] + ) + self.dataset = ParquetDataSet( + [str(p) for p in output_paths] + ) # FIXME: what if the dataset is not parquet? 
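+ # for 'manifest' sinks the dataset references the original source files; otherwise it references the files placed under the final output directory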
+ + def copy_file(src_path: Path, dst_path: Path): + # XXX: DO NOT use shutil.{copy, copy2, copyfileobj} + # they use sendfile on Linux. although they set blocksize=8M, the actual io size is fixed to 64k, resulting in low throughput. + with open(src_path, "rb") as src_file, open(dst_path, "wb") as dst_file: + shutil.copyfileobj(src_file, dst_file, length=16 * MB) + + def create_link_or_copy(src_path: Path, dst_path: Path): + if dst_path.exists(): + logger.warning( + f"destination path already exists, replacing {dst_path} with {src_path}" + ) + dst_path.unlink(missing_ok=True) + same_mount_point = str(src_path).startswith(dst_mount_point) + if sink_type == "copy": + copy_file(src_path, dst_path) + elif sink_type == "link_manifest": + if same_mount_point: + os.link(src_path, dst_path) + else: + dst_path.symlink_to(src_path) + elif sink_type == "link_or_copy": + if same_mount_point: + os.link(src_path, dst_path) + else: + copy_file(src_path, dst_path) + else: + raise RuntimeError(f"invalid sink type: {sink_type}") + return True + + if sink_type in ("copy", "link_or_copy", "link_manifest"): + if src_paths: + assert all(pool.map(create_link_or_copy, src_paths, dst_paths)) + else: + logger.warning(f"input of data sink is empty: {self}") + + if sink_type == "manifest" or sink_type == "link_manifest": + # write to a temporary file and rename it atomically + manifest_path = final_output_dir / self.manifest_filename + manifest_tmp_path = runtime_output_dir / f"{self.manifest_filename}.tmp" + with open(manifest_tmp_path, "w", buffering=2 * MB) as manifest_file: + for path in output_paths: + print(str(path), file=manifest_file) + manifest_tmp_path.rename(manifest_path) + logger.info(f"created a manifest file at {manifest_path}") + + if sink_type == "link_manifest": + # remove the staging directory + remove_path(self.staging_root) + + # check the output parquet files + # if any file is broken, an exception will be raised + if len(dst_paths) > 0 and dst_paths[0].suffix == ".parquet": + logger.info( + f"checked dataset files and found {self.dataset.num_rows} rows" + ) + + return True + + +class RootTask(Task): + @property + def exec_on_scheduler(self) -> bool: + return True + + def __init__( + self, + ctx: RuntimeContext, + input_deps: List[Task], + partition_infos: List[PartitionInfo], + ) -> None: + super().__init__(ctx, input_deps, partition_infos) + + def initialize(self): + pass + + def finalize(self): + pass + + def run(self) -> bool: + return True + + +class ExecutionPlan(object): + """ + A directed acyclic graph (DAG) of tasks. 
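+ Tasks are reachable from `root_task` through their `input_deps`; `iter_tasks` walks the graph and visits each task once.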
+ """ + + def __init__( + self, ctx: RuntimeContext, root_task: RootTask, logical_plan: "LogicalPlan" + ) -> None: + from smallpond.logical.node import LogicalPlan + + self.ctx = ctx + self.root_task = root_task + self.logical_plan: LogicalPlan = logical_plan + + def __str__(self) -> str: + visited = set() + + def to_str(task: Task, depth: int = 0) -> List[str]: + lines = [" " * depth + str(task)] + if task.id in visited: + return lines + [" " * depth + " (omitted ...)"] + visited.add(task.id) + for dep in task.input_deps.values(): + lines.extend(to_str(dep, depth + 1)) + return lines + + return os.linesep.join(to_str(self.root_task)) + + @cached_property + def _final_results(self) -> DataSinkTask: + for task in self.root_task.input_deps.values(): + if isinstance(task, DataSinkTask) and task.is_final_node: + return task + raise RuntimeError("no final results found") + + @property + def final_output(self) -> DataSet: + return self._final_results.output + + @property + def final_output_path(self) -> str: + return self._final_results.final_output_abspath + + @property + def successful(self) -> str: + return self.root_task.status == WorkStatus.SUCCEED + + @property + def leaves(self) -> List[Task]: + return [task for task in self.tasks.values() if not task.input_deps] + + @staticmethod + def iter_tasks(task: Task, visited: Set[str] = None): + visited = visited or set() + assert task.key not in visited + visited.add(task.key) + yield task + for dep in task.input_deps.values(): + if dep.key not in visited: + yield from ExecutionPlan.iter_tasks(dep, visited) + + @property + def tasks(self) -> Dict[str, Task]: + return dict((task.key, task) for task in self.iter_tasks(self.root_task)) + + @cached_property + def named_outputs(self): + assert self.successful + named_outputs: Dict[str, DataSet] = {} + task_outputs: Dict[str, List[DataSet]] = {} + + for task in self.tasks.values(): + if task.output_name: + if task.output_name not in task_outputs: + task_outputs[task.output_name] = [task.output] + else: + task_outputs[task.output_name].append(task.output) + + for name, datasets in task_outputs.items(): + named_outputs[name] = datasets[0].merge(datasets) + return named_outputs + + def get_output(self, output_name: str) -> Optional[DataSet]: + return self.named_outputs.get(output_name, None) + + @property + def analyzed_logical_plan(self): + assert self.successful + for node in self.logical_plan.nodes.values(): + for name in node.perf_metrics: + node.get_perf_stats(name) + return self.logical_plan + + +def main(): + import argparse + + from smallpond.execution.task import Task + from smallpond.io.filesystem import load + + parser = argparse.ArgumentParser(prog="task.py", description="Task Local Runner") + parser.add_argument("pickle_path", help="Path of pickled task(s)") + parser.add_argument("-t", "--task_id", default=None, help="Task id") + parser.add_argument("-r", "--retry_count", default=0, help="Task retry count") + parser.add_argument("-o", "--output_path", default=None, help="Task output path") + parser.add_argument( + "-l", "--log_level", default="DEBUG", help="Logging message level" + ) + args = parser.parse_args() + + logger.remove() + logger.add( + sys.stdout, + format=r"[{time:%Y-%m-%d %H:%M:%S.%f}] {level} {message}", + level=args.log_level, + ) + + def exec_task(task: Task, output_path: Optional[str]): + for retry_count in range(1000): + task.retry_count = retry_count + if output_path is None: + task.output_root = task.ctx.staging_root + else: + task.output_root = os.path.join(output_path, 
"output") + task.ctx.temp_root = os.path.join(output_path, "temp") + if any( + os.path.exists(path) + for path in ( + task.temp_abspath, + task.final_output_abspath, + task.runtime_output_abspath, + ) + ): + continue + task.status = WorkStatus.INCOMPLETE + task.start_time = time.time() + task.finish_time = None + logger.info(f"start to debug: {task}") + task.exec() + break + + obj = load(args.pickle_path) + logger.info(f"loaded an object of {type(obj)} from pickle file {args.pickle_path}") + + if isinstance(obj, Task): + task: Task = obj + exec_task(task, args.output_path) + elif isinstance(obj, Dict): + assert args.task_id is not None, f"error: no task id specified" + tasks: List[Task] = obj.values() + for task in tasks: + if task.id == TaskId(args.task_id) and task.retry_count == args.retry_count: + exec_task(task, args.output_path) + break + else: + logger.error(f"cannot find task with id {args.task_id}") + else: + logger.error(f"unsupported type of object: {type(obj)}") + + +if __name__ == "__main__": + main() diff --git a/smallpond/execution/workqueue.py b/smallpond/execution/workqueue.py new file mode 100644 index 0000000..5ab2525 --- /dev/null +++ b/smallpond/execution/workqueue.py @@ -0,0 +1,479 @@ +import os +import os.path +import queue +import random +import sys +import time +import uuid +from enum import Enum +from typing import Dict, Iterable, List, Optional + +import numpy as np +from GPUtil import GPU +from loguru import logger + +from smallpond.common import MB, NonzeroExitCode, OutOfMemory +from smallpond.io.filesystem import dump, load + + +class WorkStatus(Enum): + INCOMPLETE = 1 + SUCCEED = 2 + FAILED = 3 + CRASHED = 4 + EXEC_FAILED = 5 + + +class WorkItem(object): + + __slots__ = ( + "_cpu_limit", + "_gpu_limit", + "_memory_limit", + "_cpu_boost", + "_memory_boost", + "_numa_node", + "_local_gpu", + "key", + "status", + "start_time", + "finish_time", + "retry_count", + "fail_count", + "exception", + "exec_id", + "exec_cq", + "location", + ) + + def __init__( + self, + key: str, + cpu_limit: int = None, + gpu_limit: float = None, + memory_limit: int = None, + ) -> None: + self._cpu_limit = cpu_limit + self._gpu_limit = gpu_limit + self._memory_limit = ( + np.int64(memory_limit) if memory_limit is not None else None + ) + self._cpu_boost = 1 + self._memory_boost = 1 + self._numa_node = None + self._local_gpu: Dict[GPU, float] = {} + self.key = key + self.status = WorkStatus.INCOMPLETE + self.start_time = None + self.finish_time = None + self.retry_count = 0 + self.fail_count = 0 + self.exception = None + self.exec_id = "unknown" + self.exec_cq = None + self.location: Optional[str] = None + + def __repr__(self) -> str: + return self.key + + __str__ = __repr__ + + @property + def cpu_limit(self) -> int: + return int(self._cpu_boost * self._cpu_limit) + + @property + def gpu_limit(self) -> int: + return self._gpu_limit + + @property + def memory_limit(self) -> np.int64: + return ( + np.int64(self._memory_boost * self._memory_limit) + if self._memory_limit + else 0 + ) + + @property + def elapsed_time(self) -> float: + if self.start_time is None: + return 0 + if self.finish_time is None: + return time.time() - self.start_time + return self.finish_time - self.start_time + + @property + def exec_on_scheduler(self) -> bool: + return False + + @property + def local_gpu(self) -> Optional[GPU]: + """ + Return the first GPU granted to this task. + If gpu_limit is 0, return None. + If gpu_limit is greater than 1, only the first GPU is returned. 
+ """ + return next(iter(self._local_gpu.keys()), None) + + @property + # @deprecated("use `local_gpu_ranks` instead") + def local_rank(self) -> Optional[int]: + """ + Return the first GPU rank granted to this task. + If gpu_limit is 0, return None. + If gpu_limit is greater than 1, only the first GPU rank is returned. Caller should use `local_gpu_ranks` instead. + """ + if self.gpu_limit > 1: + logger.warning( + f"task {self.key} requires more than 1 GPU, but only the first GPU rank is returned. please use `local_gpu_ranks` instead." + ) + return next(iter(self._local_gpu.keys())).id if self._local_gpu else None + + @property + def local_gpu_ranks(self) -> List[int]: + """Return all GPU ranks granted to this task.""" + return [gpu.id for gpu in self._local_gpu.keys()] + + @property + def numa_node(self) -> int: + return self._numa_node + + def oom(self, nonzero_exitcode_as_oom=False): + return ( + self._memory_limit is not None + and self.status == WorkStatus.CRASHED + and ( + isinstance(self.exception, (OutOfMemory, MemoryError)) + or ( + isinstance(self.exception, NonzeroExitCode) + and nonzero_exitcode_as_oom + ) + ) + ) + + def run(self) -> bool: + return True + + def initialize(self): + """Called before run() to prepare for running the task.""" + + def finalize(self): + """Called after run() to finalize the processing.""" + + def cleanup(self): + """Called after run() even if there is an exception.""" + + def exec(self, cq: Optional["WorkQueue"] = None) -> WorkStatus: + if self.status == WorkStatus.INCOMPLETE: + try: + self.start_time = time.time() + self.exec_cq = cq + self.initialize() + if self.run(): + self.status = WorkStatus.SUCCEED + self.finalize() + else: + self.status = WorkStatus.FAILED + except Exception as ex: + logger.opt(exception=ex).error( + f"{repr(self)} crashed with error. 
node location at {self.location}" + ) + self.status = WorkStatus.CRASHED + self.exception = ex + finally: + self.cleanup() + self.exec_cq = None + self.finish_time = time.time() + return self.status + + +class StopExecutor(WorkItem): + def __init__(self, key: str, ack=True) -> None: + super().__init__(key, cpu_limit=0, gpu_limit=0, memory_limit=0) + self.ack = ack + + +class StopWorkItem(WorkItem): + def __init__(self, key: str, work_to_stop: str) -> None: + super().__init__(key, cpu_limit=0, gpu_limit=0, memory_limit=0) + self.work_to_stop = work_to_stop + + +class WorkBatch(WorkItem): + def __init__(self, key: str, works: List[WorkItem]) -> None: + cpu_limit = max(w.cpu_limit for w in works) + gpu_limit = max(w.gpu_limit for w in works) + memory_limit = max(w.memory_limit for w in works) + super().__init__( + f"{self.__class__.__name__}-{key}", cpu_limit, gpu_limit, memory_limit + ) + self.works = works + + def __str__(self) -> str: + return ( + super().__str__() + + f", works[{len(self.works)}]={self.works[:1]}...{self.works[-1:]}" + ) + + def run(self) -> bool: + logger.info(f"processing {len(self.works)} works in the batch") + for index, work in enumerate(self.works): + work.exec_id = self.exec_id + if work.exec(self.exec_cq) != WorkStatus.SUCCEED: + logger.error( + f"work item #{index+1}/{len(self.works)} in {self.key} failed: {work}" + ) + return False + logger.info(f"done {len(self.works)} works in the batch") + return True + + +class WorkQueue(object): + def __init__(self) -> None: + self.outbound_works: List[WorkItem] = [] + + def _pop_unbuffered(self, count: int) -> List[WorkItem]: + raise NotImplementedError + + def _push_unbuffered(self, item: WorkItem) -> bool: + raise NotImplementedError + + def pop(self, count=1) -> List[WorkItem]: + inbound_works: List[WorkItem] = [] + for item in self._pop_unbuffered(count): + if isinstance(item, WorkBatch): + inbound_works.extend(item.works) + else: + inbound_works.append(item) + return inbound_works + + def push(self, item: WorkItem, buffering=False) -> bool: + if buffering: + self.outbound_works.append(item) + return True + elif len(self.outbound_works) > 0: + self.outbound_works.append(item) + return self.flush() + else: + return self._push_unbuffered(item) + + def flush(self) -> bool: + if len(self.outbound_works) == 0: + return True + batch = WorkBatch(self.outbound_works[0].key, self.outbound_works) + self.outbound_works = [] + return self._push_unbuffered(batch) + + +class WorkQueueInMemory(WorkQueue): + def __init__(self, queue_type=queue.Queue) -> None: + super().__init__() + self.queue = queue_type() + + def _pop_unbuffered(self, count: int) -> List[WorkItem]: + try: + return [self.queue.get(block=False)] + except queue.Empty: + return [] + + def _push_unbuffered(self, item: WorkItem) -> bool: + self.queue.put(item) + return True + + +class WorkQueueOnFilesystem(WorkQueue): + + buffer_size = 16 * MB + + def __init__(self, workq_root: str, sort=True, random=False) -> None: + super().__init__() + self.workq_root = workq_root + self.sort = sort + self.random = random + self.buffered_works: List[WorkItem] = [] + self.temp_dir = os.path.join(self.workq_root, "temp") + self.enqueue_dir = os.path.join(self.workq_root, "enqueued") + self.dequeue_dir = os.path.join(self.workq_root, "dequeued") + os.makedirs(self.temp_dir, exist_ok=True) + os.makedirs(self.enqueue_dir, exist_ok=True) + os.makedirs(self.dequeue_dir, exist_ok=True) + + def __str__(self) -> str: + return f"{self.__class__.__name__}@{self.workq_root}" + + def 
_get_entries(self, path=None) -> List[os.DirEntry]: + dentries: List[os.DirEntry] = [] + with os.scandir(path or self.enqueue_dir) as dir_iter: + for entry in dir_iter: + if entry.is_file(): + dentries.append(entry) + return dentries + + def size(self) -> int: + return len(self._get_entries()) + + def _list_works(self, path: str, expand_batch=True) -> Iterable[WorkItem]: + dentries = self._get_entries(path) + logger.info("listing {} entries in {}", len(dentries), path) + for entry in dentries: + item: WorkItem = load(entry.path, self.buffer_size) + if expand_batch and isinstance(item, WorkBatch): + for work in item.works: + yield work + else: + yield item + + def list_enqueued(self, expand_batch=True) -> Iterable[WorkItem]: + yield from self._list_works(self.enqueue_dir, expand_batch) + + def list_dequeued(self, expand_batch=True) -> Iterable[WorkItem]: + yield from self._list_works(self.dequeue_dir, expand_batch) + + def list_works(self, expand_batch=True) -> Iterable[WorkItem]: + yield from self.list_enqueued(expand_batch) + yield from self.list_dequeued(expand_batch) + + def _pop_unbuffered(self, count: int) -> List[WorkItem]: + items = [] + dentries = self._get_entries() + + if self.sort: + dentries = sorted(dentries, key=lambda entry: entry.name) + elif self.random: + random.shuffle(dentries) + + for entry in dentries: + uuid_hex = uuid.uuid4().hex + filename = f"{entry.name}-DEQ{uuid_hex}" + dequeued_path = os.path.join(self.dequeue_dir, filename) + + try: + os.rename(entry.path, dequeued_path) + except OSError as err: + logger.debug(f"cannot rename {entry.path} to {dequeued_path}: {err}") + if items: + break + else: + continue + + items.append(load(dequeued_path, self.buffer_size)) + if len(items) >= count: + break + + return items + + def _push_unbuffered(self, item: WorkItem) -> bool: + timestamp = time.time_ns() + uuid_hex = uuid.uuid4().hex + filename = f"{item.key}-{timestamp:x}-ENQ{uuid_hex}" + tempfile_path = os.path.join(self.temp_dir, filename) + enqueued_path = os.path.join(self.enqueue_dir, filename) + + dump(item, tempfile_path, self.buffer_size) + + try: + os.rename(tempfile_path, enqueued_path) + return True + except OSError as err: + logger.critical( + f"failed to rename {tempfile_path} to {enqueued_path}: {err}" + ) + return False + + +def count_objects(obj, object_cnt=None, visited_objs=None, depth=0): + object_cnt = {} if object_cnt is None else object_cnt + visited_objs = set() if visited_objs is None else visited_objs + + if id(obj) in visited_objs: + return + else: + visited_objs.add(id(obj)) + + if isinstance(obj, dict): + for key, value in obj.items(): + count_objects(key, object_cnt, visited_objs, depth + 1) + count_objects(value, object_cnt, visited_objs, depth + 1) + elif isinstance(obj, list) or isinstance(obj, tuple): + for item in obj: + count_objects(item, object_cnt, visited_objs, depth + 1) + else: + class_name = obj.__class__.__name__ + if class_name not in object_cnt: + object_cnt[class_name] = (0, 0) + cnt, size = object_cnt[class_name] + object_cnt[class_name] = (cnt + 1, size + sys.getsizeof(obj)) + + key_attributes = ("__self__", "__dict__", "__slots__") + if not isinstance(obj, (bool, str, int, float, type(None))) and any( + attr_name in key_attributes for attr_name in dir(obj) + ): + logger.debug(f"{' ' * depth}{class_name}@{id(obj):x}") + for attr_name in dir(obj): + try: + if ( + not attr_name.startswith("__") or attr_name in key_attributes + ) and not isinstance( + getattr(obj.__class__, attr_name, None), property + ): + logger.debug( + 
f"{' ' * depth}{class_name}.{attr_name}@{id(obj):x}" + ) + count_objects( + getattr(obj, attr_name), object_cnt, visited_objs, depth + 1 + ) + except Exception as ex: + logger.warning( + f"failed to get '{attr_name}' from {repr(obj)}: {ex}" + ) + + +def main(): + import argparse + + from smallpond.execution.task import Probe + + parser = argparse.ArgumentParser( + prog="workqueue.py", description="Work Queue Reader" + ) + parser.add_argument("wq_root", help="Work queue root path") + parser.add_argument("-f", "--work_filter", default="", help="Work item filter") + parser.add_argument( + "-x", "--expand_batch", action="store_true", help="Expand batched works" + ) + parser.add_argument( + "-c", "--count_object", action="store_true", help="Count number of objects" + ) + parser.add_argument( + "-n", "--top_n_class", default=20, type=int, help="Show the top n classes" + ) + parser.add_argument( + "-l", "--log_level", default="INFO", help="Logging message level" + ) + args = parser.parse_args() + + logger.remove() + logger.add( + sys.stdout, + format=r"[{time:%Y-%m-%d %H:%M:%S.%f}] {level} {message}", + level=args.log_level, + ) + + wq = WorkQueueOnFilesystem(args.wq_root) + for work in wq.list_works(args.expand_batch): + if isinstance(work, Probe): + continue + if args.work_filter in work.key: + logger.info(work) + if args.count_object: + object_cnt = {} + count_objects(work, object_cnt) + sorted_counts = sorted( + [(v, k) for k, v in object_cnt.items()], reverse=True + ) + for count, class_name in sorted_counts[: args.top_n_class]: + logger.info(f" {class_name}: {count}") + + +if __name__ == "__main__": + main() diff --git a/smallpond/io/__init__.py b/smallpond/io/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/smallpond/io/arrow.py b/smallpond/io/arrow.py new file mode 100644 index 0000000..9c47cae --- /dev/null +++ b/smallpond/io/arrow.py @@ -0,0 +1,419 @@ +import copy +import math +import os.path +import time +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass +from typing import Iterable, List, Optional, Union + +import fsspec +import pyarrow as arrow +import pyarrow.parquet as parquet +from loguru import logger + +from smallpond.common import ( + DEFAULT_BATCH_SIZE, + DEFAULT_ROW_GROUP_BYTES, + DEFAULT_ROW_GROUP_SIZE, + MAX_PARQUET_FILE_BYTES, + MB, + split_into_rows, +) + + +@dataclass +class RowRange: + """A range of rows in a file.""" + + path: str + """Path to the file.""" + data_size: int + """The uncompressed data size in bytes.""" + file_num_rows: int + """The number of rows in the file.""" + begin: int + """Index of first row in the file.""" + end: int + """Index of last row + 1 in the file.""" + + @property + def num_rows(self) -> int: + """The number of rows in the range.""" + return self.end - self.begin + + @property + def estimated_data_size(self) -> int: + """The estimated uncompressed data size in bytes.""" + return ( + self.data_size * self.num_rows // self.file_num_rows + if self.file_num_rows > 0 + else 0 + ) + + def take(self, num_rows: int) -> "RowRange": + """ + Take `num_rows` rows from the range. + NOTE: this function modifies the current row range. 
+ """ + num_rows = min(num_rows, self.num_rows) + head = copy.copy(self) + head.end = head.begin + num_rows + self.begin += num_rows + return head + + @staticmethod + def partition_by_rows( + row_ranges: List["RowRange"], npartition: int + ) -> List[List["RowRange"]]: + """Evenly split a list of row ranges into `npartition` partitions.""" + # NOTE: `row_ranges` should not be modified by this function + row_ranges = copy.deepcopy(row_ranges) + num_rows: int = sum(row_range.num_rows for row_range in row_ranges) + num_partitions: int = npartition + row_range_partitions: List[List[RowRange]] = [] + while num_partitions: + rows_in_partition = (num_rows + num_partitions - 1) // num_partitions + num_rows -= rows_in_partition + num_partitions -= 1 + row_ranges_in_partition = [] + while rows_in_partition: + current_range = row_ranges[0] + if current_range.num_rows == 0: + row_ranges.pop(0) + continue + taken_range = current_range.take(rows_in_partition) + row_ranges_in_partition.append(taken_range) + rows_in_partition -= taken_range.num_rows + row_range_partitions.append(row_ranges_in_partition) + assert num_rows == 0 and num_partitions == 0 + return row_range_partitions + + +def convert_type_to_large(type_: arrow.DataType) -> arrow.DataType: + """ + Convert all string and binary types to large types recursively. + """ + # Since arrow uses 32-bit signed offsets for string and binary types, convert all string and binary columns + # to large_string and large_binary to avoid offset overflow, see https://issues.apache.org/jira/browse/ARROW-17828. + if arrow.types.is_string(type_): + return arrow.large_string() + elif arrow.types.is_binary(type_): + return arrow.large_binary() + elif isinstance(type_, arrow.ListType): + return arrow.list_(convert_type_to_large(type_.value_type)) + elif isinstance(type_, arrow.StructType): + return arrow.struct( + [ + arrow.field( + field.name, + convert_type_to_large(field.type), + nullable=field.nullable, + ) + for field in type_ + ] + ) + elif isinstance(type_, arrow.MapType): + return arrow.map_( + convert_type_to_large(type_.key_type), + convert_type_to_large(type_.item_type), + ) + else: + return type_ + + +def convert_types_to_large_string(schema: arrow.Schema) -> arrow.Schema: + """ + Convert all string and binary types to large types in the schema. + """ + new_fields = [] + for field in schema: + new_type = convert_type_to_large(field.type) + new_field = arrow.field( + field.name, new_type, nullable=field.nullable, metadata=field.metadata + ) + new_fields.append(new_field) + return arrow.schema(new_fields, metadata=schema.metadata) + + +def cast_columns_to_large_string(table: arrow.Table) -> arrow.Table: + schema = convert_types_to_large_string(table.schema) + return table.cast(schema) + + +def filter_schema( + schema: arrow.Schema, + included_cols: Optional[List[str]] = None, + excluded_cols: Optional[List[str]] = None, +): + assert included_cols is None or excluded_cols is None + if included_cols is None and excluded_cols is None: + return schema + if included_cols is not None: + fields = [schema.field(col_name) for col_name in included_cols] + if excluded_cols is not None: + fields = [ + schema.field(col_name) + for col_name in schema.names + if col_name not in excluded_cols + ] + return arrow.schema(fields, metadata=schema.metadata) + + +def _iter_record_batches( + file: parquet.ParquetFile, + columns: List[str], + offset: int, + length: int, + batch_size: int, +) -> Iterable[arrow.RecordBatch]: + """ + Read record batches from a range of a parquet file. 
+ """ + current_offset = 0 + required_l, required_r = offset, offset + length + + for batch in file.iter_batches( + batch_size=batch_size, columns=columns, use_threads=False + ): + current_l, current_r = current_offset, current_offset + batch.num_rows + # check if intersection is null + if current_r <= required_l: + pass + elif current_l >= required_r: + break + else: + intersection_l = max(required_l, current_l) + intersection_r = min(required_r, current_r) + trimmed = batch.slice( + intersection_l - current_offset, intersection_r - intersection_l + ) + assert ( + trimmed.num_rows == intersection_r - intersection_l + ), f"trimmed.num_rows {trimmed.num_rows} != batch_length {intersection_r - intersection_l}" + yield cast_columns_to_large_string(trimmed) + current_offset += batch.num_rows + + +def build_batch_reader_from_files( + paths_or_ranges: Union[List[str], List[RowRange]], + *, + columns: Optional[List[str]] = None, + batch_size: int = DEFAULT_BATCH_SIZE, + max_batch_byte_size: Optional[int] = None, + filesystem: fsspec.AbstractFileSystem = None, +) -> arrow.RecordBatchReader: + assert len(paths_or_ranges) > 0, "paths_or_ranges must be a non-empty list" + schema = _read_schema_from_file(paths_or_ranges[0], columns, filesystem) + iterator = _iter_record_batches_from_files( + paths_or_ranges, columns, batch_size, max_batch_byte_size, filesystem + ) + return arrow.RecordBatchReader.from_batches(schema, iterator) + + +def _read_schema_from_file( + path_or_range: Union[str, RowRange], + columns: Optional[List[str]] = None, + filesystem: fsspec.AbstractFileSystem = None, +) -> arrow.Schema: + path = path_or_range.path if isinstance(path_or_range, RowRange) else path_or_range + schema = parquet.read_schema( + filesystem.unstrip_protocol(path) if filesystem else path, filesystem=filesystem + ) + if columns is not None: + assert all( + c in schema.names for c in columns + ), f"""some of {columns} cannot be found in schema of {path}: + {schema} + + The following query can help to find files with missing columns: + duckdb-dev -c "select * from ( select file_name, list(name) as column_names, list_filter({columns}, c -> not list_contains(column_names, c)) as missing_columns FROM parquet_schema(['{os.path.join(os.path.dirname(path), '*.parquet')}']) group by file_name ) where len(missing_columns) > 0" + """ + schema = filter_schema(schema, columns) + return convert_types_to_large_string(schema) + + +def _iter_record_batches_from_files( + paths_or_ranges: Union[List[str], List[RowRange]], + columns: Optional[List[str]] = None, + batch_size: int = DEFAULT_BATCH_SIZE, + max_batch_byte_size: Optional[int] = None, + filesystem: fsspec.AbstractFileSystem = None, +) -> Iterable[arrow.RecordBatch]: + """ + Build a batch reader from a list of row ranges. 
+ """ + buffered_batches = [] + buffered_rows = 0 + buffered_bytes = 0 + + def combine_buffered_batches( + batches: List[arrow.RecordBatch], + ) -> Iterable[arrow.RecordBatch]: + table = arrow.Table.from_batches(batches) + yield from table.combine_chunks().to_batches(batch_size) + + for path_or_range in paths_or_ranges: + path = ( + path_or_range.path if isinstance(path_or_range, RowRange) else path_or_range + ) + with parquet.ParquetFile( + filesystem.unstrip_protocol(path) if filesystem else path, + buffer_size=16 * MB, + filesystem=filesystem, + ) as file: + if isinstance(path_or_range, RowRange): + offset, length = path_or_range.begin, path_or_range.num_rows + else: + offset, length = 0, file.metadata.num_rows + for batch in _iter_record_batches( + file, columns, offset, length, batch_size + ): + batch_size_exceeded = batch.num_rows + buffered_rows >= batch_size + batch_byte_size_exceeded = ( + max_batch_byte_size is not None + and batch.nbytes + buffered_bytes >= max_batch_byte_size + ) + if not batch_size_exceeded and not batch_byte_size_exceeded: + buffered_batches.append(batch) + buffered_rows += batch.num_rows + buffered_bytes += batch.nbytes + else: + if batch_size_exceeded: + buffered_batches.append( + batch.slice(0, batch_size - buffered_rows) + ) + batch = batch.slice(batch_size - buffered_rows) + if buffered_batches: + yield from combine_buffered_batches(buffered_batches) + buffered_batches = [batch] + buffered_rows = batch.num_rows + buffered_bytes = batch.nbytes + + if buffered_batches: + yield from combine_buffered_batches(buffered_batches) + + +def read_parquet_files_into_table( + paths_or_ranges: Union[List[str], List[RowRange]], + columns: List[str] = None, + filesystem: fsspec.AbstractFileSystem = None, +) -> arrow.Table: + batch_reader = build_batch_reader_from_files( + paths_or_ranges, columns=columns, filesystem=filesystem + ) + return batch_reader.read_all() + + +def load_from_parquet_files( + paths_or_ranges: Union[List[str], List[RowRange]], + columns: List[str] = None, + max_workers: int = 16, + filesystem: fsspec.AbstractFileSystem = None, +) -> arrow.Table: + start_time = time.time() + assert len(paths_or_ranges) > 0, "paths_or_ranges must be a non-empty list" + paths = [ + path_or_range.path if isinstance(path_or_range, RowRange) else path_or_range + for path_or_range in paths_or_ranges + ] + total_compressed_size = sum( + ( + path_or_range.data_size + if isinstance(path_or_range, RowRange) + else os.path.getsize(path_or_range) + ) + for path_or_range in paths_or_ranges + ) + logger.debug( + f"loading {len(paths)} parquet files (compressed size: {total_compressed_size/MB:.3f}MB): {paths[:3]}..." + ) + num_workers = min(len(paths), max_workers) + with ThreadPoolExecutor(num_workers) as pool: + running_works = [ + pool.submit(read_parquet_files_into_table, batch, columns, filesystem) + for batch in split_into_rows(paths_or_ranges, num_workers) + ] + tables = [work.result() for work in running_works] + logger.debug( + f"collected {len(tables)} tables from: {paths[:3]}... 
(elapsed: {time.time() - start_time:.3f} secs)" + ) + return arrow.concat_tables(tables) + + +def parquet_write_table( + table, where, filesystem: fsspec.AbstractFileSystem = None, **write_table_args +) -> int: + if filesystem is not None: + return parquet.write_table( + table, + where=(filesystem.unstrip_protocol(where) if filesystem else where), + filesystem=filesystem, + **write_table_args, + ) + with open(where, "wb", buffering=32 * MB) as file: + return parquet.write_table(table, where=file, **write_table_args) + + +def dump_to_parquet_files( + table: arrow.Table, + output_dir: str, + filename: str = "data", + compression="ZSTD", + compression_level=3, + row_group_size=DEFAULT_ROW_GROUP_SIZE, + row_group_bytes=DEFAULT_ROW_GROUP_BYTES, + use_dictionary=False, + max_workers=16, + filesystem: fsspec.AbstractFileSystem = None, +) -> bool: + table = cast_columns_to_large_string(table) + if table.num_rows == 0: + logger.warning(f"creating empty parquet file in {output_dir}") + parquet_write_table( + table, + os.path.join(output_dir, f"{filename}-0.parquet"), + compression=compression, + row_group_size=row_group_size, + ) + return True + + start_time = time.time() + avg_row_size = max(1, table.nbytes // table.num_rows) + row_group_size = min(row_group_bytes // avg_row_size, row_group_size) + logger.debug( + f"dumping arrow table ({table.nbytes/MB:.3f}MB, {table.num_rows} rows) to {output_dir}, avg row size: {avg_row_size}, row group size: {row_group_size}" + ) + + batches = table.to_batches(max_chunksize=row_group_size) + num_workers = min(len(batches), max_workers) + num_tables = max(math.ceil(table.nbytes / MAX_PARQUET_FILE_BYTES), num_workers) + logger.debug(f"evenly distributed {len(batches)} batches into {num_tables} files") + tables = [ + arrow.Table.from_batches(batch, table.schema) + for batch in split_into_rows(batches, num_tables) + ] + assert sum(t.num_rows for t in tables) == table.num_rows + + logger.debug(f"writing {len(tables)} files to {output_dir}") + with ThreadPoolExecutor(num_workers) as pool: + running_works = [ + pool.submit( + parquet_write_table, + table=table, + where=os.path.join(output_dir, f"{filename}-{i}.parquet"), + use_dictionary=use_dictionary, + compression=compression, + compression_level=compression_level, + row_group_size=row_group_size, + write_batch_size=max(16 * 1024, row_group_size // 8), + data_page_size=max(64 * MB, row_group_bytes // 8), + filesystem=filesystem, + ) + for i, table in enumerate(tables) + ] + assert all(work.result() or True for work in running_works) + + logger.debug( + f"finished writing {len(tables)} files to {output_dir} (elapsed: {time.time() - start_time:.3f} secs)" + ) + return True diff --git a/smallpond/io/filesystem.py b/smallpond/io/filesystem.py new file mode 100644 index 0000000..c76d406 --- /dev/null +++ b/smallpond/io/filesystem.py @@ -0,0 +1,136 @@ +import io +import os +import shutil +import tempfile +import time +from typing import Any + +import cloudpickle +import zstandard as zstd +from loguru import logger + +from smallpond.common import MB + +HF3FS_MOUNT_POINT_PREFIX = "/hf3fs" +HF3FS_FSSPEC_PROTOCOL = "hf3fs://" + + +def on_hf3fs(path: str): + return path.startswith(HF3FS_MOUNT_POINT_PREFIX) + + +def extract_hf3fs_mount_point(path: str): + return os.path.join("/", *path.split("/")[1:3]) if on_hf3fs(path) else None + + +def remove_path(path: str): + realpath = os.path.realpath(path) + + if os.path.islink(path): + logger.debug(f"removing link: {path}") + os.unlink(path) + + if not os.path.exists(realpath): + 
return + + logger.debug(f"removing path: {realpath}") + if on_hf3fs(realpath): + try: + link = os.path.join( + extract_hf3fs_mount_point(realpath), + "3fs-virt/rm-rf", + f"{os.path.basename(realpath)}-{time.time_ns()}", + ) + os.symlink(realpath, link) + return + except Exception as ex: + logger.opt(exception=ex).debug( + f"fast recursive remove failed, fall back to shutil.rmtree('{realpath}')" + ) + shutil.rmtree(realpath, ignore_errors=True) + + +def find_mount_point(path: str): + path = os.path.abspath(path) + while not os.path.ismount(path): + path = os.path.dirname(path) + return path + + +def dump(obj: Any, path: str, buffering=2 * MB, atomic_write=False) -> int: + """ + Dump an object to a file. + + Args: + obj: The object to dump. + path: The path to the file to dump the object to. + buffering: The buffering size. + atomic_write: Whether to atomically write the file. + + Returns: + The size of the file. + """ + + def get_pickle_trace(obj): + try: + import dill + import dill.detect + except ImportError: + return None, None + pickle_trace = io.StringIO() + pickle_error = None + with dill.detect.trace(pickle_trace): + try: + dill.dumps(obj, recurse=True) + except Exception as ex: + pickle_error = ex + return pickle_trace.getvalue(), pickle_error + + def write_to_file(fout): + with zstd.ZstdCompressor().stream_writer(fout, closefd=False) as zstd_writer: + try: + cloudpickle.dump(obj, zstd_writer) + except zstd.ZstdError as ex: + raise + except Exception as ex: + trace_str, trace_err = get_pickle_trace(obj) + logger.opt(exception=ex).error( + f"pickle trace of {repr(obj)}:{os.linesep}{trace_str}" + ) + if trace_err is None: + raise + else: + raise trace_err from ex + logger.trace("{} saved to {}", repr(obj), path) + + size = 0 + + if atomic_write: + directory, filename = os.path.split(path) + with tempfile.NamedTemporaryFile( + "wb", buffering=buffering, dir=directory, prefix=filename, delete=False + ) as fout: + write_to_file(fout) + fout.seek(0, os.SEEK_END) + size = fout.tell() + os.rename(fout.name, path) + else: + with open(path, "wb", buffering=buffering) as fout: + write_to_file(fout) + fout.seek(0, os.SEEK_END) + size = fout.tell() + + if size >= buffering: + logger.debug(f"created a large pickle file ({size/MB:.3f}MB): {path}") + return size + + +def load(path: str, buffering=2 * MB) -> Any: + """ + Load an object from a file. 
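+
+    This is the counterpart of `dump`: the file is zstd-decompressed and the
+    pickled object is restored with cloudpickle.
+
+    Example (illustrative path)::
+
+        >>> size = dump({"answer": 42}, "/tmp/state.pickle.zst")
+        >>> load("/tmp/state.pickle.zst")
+        {'answer': 42}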
+ """ + with open(path, "rb", buffering=buffering) as fin: + with zstd.ZstdDecompressor().stream_reader(fin) as zstd_reader: + obj = cloudpickle.load(zstd_reader) + logger.trace("{} loaded from {}", repr(obj), path) + return obj diff --git a/smallpond/logical/__init__.py b/smallpond/logical/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/smallpond/logical/dataset.py b/smallpond/logical/dataset.py new file mode 100644 index 0000000..a507f94 --- /dev/null +++ b/smallpond/logical/dataset.py @@ -0,0 +1,1098 @@ +import copy +import functools +import glob +import os.path +import random +import re +from collections import OrderedDict +from concurrent.futures import ThreadPoolExecutor +from typing import Callable, Dict, List, Optional, Union + +import duckdb +import fsspec +import pandas as pd +import pyarrow as arrow +import pyarrow.parquet as parquet +from loguru import logger + +from smallpond.common import ( + DEFAULT_BATCH_SIZE, + GB, + PARQUET_METADATA_KEY_PREFIX, + split_into_rows, +) +from smallpond.io.arrow import ( + RowRange, + build_batch_reader_from_files, + dump_to_parquet_files, + load_from_parquet_files, +) +from smallpond.logical.udf import UDFContext + +magic_check = re.compile(r"([*?]|\[.*\])") +magic_check_bytes = re.compile(rb"([*?]|\[.*\])") + + +def has_magic(s): + if isinstance(s, bytes): + match = magic_check_bytes.search(s) + else: + match = magic_check.search(s) + return match is not None + + +class DataSet(object): + """ + The base class for all datasets. + """ + + __slots__ = ( + "paths", + "root_dir", + "recursive", + "columns", + "__dict__", + "_union_by_name", + "_resolved_paths", + "_absolute_paths", + "_resolved_num_rows", + ) + + def __init__( + self, + paths: Union[str, List[str]], + root_dir: Optional[str] = "", + recursive=False, + columns: Optional[List[str]] = None, + union_by_name=False, + ) -> None: + """ + Construct a dataset from a list of paths. + + Parameters + ---------- + paths + A path or a list of paths or path patterns. + e.g. `['data/100.parquet', '/datasetA/*.parquet']`. + root_dir, optional + Relative paths in `paths` would be resolved under `root_dir` if specified. + recursive, optional + Resolve path patterns recursively if true. + columns, optional + Only load the specified columns if not None. + union_by_name, optional + Unify the columns of different files by name (see https://duckdb.org/docs/data/multiple_files/combining_schemas#union-by-name). + """ + self.paths = [paths] if isinstance(paths, str) else paths + "The paths to the dataset files." + self.root_dir = os.path.abspath(root_dir) if root_dir is not None else None + "The root directory of paths." + self.recursive = recursive + "Whether to resolve path patterns recursively." + self.columns = columns + "The columns to load from the dataset files." + self._union_by_name = union_by_name + self._resolved_paths: List[str] = None + self._absolute_paths: List[str] = None + self._resolved_num_rows: int = None + + def __str__(self) -> str: + return f"{self.__class__.__name__}: paths[{len(self.paths)}]={self.paths[:3]}...', root_dir={self.root_dir}, columns={self.columns}" + + __repr__ = __str__ + + @property + def _resolved_path_str(self) -> str: + return ", ".join(map(lambda x: f"'{x}'", self.resolved_paths)) + + @property + def _column_str(self) -> str: + """ + A column string used in SQL select clause. 
+ """ + return ", ".join(self.columns) if self.columns else "*" + + @property + def union_by_name(self) -> bool: + """ + Whether to unify the columns of different files by name. + """ + return self._union_by_name or self.columns is not None + + @property + def udfs(self) -> List[UDFContext]: + return [] + + @staticmethod + def merge(datasets: "List[DataSet]") -> "DataSet": + """ + Merge multiple datasets into a single dataset. + """ + raise NotImplementedError + + def reset( + self, + paths: Optional[List[str]] = None, + root_dir: Optional[str] = "", + recursive=None, + ) -> None: + """ + Reset the dataset with new paths, root_dir, and recursive flag. + """ + self.partition_by_files.cache_clear() + self.paths = paths or [] + self.root_dir = os.path.abspath(root_dir) if root_dir is not None else None + self.recursive = recursive if recursive is not None else self.recursive + self._resolved_paths = None + self._absolute_paths = None + self._resolved_num_rows = None + + @property + def num_files(self) -> int: + """ + The number of files in the dataset. + """ + return len(self.resolved_paths) + + @property + def num_rows(self) -> int: + """ + The number of rows in the dataset. + """ + if self._resolved_num_rows is None: + sql_query = f"select count(*) from {self.sql_query_fragment()}" + row = duckdb.sql(sql_query).fetchone() + assert row is not None, "no rows returned" + self._resolved_num_rows = row[0] + return self._resolved_num_rows + + @property + def empty(self) -> bool: + """ + Whether the dataset is empty. + """ + if self._resolved_paths is not None: + return len(self._resolved_paths) == 0 + for path in self.paths: + if has_magic(path): + if any( + glob.iglob( + os.path.join(self.root_dir or "", path), + recursive=self.recursive, + ) + ): + return False + else: + return False + return True + + @property + def resolved_paths(self) -> List[str]: + """ + An ordered list of absolute paths of files. + File patterns are expanded to absolute paths. + + Example:: + >>> DataSet(['data/100.parquet', '/datasetA/*.parquet']).resolved_paths + ['/datasetA/1.parquet', '/datasetA/2.parquet', '/home/user/data/100.parquet'] + """ + if self._resolved_paths is None: + resolved_paths = [] + wildcard_paths = [] + for path in self.absolute_paths: + if has_magic(path): + wildcard_paths.append(path) + else: + resolved_paths.append(path) + if wildcard_paths: + if len(wildcard_paths) == 1: + expanded_paths = glob.glob( + wildcard_paths[0], recursive=self.recursive + ) + else: + logger.debug( + "resolving {} paths with wildcards in {}", + len(wildcard_paths), + self, + ) + with ThreadPoolExecutor(min(32, len(wildcard_paths))) as pool: + expanded_paths = [ + p + for paths in pool.map( + lambda p: glob.glob(p, recursive=self.recursive), + wildcard_paths, + ) + for p in paths + ] + resolved_paths.extend(expanded_paths) + logger.debug( + "resolved {} files from {} wildcard path(s) in {}", + len(expanded_paths), + len(wildcard_paths), + self, + ) + self._resolved_paths = sorted(resolved_paths) + return self._resolved_paths + + @property + def absolute_paths(self) -> List[str]: + """ + An ordered list of absolute paths of the given file patterns. 
+ + Example:: + >>> DataSet(['data/100.parquet', '/datasetA/*.parquet']).absolute_paths + ['/datasetA/*.parquet', '/home/user/data/100.parquet'] + """ + if self._absolute_paths is None: + if self.root_dir is None: + self._absolute_paths = sorted(self.paths) + else: + self._absolute_paths = [ + os.path.join(self.root_dir, p) for p in sorted(self.paths) + ] + return self._absolute_paths + + def sql_query_fragment( + self, + filesystem: fsspec.AbstractFileSystem = None, + conn: duckdb.DuckDBPyConnection = None, + ) -> str: + """ + Return a sql fragment that represents the dataset. + """ + raise NotImplementedError + + def log(self, num_rows=200): + """ + Log the dataset to the logger. + """ + import pandas as pd + + pd.set_option("display.max_columns", None) # Show all columns + pd.set_option("display.max_rows", None) # Optionally show all rows + pd.set_option("display.max_colwidth", None) # No truncation of column contents + pd.set_option("display.expand_frame_repr", False) # Do not wrap rows + logger.info("{} ->\n{}", self, self.to_pandas().head(num_rows)) + + def to_pandas(self) -> pd.DataFrame: + """ + Convert the dataset to a pandas dataframe. + """ + return self.to_arrow_table().to_pandas() + + def to_arrow_table( + self, + max_workers: int = 16, + filesystem: fsspec.AbstractFileSystem = None, + conn: duckdb.DuckDBPyConnection = None, + ) -> arrow.Table: + """ + Load the dataset to an arrow table. + + Parameters + ---------- + max_workers, optional + The maximum number of worker threads to use. Default to 16. + filesystem, optional + If provided, use the filesystem to load the dataset. + conn, optional + A duckdb connection. If provided, use duckdb to load the dataset. + """ + sql_query = f"select {self._column_str} from {self.sql_query_fragment(filesystem, conn)}" + if conn is not None: + return conn.sql(sql_query).fetch_arrow_table() + else: + return duckdb.sql(sql_query).fetch_arrow_table() + + def to_batch_reader( + self, + batch_size: int = DEFAULT_BATCH_SIZE, + filesystem: fsspec.AbstractFileSystem = None, + conn: duckdb.DuckDBPyConnection = None, + ) -> arrow.RecordBatchReader: + """ + Return an arrow record batch reader to read the dataset. + + Parameters + ---------- + batch_size, optional + The record batch size. Default to 122880. + filesystem, optional + If provided, use the filesystem to load the dataset. + conn, optional + A duckdb connection. If provided, use duckdb to load the dataset. + """ + sql_query = f"select {self._column_str} from {self.sql_query_fragment(filesystem, conn)}" + if conn is not None: + return conn.sql(sql_query).fetch_arrow_reader(batch_size) + else: + return duckdb.sql(sql_query).fetch_arrow_reader(batch_size) + + def _init_file_partitions(self, npartition: int) -> "List[DataSet]": + """ + Return `npartition` empty datasets. + """ + file_partitions = [] + for _ in range(npartition): + empty_dataset = copy.copy(self) + empty_dataset.reset() + file_partitions.append(empty_dataset) + return file_partitions + + @functools.lru_cache + def partition_by_files( + self, npartition: int, random_shuffle: bool = False + ) -> "List[DataSet]": + """ + Evenly split into `npartition` datasets by files. 
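+
+        Files are distributed across partitions as evenly as possible. If
+        `random_shuffle` is true, the files are shuffled before splitting and
+        the resulting partitions are returned in shuffled order as well.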
+ """ + assert npartition > 0, f"npartition has negative value: {npartition}" + if npartition > self.num_files: + logger.debug( + f"number of partitions {npartition} is greater than the number of files {self.num_files}" + ) + + resolved_paths = ( + random.sample(self.resolved_paths, len(self.resolved_paths)) + if random_shuffle + else self.resolved_paths + ) + evenly_split_groups = split_into_rows(resolved_paths, npartition) + num_paths_in_groups = list(map(len, evenly_split_groups)) + + file_partitions = self._init_file_partitions(npartition) + for i, paths in enumerate(evenly_split_groups): + file_partitions[i].reset(paths, None) + + logger.debug( + f"created {npartition} file partitions (min #files: {min(num_paths_in_groups)}, max #files: {max(num_paths_in_groups)}, avg #files: {sum(num_paths_in_groups)/len(num_paths_in_groups):.3f}) from {self}" + ) + return ( + random.sample(file_partitions, len(file_partitions)) + if random_shuffle + else file_partitions + ) + + +class PartitionedDataSet(DataSet): + """ + A dataset that is partitioned into multiple datasets. + """ + + __slots__ = ("datasets",) + + def __init__(self, datasets: List[DataSet]) -> None: + assert len(datasets) > 0, "no dataset given" + self.datasets = datasets + absolute_paths = [p for dataset in datasets for p in dataset.absolute_paths] + super().__init__( + absolute_paths, + datasets[0].root_dir, + datasets[0].recursive, + datasets[0].columns, + datasets[0].union_by_name, + ) + + def __getitem__(self, key: int) -> DataSet: + return self.datasets[key] + + @property + def udfs(self) -> List[UDFContext]: + return [udf for dataset in self.datasets for udf in dataset.udfs] + + @staticmethod + def merge(datasets: "List[PartitionedDataSet]") -> DataSet: + # merge partitioned datasets results in an unpartitioned dataset + assert all(isinstance(dataset, PartitionedDataSet) for dataset in datasets) + datasets = [d for dataset in datasets for d in dataset] + return datasets[0].merge(datasets) + + +class FileSet(DataSet): + """ + A set of files. + """ + + @staticmethod + def merge(datasets: "List[FileSet]") -> "FileSet": + assert all(isinstance(dataset, FileSet) for dataset in datasets) + absolute_paths = [p for dataset in datasets for p in dataset.absolute_paths] + return FileSet(absolute_paths) + + def to_arrow_table( + self, + max_workers=16, + filesystem: fsspec.AbstractFileSystem = None, + conn: duckdb.DuckDBPyConnection = None, + ) -> arrow.Table: + return arrow.table([self.resolved_paths], names=["resolved_paths"]) + + def to_batch_reader( + self, + batch_size=DEFAULT_BATCH_SIZE, + filesystem: fsspec.AbstractFileSystem = None, + conn: duckdb.DuckDBPyConnection = None, + ) -> arrow.RecordBatchReader: + return self.to_arrow_table().to_reader(batch_size) + + +class CsvDataSet(DataSet): + """ + A set of csv files. 
+ """ + + __slots__ = ( + "schema", + "delim", + "max_line_size", + "parallel", + "header", + ) + + def __init__( + self, + paths: List[str], + schema: Dict[str, str], + delim=",", + max_line_size: Optional[int] = None, + parallel=True, + header=False, + root_dir: Optional[str] = "", + recursive=False, + columns: Optional[List[str]] = None, + union_by_name=False, + ) -> None: + super().__init__(paths, root_dir, recursive, columns, union_by_name) + assert isinstance( + schema, OrderedDict + ), f"type of csv schema is not OrderedDict: {type(schema)}" + self.schema = schema + self.delim = delim + self.max_line_size = max_line_size + self.parallel = parallel + self.header = header + + @staticmethod + def merge(datasets: "List[CsvDataSet]") -> "CsvDataSet": + assert all(isinstance(dataset, CsvDataSet) for dataset in datasets) + absolute_paths = [p for dataset in datasets for p in dataset.absolute_paths] + return CsvDataSet( + absolute_paths, + datasets[0].schema, + datasets[0].delim, + datasets[0].max_line_size, + datasets[0].parallel, + recursive=any(dataset.recursive for dataset in datasets), + columns=datasets[0].columns, + union_by_name=any(dataset.union_by_name for dataset in datasets), + ) + + def sql_query_fragment( + self, + filesystem: fsspec.AbstractFileSystem = None, + conn: duckdb.DuckDBPyConnection = None, + ) -> str: + schema_str = ", ".join( + map(lambda x: f"'{x[0]}': '{x[1]}'", self.schema.items()) + ) + max_line_size_str = ( + f"max_line_size={self.max_line_size}, " + if self.max_line_size is not None + else "" + ) + return ( + f"( select {self._column_str} from read_csv([ {self._resolved_path_str} ], delim='{self.delim}', columns={{ {schema_str} }}, header={self.header}, " + f"{max_line_size_str} parallel={self.parallel}, union_by_name={self.union_by_name}) )" + ) + + +class JsonDataSet(DataSet): + """ + A set of json files. + """ + + __slots__ = ( + "schema", + "format", + "max_object_size", + ) + + def __init__( + self, + paths: List[str], + schema: Dict[str, str], + format="newline_delimited", + max_object_size=1 * GB, + root_dir: Optional[str] = "", + recursive=False, + columns: Optional[List[str]] = None, + union_by_name=False, + ) -> None: + super().__init__(paths, root_dir, recursive, columns, union_by_name) + self.schema = schema + self.format = format + self.max_object_size = max_object_size + + @staticmethod + def merge(datasets: "List[JsonDataSet]") -> "JsonDataSet": + assert all(isinstance(dataset, JsonDataSet) for dataset in datasets) + absolute_paths = [p for dataset in datasets for p in dataset.absolute_paths] + return JsonDataSet( + absolute_paths, + datasets[0].schema, + datasets[0].format, + datasets[0].max_object_size, + recursive=any(dataset.recursive for dataset in datasets), + columns=datasets[0].columns, + union_by_name=any(dataset.union_by_name for dataset in datasets), + ) + + def sql_query_fragment( + self, + filesystem: fsspec.AbstractFileSystem = None, + conn: duckdb.DuckDBPyConnection = None, + ) -> str: + schema_str = ", ".join( + map(lambda x: f"'{x[0]}': '{x[1]}'", self.schema.items()) + ) + return ( + f"( select {self._column_str} from read_json([ {self._resolved_path_str} ], format='{self.format}', columns={{ {schema_str} }}, " + f"maximum_object_size={self.max_object_size}, union_by_name={self.union_by_name}) )" + ) + + +class ParquetDataSet(DataSet): + """ + A set of parquet files. 
+ """ + + __slots__ = ( + "generated_columns", + "_resolved_row_ranges", + ) + + def __init__( + self, + paths: List[str], + root_dir: Optional[str] = "", + recursive=False, + columns: Optional[List[str]] = None, + generated_columns: Optional[List[str]] = None, + union_by_name=False, + ) -> None: + super().__init__(paths, root_dir, recursive, columns, union_by_name) + self.generated_columns = generated_columns or [] + "Generated columns of DuckDB `read_parquet` function. e.g. `file_name`, `file_row_number`." + self._resolved_row_ranges: List[RowRange] = None + + def __str__(self) -> str: + s = super().__str__() + f", generated_columns={self.generated_columns}" + if self._resolved_row_ranges: + s += f", resolved_row_ranges[{len(self._resolved_row_ranges)}]={self._resolved_row_ranges[:3]}..." + return s + + __repr__ = __str__ + + @property + def _column_str(self) -> str: + if not self.columns: + return "*" + if "*" in self.columns: + return ", ".join(self.columns) + return ", ".join(self.columns + self.generated_columns) + + @staticmethod + def merge(datasets: "List[ParquetDataSet]") -> "ParquetDataSet": + assert all(isinstance(dataset, ParquetDataSet) for dataset in datasets) + dataset = ParquetDataSet( + paths=[p for dataset in datasets for p in dataset.absolute_paths], + root_dir=None, + recursive=any(dataset.recursive for dataset in datasets), + columns=datasets[0].columns, + generated_columns=datasets[0].generated_columns, + union_by_name=any(dataset.union_by_name for dataset in datasets), + ) + # merge row ranges if any dataset has resolved row ranges + if any(dataset._resolved_row_ranges is not None for dataset in datasets): + dataset._resolved_row_ranges = [ + row_range + for dataset in datasets + for row_range in dataset.resolved_row_ranges + ] + return dataset + + @staticmethod + def create_from(table: arrow.Table, output_dir: str, filename: str = "data"): + dump_to_parquet_files(table, output_dir, filename) + return ParquetDataSet([os.path.join(output_dir, f"{filename}*.parquet")]) + + def reset( + self, + paths: Optional[List[str]] = None, + root_dir: Optional[str] = "", + recursive=None, + ) -> None: + """ + NOTE: all row ranges will be cleared. DO NOT call this if you want to keep partial files. + """ + super().reset(paths, root_dir, recursive) + self._resolved_row_ranges = None + self.partition_by_files.cache_clear() + self.partition_by_rows.cache_clear() + self.partition_by_size.cache_clear() + + @property + def resolved_row_ranges(self) -> List[RowRange]: + """ + Return row ranges for each parquet file. 
+ """ + if self._resolved_row_ranges is None: + if len(self.resolved_paths) == 0: + self._resolved_row_ranges = [] + else: + + def resolve_row_range(path: str) -> RowRange: + # read parquet metadata to get number of rows + metadata = parquet.read_metadata(path) + num_rows = metadata.num_rows + uncompressed_data_size = sum( + metadata.row_group(i).total_byte_size + for i in range(metadata.num_row_groups) + ) + return RowRange( + path, + data_size=uncompressed_data_size, + file_num_rows=num_rows, + begin=0, + end=num_rows, + ) + + with ThreadPoolExecutor( + max_workers=min(32, len(self.resolved_paths)) + ) as pool: + self._resolved_row_ranges = list( + pool.map(resolve_row_range, self.resolved_paths) + ) + return self._resolved_row_ranges + + @property + def num_rows(self) -> int: + if self._resolved_num_rows is None: + self._resolved_num_rows = sum( + row_range.num_rows for row_range in self.resolved_row_ranges + ) + return self._resolved_num_rows + + @property + def empty(self) -> bool: + # this method should be quick. do not resolve row ranges. + if self._resolved_num_rows is not None or self._resolved_row_ranges is not None: + return self.num_rows == 0 + return super().empty + + @property + def estimated_data_size(self) -> int: + """ + Return the estimated data size in bytes. + """ + return sum( + row_range.estimated_data_size for row_range in self.resolved_row_ranges + ) + + def sql_query_fragment( + self, + filesystem: fsspec.AbstractFileSystem = None, + conn: duckdb.DuckDBPyConnection = None, + ) -> str: + extra_parameters = ( + "".join(f", {col}=true" for col in self.generated_columns) + if self.generated_columns + else "" + ) + parquet_file_queries = [] + full_row_ranges = [] + + for row_range in self.resolved_row_ranges: + path = ( + filesystem.unstrip_protocol(row_range.path) + if filesystem + else row_range.path + ) + if row_range.num_rows == row_range.file_num_rows: + full_row_ranges.append(row_range) + else: + sql_query = f""" + select {self._column_str} + from read_parquet('{path}' {extra_parameters}, file_row_number=true) + where file_row_number between {row_range.begin} and {row_range.end - 1} + """ + if "file_row_number" not in self.generated_columns: + sql_query = f"select columns(c -> c != 'file_row_number') from ( {sql_query} )" + parquet_file_queries.append(sql_query) + + # NOTE: prefer: read_parquet([path1, path2, ...]) + # instead of: read_parquet(path1) union all read_parquet(path2) union all ... 
+ # for performance + if full_row_ranges: + # XXX: duckdb uses the first file as the estimated cardinality of `read_parquet` + # to prevent incorrect estimation, we move the largest file to the head + largest_index = max( + range(len(full_row_ranges)), + key=lambda i: full_row_ranges[i].file_num_rows, + ) + full_row_ranges[0], full_row_ranges[largest_index] = ( + full_row_ranges[largest_index], + full_row_ranges[0], + ) + parquet_file_str = ",\n ".join( + map(lambda x: f"'{x.path}'", full_row_ranges) + ) + parquet_file_queries.insert( + 0, + f""" + select {self._column_str} + from read_parquet([ + {parquet_file_str} + ], union_by_name={self.union_by_name} {extra_parameters}) + """, + ) + + union_op = " union all by name " if self.union_by_name else " union all " + return f"( {union_op.join(parquet_file_queries)} )" + + def to_arrow_table( + self, + max_workers=16, + filesystem: fsspec.AbstractFileSystem = None, + conn: duckdb.DuckDBPyConnection = None, + ) -> arrow.Table: + if conn is not None: + return super().to_arrow_table(max_workers, filesystem, conn) + + tables = [] + if self.resolved_row_ranges: + tables.append( + load_from_parquet_files( + self.resolved_row_ranges, self.columns, max_workers, filesystem + ) + ) + return arrow.concat_tables(tables) + + def to_batch_reader( + self, + batch_size=DEFAULT_BATCH_SIZE, + filesystem: fsspec.AbstractFileSystem = None, + conn: duckdb.DuckDBPyConnection = None, + ) -> arrow.RecordBatchReader: + if conn is not None: + return super().to_batch_reader(batch_size, filesystem, conn) + + return build_batch_reader_from_files( + self.resolved_row_ranges, + columns=self.columns, + batch_size=batch_size, + filesystem=filesystem, + ) + + @functools.lru_cache + def partition_by_files( + self, npartition: int, random_shuffle: bool = False + ) -> "List[ParquetDataSet]": + if self._resolved_row_ranges is not None: + return self.partition_by_rows(npartition, random_shuffle) + else: + return super().partition_by_files(npartition, random_shuffle) + + @functools.lru_cache + def partition_by_rows( + self, npartition: int, random_shuffle: bool = False + ) -> "List[ParquetDataSet]": + """ + Evenly split the dataset into `npartition` partitions by rows. + If `random_shuffle` is True, shuffle the files before partitioning. + """ + assert npartition > 0, f"npartition has negative value: {npartition}" + + resolved_row_ranges = self.resolved_row_ranges + resolved_row_ranges = ( + random.sample(resolved_row_ranges, len(resolved_row_ranges)) + if random_shuffle + else resolved_row_ranges + ) + + def create_dataset(row_ranges: List[RowRange]) -> ParquetDataSet: + row_ranges = sorted(row_ranges, key=lambda x: x.path) + resolved_paths = [x.path for x in row_ranges] + dataset = ParquetDataSet( + resolved_paths, + columns=self.columns, + generated_columns=self.generated_columns, + union_by_name=self.union_by_name, + ) + dataset._resolved_paths = resolved_paths + dataset._resolved_row_ranges = row_ranges + return dataset + + return [ + create_dataset(row_ranges) + for row_ranges in RowRange.partition_by_rows( + resolved_row_ranges, npartition + ) + ] + + @functools.lru_cache + def partition_by_size(self, max_partition_size: int) -> "List[ParquetDataSet]": + """ + Split the dataset into multiple partitions so that each partition has at most `max_partition_size` bytes. 
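+
+        The number of partitions is derived from the estimated uncompressed data
+        size as `estimated_data_size // max_partition_size + 1`. For example,
+        roughly 2.5 GB of estimated data with a 1 GB `max_partition_size` is
+        split into 3 partitions of about 0.83 GB each.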
+ """ + if self.empty: + return [] + estimated_data_size = sum( + row_range.estimated_data_size for row_range in self.resolved_row_ranges + ) + npartition = estimated_data_size // max_partition_size + 1 + return self.partition_by_rows(npartition) + + @staticmethod + def _read_partition_key( + path: str, data_partition_column: str, hive_partitioning: bool + ) -> int: + """ + Get the partition key of the parquet file. + + Examples + -------- + ``` + >>> ParquetDataSet._read_partition_key("output/000.parquet", "key", hive_partitioning=False) + 1 + >>> ParquetDataSet._read_partition_key("output/key=1/000.parquet", "key", hive_partitioning=True) + 1 + ``` + """ + + def parse_partition_key(key: str): + try: + return int(key) + except ValueError: + logger.error( + f"cannot parse partition key '{data_partition_column}' of {path} from: {key}" + ) + raise + + if hive_partitioning: + path_part_prefix = data_partition_column + "=" + for part in path.split(os.path.sep): + if part.startswith(path_part_prefix): + return parse_partition_key(part[len(path_part_prefix) :]) + raise RuntimeError( + f"cannot extract hive partition key '{data_partition_column}' from path: {path}" + ) + + with parquet.ParquetFile(path) as file: + kv_metadata = file.schema_arrow.metadata or file.metadata.metadata + if kv_metadata is not None: + for key, val in kv_metadata.items(): + key, val = key.decode("utf-8"), val.decode("utf-8") + if key == PARQUET_METADATA_KEY_PREFIX + data_partition_column: + return parse_partition_key(val) + if file.metadata.num_rows == 0: + logger.warning( + f"cannot read partition keys from empty parquet file: {path}" + ) + return None + for batch in file.iter_batches( + batch_size=128, columns=[data_partition_column], use_threads=False + ): + assert ( + data_partition_column in batch.column_names + ), f"cannot find column '{data_partition_column}' in {batch.column_names}" + assert ( + batch.num_columns == 1 + ), f"unexpected num of columns: {batch.column_names}" + uniq_partition_keys = set(batch.columns[0].to_pylist()) + assert ( + uniq_partition_keys and len(uniq_partition_keys) == 1 + ), f"partition keys found in {path} not unique: {uniq_partition_keys}" + return uniq_partition_keys.pop() + + def load_partitioned_datasets( + self, npartition: int, data_partition_column: str, hive_partitioning=False + ) -> "List[ParquetDataSet]": + """ + Split the dataset into a list of partitioned datasets. 
+ """ + assert npartition > 0, f"npartition has negative value: {npartition}" + if npartition > self.num_files: + logger.debug( + f"number of partitions {npartition} is greater than the number of files {self.num_files}" + ) + + file_partitions: List[ParquetDataSet] = self._init_file_partitions(npartition) + for dataset in file_partitions: + # elements will be appended later + dataset._absolute_paths = [] + dataset._resolved_paths = [] + dataset._resolved_row_ranges = [] + + if not self.resolved_paths: + logger.debug(f"create {npartition} empty data partitions from {self}") + return file_partitions + + with ThreadPoolExecutor(min(32, len(self.resolved_paths))) as pool: + partition_keys = pool.map( + lambda path: ParquetDataSet._read_partition_key( + path, data_partition_column, hive_partitioning + ), + self.resolved_paths, + ) + + for row_range, partition_key in zip(self.resolved_row_ranges, partition_keys): + if partition_key is not None: + assert ( + 0 <= partition_key <= npartition + ), f"invalid partition key {partition_key} found in {row_range.path}" + dataset = file_partitions[partition_key] + dataset.paths.append(row_range.path) + dataset._absolute_paths.append(row_range.path) + dataset._resolved_paths.append(row_range.path) + dataset._resolved_row_ranges.append(row_range) + + logger.debug(f"loaded {npartition} data partitions from {self}") + return file_partitions + + def remove_empty_files(self) -> None: + """ + Remove empty parquet files from the dataset. + """ + new_row_ranges = [ + row_range + for row_range in self.resolved_row_ranges + if row_range.num_rows > 0 + ] + if len(new_row_ranges) == 0: + # keep at least one file to avoid empty dataset + new_row_ranges = self.resolved_row_ranges[:1] + if len(new_row_ranges) == len(self.resolved_row_ranges): + # no empty files found + return + logger.info( + f"removed {len(self.resolved_row_ranges) - len(new_row_ranges)}/{len(self.resolved_row_ranges)} empty parquet files from {self}" + ) + self._resolved_row_ranges = new_row_ranges + self._resolved_paths = [row_range.path for row_range in new_row_ranges] + self._absolute_paths = self._resolved_paths + self.paths = self._resolved_paths + + +class SqlQueryDataSet(DataSet): + """ + The result of a sql query. + """ + + __slots__ = ( + "sql_query", + "query_builder", + ) + + def __init__( + self, + sql_query: str, + query_builder: Callable[ + [duckdb.DuckDBPyConnection, fsspec.AbstractFileSystem], str + ] = None, + ) -> None: + super().__init__([]) + self.sql_query = sql_query + self.query_builder = query_builder + + @property + def num_rows(self) -> int: + num_rows = duckdb.sql( + f"select count(*) as num_rows from {self.sql_query_fragment()}" + ).fetchall() + return num_rows[0][0] + + def sql_query_fragment( + self, + filesystem: fsspec.AbstractFileSystem = None, + conn: duckdb.DuckDBPyConnection = None, + ) -> str: + sql_query = ( + self.sql_query + if self.query_builder is None + else self.query_builder(conn, filesystem) + ) + return f"( {sql_query} )" + + +class ArrowTableDataSet(DataSet): + """ + An arrow table. 
+ """ + + def __init__(self, table: arrow.Table) -> None: + super().__init__([]) + self.table = copy.deepcopy(table) + + def to_arrow_table( + self, + max_workers=16, + filesystem: fsspec.AbstractFileSystem = None, + conn: duckdb.DuckDBPyConnection = None, + ) -> arrow.Table: + return self.table + + def to_batch_reader( + self, + batch_size=DEFAULT_BATCH_SIZE, + filesystem: fsspec.AbstractFileSystem = None, + conn: duckdb.DuckDBPyConnection = None, + ) -> arrow.RecordBatchReader: + return self.table.to_reader(batch_size) + + def sql_query_fragment( + self, + filesystem: fsspec.AbstractFileSystem = None, + conn: duckdb.DuckDBPyConnection = None, + ) -> str: + name = f"arrow_table_{id(self.table)}" + self.table.to_pandas().to_sql(name, conn, index=False) + return f"( select * from {name} )" + + +class PandasDataSet(DataSet): + """ + A pandas dataframe. + """ + + def __init__(self, df: pd.DataFrame) -> None: + super().__init__([]) + self.df = df + + def to_pandas(self) -> pd.DataFrame: + return self.df + + def to_arrow_table( + self, + max_workers=16, + filesystem: fsspec.AbstractFileSystem = None, + conn: duckdb.DuckDBPyConnection = None, + ) -> arrow.Table: + return arrow.Table.from_pandas(self.df) + + def to_batch_reader( + self, + batch_size=DEFAULT_BATCH_SIZE, + filesystem: fsspec.AbstractFileSystem = None, + conn: duckdb.DuckDBPyConnection = None, + ) -> arrow.RecordBatchReader: + return self.to_arrow_table().to_reader(batch_size) + + def sql_query_fragment( + self, + filesystem: fsspec.AbstractFileSystem = None, + conn: duckdb.DuckDBPyConnection = None, + ) -> str: + name = f"pandas_table_{id(self.df)}" + self.df.to_sql(name, conn, index=False) + return f"( select * from {name} )" diff --git a/smallpond/logical/node.py b/smallpond/logical/node.py new file mode 100644 index 0000000..0ebab8f --- /dev/null +++ b/smallpond/logical/node.py @@ -0,0 +1,2136 @@ +import copy +import math +import os +import os.path +import re +import traceback +import warnings +from collections import defaultdict +from typing import ( + Callable, + Dict, + Generic, + Iterable, + List, + Literal, + Optional, + Tuple, + TypeVar, + Union, +) + +import numpy as np +import pandas as pd +import pyarrow as arrow +from graphviz import Digraph + +from smallpond.common import ( + DATA_PARTITION_COLUMN_NAME, + DEFAULT_BATCH_SIZE, + DEFAULT_ROW_GROUP_SIZE, + GB, + GENERATED_COLUMNS, +) +from smallpond.execution.task import ( + ArrowBatchTask, + ArrowComputeTask, + ArrowStreamTask, + DataSinkTask, + DataSourceTask, + EvenlyDistributedPartitionProducerTask, + HashPartitionTask, + LoadPartitionedDataSetProducerTask, + MergeDataSetsTask, + PandasBatchTask, + PandasComputeTask, + PartitionConsumerTask, + PartitionInfo, + PartitionProducerTask, + PerfStats, + ProjectionTask, + PythonScriptTask, + RepeatPartitionProducerTask, + RuntimeContext, + SplitDataSetTask, + SqlEngineTask, + Task, + UserDefinedPartitionProducerTask, +) +from smallpond.logical.dataset import DataSet, ParquetDataSet +from smallpond.logical.udf import ( + DuckDbExtensionContext, + ExternalModuleContext, + PythonUDFContext, + UDFContext, + UDFType, + UserDefinedFunction, +) + + +class NodeId(int): + """ + A unique identifier for each node. + """ + + def __str__(self) -> str: + return f"{self:06d}" + + +class Context(object): + """ + Global context for each logical plan. + Right now it's mainly used to keep a list of Python UDFs. 
+ """ + + def __init__(self) -> None: + self.next_node_id = 0 + self.udfs: Dict[str, UDFContext] = {} + + def _new_node_id(self) -> NodeId: + """ + Generate a new node id. + """ + self.next_node_id += 1 + return NodeId(self.next_node_id) + + def create_function( + self, + name: str, + func: Callable, + params: Optional[List[UDFType]], + return_type: Optional[UDFType], + use_arrow_type=False, + ) -> str: + """ + Define a Python UDF to be referenced in the logical plan. + Currently only scalar functions (return one element per row) are supported. + See https://duckdb.org/docs/archive/0.9.2/api/python/function. + + Parameters + ---------- + name + A unique function name, which can be referenced in SQL query. + func + The Python function you wish to register as a UDF. + params + A list of column types for function parameters, including basic types: + `UDFType.INTEGER`, `UDFType.FLOAT`, `UDFType.VARCHAR`, `UDFType.BLOB` etc, + and container types: + ``` + UDFListType(UDFType.INTEGER), + UDFMapType(UDFType.VARCHAR, UDFType.INTEGER), + UDFListType(UDFStructType({'int': 'INTEGER', 'str': 'VARCHAR'})). + ``` + These types are simple wrappers of duckdb types defined in https://duckdb.org/docs/api/python/types.html. + Set params to `UDFAnyParameters()` allows the udf to accept parameters of any type. + use_arrow_type, optional + Specify true to use PyArrow Tables, by default use built-in Python types. + return_type + The return type of the function, see the above note for `params`. + + Returns + ------- + The unique function name. + """ + self.udfs[name] = PythonUDFContext( + name, func, params, return_type, use_arrow_type + ) + return name + + def create_external_module(self, module_path: str, name: str = None) -> str: + """ + Load an external DuckDB module. + """ + name = name or os.path.basename(module_path) + self.udfs[name] = ExternalModuleContext(name, module_path) + return name + + def create_duckdb_extension(self, extension_path: str, name: str = None) -> str: + """ + Load a DuckDB extension. + """ + name = name or os.path.basename(extension_path) + self.udfs[name] = DuckDbExtensionContext(name, extension_path) + return name + + +class Node(object): + """ + The base class for all nodes. + """ + + enable_resource_boost = False + + def __init__( + self, + ctx: Context, + input_deps: "Tuple[Node, ...]", + output_name: Optional[str] = None, + output_path: Optional[str] = None, + cpu_limit: int = 1, + gpu_limit: float = 0, + memory_limit: Optional[int] = None, + ) -> None: + """ + The base class for all nodes in logical plan. + + Parameters + ---------- + ctx + The context of logical plan. + input_deps + Define the inputs of this node. + output_name, optional + The prefix of output directories and filenames for tasks generated from this node. + The default `output_name` is the class name of the task created for this node, e.g. + `HashPartitionTask, SqlEngineTask, PythonScriptTask`, etc. + + The `output_name` string should only include alphanumeric characters and underscore. + In other words, it matches regular expression `[a-zA-Z0-9_]+`. + + If `output_name` is set and `output_path` is None, the path format of output files is: + `{job_root_path}/output/{output_name}/{task_runtime_id}/{output_name}-{task_runtime_id}-{seqnum}.parquet` + where `{task_runtime_id}` is defined as `{job_id}.{task_id}.{sched_epoch}.{task_retry_count}`. + output_path, optional + The absolute path of a customized output folder for tasks generated from this node. 
+ Any shared folder that can be accessed by executor and scheduler is allowed + although IO performance varies across filesystems. + + If both `output_name` and `output_path` are specified, the path format of output files is: + `{output_path}/{output_name}/{task_runtime_id}/{output_name}-{task_runtime_id}-{seqnum}.parquet` + where `{task_runtime_id}` is defined as `{job_id}.{task_id}.{sched_epoch}.{task_retry_count}`. + cpu_limit, optional + The max number of CPUs would be used by tasks generated from this node. + This is a resource requirement specified by the user and used to guide + task scheduling. smallpond does NOT enforce this limit. + gpu_limit, optional + The max number of GPUs would be used by tasks generated from this node. + This is a resource requirement specified by the user and used to guide + task scheduling. smallpond does NOT enforce this limit. + memory_limit, optional + The max memory would be used by tasks generated from this node. + The memory limit is automatically calculated based memory-to-cpu ratio of executor machine if not specified. + This is a resource requirement specified by the user and used to guide + task scheduling. smallpond does NOT enforce this limit. + """ + assert isinstance( + input_deps, Iterable + ), f"input_deps is not iterable: {input_deps}" + assert all( + isinstance(node, Node) for node in input_deps + ), f"some of input_deps are not instances of Node: {input_deps}" + assert output_name is None or re.match( + "[a-zA-Z0-9_]+", output_name + ), f"output_name has invalid format: {output_name}" + assert output_path is None or os.path.isabs( + output_path + ), f"output_path is not an absolute path: {output_path}" + self.ctx = ctx + self.id = self.ctx._new_node_id() + self.input_deps = input_deps + self.output_name = output_name + self.output_path = output_path + self.cpu_limit = max(cpu_limit, gpu_limit * 8) + self.gpu_limit = gpu_limit + self.memory_limit = memory_limit + self.generated_tasks: List[str] = [] + self.perf_stats: Dict[str, PerfStats] = {} + self.perf_metrics: Dict[str, List[float]] = defaultdict(list) + # record the location where the node is constructed in user code + frame = next( + frame + for frame in reversed(traceback.extract_stack()) + if frame.filename != __file__ + and not frame.filename.endswith("/dataframe.py") + ) + self.location = f"{frame.filename}:{frame.lineno}" + + def __repr__(self) -> str: + return f"{self.__class__.__name__}-{self.id}" + + def __str__(self) -> str: + return ( + f"{repr(self)}: input_deps={self.input_deps}, output_name={self.output_name}, " + f"tasks[{len(self.generated_tasks)}]={self.generated_tasks[:1]}...{self.generated_tasks[-1:]}, " + f"resource_limit={self.cpu_limit}CPUs/{self.gpu_limit}GPUs/{(self.memory_limit or 0)//GB}GB" + ) + + @staticmethod + def task_factory(task_builder): + def wrapper(self: Node, *args, **kwargs): + task: Task = task_builder(self, *args, **kwargs) + task.node_id = self.id + task.location = self.location + self.generated_tasks.append(task.key) + return task + + return wrapper + + def slim_copy(self): + node = copy.copy(self) + del node.input_deps, node.generated_tasks + return node + + def create_task( + self, + runtime_ctx: RuntimeContext, + input_deps: List[Task], + partition_infos: List[PartitionInfo], + ) -> Task: + raise NotImplementedError + + def add_perf_metrics(self, name, value: Union[List[float], float]): + self.perf_metrics[name].append(value) + self.perf_stats.pop(name, None) + + def get_perf_stats(self, name): + if name not in self.perf_stats: + if name 
not in self.perf_metrics: + return None + values = self.perf_metrics[name] + min, max, avg = np.min(values), np.max(values), np.average(values) + p50, p75, p95, p99 = np.percentile(values, (50, 75, 95, 99)) + self.perf_stats[name] = PerfStats( + len(values), sum(values), min, max, avg, p50, p75, p95, p99 + ) + return self.perf_stats[name] + + @property + def num_partitions(self) -> int: + raise NotImplementedError("num_partitions") + + +class DataSourceNode(Node): + """ + All inputs of a logical plan are represented as `DataSourceNode`. It does not depend on any other node. + """ + + def __init__(self, ctx: Context, dataset: DataSet) -> None: + """ + Construct a DataSourceNode. See :func:`Node.__init__` to find comments on other parameters. + + Parameters + ---------- + dataset + A DataSet instance serving as a input of the plan. Set to `None` to create a dummy data source. + """ + assert dataset is None or isinstance(dataset, DataSet) + super().__init__(ctx, []) + self.dataset = dataset if dataset is not None else ParquetDataSet([]) + + def __str__(self) -> str: + return super().__str__() + f", dataset=<{self.dataset}>" + + @Node.task_factory + def create_task( + self, + runtime_ctx: RuntimeContext, + input_deps: List[Task], + partition_infos: List[PartitionInfo], + ) -> DataSourceTask: + return DataSourceTask(runtime_ctx, self.dataset, partition_infos) + + @property + def num_partitions(self) -> int: + return 1 + + +DataSinkType = Literal["link", "copy", "link_or_copy", "manifest"] + + +class DataSinkNode(Node): + """ + Collect the output files of `input_deps` to `output_path`. + Depending on the options, it may create hard links, symbolic links, manifest files, or copy files. + """ + + def __init__( + self, + ctx: Context, + input_deps: Tuple[Node, ...], + output_path: str, + type: DataSinkType = "link", + manifest_only=False, + is_final_node=False, + ) -> None: + """ + Construct a DataSinkNode. See :func:`Node.__init__` to find comments on other parameters. + + Parameters + ---------- + output_path + The absolute path of a customized output folder. If set to None, an output + folder would be created under the default output root. + Any shared folder that can be accessed by executor and scheduler is allowed + although IO performance varies across filesystems. + type, optional + The operation type of the sink node. + "link" (default): If an output file is under the same mount point as `output_path`, a hard link is created; otherwise a symlink. + "copy": Copies files to the output path. + "link_or_copy": If an output file is under the same mount point as `output_path`, creates a hard link; otherwise copies the file. + "manifest": Creates a manifest file under `output_path`. Every line of the manifest file is a path string. + manifest_only, optional, deprecated + Set type to "manifest". + """ + assert type in ( + "link", + "copy", + "link_or_copy", + "manifest", + ), f"invalid sink type: {type}" + super().__init__( + ctx, input_deps, None, output_path, cpu_limit=1, gpu_limit=0, memory_limit=0 + ) + self.type: DataSinkType = "manifest" if manifest_only else type + self.is_final_node = is_final_node + + def __str__(self) -> str: + return super().__str__() + f", output_path={self.output_path}, type={self.type}" + + @Node.task_factory + def create_task( + self, + runtime_ctx: RuntimeContext, + input_deps: List[Task], + partition_infos: List[PartitionInfo], + ) -> DataSinkTask: + # design considerations: + # 1. data copy should start as soon as possible. + # 2. 
file names may conflict across partitions of different tasks. + # we should rename files **if and only if** there are conflicts. + # 3. resolving conflicts requires a single task. + if self.type == "copy" or self.type == "link_or_copy": + # so we create two phase tasks: + # phase1: copy data to a temp directory, for each input partition in parallel + input_deps = [ + self._create_phase1_task( + runtime_ctx, task, [PartitionInfo(i, len(input_deps))] + ) + for i, task in enumerate(input_deps) + ] + # phase2: resolve file name conflicts, hard link files, create manifest file, and clean up temp directory + return DataSinkTask( + runtime_ctx, + input_deps, + [PartitionInfo()], + self.output_path, + type="link_manifest", + is_final_node=self.is_final_node, + ) + elif self.type == "link": + return DataSinkTask( + runtime_ctx, + input_deps, + [PartitionInfo()], + self.output_path, + type="link_manifest", + is_final_node=self.is_final_node, + ) + elif self.type == "manifest": + return DataSinkTask( + runtime_ctx, + input_deps, + [PartitionInfo()], + self.output_path, + type="manifest", + is_final_node=self.is_final_node, + ) + else: + raise ValueError(f"invalid sink type: {self.type}") + + @Node.task_factory + def _create_phase1_task( + self, + runtime_ctx: RuntimeContext, + input_dep: Task, + partition_infos: List[PartitionInfo], + ) -> DataSinkTask: + return DataSinkTask( + runtime_ctx, [input_dep], partition_infos, self.output_path, type=self.type + ) + + +class PythonScriptNode(Node): + """ + Run Python code to process the input datasets with `PythonScriptNode.process(...)`. + + If the code needs to access attributes of runtime task, e.g. `local_rank`, `random_seed_long`, `numpy_random_gen`, + + - create a subclass of `PythonScriptTask`, which implements `PythonScriptTask.process(...)`, + - override `PythonScriptNode.spawn(...)` and return an instance of the subclass. + + Or use `runtime_ctx.task` in `process(runtime_ctx: RuntimeContext, ...)` function. + """ + + def __init__( + self, + ctx: Context, + input_deps: Tuple[Node, ...], + *, + process_func: Optional[ + Callable[[RuntimeContext, List[DataSet], str], bool] + ] = None, + output_name: Optional[str] = None, + output_path: Optional[str] = None, + cpu_limit: int = 1, + gpu_limit: float = 0, + memory_limit: Optional[int] = None, + ): + """ + Construct a PythonScriptNode. See :func:`Node.__init__` to find comments on other parameters. + + Parameters + ---------- + process_func, optional + User-defined process function, which should have the same signature as `self.process(...)`. + If user-defined function has extra parameters, use `functools.partial(...)` to bind arguments. + See `test_partial_process_func` in `test/test_execution.py` for examples of usage. 
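+
+            A minimal sketch of binding an extra argument with `functools.partial`
+            (the function `write_marker` and the variables `ctx` and `upstream` are
+            illustrative, not taken from the test suite):
+
+            .. code-block:: python
+
+                import functools
+                from pathlib import Path
+
+                def write_marker(runtime_ctx, input_datasets, output_path, tag="demo"):
+                    # write a small marker file into this task's output directory
+                    (Path(output_path) / f"{tag}.SUCCESS").write_text(tag)
+                    return True
+
+                node = PythonScriptNode(
+                    ctx, (upstream,), process_func=functools.partial(write_marker, tag="run1")
+                )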
+ """ + super().__init__( + ctx, + input_deps, + output_name, + output_path, + cpu_limit, + gpu_limit, + memory_limit, + ) + self.process_func = process_func + + def __str__(self) -> str: + return super().__str__() + f", process_func={self.process_func}" + + @Node.task_factory + def create_task( + self, + runtime_ctx: RuntimeContext, + input_deps: List[Task], + partition_infos: List[PartitionInfo], + ) -> PythonScriptTask: + return self.spawn( + runtime_ctx, + input_deps, + partition_infos, + self.process_func + or self.slim_copy().process, # warn: do not call self.slim_copy() in __init__ as attributes may not be fully initialized + self.output_name, + self.output_path, + self.cpu_limit, + self.gpu_limit, + self.memory_limit, + ) + + def spawn(self, *args, **kwargs) -> PythonScriptTask: + """ + Return an instance of subclass of `PythonScriptTask`. The subclass should override `PythonScriptTask.process(...)`. + + Examples + -------- + ``` + class OutputMsgPythonTask(PythonScriptTask): + + def __init__(self, msg: str, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.msg = msg + + def process(self, runtime_ctx: RuntimeContext, input_datasets: List[DataSet], output_path: str) -> bool: + fout = (Path(output_path) / f"{self.output_filename}-{self.local_rank}.SUCCESS") + fout.write_text(f"msg: {self.msg}, seed: {self.random_seed_uint32}, rank: {self.local_rank}") + return True + + + class OutputMsgPythonNode(PythonScriptNode): + + def spawn(self, *args, **kwargs) -> OutputMsgPythonTask: + return OutputMsgPythonTask("python script", *args, **kwargs) + ``` + """ + return PythonScriptTask(*args, **kwargs) + + def process( + self, + runtime_ctx: RuntimeContext, + input_datasets: List[DataSet], + output_path: str, + ) -> bool: + """ + Put user-defined code here. + + Parameters + ---------- + runtime_ctx + The runtime context, which defines a few global configuration info. + input_datasets + A list of input datasets. The number of datasets equal to the number of input_deps. + output_path + The absolute path of output directory created for each task generated from this node. + The outputs generated by this node would be consumed by tasks of downstream nodes. + + Returns + ------- + Return true if success. Return false or throw an exception if there is any error. + """ + raise NotImplementedError + + +class ArrowComputeNode(Node): + """ + Run Python code to process the input datasets, which have been loaded as Apache Arrow tables. + See https://arrow.apache.org/docs/python/generated/pyarrow.Table.html. + + If the code needs to access attributes of runtime task, e.g. `local_rank`, `random_seed_long`, `numpy_random_gen`, + + - create a subclass of `ArrowComputeTask`, which implements `ArrowComputeTask.process(...)`, + - override `ArrowComputeNode.spawn(...)` and return an instance of the subclass. + + Or use `runtime_ctx.task` in `process(runtime_ctx: RuntimeContext, ...)` function. + """ + + default_row_group_size = DEFAULT_ROW_GROUP_SIZE + + def __init__( + self, + ctx: Context, + input_deps: Tuple[Node, ...], + *, + process_func: Callable[[RuntimeContext, List[arrow.Table]], arrow.Table] = None, + parquet_row_group_size: int = None, + parquet_dictionary_encoding=False, + parquet_compression="ZSTD", + parquet_compression_level=3, + use_duckdb_reader=False, + output_name: str = None, + output_path: str = None, + cpu_limit: int = 1, + gpu_limit: float = 0, + memory_limit: Optional[int] = None, + ) -> None: + """ + Construct a ArrowComputeNode. 
See :func:`Node.__init__` to find comments on other parameters. + + Parameters + ---------- + process_func, optional + User-defined process function, which should have the same signature as `self.process(...)`. + If user-defined function has extra parameters, use `functools.partial(...)` to bind arguments. + See `test_partial_process_func` in `test/test_execution.py` for examples of usage. + parquet_row_group_size, optional + The number of rows stored in each row group of parquet file. + Large row group size provides more opportunities to compress the data. + Small row groups size could make filtering rows faster and achieve high concurrency. + See https://duckdb.org/docs/data/parquet/tips.html#selecting-a-row_group_size. + parquet_dictionary_encoding, optional + Specify if we should use dictionary encoding in general or only for some columns. + See `use_dictionary` in https://arrow.apache.org/docs/python/generated/pyarrow.parquet.ParquetWriter.html. + use_duckdb_reader, optional + Use duckdb (instead of pyarrow parquet module) to load parquet files as arrow table. + cpu_limit, optional + The max number of CPUs would be used by tasks generated from this node. + This is a resource requirement specified by the user and used to guide + task scheduling. smallpond does NOT enforce this limit. + gpu_limit, optional + The max number of GPUs would be used by tasks generated from this node. + This is a resource requirement specified by the user and used to guide + task scheduling. smallpond does NOT enforce this limit. + memory_limit, optional + The max memory would be used by tasks generated from this node. + This is a resource requirement specified by the user and used to guide + task scheduling. smallpond does NOT enforce this limit. + """ + super().__init__( + ctx, + input_deps, + output_name, + output_path, + cpu_limit, + gpu_limit, + memory_limit, + ) + self.parquet_row_group_size = ( + parquet_row_group_size or self.default_row_group_size + ) + self.parquet_dictionary_encoding = parquet_dictionary_encoding + self.parquet_compression = parquet_compression + self.parquet_compression_level = parquet_compression_level + self.use_duckdb_reader = use_duckdb_reader + self.process_func = process_func + + def __str__(self) -> str: + return super().__str__() + f", process_func={self.process_func}" + + @Node.task_factory + def create_task( + self, + runtime_ctx: RuntimeContext, + input_deps: List[Task], + partition_infos: List[PartitionInfo], + ) -> ArrowComputeTask: + return self.spawn( + runtime_ctx, + input_deps, + partition_infos, + self.process_func + or self.slim_copy().process, # warn: do not call self.slim_copy() in __init__ as attributes may not be fully initialized + self.parquet_row_group_size, + self.parquet_dictionary_encoding, + self.parquet_compression, + self.parquet_compression_level, + self.use_duckdb_reader, + self.output_name, + self.output_path, + self.cpu_limit, + self.gpu_limit, + self.memory_limit, + ) + + def spawn(self, *args, **kwargs) -> ArrowComputeTask: + """ + Return an instance of subclass of `ArrowComputeTask`. The subclass should override `ArrowComputeTask.process(...)`. 
+ + Examples + -------- + ``` + class CopyInputArrowTask(ArrowComputeTask): + + def __init__(self, msg: str, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.msg = msg + + def process(self, runtime_ctx: RuntimeContext, input_tables: List[arrow.Table]) -> arrow.Table: + return input_tables[0] + + + class CopyInputArrowNode(ArrowComputeNode): + + def spawn(self, *args, **kwargs) -> CopyInputArrowTask: + return CopyInputArrowTask("arrow compute", *args, **kwargs) + ``` + """ + return ArrowComputeTask(*args, **kwargs) + + def process( + self, runtime_ctx: RuntimeContext, input_tables: List[arrow.Table] + ) -> arrow.Table: + """ + Put user-defined code here. + + Parameters + ---------- + runtime_ctx + The runtime context, which defines a few global configuration info. + input_datasets + A list of arrow tables. The number of arrow tables equal to the number of input_deps. + + Returns + ------- + Return the output as a arrow table. Throw an exception if there is any error. + """ + raise NotImplementedError + + +class ArrowStreamNode(Node): + """ + Run Python code to process the input datasets, which have been loaded as RecordBatchReader. + See https://arrow.apache.org/docs/python/generated/pyarrow.RecordBatchReader.html. + + If the code needs to access attributes of runtime task, e.g. `local_rank`, `random_seed_long`, `numpy_random_gen`, + - create a subclass of `ArrowStreamTask`, which implements `ArrowStreamTask.process(...)`, + - override `ArrowStreamNode.spawn(...)` and return an instance of the subclass. + + Or use `runtime_ctx.task` in `process(runtime_ctx: RuntimeContext, ...)` function. + """ + + default_batch_size = DEFAULT_BATCH_SIZE + default_row_group_size = DEFAULT_ROW_GROUP_SIZE + default_secs_checkpoint_interval = 180 + + def __init__( + self, + ctx: Context, + input_deps: Tuple[Node, ...], + *, + process_func: Callable[ + [RuntimeContext, List[arrow.RecordBatchReader]], Iterable[arrow.Table] + ] = None, + background_io_thread=True, + streaming_batch_size: int = None, + secs_checkpoint_interval: int = None, + parquet_row_group_size: int = None, + parquet_dictionary_encoding=False, + parquet_compression="ZSTD", + parquet_compression_level=3, + use_duckdb_reader=False, + output_name: str = None, + output_path: str = None, + cpu_limit: int = 1, + gpu_limit: float = 0, + memory_limit: Optional[int] = None, + ) -> None: + """ + Construct a ArrowStreamNode. See :func:`Node.__init__` to find comments on other parameters. + + Parameters + ---------- + process_func, optional + User-defined process function, which should have the same signature as `self.process(...)`. + If user-defined function has extra parameters, use `functools.partial(...)` to bind arguments. + See `test_partial_process_func` in `test/test_execution.py` for examples of usage. + background_io_thread, optional + Create a background IO thread for read/write. + streaming_batch_size, optional + Split the input datasets into batches, each of which has length less or equal to `streaming_batch_size`. + secs_checkpoint_interval, optional + Create a checkpoint of the stream task every `secs_checkpoint_interval` seconds. + parquet_row_group_size, optional + The number of rows stored in each row group of parquet file. + Large row group size provides more opportunities to compress the data. + Small row groups size could make filtering rows faster and achieve high concurrency. + See https://duckdb.org/docs/data/parquet/tips.html#selecting-a-row_group_size. 
+ parquet_dictionary_encoding, optional + Specify if we should use dictionary encoding in general or only for some columns. + See `use_dictionary` in https://arrow.apache.org/docs/python/generated/pyarrow.parquet.ParquetWriter.html. + use_duckdb_reader, optional + Use duckdb (instead of pyarrow parquet module) to load parquet files as arrow table. + cpu_limit, optional + The max number of CPUs would be used by tasks generated from this node. + This is a resource requirement specified by the user and used to guide + task scheduling. smallpond does NOT enforce this limit. + gpu_limit, optional + The max number of GPUs would be used by tasks generated from this node. + This is a resource requirement specified by the user and used to guide + task scheduling. smallpond does NOT enforce this limit. + memory_limit, optional + The max memory would be used by tasks generated from this node. + This is a resource requirement specified by the user and used to guide + task scheduling. smallpond does NOT enforce this limit. + """ + super().__init__( + ctx, + input_deps, + output_name, + output_path, + cpu_limit, + gpu_limit, + memory_limit, + ) + self.background_io_thread = background_io_thread and self.cpu_limit > 1 + self.streaming_batch_size = streaming_batch_size or self.default_batch_size + self.secs_checkpoint_interval = secs_checkpoint_interval or math.ceil( + self.default_secs_checkpoint_interval + / min(6, self.gpu_limit + 2, self.cpu_limit) + ) + self.parquet_row_group_size = ( + parquet_row_group_size or self.default_row_group_size + ) + self.parquet_dictionary_encoding = parquet_dictionary_encoding + self.parquet_compression = parquet_compression + self.parquet_compression_level = parquet_compression_level + self.use_duckdb_reader = use_duckdb_reader + self.process_func = process_func + + def __str__(self) -> str: + return ( + super().__str__() + + f", process_func={self.process_func}, background_io_thread={self.background_io_thread}, streaming_batch_size={self.streaming_batch_size}, checkpoint_interval={self.secs_checkpoint_interval}s" + ) + + @Node.task_factory + def create_task( + self, + runtime_ctx: RuntimeContext, + input_deps: List[Task], + partition_infos: List[PartitionInfo], + ) -> ArrowStreamTask: + return self.spawn( + runtime_ctx, + input_deps, + partition_infos, + self.process_func + or self.slim_copy().process, # warn: do not call self.slim_copy() in __init__ as attributes may not be fully initialized + self.background_io_thread, + self.streaming_batch_size, + self.secs_checkpoint_interval, + self.parquet_row_group_size, + self.parquet_dictionary_encoding, + self.parquet_compression, + self.parquet_compression_level, + self.use_duckdb_reader, + self.output_name, + self.output_path, + self.cpu_limit, + self.gpu_limit, + self.memory_limit, + ) + + def spawn(self, *args, **kwargs) -> ArrowStreamTask: + """ + Return an instance of subclass of `ArrowStreamTask`. The subclass should override `ArrowStreamTask.process(...)`. 
+ + Examples + -------- + ``` + class CopyInputStreamTask(ArrowStreamTask): + + def __init__(self, msg: str, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.msg = msg + + def process(self, runtime_ctx: RuntimeContext, input_readers: List[arrow.RecordBatchReader]) -> Iterable[arrow.Table]: + for batch in input_readers[0]: + yield arrow.Table.from_batches([batch]) + + + class CopyInputStreamNode(ArrowStreamNode): + + default_batch_size = 10 + + def spawn(self, *args, **kwargs) -> CopyInputStreamTask: + return CopyInputStreamTask("arrow stream", *args, **kwargs) + ``` + """ + return ArrowStreamTask(*args, **kwargs) + + def process( + self, runtime_ctx: RuntimeContext, input_readers: List[arrow.RecordBatchReader] + ) -> Iterable[arrow.Table]: + """ + Put user-defined code here. + + Parameters + ---------- + runtime_ctx + The runtime context, which defines a few global configuration info. + input_readers + A list of RecordBatchReader. The number of readers equal to the number of input_deps. + + Returns + ------- + Return the output as a arrow table. Throw an exception if there is any error. + """ + raise NotImplementedError + + +class ArrowBatchNode(ArrowStreamNode): + """ + Run user-defined code to process the input datasets as a series of arrow tables. + """ + + def spawn(self, *args, **kwargs) -> ArrowBatchTask: + return ArrowBatchTask(*args, **kwargs) + + def process( + self, runtime_ctx: RuntimeContext, input_tables: List[arrow.Table] + ) -> arrow.Table: + raise NotImplementedError + + +class PandasComputeNode(ArrowComputeNode): + """ + Run Python code to process the input datasets as a single pandas DataFrame. + """ + + def spawn(self, *args, **kwargs) -> PandasComputeTask: + return PandasComputeTask(*args, **kwargs) + + def process( + self, runtime_ctx: RuntimeContext, input_dfs: List[pd.DataFrame] + ) -> pd.DataFrame: + raise NotImplementedError + + +class PandasBatchNode(ArrowStreamNode): + """ + Run Python code to process the input datasets as a series of pandas DataFrames. + """ + + def spawn(self, *args, **kwargs) -> PandasBatchTask: + return PandasBatchTask(*args, **kwargs) + + def process( + self, runtime_ctx: RuntimeContext, input_dfs: List[pd.DataFrame] + ) -> pd.DataFrame: + raise NotImplementedError + + +class SqlEngineNode(Node): + """ + Run SQL query against the outputs of input_deps. + """ + + max_udf_cpu_limit = 3 + default_cpu_limit = 1 + default_memory_limit = None + default_row_group_size = DEFAULT_ROW_GROUP_SIZE + enable_resource_boost = True + + def __init__( + self, + ctx: Context, + input_deps: Tuple[Node, ...], + sql_query: Union[str, Iterable[str]], + *, + udfs: List[Union[str, UserDefinedFunction]] = None, + per_thread_output=True, + materialize_output=True, + materialize_in_memory=False, + relax_memory_if_oom=None, + batched_processing=False, + extension_paths: List[str] = None, + udf_module_paths: List[str] = None, + enable_temp_directory=False, + parquet_row_group_size: int = None, + parquet_dictionary_encoding: bool = False, + parquet_compression="ZSTD", + parquet_compression_level=3, + output_name: Optional[str] = None, + output_path: Optional[str] = None, + cpu_limit: Optional[int] = None, + memory_limit: Optional[int] = None, + cpu_overcommit_ratio: float = 1.0, + memory_overcommit_ratio: float = 0.9, + ) -> None: + """ + Construct a SqlEngineNode. See :func:`Node.__init__` to find comments on other parameters. 
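+
+        A minimal usage sketch (assuming `ctx`, `users`, and `orders` are existing nodes
+        in the plan; the table aliases and column names are illustrative):
+
+        .. code-block:: python
+
+            top_users = SqlEngineNode(
+                ctx,
+                (users, orders),
+                "select u.user_id, count(*) as n_orders "
+                "from {0} u join {1} o on u.user_id = o.user_id "
+                "group by u.user_id order by n_orders desc",
+            )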
+ + Parameters + ---------- + sql_query + SQL query string or a list of query strings, currently DuckDB query syntax is supported, + see https://duckdb.org/docs/sql/query_syntax/select. + All queries are executed. But only the results of the last query is persisted as the output. + + The output dataset of each `input_deps` can be referenced as `{0}`, `{1}`, `{2}`, etc. + For example, the following query counts the total number of product items + from `{0}` that have `category_id` included in `{1}`. + + .. code-block:: + + select count(product_item.id) from {0} + where product_item.id > 0 and + product_item.category_id in ( select category_id from {1} ) + + The following placeholders are supported in the query: + + - `{batch_index}`: the index of the current batch. + - `{query_index}`: the index of the current query. + - `{rand_seed}`: the random seed of the current query. + - `{__data_partition__}`: the index of the current data partition. + udfs, optional + A list of user-defined functions to be referenced in `sql_query`. + Each element can be one of the following: + + - A `@udf` decorated function. + - A path to a duckdb extension file, e.g. `path/to/udf.duckdb_extension`. + - A string returned by `ctx.create_function()` or `ctx.create_duckdb_extension()`. + + If `udfs` is not empty, the resource requirement is downgraded to `min(cpu_limit, 3)` and `min(memory_limit, 50*GB)` + since UDF execution in duckdb is not highly paralleled. + per_thread_output, optional + If the final number of Parquet files is not important, writing one file per thread can significantly improve performance. + Also see https://duckdb.org/docs/data/parquet/tips.html#enabling-per_thread_output. + materialize_output, optional + Query result is materialized to the underlying filesystem as parquet files if enabled. + materialize_in_memory, optional + Materialize query result in memory before writing to the underlying filesystem, by default False. + relax_memory_if_oom, optional + Double the memory limit and retry if sql engine OOM, by default False. + batched_processing, optional + Split input dataset into multiple batches, each of which fits into memory limit, and then run sql query against each batch. + Enabled only if `len(input_deps) == 1`. + extension_paths, optional + A list of duckdb extension paths to be loaded at runtime. + enable_temp_directory, optional + Write temp files when memory is low, by default False. + parquet_row_group_size, optional + The number of rows stored in each row group of parquet file. + Large row group size provides more opportunities to compress the data. + Small row groups size could make filtering rows faster and achieve high concurrency. + See https://duckdb.org/docs/data/parquet/tips.html#selecting-a-row_group_size. + parquet_dictionary_encoding, optional + Specify if we should use dictionary encoding in general or only for some columns. + When encoding the column, if the dictionary size is too large, the column will fallback to PLAIN encoding. + By default, dictionary encoding is enabled for all columns. Set it to False to disable dictionary encoding, + or pass in column names to enable it only for specific columns. eg: parquet_dictionary_encoding=['column_1'] + cpu_limit, optional + The max number of CPUs used by the SQL engine. + memory_limit, optional + The max memory used by the SQL engine. + cpu_overcommit_ratio, optional + The effective number of threads used by the SQL engine is: `cpu_limit * cpu_overcommit_ratio`. 
+ memory_overcommit_ratio, optional + The effective size of memory used by the SQL engine is: `memory_limit * memory_overcommit_ratio`. + """ + + cpu_limit = cpu_limit or self.default_cpu_limit + memory_limit = memory_limit or self.default_memory_limit + if udfs is not None: + if ( + self.max_udf_cpu_limit is not None + and cpu_limit > self.max_udf_cpu_limit + ): + warnings.warn( + f"UDF execution is not highly paralleled, downgrade cpu_limit from {cpu_limit} to {self.max_udf_cpu_limit}" + ) + cpu_limit = self.max_udf_cpu_limit + memory_limit = None + if relax_memory_if_oom is not None: + warnings.warn( + "Argument 'relax_memory_if_oom' has been deprecated", + DeprecationWarning, + stacklevel=3, + ) + + assert isinstance(sql_query, str) or ( + isinstance(sql_query, Iterable) + and all(isinstance(q, str) for q in sql_query) + ) + super().__init__( + ctx, + input_deps, + output_name, + output_path, + cpu_limit=cpu_limit, + memory_limit=memory_limit, + ) + self.sql_queries = ( + [sql_query] if isinstance(sql_query, str) else list(sql_query) + ) + self.udfs = [ + ctx.create_duckdb_extension(path) for path in extension_paths or [] + ] + [ctx.create_external_module(path) for path in udf_module_paths or []] + for udf in udfs or []: + if isinstance(udf, UserDefinedFunction): + name = ctx.create_function( + udf.name, udf.func, udf.params, udf.return_type, udf.use_arrow_type + ) + else: + assert isinstance(udf, str), f"udf must be a string: {udf}" + if udf in ctx.udfs: + name = udf + elif udf.endswith(".duckdb_extension"): + name = ctx.create_duckdb_extension(udf) + elif udf.endswith(".so"): + name = ctx.create_external_module(udf) + else: + raise ValueError(f"invalid udf: {udf}") + self.udfs.append(name) + + self.per_thread_output = per_thread_output + self.materialize_output = materialize_output + self.materialize_in_memory = materialize_in_memory + self.batched_processing = batched_processing and len(input_deps) == 1 + self.enable_temp_directory = enable_temp_directory + self.parquet_row_group_size = ( + parquet_row_group_size or self.default_row_group_size + ) + self.parquet_dictionary_encoding = parquet_dictionary_encoding + self.parquet_compression = parquet_compression + self.parquet_compression_level = parquet_compression_level + self.cpu_overcommit_ratio = cpu_overcommit_ratio + self.memory_overcommit_ratio = memory_overcommit_ratio + + def __str__(self) -> str: + return ( + super().__str__() + + f", sql_query=<{self.oneline_query[:100]}...>, udfs={self.udfs}, batched_processing={self.batched_processing}" + ) + + @property + def oneline_query(self) -> str: + return "; ".join( + " ".join(filter(None, map(str.strip, query.splitlines()))) + for query in self.sql_queries + ) + + @Node.task_factory + def create_task( + self, + runtime_ctx: RuntimeContext, + input_deps: List[Task], + partition_infos: List[PartitionInfo], + ) -> SqlEngineTask: + return self.spawn( + runtime_ctx, + input_deps, + partition_infos, + self.sql_queries, + udfs=[self.ctx.udfs[name] for name in self.udfs], + per_thread_output=self.per_thread_output, + materialize_output=self.materialize_output, + materialize_in_memory=self.materialize_in_memory, + batched_processing=self.batched_processing, + enable_temp_directory=self.enable_temp_directory, + parquet_row_group_size=self.parquet_row_group_size, + parquet_dictionary_encoding=self.parquet_dictionary_encoding, + parquet_compression=self.parquet_compression, + parquet_compression_level=self.parquet_compression_level, + output_name=self.output_name, + 
output_path=self.output_path, + cpu_limit=self.cpu_limit, + gpu_limit=self.gpu_limit, + memory_limit=self.memory_limit, + cpu_overcommit_ratio=self.cpu_overcommit_ratio, + memory_overcommit_ratio=self.memory_overcommit_ratio, + ) + + def spawn(self, *args, **kwargs) -> SqlEngineTask: + return SqlEngineTask(*args, **kwargs) + + @property + def num_partitions(self) -> int: + return self.input_deps[0].num_partitions + + +class UnionNode(Node): + """ + Union two or more nodes into one flow of data. + """ + + def __init__(self, ctx: Context, input_deps: Tuple[Node, ...]): + """ + Union two or more `input_deps` into one flow of data. + + Parameters + ---------- + input_deps + All input deps should have the same set of partition dimensions. + """ + super().__init__(ctx, input_deps) + + +class RootNode(Node): + """ + A virtual node that assembles multiple nodes and outputs nothing. + """ + + def __init__(self, ctx: Context, input_deps: Tuple[Node, ...]): + """ + Assemble multiple nodes into a root node. + """ + super().__init__(ctx, input_deps) + + +class ConsolidateNode(Node): + """ + Consolidate partitions into larger ones. + """ + + def __init__(self, ctx: Context, input_dep: Node, dimensions: List[str]): + """ + Effectively reduces the number of partitions without shuffling the data across the network. + + Parameters + ---------- + dimensions + Partitions would be grouped by these `dimensions` and consolidated into larger partitions. + """ + assert isinstance( + dimensions, Iterable + ), f"dimensions is not iterable: {dimensions}" + assert all( + isinstance(dim, str) for dim in dimensions + ), f"some dimensions are not strings: {dimensions}" + super().__init__(ctx, [input_dep]) + self.dimensions = set(list(dimensions) + [PartitionInfo.toplevel_dimension]) + + def __str__(self) -> str: + return super().__str__() + f", dimensions={self.dimensions}" + + @Node.task_factory + def create_task( + self, + runtime_ctx: RuntimeContext, + input_deps: List[Task], + partition_infos: List[PartitionInfo], + ) -> MergeDataSetsTask: + return MergeDataSetsTask(runtime_ctx, input_deps, partition_infos) + + +class PartitionNode(Node): + """ + The base class for all partition nodes. + """ + + max_num_producer_tasks = 100 + max_card_of_producers_x_consumers = 4_096_000 + + def __init__( + self, + ctx: Context, + input_deps: Tuple[Node, ...], + npartitions: int, + dimension: str = None, + nested: bool = False, + output_name: str = None, + output_path: str = None, + cpu_limit: int = 1, + memory_limit: Optional[int] = None, + ) -> None: + """ + Partition the outputs of `input_deps` into n partitions. + + Parameters + ---------- + npartitions + The dataset would be split and distributed to `npartitions` partitions. + dimension + The unique partition dimension. Required if this is a nested partition. + nested, optional + `npartitions` subpartitions are created in each existing partition of `input_deps` if true. + + Examples + -------- + See unit tests in `test/test_partition.py`. For nested partition see `test_nested_partition`. + Why nested partition? See **5.1 Partial Partitioning** of [Advanced partitioning techniques for massively distributed computation](https://dl.acm.org/doi/10.1145/2213836.2213839). 
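+
+        A minimal sketch of a nested (two-level) layout built from concrete subclasses
+        (assuming `ctx` and `source` exist; the dimension and column names are illustrative):
+
+        .. code-block:: python
+
+            # first level: spread the input files evenly over 8 partitions
+            by_file = DataSetPartitionNode(ctx, (source,), npartitions=8)
+            # second level: 16 hash subpartitions inside each of the 8 partitions
+            by_user = HashPartitionNode(
+                ctx, (by_file,), 16, hash_columns=["user_id"],
+                dimension="by_user", nested=True,
+            )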
+ """ + assert isinstance( + npartitions, int + ), f"npartitions is not an integer: {npartitions}" + assert dimension is None or re.match( + "[a-zA-Z0-9_]+", dimension + ), f"dimension has invalid format: {dimension}" + assert not ( + nested and dimension is None + ), f"nested partition should have dimension" + super().__init__( + ctx, input_deps, output_name, output_path, cpu_limit, 0, memory_limit + ) + self.npartitions = npartitions + self.dimension = ( + dimension if dimension is not None else PartitionInfo.default_dimension + ) + self.nested = nested + + def __str__(self) -> str: + return ( + super().__str__() + + f", npartitions={self.npartitions}, dimension={self.dimension}, nested={self.nested}" + ) + + @Node.task_factory + def create_producer_task( + self, + runtime_ctx: RuntimeContext, + input_deps: List[Task], + partition_infos: List[PartitionInfo], + ) -> PartitionProducerTask: + raise NotImplementedError + + @Node.task_factory + def create_consumer_task( + self, + runtime_ctx: RuntimeContext, + input_deps: List[Task], + partition_infos: List[PartitionInfo], + ) -> PartitionConsumerTask: + return PartitionConsumerTask(runtime_ctx, input_deps, partition_infos) + + @Node.task_factory + def create_merge_task( + self, + runtime_ctx: RuntimeContext, + input_deps: List[Task], + partition_infos: List[PartitionInfo], + ) -> MergeDataSetsTask: + return MergeDataSetsTask(runtime_ctx, input_deps, partition_infos) + + @Node.task_factory + def create_split_task( + self, + runtime_ctx: RuntimeContext, + input_deps: List[Task], + partition_infos: List[PartitionInfo], + ) -> SplitDataSetTask: + return SplitDataSetTask(runtime_ctx, input_deps, partition_infos) + + @property + def num_partitions(self) -> int: + return self.npartitions + + +class RepeatPartitionNode(PartitionNode): + """ + Create a new partition dimension by repeating the `input_deps`. This is always a nested partition. + """ + + def __init__( + self, + ctx: Context, + input_deps: Tuple[Node, ...], + npartitions: int, + dimension: str, + cpu_limit: int = 1, + memory_limit: Optional[int] = None, + ) -> None: + super().__init__( + ctx, + input_deps, + npartitions, + dimension, + nested=True, + cpu_limit=cpu_limit, + memory_limit=memory_limit, + ) + + @Node.task_factory + def create_producer_task( + self, + runtime_ctx: RuntimeContext, + input_deps: List[Task], + partition_infos: List[PartitionInfo], + ) -> RepeatPartitionProducerTask: + return RepeatPartitionProducerTask( + runtime_ctx, + input_deps, + partition_infos, + self.npartitions, + self.dimension, + self.cpu_limit, + self.memory_limit, + ) + + +class UserDefinedPartitionNode(PartitionNode): + """ + Distribute the output files or rows of `input_deps` into n partitions based on user code. + See unit test `test_user_defined_partition` in `test/test_partition.py`. 
+ """ + + def __init__( + self, + ctx: Context, + input_deps: Tuple[Node, ...], + npartitions: int, + dimension: str = None, + nested: bool = False, + cpu_limit: int = 1, + memory_limit: Optional[int] = None, + ) -> None: + super().__init__( + ctx, + input_deps, + npartitions, + dimension, + nested, + cpu_limit=cpu_limit, + memory_limit=memory_limit, + ) + + @Node.task_factory + def create_producer_task( + self, + runtime_ctx: RuntimeContext, + input_deps: List[Task], + partition_infos: List[PartitionInfo], + ) -> UserDefinedPartitionProducerTask: + return UserDefinedPartitionProducerTask( + runtime_ctx, + input_deps, + partition_infos, + self.npartitions, + self.dimension, + self.partition, + self.cpu_limit, + self.memory_limit, + ) + + def partition(self, runtime_ctx: RuntimeContext, dataset: DataSet) -> List[DataSet]: + raise NotImplementedError + + +class UserPartitionedDataSourceNode(UserDefinedPartitionNode): + max_num_producer_tasks = 1 + + def __init__( + self, ctx: Context, partitioned_datasets: List[DataSet], dimension: str = None + ) -> None: + assert isinstance(partitioned_datasets, Iterable) and all( + isinstance(dataset, DataSet) for dataset in partitioned_datasets + ) + super().__init__( + ctx, + [DataSourceNode(ctx, dataset=None)], + len(partitioned_datasets), + dimension, + nested=False, + ) + self.partitioned_datasets = partitioned_datasets + + def partition(self, runtime_ctx: RuntimeContext, dataset: DataSet) -> List[DataSet]: + return self.partitioned_datasets + + +class EvenlyDistributedPartitionNode(PartitionNode): + """ + Evenly distribute the output files or rows of `input_deps` into n partitions. + """ + + max_num_producer_tasks = 1 + + def __init__( + self, + ctx: Context, + input_deps: Tuple[Node, ...], + npartitions: int, + dimension: str = None, + nested: bool = False, + *, + partition_by_rows=False, + random_shuffle=False, + output_name: str = None, + output_path: str = None, + cpu_limit: int = 1, + memory_limit: Optional[int] = None, + ) -> None: + """ + Evenly distribute the output files or rows of `input_deps` into n partitions. + + Parameters + ---------- + partition_by_rows, optional + Evenly distribute rows instead of input files into `npartitions` partitions, by default distribute by files. + random_shuffle, optional + Random shuffle the list of paths or parquet row groups (if `partition_by_rows=True`) in input datasets. + """ + super().__init__( + ctx, + input_deps, + npartitions, + dimension, + nested, + output_name=output_name, + output_path=output_path, + cpu_limit=cpu_limit, + memory_limit=memory_limit, + ) + self.partition_by_rows = partition_by_rows and npartitions > 1 + self.random_shuffle = random_shuffle + + def __str__(self) -> str: + return ( + super().__str__() + + f", partition_by_rows={self.partition_by_rows}, random_shuffle={self.random_shuffle}" + ) + + @Node.task_factory + def create_producer_task( + self, + runtime_ctx: RuntimeContext, + input_deps: List[Task], + partition_infos: List[PartitionInfo], + ): + return EvenlyDistributedPartitionProducerTask( + runtime_ctx, + input_deps, + partition_infos, + self.npartitions, + self.dimension, + self.partition_by_rows, + self.random_shuffle, + self.cpu_limit, + self.memory_limit, + ) + + +class LoadPartitionedDataSetNode(PartitionNode): + """ + Load existing partitioned dataset (only parquet files are supported). 
+ """ + + max_num_producer_tasks = 10 + + def __init__( + self, + ctx: Context, + input_deps: Tuple[Node, ...], + npartitions: int, + dimension: str = None, + nested: bool = False, + data_partition_column: str = None, + hive_partitioning: bool = False, + cpu_limit: int = 1, + memory_limit: Optional[int] = None, + ) -> None: + assert ( + dimension or data_partition_column + ), f"Both 'dimension' and 'data_partition_column' are none or empty" + super().__init__( + ctx, + input_deps, + npartitions, + dimension or data_partition_column, + nested, + cpu_limit=cpu_limit, + memory_limit=memory_limit, + ) + self.data_partition_column = data_partition_column + self.hive_partitioning = hive_partitioning + + def __str__(self) -> str: + return ( + super().__str__() + + f", data_partition_column={self.data_partition_column}, hive_partitioning={self.hive_partitioning}" + ) + + @Node.task_factory + def create_producer_task( + self, + runtime_ctx: RuntimeContext, + input_deps: List[Task], + partition_infos: List[PartitionInfo], + ): + return LoadPartitionedDataSetProducerTask( + runtime_ctx, + input_deps, + partition_infos, + self.npartitions, + self.dimension, + self.data_partition_column, + self.hive_partitioning, + self.cpu_limit, + self.memory_limit, + ) + + +def DataSetPartitionNode( + ctx: Context, + input_deps: Tuple[Node, ...], + npartitions: int, + *, + partition_by_rows=False, + random_shuffle=False, + data_partition_column=None, +): + """ + Partition the outputs of `input_deps` into n partitions. + + Parameters + ---------- + npartitions + The number of partitions. The input files or rows would be evenly distributed to `npartitions` partitions. + partition_by_rows, optional + Evenly distribute rows instead of input files into `npartitions` partitions, by default distribute by files. + random_shuffle, optional + Random shuffle the list of paths or parquet row groups (if `partition_by_rows=True`) in input datasets. + data_partition_column, optional + Partition by files based on the partition keys stored in `data_partition_column` if specified. + Default column name used by `HashPartitionNode` is `DATA_PARTITION_COLUMN_NAME`. + + Examples + -------- + See unit test `test_load_partitioned_datasets` in `test/test_partition.py`. + """ + assert not ( + partition_by_rows and data_partition_column + ), "partition_by_rows and data_partition_column cannot be set at the same time" + if data_partition_column is None: + partition_node = EvenlyDistributedPartitionNode( + ctx, + input_deps, + npartitions, + dimension=None, + nested=False, + partition_by_rows=partition_by_rows, + random_shuffle=random_shuffle, + ) + if npartitions == 1: + return ConsolidateNode(ctx, partition_node, dimensions=[]) + else: + return partition_node + else: + return LoadPartitionedDataSetNode( + ctx, + input_deps, + npartitions, + dimension=data_partition_column, + nested=False, + data_partition_column=data_partition_column, + hive_partitioning=False, + ) + + +class HashPartitionNode(PartitionNode): + """ + Partition the outputs of `input_deps` into n partitions based on the hash values of `hash_columns`. 
+ """ + + default_cpu_limit = 1 + default_memory_limit = None + default_data_partition_column = DATA_PARTITION_COLUMN_NAME + default_engine_type = "duckdb" + default_row_group_size = DEFAULT_ROW_GROUP_SIZE + max_num_producer_tasks = 1000 + enable_resource_boost = True + + def __init__( + self, + ctx: Context, + input_deps: Tuple[Node, ...], + npartitions: int, + hash_columns: List[str] = None, + data_partition_column: str = None, + *, + dimension: str = None, + nested: bool = False, + engine_type: Literal["duckdb", "arrow"] = None, + random_shuffle: bool = False, + shuffle_only: bool = False, + drop_partition_column: bool = False, + use_parquet_writer: bool = False, + hive_partitioning: bool = False, + parquet_row_group_size: int = None, + parquet_dictionary_encoding=False, + parquet_compression="ZSTD", + parquet_compression_level=3, + output_name: str = None, + output_path: str = None, + cpu_limit: Optional[int] = None, + memory_limit: Optional[int] = None, + ) -> None: + """ + Construct a HashPartitionNode. See :func:`Node.__init__` to find comments on other parameters. + + Parameters + ---------- + npartitions + The number of hash partitions. The number of generated parquet files would be proportional to `npartitions`. + hash_columns + The hash values are computed from `hash_columns`. + data_partition_column, optional + The name of column used to store partition keys. + engine_type, optional + The underlying query engine for hash partition. + random_shuffle, optional + Ignore `hash_columns` and shuffle each row to a random partition if true. + shuffle_only, optional + Ignore `hash_columns` and shuffle each row to the partition specified in `data_partition_column` if true. + drop_partition_column, optional + Exclude `data_partition_column` in output if true. + use_parquet_writer, optional + Convert partition data to arrow tables and append with parquet writer if true. This creates less number of + intermediate files but makes partitioning slower. + hive_partitioning, optional + Use Hive partitioned write of duckdb if true. + parquet_row_group_size, optional + The number of rows stored in each row group of parquet file. + Large row group size provides more opportunities to compress the data. + Small row groups size could make filtering rows faster and achieve high concurrency. + See https://duckdb.org/docs/data/parquet/tips.html#selecting-a-row_group_size. + parquet_dictionary_encoding, optional + Specify if we should use dictionary encoding in general or only for some columns. + See `use_dictionary` in https://arrow.apache.org/docs/python/generated/pyarrow.parquet.ParquetWriter.html. 
+ """ + assert ( + not random_shuffle or not shuffle_only + ), f"random_shuffle and shuffle_only cannot be enabled at the same time" + assert ( + not shuffle_only or data_partition_column is not None + ), f"data_partition_column not specified for shuffle-only partitioning" + assert data_partition_column is None or re.match( + "[a-zA-Z0-9_]+", data_partition_column + ), f"data_partition_column has invalid format: {data_partition_column}" + assert engine_type in ( + None, + "duckdb", + "arrow", + ), f"unknown query engine type: {engine_type}" + data_partition_column = ( + data_partition_column or self.default_data_partition_column + ) + super().__init__( + ctx, + input_deps, + npartitions, + dimension or data_partition_column, + nested, + output_name, + output_path, + cpu_limit or self.default_cpu_limit, + memory_limit or self.default_memory_limit, + ) + self.hash_columns = ["random()"] if random_shuffle else hash_columns + self.data_partition_column = data_partition_column + self.engine_type = engine_type or self.default_engine_type + self.random_shuffle = random_shuffle + self.shuffle_only = shuffle_only + self.drop_partition_column = drop_partition_column + self.use_parquet_writer = use_parquet_writer + self.hive_partitioning = hive_partitioning and self.engine_type == "duckdb" + self.parquet_row_group_size = ( + parquet_row_group_size or self.default_row_group_size + ) + self.parquet_dictionary_encoding = parquet_dictionary_encoding + self.parquet_compression = parquet_compression + self.parquet_compression_level = parquet_compression_level + + def __str__(self) -> str: + return ( + super().__str__() + + f", hash_columns={self.hash_columns}, data_partition_column={self.data_partition_column}, engine_type={self.engine_type}, hive_partitioning={self.hive_partitioning}" + ) + + @Node.task_factory + def create_producer_task( + self, + runtime_ctx: RuntimeContext, + input_deps: List[Task], + partition_infos: List[PartitionInfo], + ) -> HashPartitionTask: + return HashPartitionTask.create( + self.engine_type, + runtime_ctx, + input_deps, + partition_infos, + self.npartitions, + self.dimension, + self.hash_columns, + self.data_partition_column, + self.random_shuffle, + self.shuffle_only, + self.drop_partition_column, + self.use_parquet_writer, + self.hive_partitioning, + self.parquet_row_group_size, + self.parquet_dictionary_encoding, + self.parquet_compression, + self.parquet_compression_level, + self.output_name, + self.output_path, + self.cpu_limit, + self.memory_limit, + ) + + +class ShuffleNode(HashPartitionNode): + def __init__( + self, + ctx: Context, + input_deps: Tuple[Node, ...], + npartitions: int, + data_partition_column: str = None, + *, + dimension: str = None, + nested: bool = False, + engine_type: Literal["duckdb", "arrow"] = None, + use_parquet_writer: bool = False, + hive_partitioning: bool = False, + parquet_row_group_size: int = None, + parquet_dictionary_encoding=False, + parquet_compression="ZSTD", + parquet_compression_level=3, + output_name: str = None, + output_path: str = None, + cpu_limit: Optional[int] = None, + memory_limit: Optional[int] = None, + ) -> None: + super().__init__( + ctx, + input_deps, + npartitions, + hash_columns=None, + data_partition_column=data_partition_column, + dimension=dimension, + nested=nested, + engine_type=engine_type, + random_shuffle=False, + shuffle_only=True, + drop_partition_column=False, + use_parquet_writer=use_parquet_writer, + hive_partitioning=hive_partitioning, + parquet_row_group_size=parquet_row_group_size, + 
parquet_dictionary_encoding=parquet_dictionary_encoding, + parquet_compression=parquet_compression, + parquet_compression_level=parquet_compression_level, + output_name=output_name, + output_path=output_path, + cpu_limit=cpu_limit, + memory_limit=memory_limit, + ) + + +class RangePartitionNode(PartitionNode): + """ + Partition the outputs of `input_deps` into partitions defined by `split_points`. This node is not implemented yet. + """ + + def __init__( + self, + ctx: Context, + input_deps: Tuple[Node, ...], + split_points: List, + dimension: str = None, + nested: bool = False, + output_name: str = None, + output_path: str = None, + cpu_limit: int = 16, + memory_limit: int = 128 * GB, + ) -> None: + super().__init__( + ctx, + input_deps, + len(split_points) + 1, + dimension, + nested, + output_name, + output_path, + cpu_limit, + memory_limit, + ) + self.split_points = split_points + + +class ProjectionNode(Node): + """ + Select columns from output of an input node. + """ + + def __init__( + self, + ctx: Context, + input_dep: Node, + columns: List[str] = None, + generated_columns: List[Literal["filename", "file_row_number"]] = None, + union_by_name=None, + ) -> None: + """ + Construct a ProjectNode to select only the `columns` from output of `input_dep`. + + Parameters + ---------- + input_dep + The input node whose output would be selected. + columns, optional + The columns to be selected or created. Select all columns if set to `None`. + generated_columns + Auto generated columns, supported values: `filename`, `file_row_number`. + union_by_name, optional + Unify the columns of different files by name (see https://duckdb.org/docs/data/multiple_files/combining_schemas#union-by-name). + + Examples + -------- + First create an ArrowComputeNode to extract hosts from urls. + + .. code-block:: python + + class ParseUrl(ArrowComputeNode): + def process(self, runtime_ctx: RuntimeContext, input_tables: List[arrow.Table]) -> arrow.Table: + assert input_tables[0].column_names == ["url"] # check url is the only column in table + urls, = input_tables[0].columns + hosts = [url.as_py().split("/", maxsplit=2)[0] for url in urls] + return arrow.Table.from_arrays([hosts, urls], names=["host", "url"]) + + Suppose there are several columns in output of `data_partitions`, + `ProjectionNode(..., ["url"])` selects the `url` column. + Then only this column would be loaded into arrow table when feeding data to `ParseUrl`. + + .. 
code-block:: python + + urls_with_host = ParseUrl(ctx, (ProjectionNode(ctx, data_partitions, ["url"]),)) + """ + columns = columns or ["*"] + generated_columns = generated_columns or [] + assert all( + col in GENERATED_COLUMNS for col in generated_columns + ), f"invalid values found in generated columns: {generated_columns}" + assert not ( + set(columns) & set(generated_columns) + ), f"columns {columns} and generated columns {generated_columns} share common columns" + super().__init__(ctx, [input_dep]) + self.columns = columns + self.generated_columns = generated_columns + self.union_by_name = union_by_name + + def __str__(self) -> str: + return ( + super().__str__() + + f", columns={self.columns}, generated_columns={self.generated_columns}, union_by_name={self.union_by_name}" + ) + + @Node.task_factory + def create_task( + self, + runtime_ctx: RuntimeContext, + input_deps: List[Task], + partition_infos: List[PartitionInfo], + ) -> ProjectionTask: + return ProjectionTask( + runtime_ctx, + input_deps, + partition_infos, + self.columns, + self.generated_columns, + self.union_by_name, + ) + + +class LimitNode(SqlEngineNode): + """ + Limit the number of rows of the output of an input node. + """ + + def __init__(self, ctx: Context, input_dep: Node, limit: int) -> None: + """ + Construct a LimitNode to limit the number of rows of the output of `input_dep`. + + Parameters + ---------- + input_dep + The input node whose output would be limited. + limit + The number of rows to be limited. + """ + super().__init__(ctx, (input_dep,), f"select * from {{0}} limit {limit}") + self.limit = limit + + def __str__(self) -> str: + return super().__str__() + f", limit={self.limit}" + + @Node.task_factory + def create_merge_task( + self, + runtime_ctx: RuntimeContext, + input_deps: List[Task], + partition_infos: List[PartitionInfo], + ) -> MergeDataSetsTask: + return MergeDataSetsTask(runtime_ctx, input_deps, partition_infos) + + +T = TypeVar("T") + + +class LogicalPlanVisitor(Generic[T]): + """ + Visit the nodes of a logcial plan in depth-first order. + """ + + def visit(self, node: Node, depth: int = 0) -> T: + """ + Visit a node depending on its type. + If the method for the node type is not implemented, call `generic_visit`. 
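+
+        A minimal sketch of a concrete visitor (the class name is illustrative) that
+        counts the `SqlEngineNode`s in a plan (excluding `LimitNode`, which has its own
+        visit method); shared nodes are counted once per path since this base class does
+        not memoize visits:
+
+        ```
+        class SqlNodeCounter(LogicalPlanVisitor[int]):
+
+            def generic_visit(self, node: Node, depth: int) -> int:
+                # sum the counts over all input dependencies
+                return sum(self.visit(dep, depth + 1) for dep in node.input_deps)
+
+            def visit_query_engine_node(self, node: SqlEngineNode, depth: int) -> int:
+                return 1 + self.generic_visit(node, depth)
+
+        num_sql_nodes = SqlNodeCounter().visit(plan.root_node)  # `plan` is a LogicalPlan
+        ```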
+ """ + if isinstance(node, DataSourceNode): + return self.visit_data_source_node(node, depth) + elif isinstance(node, DataSinkNode): + return self.visit_data_sink_node(node, depth) + elif isinstance(node, RootNode): + return self.visit_root_node(node, depth) + elif isinstance(node, UnionNode): + return self.visit_union_node(node, depth) + elif isinstance(node, ConsolidateNode): + return self.visit_consolidate_node(node, depth) + elif isinstance(node, PartitionNode): + return self.visit_partition_node(node, depth) + elif isinstance(node, PythonScriptNode): + return self.visit_python_script_node(node, depth) + elif isinstance(node, ArrowComputeNode): + return self.visit_arrow_compute_node(node, depth) + elif isinstance(node, ArrowStreamNode): + return self.visit_arrow_stream_node(node, depth) + elif isinstance(node, LimitNode): + return self.visit_limit_node(node, depth) + elif isinstance(node, SqlEngineNode): + return self.visit_query_engine_node(node, depth) + elif isinstance(node, ProjectionNode): + return self.visit_projection_node(node, depth) + else: + raise Exception(f"Unknown node type: {node}") + + def generic_visit(self, node: Node, depth: int) -> T: + """This visitor calls visit() on all children of the node.""" + for dep in node.input_deps: + self.visit(dep, depth + 1) + + def visit_data_source_node(self, node: DataSourceNode, depth: int) -> T: + return self.generic_visit(node, depth) + + def visit_data_sink_node(self, node: DataSinkNode, depth: int) -> T: + return self.generic_visit(node, depth) + + def visit_root_node(self, node: RootNode, depth: int) -> T: + return self.generic_visit(node, depth) + + def visit_union_node(self, node: UnionNode, depth: int) -> T: + return self.generic_visit(node, depth) + + def visit_consolidate_node(self, node: ConsolidateNode, depth: int) -> T: + return self.generic_visit(node, depth) + + def visit_partition_node(self, node: PartitionNode, depth: int) -> T: + return self.generic_visit(node, depth) + + def visit_python_script_node(self, node: PythonScriptNode, depth: int) -> T: + return self.generic_visit(node, depth) + + def visit_arrow_compute_node(self, node: ArrowComputeNode, depth: int) -> T: + return self.generic_visit(node, depth) + + def visit_arrow_stream_node(self, node: ArrowStreamNode, depth: int) -> T: + return self.generic_visit(node, depth) + + def visit_query_engine_node(self, node: SqlEngineNode, depth: int) -> T: + return self.generic_visit(node, depth) + + def visit_projection_node(self, node: ProjectionNode, depth: int) -> T: + return self.generic_visit(node, depth) + + def visit_limit_node(self, node: LimitNode, depth: int) -> T: + return self.generic_visit(node, depth) + + +class LogicalPlan(object): + """ + The logical plan that defines a directed acyclic computation graph. + """ + + def __init__(self, ctx: Context, root_node: Node) -> None: + self.ctx = ctx + self.root_node = root_node + + def __str__(self) -> str: + return self.explain_str() + + def explain_str(self) -> str: + """ + Return a string that shows the structure of the logical plan. 
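+
+        `str(plan)` returns the same text, for example (assuming `plan` is a `LogicalPlan`):
+
+        ```
+        print(plan.explain_str())  # or simply: print(plan)
+        ```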
+ """ + visited = set() + + def to_str(node: Node, depth: int = 0) -> List[str]: + lines = [" " * depth + str(node) + ", file= " + node.location] + if node.id in visited: + return lines + [" " * depth + " (omitted ...)"] + visited.add(node.id) + lines += [ + " " * depth + f" | {name}: {stats}" + for name, stats in node.perf_stats.items() + ] + for dep in node.input_deps: + lines.extend(to_str(dep, depth + 1)) + return lines + + return os.linesep.join(to_str(self.root_node)) + + def graph(self) -> Digraph: + """ + Return a graphviz graph that shows the structure of the logical plan. + """ + dot = Digraph(comment="smallpond") + for node in self.nodes.values(): + dot.node(str(node.id), repr(node)) + for dep in node.input_deps: + dot.edge(str(dep.id), str(node.id)) + return dot + + @property + def nodes(self) -> Dict[NodeId, Node]: + nodes = {} + + def collect_nodes(node: Node): + if node.id in nodes: + return + nodes[node.id] = node + for dep in node.input_deps: + collect_nodes(dep) + + collect_nodes(self.root_node) + return nodes diff --git a/smallpond/logical/optimizer.py b/smallpond/logical/optimizer.py new file mode 100644 index 0000000..c84a0b7 --- /dev/null +++ b/smallpond/logical/optimizer.py @@ -0,0 +1,59 @@ +from smallpond.execution.task import * +from smallpond.logical.node import * + + +class Optimizer(LogicalPlanVisitor[Node]): + """ + Optimize the logical plan. + """ + + def __init__(self, exclude_nodes: Set[Node]): + self.exclude_nodes = exclude_nodes + """A set of nodes that will not be optimized.""" + self.optimized_node_map: Dict[Node, Node] = {} + """A map from original node to optimized node.""" + + def visit(self, node: Node, depth: int = 0) -> Node: + # stop recursion if the node is excluded + if node in self.exclude_nodes: + return node + # memoize the optimized node + if node in self.optimized_node_map: + return self.optimized_node_map[node] + optimized_node = super().visit(node, depth) + self.optimized_node_map[node] = optimized_node + return optimized_node + + def generic_visit(self, node: Node, depth: int) -> Node: + # by default, recursively optimize the input deps + new_node = copy.copy(node) + new_node.input_deps = [self.visit(dep, depth + 1) for dep in node.input_deps] + return new_node + + def visit_query_engine_node(self, node: SqlEngineNode, depth: int) -> Node: + # fuse consecutive SqlEngineNodes + if len(node.input_deps) == 1 and isinstance( + child := self.visit(node.input_deps[0], depth + 1), SqlEngineNode + ): + fused = copy.copy(node) + fused.input_deps = child.input_deps + fused.udfs = node.udfs + child.udfs + fused.cpu_limit = max(node.cpu_limit, child.cpu_limit) + fused.gpu_limit = max(node.gpu_limit, child.gpu_limit) + fused.memory_limit = ( + max(node.memory_limit, child.memory_limit) + if node.memory_limit is not None and child.memory_limit is not None + else node.memory_limit or child.memory_limit + ) + # merge the sql queries + # example: + # ``` + # child.sql_queries = ["select * from {0}"] + # node.sql_queries = ["select a, b from {0}"] + # fused.sql_queries = ["select a, b from (select * from {0})"] + # ``` + fused.sql_queries = child.sql_queries[:-1] + [ + query.format(f"({child.sql_queries[-1]})") for query in node.sql_queries + ] + return fused + return self.generic_visit(node, depth) diff --git a/smallpond/logical/planner.py b/smallpond/logical/planner.py new file mode 100644 index 0000000..e641ed4 --- /dev/null +++ b/smallpond/logical/planner.py @@ -0,0 +1,348 @@ +from smallpond.execution.task import * +from smallpond.logical.node import 
* + +TaskGroup = List[Task] + + +class Planner(LogicalPlanVisitor[TaskGroup]): + """ + Create an execution plan (tasks) from a logical plan (nodes). + """ + + def __init__(self, runtime_ctx: RuntimeContext) -> None: + self.runtime_ctx = runtime_ctx + self.node_to_tasks: Dict[Node, TaskGroup] = {} + + @logger.catch(reraise=True, message="failed to build computation graph") + def create_exec_plan( + self, logical_plan: LogicalPlan, manifest_only_final_results=True + ) -> ExecutionPlan: + logical_plan = copy.deepcopy(logical_plan) + + # if --output_path is specified, copy files to the output path + # otherwise, create manifest files only + sink_type = ( + "copy" if self.runtime_ctx.final_output_path is not None else "manifest" + ) + final_sink_type = ( + "copy" + if self.runtime_ctx.final_output_path is not None + else "manifest" if manifest_only_final_results else "link" + ) + + # create DataSinkNode for each named output node (same name share the same sink node) + nodes_groupby_output_name: Dict[str, List[Node]] = defaultdict(list) + for node in logical_plan.nodes.values(): + if node.output_name is not None: + if node.output_name in nodes_groupby_output_name: + warnings.warn( + f"{node} has duplicate output name: {node.output_name}" + ) + nodes_groupby_output_name[node.output_name].append(node) + sink_nodes = {} # { output_name: DataSinkNode } + for output_name, nodes in nodes_groupby_output_name.items(): + output_path = os.path.join( + self.runtime_ctx.final_output_path or self.runtime_ctx.output_root, + output_name, + ) + sink_nodes[output_name] = DataSinkNode( + logical_plan.ctx, tuple(nodes), output_path, type=sink_type + ) + + # create DataSinkNode for root node + # XXX: special case optimization to avoid copying files twice + # if root node is DataSetPartitionNode(npartitions=1), and all its input nodes are named, create manifest files instead of copying files. + if ( + isinstance(logical_plan.root_node, ConsolidateNode) + and len(logical_plan.root_node.input_deps) == 1 + and isinstance( + partition_node := logical_plan.root_node.input_deps[0], + EvenlyDistributedPartitionNode, + ) + and all(node.output_name is not None for node in partition_node.input_deps) + ): + sink_nodes["FinalResults"] = DataSinkNode( + logical_plan.ctx, + tuple( + sink_nodes[node.output_name] for node in partition_node.input_deps + ), + output_path=os.path.join( + self.runtime_ctx.final_output_path or self.runtime_ctx.output_root, + "FinalResults", + ), + type="manifest", + is_final_node=True, + ) + # if root node also has output_name, create manifest files instead of copying files. 
+ elif (output_name := logical_plan.root_node.output_name) is not None: + sink_nodes["FinalResults"] = DataSinkNode( + logical_plan.ctx, + (sink_nodes[output_name],), + output_path=os.path.join( + self.runtime_ctx.final_output_path or self.runtime_ctx.output_root, + "FinalResults", + ), + type="manifest", + is_final_node=True, + ) + else: + # normal case + sink_nodes["FinalResults"] = DataSinkNode( + logical_plan.ctx, + (logical_plan.root_node,), + output_path=os.path.join( + self.runtime_ctx.final_output_path or self.runtime_ctx.output_root, + "FinalResults", + ), + type=final_sink_type, + is_final_node=True, + ) + + # assemble sink nodes as new root node + root_node = RootNode(logical_plan.ctx, tuple(sink_nodes.values())) + logical_plan = LogicalPlan(logical_plan.ctx, root_node) + + # generate tasks + [root_task] = self.visit(root_node) + # print logical plan with the generated runtime tasks + logger.info(f"logical plan:{os.linesep}{str(logical_plan)}") + + exec_plan = ExecutionPlan(self.runtime_ctx, root_task, logical_plan) + + return exec_plan + + def visit(self, node: Node, depth: int = 0) -> TaskGroup: + # memoize the tasks + if node in self.node_to_tasks: + return self.node_to_tasks[node] + retvals = super().visit(node, depth) + self.node_to_tasks[node] = retvals + return retvals + + def visit_data_source_node(self, node: DataSourceNode, depth: int) -> TaskGroup: + assert not node.input_deps, f"data source should be leaf node: {node}" + return [node.create_task(self.runtime_ctx, [], [PartitionInfo()])] + + def visit_data_sink_node(self, node: DataSinkNode, depth: int) -> TaskGroup: + all_input_deps = [ + task for dep in node.input_deps for task in self.visit(dep, depth + 1) + ] + return [node.create_task(self.runtime_ctx, all_input_deps, [PartitionInfo()])] + + def visit_root_node(self, node: RootNode, depth: int) -> TaskGroup: + all_input_deps = [ + task for dep in node.input_deps for task in self.visit(dep, depth + 1) + ] + return [RootTask(self.runtime_ctx, all_input_deps, [PartitionInfo()])] + + def visit_union_node(self, node: UnionNode, depth: int) -> TaskGroup: + all_input_deps = [ + task for dep in node.input_deps for task in self.visit(dep, depth + 1) + ] + unique_partition_dims = set(task.partition_dims for task in all_input_deps) + assert ( + len(unique_partition_dims) == 1 + ), f"cannot union partitions with different dimensions: {unique_partition_dims}" + return all_input_deps + + def visit_consolidate_node(self, node: ConsolidateNode, depth: int) -> TaskGroup: + input_deps_taskgroups = [self.visit(dep, depth + 1) for dep in node.input_deps] + assert ( + len(input_deps_taskgroups) == 1 + ), f"consolidate node only accepts one input node, but found: {input_deps_taskgroups}" + unique_partition_dims = set( + task.partition_dims for task in input_deps_taskgroups[0] + ) + assert ( + len(unique_partition_dims) == 1 + ), f"cannot consolidate partitions with different dimensions: {unique_partition_dims}" + existing_dimensions = set(unique_partition_dims.pop()) + assert ( + node.dimensions.intersection(existing_dimensions) == node.dimensions + ), f"cannot found some of {node.dimensions} in {existing_dimensions}" + # group tasks by partitions + input_deps_groupby_partitions: Dict[Tuple, List[Task]] = defaultdict(list) + for task in input_deps_taskgroups[0]: + partition_infos = tuple( + info + for info in task.partition_infos + if info.dimension in node.dimensions + ) + input_deps_groupby_partitions[partition_infos].append(task) + return [ + node.create_task(self.runtime_ctx, 
input_deps, partition_infos) + for partition_infos, input_deps in input_deps_groupby_partitions.items() + ] + + def visit_partition_node(self, node: PartitionNode, depth: int) -> TaskGroup: + all_input_deps = [ + task for dep in node.input_deps for task in self.visit(dep, depth + 1) + ] + unique_partition_dims = set(task.partition_dims for task in all_input_deps) + assert ( + len(unique_partition_dims) == 1 + ), f"cannot partition input_deps with different dimensions: {unique_partition_dims}" + + if node.nested: + assert ( + node.dimension not in unique_partition_dims + ), f"found duplicate partition dimension '{node.dimension}', existing dimensions: {unique_partition_dims}" + assert ( + len(all_input_deps) * node.npartitions + <= node.max_card_of_producers_x_consumers + ), f"{len(all_input_deps)=} * {node.npartitions=} > {node.max_card_of_producers_x_consumers=}" + producer_tasks = [ + node.create_producer_task( + self.runtime_ctx, [task], task.partition_infos + ) + for task in all_input_deps + ] + return [ + node.create_consumer_task( + self.runtime_ctx, + [producer], + list(producer.partition_infos) + + [PartitionInfo(partition_idx, node.npartitions, node.dimension)], + ) + for producer in producer_tasks + for partition_idx in range(node.npartitions) + ] + else: + max_num_producer_tasks = min( + node.max_num_producer_tasks, + math.ceil(node.max_card_of_producers_x_consumers / node.npartitions), + ) + num_parallel_tasks = ( + 2 + * self.runtime_ctx.num_executors + * math.ceil(self.runtime_ctx.usable_cpu_count / node.cpu_limit) + ) + num_producer_tasks = max(1, min(max_num_producer_tasks, num_parallel_tasks)) + if len(all_input_deps) < num_producer_tasks: + merge_datasets_task = node.create_merge_task( + self.runtime_ctx, all_input_deps, [PartitionInfo()] + ) + split_dataset_tasks = [ + node.create_split_task( + self.runtime_ctx, + [merge_datasets_task], + [PartitionInfo(partition_idx, num_producer_tasks)], + ) + for partition_idx in range(num_producer_tasks) + ] + else: + split_dataset_tasks = [ + node.create_merge_task( + self.runtime_ctx, + tasks, + [PartitionInfo(partition_idx, num_producer_tasks)], + ) + for partition_idx, tasks in enumerate( + split_into_rows(all_input_deps, num_producer_tasks) + ) + ] + producer_tasks = [ + node.create_producer_task( + self.runtime_ctx, [split_dataset], split_dataset.partition_infos + ) + for split_dataset in split_dataset_tasks + ] + return [ + node.create_consumer_task( + self.runtime_ctx, + producer_tasks, + [ + PartitionInfo(), + PartitionInfo(partition_idx, node.npartitions, node.dimension), + ], + ) + for partition_idx in range(node.npartitions) + ] + + def broadcast_input_deps(self, node: Node, depth: int): + # if no input deps, return a single partition + if not node.input_deps: + yield [], (PartitionInfo(),) + return + + input_deps_taskgroups = [self.visit(dep, depth + 1) for dep in node.input_deps] + input_deps_most_ndims = max( + input_deps_taskgroups, + key=lambda taskgroup: ( + len(taskgroup[0].partition_dims), + max(t.partition_infos for t in taskgroup), + ), + ) + input_deps_maps = [ + ( + taskgroup[0].partition_dims, + dict((t.partition_infos, t) for t in taskgroup), + ) + for taskgroup in input_deps_taskgroups + ] + + for main_input in input_deps_most_ndims: + input_deps = [] + for input_deps_dims, input_deps_map in input_deps_maps: + partition_infos = tuple( + info + for info in main_input.partition_infos + if info.dimension in input_deps_dims + ) + input_dep = input_deps_map.get(partition_infos, None) + assert ( + input_dep is 
not None + ), f"""the partition dimensions or npartitions of inputs {node.input_deps} of {repr(node)} are not compatible + cannot match {main_input.partition_infos} against any of {input_deps_map.keys()}""" + input_deps.append(input_dep) + yield input_deps, main_input.partition_infos + + def visit_python_script_node(self, node: PythonScriptNode, depth: int) -> TaskGroup: + return [ + node.create_task(self.runtime_ctx, input_deps, partition_infos) + for input_deps, partition_infos in self.broadcast_input_deps(node, depth) + ] + + def visit_arrow_compute_node(self, node: ArrowComputeNode, depth: int) -> TaskGroup: + return [ + node.create_task(self.runtime_ctx, input_deps, partition_infos) + for input_deps, partition_infos in self.broadcast_input_deps(node, depth) + ] + + def visit_arrow_stream_node(self, node: ArrowStreamNode, depth: int) -> TaskGroup: + return [ + node.create_task(self.runtime_ctx, input_deps, partition_infos) + for input_deps, partition_infos in self.broadcast_input_deps(node, depth) + ] + + def visit_query_engine_node(self, node: SqlEngineNode, depth: int) -> TaskGroup: + return [ + node.create_task(self.runtime_ctx, input_deps, partition_infos) + for input_deps, partition_infos in self.broadcast_input_deps(node, depth) + ] + + def visit_projection_node(self, node: ProjectionNode, depth: int) -> TaskGroup: + assert ( + len(node.input_deps) == 1 + ), f"projection node only accepts one input node, but found: {node.input_deps}" + return [ + node.create_task(self.runtime_ctx, [task], task.partition_infos) + for task in self.visit(node.input_deps[0], depth + 1) + ] + + def visit_limit_node(self, node: LimitNode, depth: int) -> TaskGroup: + assert ( + len(node.input_deps) == 1 + ), f"limit node only accepts one input node, but found: {node.input_deps}" + all_input_deps = self.visit(node.input_deps[0], depth + 1) + partial_limit_tasks = [ + node.create_task(self.runtime_ctx, [task], task.partition_infos) + for task in all_input_deps + ] + merge_task = node.create_merge_task( + self.runtime_ctx, partial_limit_tasks, [PartitionInfo()] + ) + global_limit_task = node.create_task( + self.runtime_ctx, [merge_task], merge_task.partition_infos + ) + return [global_limit_task] diff --git a/smallpond/logical/udf.py b/smallpond/logical/udf.py new file mode 100644 index 0000000..f8a8b7a --- /dev/null +++ b/smallpond/logical/udf.py @@ -0,0 +1,269 @@ +import importlib +import os.path +from dataclasses import dataclass +from enum import Enum +from typing import Callable, Dict, List, Optional, Union + +import duckdb +import duckdb.typing + + +class UDFType(Enum): + """ + A wrapper of duckdb.typing.DuckDBPyType + + See https://duckdb.org/docs/api/python/types.html + """ + + SQLNULL = 1 + BOOLEAN = 2 + TINYINT = 3 + UTINYINT = 4 + SMALLINT = 5 + USMALLINT = 6 + INTEGER = 7 + UINTEGER = 8 + BIGINT = 9 + UBIGINT = 10 + HUGEINT = 11 + UUID = 12 + FLOAT = 13 + DOUBLE = 14 + DATE = 15 + TIMESTAMP = 16 + TIMESTAMP_MS = 17 + TIMESTAMP_NS = 18 + TIMESTAMP_S = 19 + TIME = 20 + TIME_TZ = 21 + TIMESTAMP_TZ = 22 + VARCHAR = 23 + BLOB = 24 + BIT = 25 + INTERVAL = 26 + + def to_duckdb_type(self) -> duckdb.typing.DuckDBPyType: + if self == UDFType.SQLNULL: + return duckdb.typing.SQLNULL + elif self == UDFType.BOOLEAN: + return duckdb.typing.BOOLEAN + elif self == UDFType.TINYINT: + return duckdb.typing.TINYINT + elif self == UDFType.UTINYINT: + return duckdb.typing.UTINYINT + elif self == UDFType.SMALLINT: + return duckdb.typing.SMALLINT + elif self == UDFType.USMALLINT: + return duckdb.typing.USMALLINT + elif 
self == UDFType.INTEGER: + return duckdb.typing.INTEGER + elif self == UDFType.UINTEGER: + return duckdb.typing.UINTEGER + elif self == UDFType.BIGINT: + return duckdb.typing.BIGINT + elif self == UDFType.UBIGINT: + return duckdb.typing.UBIGINT + elif self == UDFType.HUGEINT: + return duckdb.typing.HUGEINT + elif self == UDFType.UUID: + return duckdb.typing.UUID + elif self == UDFType.FLOAT: + return duckdb.typing.FLOAT + elif self == UDFType.DOUBLE: + return duckdb.typing.DOUBLE + elif self == UDFType.DATE: + return duckdb.typing.DATE + elif self == UDFType.TIMESTAMP: + return duckdb.typing.TIMESTAMP + elif self == UDFType.TIMESTAMP_MS: + return duckdb.typing.TIMESTAMP_MS + elif self == UDFType.TIMESTAMP_NS: + return duckdb.typing.TIMESTAMP_NS + elif self == UDFType.TIMESTAMP_S: + return duckdb.typing.TIMESTAMP_S + elif self == UDFType.TIME: + return duckdb.typing.TIME + elif self == UDFType.TIME_TZ: + return duckdb.typing.TIME_TZ + elif self == UDFType.TIMESTAMP_TZ: + return duckdb.typing.TIMESTAMP_TZ + elif self == UDFType.VARCHAR: + return duckdb.typing.VARCHAR + elif self == UDFType.BLOB: + return duckdb.typing.BLOB + elif self == UDFType.BIT: + return duckdb.typing.BIT + elif self == UDFType.INTERVAL: + return duckdb.typing.INTERVAL + return None + + +class UDFStructType: + """ + A wrapper of duckdb.struct_type, eg: UDFStructType({'host': 'VARCHAR', 'path:' 'VARCHAR', 'query': 'VARCHAR'}) + + See https://duckdb.org/docs/api/python/types.html#a-field_one-b-field_two--n-field_n + """ + + def __init__(self, fields: Union[Dict[str, str], List[str]]) -> None: + self.fields = fields + + def to_duckdb_type(self) -> duckdb.typing.DuckDBPyType: + return duckdb.struct_type(self.fields) + + +class UDFListType: + """ + A wrapper of duckdb.list_type, eg: UDFListType(UDFType.INTEGER) + + See https://duckdb.org/docs/api/python/types.html#listchild_type + """ + + def __init__(self, child) -> None: + self.child = child + + def to_duckdb_type(self) -> duckdb.typing.DuckDBPyType: + return duckdb.list_type(self.child.to_duckdb_type()) + + +class UDFMapType: + """ + A wrapper of duckdb.map_type, eg: UDFMapType(UDFType.VARCHAR, UDFType.INTEGER) + + See https://duckdb.org/docs/api/python/types.html#dictkey_type-value_type + """ + + def __init__(self, key, value) -> None: + self.key = key + self.value = value + + def to_duckdb_type(self) -> duckdb.typing.DuckDBPyType: + return duckdb.map_type(self.key.to_duckdb_type(), self.value.to_duckdb_type()) + + +class UDFAnyParameters: + """ + Accept parameters of any types in UDF. 
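+
+    When passed as `params` to `PythonUDFContext`, `to_duckdb_type()` returns
+    None, so no explicit parameter list is handed to `conn.create_function`.
+    A hypothetical construction sketch:
+    ```
+    PythonUDFContext("to_text", lambda v: str(v), UDFAnyParameters(), UDFType.VARCHAR)
+    ```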
+ """ + + def __init__(self) -> None: + pass + + def to_duckdb_type(self) -> duckdb.typing.DuckDBPyType: + return None + + +class UDFContext(object): + def bind(self, conn: duckdb.DuckDBPyConnection): + raise NotImplementedError + + +class PythonUDFContext(UDFContext): + def __init__( + self, + name: str, + func: Callable, + params: Optional[List[UDFType]], + return_type: Optional[UDFType], + use_arrow_type=False, + ): + self.name = name + self.func = func + self.params = params + self.return_type = return_type + self.use_arrow_type = use_arrow_type + + def __str__(self) -> str: + return f"{self.name}@{self.func}" + + __repr__ = __str__ + + def bind(self, conn: duckdb.DuckDBPyConnection): + if isinstance(self.params, UDFAnyParameters): + duckdb_args = self.params.to_duckdb_type() + else: + duckdb_args = [arg.to_duckdb_type() for arg in self.params] + conn.create_function( + self.name, + self.func, + duckdb_args, + self.return_type.to_duckdb_type(), + type=("arrow" if self.use_arrow_type else "native"), + ) + # logger.debug(f"created python udf: {self.name}({self.params}) -> {self.return_type}") + + +class ExternalModuleContext(UDFContext): + def __init__(self, name: str, module_path: str) -> None: + self.name = name + self.module_path = module_path + + def __str__(self) -> str: + return f"{self.name}@{self.module_path}" + + __repr__ = __str__ + + def bind(self, conn: duckdb.DuckDBPyConnection): + module_name, _ = os.path.splitext(os.path.basename(self.module_path)) + spec = importlib.util.spec_from_file_location(module_name, self.module_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + module.create_duckdb_udfs(conn) + # logger.debug(f"loaded external module at {self.module_path}, udf functions: {module.udfs}") + + +class DuckDbExtensionContext(UDFContext): + def __init__(self, name: str, extension_path: str) -> None: + self.name = name + self.extension_path = extension_path + + def __str__(self) -> str: + return f"{self.name}@{self.extension_path}" + + __repr__ = __str__ + + def bind(self, conn: duckdb.DuckDBPyConnection): + conn.load_extension(self.extension_path) + # logger.debug(f"loaded duckdb extension at {self.extension_path}") + + +@dataclass +class UserDefinedFunction: + """ + A python user-defined function. + """ + + name: str + func: Callable + params: List[UDFType] + return_type: UDFType + use_arrow_type: bool + + def __call__(self, *args, **kwargs): + return self.func(*args, **kwargs) + + +def udf( + params: List[UDFType], + return_type: UDFType, + use_arrow_type: bool = False, + name: Optional[str] = None, +) -> Callable[[Callable], UserDefinedFunction]: + """ + A decorator to define a Python UDF. + + Examples + -------- + ``` + @udf(params=[UDFType.INTEGER, UDFType.INTEGER], return_type=UDFType.INTEGER) + def gcd(a: int, b: int) -> int: + while b: + a, b = b, a % b + return a + ``` + + See `Context.create_function` for more details. + """ + return lambda func: UserDefinedFunction( + name or func.__name__, func, params, return_type, use_arrow_type + ) diff --git a/smallpond/platform/__init__.py b/smallpond/platform/__init__.py new file mode 100644 index 0000000..57dfbbe --- /dev/null +++ b/smallpond/platform/__init__.py @@ -0,0 +1,36 @@ +from typing import Optional + +from smallpond.platform.base import Platform +from smallpond.platform.mpi import MPI + +_platforms = { + "mpi": MPI, +} + + +def get_platform(name: Optional[str] = None) -> Platform: + """ + Get a platform by name. 
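+    Built-in names: "mpi". Any other name is treated as an importable module
+    path that defines a `Platform` subclass.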
+ If name is not specified, try to get an available platform. + """ + if name is None: + for platform in _platforms.values(): + if platform.is_available(): + return platform() + return Platform() + + if name in _platforms: + return _platforms[name]() + + # load platform from a custom python module + from importlib import import_module + + module = import_module(name) + + # find the exact class that inherits from Platform + for name in dir(module): + cls = getattr(module, name) + if isinstance(cls, type) and issubclass(cls, Platform): + return cls() + + raise RuntimeError(f"no Platform class found in module: {name}") diff --git a/smallpond/platform/base.py b/smallpond/platform/base.py new file mode 100644 index 0000000..93caf69 --- /dev/null +++ b/smallpond/platform/base.py @@ -0,0 +1,109 @@ +import os +import signal +import subprocess +import uuid +from datetime import datetime +from typing import List, Optional + + +class Platform: + """ + Base class for all platforms. + """ + + @staticmethod + def is_available() -> bool: + """ + Whether the platform is available in the current environment. + """ + return False + + @classmethod + def __str__(cls) -> str: + return cls.__name__ + + def start_job( + self, + num_nodes: int, + entrypoint: str, + args: List[str], + envs: dict = {}, + extra_opts: dict = {}, + ) -> List[str]: + """ + Start a job on the platform. + Return the job ids. + """ + pids = [] + for _ in range(num_nodes): + popen = subprocess.Popen( + ["python", entrypoint, *args], + env={**os.environ, **envs}, + stdout=subprocess.DEVNULL, + stderr=subprocess.STDOUT, + ) + pids.append(str(popen.pid)) + return pids + + def stop_job(self, pid: str) -> None: + """ + Stop the job. + """ + os.kill(int(pid), signal.SIGKILL) + + @staticmethod + def default_job_id() -> str: + """ + Return the default job id. + """ + return str(uuid.uuid4()) + + @staticmethod + def default_job_time() -> datetime: + """ + Return the default job time. + """ + return datetime.now() + + @staticmethod + def default_data_root() -> Optional[str]: + """ + Get the default data root for the platform. + If the platform does not have a default data root, return None. + """ + from loguru import logger + + default = os.path.expanduser("~/.smallpond/data") + logger.warning(f"data root is not set, using default: {default}") + return default + + @staticmethod + def default_share_log_analytics() -> bool: + """ + Whether to share log analytics by default. + """ + return False + + @staticmethod + def shared_log_root() -> Optional[str]: + """ + Return the shared log root. + """ + return None + + @staticmethod + def grafana_homepath() -> Optional[str]: + """ + Return the homepath of grafana. + """ + homebrew_installed_homepath = "/opt/homebrew/opt/grafana/share/grafana" + if os.path.exists(homebrew_installed_homepath): + return homebrew_installed_homepath + return None + + @staticmethod + def default_memory_allocator() -> str: + """ + Get the default memory allocator for the platform. + """ + return "system" diff --git a/smallpond/platform/mpi.py b/smallpond/platform/mpi.py new file mode 100644 index 0000000..aafd49d --- /dev/null +++ b/smallpond/platform/mpi.py @@ -0,0 +1,39 @@ +import shutil +import subprocess +from typing import List +from loguru import logger + +from smallpond.platform.base import Platform + + +class MPI(Platform): + """ + MPI platform. 
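+
+    Jobs are started as `mpirun -n <num_nodes> -x KEY=VALUE ... python <entrypoint> <args>`;
+    entries of `envs` are forwarded to the MPI ranks via `-x`.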
+ """ + + @staticmethod + def is_available() -> bool: + return shutil.which("mpirun") is not None + + def start_job( + self, + num_nodes: int, + entrypoint: str, + args: List[str], + envs: dict = {}, + extra_opts: dict = {}, + ) -> List[str]: + mpirun_cmd = ["mpirun", "-n", str(num_nodes)] + for key, value in envs.items(): + mpirun_cmd += ["-x", f"{key}={value}"] + mpirun_cmd += ["python", entrypoint] + args + + logger.debug(f"start job with command: {' '.join(mpirun_cmd)}") + subprocess.Popen( + mpirun_cmd, + stdout=subprocess.DEVNULL, + stderr=subprocess.STDOUT, + text=True, + ) + + return [] diff --git a/smallpond/session.py b/smallpond/session.py new file mode 100644 index 0000000..171be42 --- /dev/null +++ b/smallpond/session.py @@ -0,0 +1,388 @@ +""" +This module defines the `Session` class, which is the entry point for smallpond interactive mode. +""" + +from __future__ import annotations + +import json +import os +import shutil +import socket +import subprocess +import sys +import threading +from dataclasses import dataclass +from datetime import datetime +from typing import Optional, Tuple + +import ray +from graphviz import Digraph +import graphviz.backend.execute +from loguru import logger + +import smallpond +from smallpond.execution.task import JobId, RuntimeContext +from smallpond.logical.node import Context +from smallpond.platform import Platform, get_platform + + +class SessionBase: + def __init__(self, **kwargs): + """ + Create a smallpond environment. + """ + super().__init__() + self._ctx = Context() + self.config, self._platform = Config.from_args_and_env(**kwargs) + + # construct runtime context for Tasks + runtime_ctx = RuntimeContext( + job_id=JobId(hex=self.config.job_id), + job_time=self.config.job_time, + data_root=self.config.data_root, + num_executors=self.config.num_executors, + bind_numa_node=self.config.bind_numa_node, + shared_log_root=self._platform.shared_log_root(), + ) + self._runtime_ctx = runtime_ctx + + # if `spawn` is specified, spawn a job and exit + if os.environ.get("SP_SPAWN") == "1": + self._spawn_self() + exit(0) + + self._runtime_ctx.initialize(exec_id=socket.gethostname()) + logger.info(f"using platform: {self._platform}") + logger.info(f"command-line arguments: {' '.join(sys.argv)}") + logger.info(f"session config: {self.config}") + + def setup_worker(): + runtime_ctx._init_logs( + exec_id=socket.gethostname(), capture_stdout_stderr=True + ) + + if self.config.ray_address is None: + # find the memory allocator + if self.config.memory_allocator == "system": + malloc_path = "" + elif self.config.memory_allocator == "jemalloc": + malloc_path = shutil.which("libjemalloc.so.2") + assert malloc_path is not None, "jemalloc is not installed" + elif self.config.memory_allocator == "mimalloc": + malloc_path = shutil.which("libmimalloc.so.2.1") + assert malloc_path is not None, "mimalloc is not installed" + else: + raise ValueError( + f"unsupported memory allocator: {self.config.memory_allocator}" + ) + memory_purge_delay = 10000 + + # start ray head node + # for ray head node to access grafana + os.environ["RAY_GRAFANA_HOST"] = "http://localhost:8122" + self._ray_address = ray.init( + # start a new local cluster + address="local", + # disable local CPU resource if not running on localhost + num_cpus=( + 0 + if self.config.num_executors > 0 + else self._runtime_ctx.usable_cpu_count + ), + # set the memory limit to the available memory size + _memory=self._runtime_ctx.usable_memory_size, + # setup logging for workers + log_to_driver=False, + 
runtime_env={ + "worker_process_setup_hook": setup_worker, + "env_vars": { + "LD_PRELOAD": malloc_path, + "MALLOC_CONF": f"percpu_arena:percpu,background_thread:true,metadata_thp:auto,dirty_decay_ms:{memory_purge_delay},muzzy_decay_ms:{memory_purge_delay},oversize_threshold:0,lg_tcache_max:16", + "MIMALLOC_PURGE_DELAY": f"{memory_purge_delay}", + "ARROW_DEFAULT_MEMORY_POOL": self.config.memory_allocator, + "ARROW_IO_THREADS": "2", + "OMP_NUM_THREADS": "2", + "POLARS_MAX_THREADS": "2", + "NUMEXPR_MAX_THREADS": "2", + "RAY_PROFILING": "1", + }, + }, + dashboard_host="0.0.0.0", + dashboard_port=8008, + # for prometheus to scrape metrics + _metrics_export_port=8080, + ).address_info["gcs_address"] + logger.info(f"started ray cluster at {self._ray_address}") + + self._prometheus_process = self._start_prometheus() + self._grafana_process = self._start_grafana() + else: + self._ray_address = self.config.ray_address + self._prometheus_process = None + self._grafana_process = None + logger.info(f"connected to ray cluster at {self._ray_address}") + + # start workers + if self.config.num_executors > 0: + # override configs + kwargs["job_id"] = self.config.job_id + + self._job_names = self._platform.start_job( + self.config.num_executors, + entrypoint=os.path.join(os.path.dirname(__file__), "worker.py"), + args=[ + f"--ray_address={self._ray_address}", + f"--log_dir={self._runtime_ctx.log_root}", + *(["--bind_numa_node"] if self.config.bind_numa_node else []), + ], + extra_opts=kwargs, + ) + else: + self._job_names = [] + + # spawn a thread to periodically dump metrics + self._stop_event = threading.Event() + self._dump_thread = threading.Thread( + name="dump_thread", target=self._dump_periodically, daemon=True + ) + self._dump_thread.start() + + def shutdown(self): + """ + Shutdown the session. + """ + logger.info("shutting down session") + self._stop_event.set() + + # stop all jobs + for job_name in self._job_names: + self._platform.stop_job(job_name) + self._job_names = [] + + self._dump_thread.join() + if self.config.ray_address is None: + ray.shutdown() + if self._prometheus_process is not None: + self._prometheus_process.terminate() + self._prometheus_process.wait() + self._prometheus_process = None + logger.info("stopped prometheus") + if self._grafana_process is not None: + self._grafana_process.terminate() + self._grafana_process.wait() + self._grafana_process = None + logger.info("stopped grafana") + + def _spawn_self(self): + """ + Spawn a new job to run the current script. + """ + self._platform.start_job( + num_nodes=1, + entrypoint=sys.argv[0], + args=sys.argv[1:], + extra_opts=dict( + tags=["smallpond", "scheduler", smallpond.__version__], + ), + envs={ + k: v + for k, v in os.environ.items() + if k.startswith("SP_") and k != "SP_SPAWN" + }, + ) + + def _start_prometheus(self) -> Optional[subprocess.Popen]: + """ + Start prometheus server if it exists. 
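+        Uses the scrape config generated by Ray under
+        /tmp/ray/session_latest/metrics/prometheus/prometheus.yml and writes
+        TSDB data and logs under the session log root.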
+ """ + prometheus_path = shutil.which("prometheus") + if prometheus_path is None: + logger.warning("prometheus is not found") + return None + os.makedirs(f"{self._runtime_ctx.log_root}/prometheus", exist_ok=True) + proc = subprocess.Popen( + [ + prometheus_path, + "--config.file=/tmp/ray/session_latest/metrics/prometheus/prometheus.yml", + f"--storage.tsdb.path={self._runtime_ctx.log_root}/prometheus/data", + ], + stderr=open(f"{self._runtime_ctx.log_root}/prometheus/prometheus.log", "w"), + ) + logger.info("started prometheus") + return proc + + def _start_grafana(self) -> Optional[subprocess.Popen]: + """ + Start grafana server if it exists. + """ + homepath = self._platform.grafana_homepath() + if homepath is None: + logger.warning("grafana is not found") + return None + os.makedirs(f"{self._runtime_ctx.log_root}/grafana", exist_ok=True) + proc = subprocess.Popen( + [ + shutil.which("grafana"), + "server", + "--config", + "/tmp/ray/session_latest/metrics/grafana/grafana.ini", + "-homepath", + homepath, + "web", + ], + stdout=open(f"{self._runtime_ctx.log_root}/grafana/grafana.log", "w"), + env={ + "GF_SERVER_HTTP_PORT": "8122", # redirect to an available port + "GF_SERVER_ROOT_URL": os.environ.get("RAY_GRAFANA_IFRAME_HOST") + or "http://localhost:8122", + "GF_PATHS_DATA": f"{self._runtime_ctx.log_root}/grafana/data", + }, + ) + logger.info(f"started grafana at http://localhost:8122") + return proc + + @property + def runtime_ctx(self) -> RuntimeContext: + return self._runtime_ctx + + def graph(self) -> Digraph: + """ + Get the logical plan graph. + """ + # implemented in Session class + raise NotImplementedError("graph") + + def dump_graph(self, path: Optional[str] = None): + """ + Dump the logical plan graph to a file. + """ + path = path or os.path.join(self.runtime_ctx.log_root, "graph") + try: + self.graph().render(path, format="png") + logger.debug(f"dumped graph to {path}") + except graphviz.backend.execute.ExecutableNotFound as e: + logger.warning(f"graphviz is not installed, skipping graph dump") + + def dump_timeline(self, path: Optional[str] = None): + """ + Dump the task timeline to a file. + """ + path = path or os.path.join(self.runtime_ctx.log_root, "timeline") + # the default timeline is grouped by worker + exec_path = f"{path}_exec" + ray.timeline(exec_path) + logger.debug(f"dumped timeline to {exec_path}") + + # generate another timeline grouped by node + with open(exec_path) as f: + records = json.load(f) + new_records = [] + for record in records: + # swap record name and pid-tid + name = record["name"] + try: + node_id = name.split(",")[-1] + task_id = name.split("-")[1].split(".")[0] + task_name = name.split("-")[0] + record["pid"] = f"{node_id}-{task_name}" + record["tid"] = f"task {task_id}" + new_records.append(record) + except Exception: + # filter out other records + pass + node_path = f"{path}_plan" + with open(node_path, "w") as f: + json.dump(new_records, f) + logger.debug(f"dumped timeline to {node_path}") + + def _summarize_task(self) -> Tuple[int, int]: + # implemented in Session class + raise NotImplementedError("summarize_task") + + def _dump_periodically(self): + """ + Dump the graph and timeline every minute. + Set `self._stop_event` to have a final dump and stop this thread. 
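+        `shutdown()` sets the event and then joins this thread.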
+ """ + while not self._stop_event.is_set(): + self._stop_event.wait(60) + self.dump_graph() + self.dump_timeline() + num_total_tasks, num_finished_tasks = self._summarize_task() + percent = ( + num_finished_tasks / num_total_tasks * 100 if num_total_tasks > 0 else 0 + ) + logger.info( + f"progress: {num_finished_tasks}/{num_total_tasks} tasks ({percent:.1f}%)" + ) + + +@dataclass +class Config: + """ + Configuration for a session. + """ + + job_id: str # JOBID + job_time: datetime # JOB_TIME + data_root: str # DATA_ROOT + num_executors: int # NUM_NODES_TOTAL + ray_address: Optional[str] # RAY_ADDRESS + bind_numa_node: bool # BIND_NUMA_NODE + memory_allocator: str # MEMORY_ALLOCATOR + remove_output_root: bool + + @staticmethod + def from_args_and_env( + platform: Optional[str] = None, + job_id: Optional[str] = None, + job_time: Optional[datetime] = None, + data_root: Optional[str] = None, + num_executors: Optional[int] = None, + ray_address: Optional[str] = None, + bind_numa_node: Optional[bool] = None, + memory_allocator: Optional[str] = None, + _remove_output_root: bool = True, + **kwargs, + ) -> Config: + """ + Load config from arguments and environment variables. + If not specified, use the default value. + """ + + def get_env(key: str, type: type = str): + """ + Get an environment variable and convert it to the given type. + If the variable is not set, return None. + """ + value = os.environ.get(f"SP_{key}") + return type(value) if value is not None else None + + platform = get_platform(get_env("PLATFORM") or platform) + job_id = get_env("JOBID") or job_id or platform.default_job_id() + job_time = ( + get_env("JOB_TIME", datetime.fromisoformat) + or job_time + or platform.default_job_time() + ) + data_root = get_env("DATA_ROOT") or data_root or platform.default_data_root() + num_executors = get_env("NUM_EXECUTORS", int) or num_executors or 0 + ray_address = get_env("RAY_ADDRESS") or ray_address + bind_numa_node = get_env("BIND_NUMA_NODE") == "1" or bind_numa_node + memory_allocator = ( + get_env("MEMORY_ALLOCATOR") + or memory_allocator + or platform.default_memory_allocator() + ) + + config = Config( + job_id=job_id, + job_time=job_time, + data_root=data_root, + num_executors=num_executors, + ray_address=ray_address, + bind_numa_node=bind_numa_node, + memory_allocator=memory_allocator, + remove_output_root=_remove_output_root, + ) + return config, platform diff --git a/smallpond/utility.py b/smallpond/utility.py new file mode 100644 index 0000000..80c37f0 --- /dev/null +++ b/smallpond/utility.py @@ -0,0 +1,199 @@ +import cProfile +import inspect +import io +import logging +import pstats +import queue +import subprocess +import sys +import threading +from typing import Any, Dict, Iterable + +from loguru import logger + + +def overall_stats( + ctx, + inp, + sql_per_part, + sql_on_merged, + output_name, + output_dir=None, + cpu_limit=2, + memory_limit=30 << 30, +): + from smallpond.logical.node import DataSetPartitionNode, DataSinkNode, SqlEngineNode + + n = SqlEngineNode( + ctx, inp, sql_per_part, cpu_limit=cpu_limit, memory_limit=memory_limit + ) + p = DataSetPartitionNode(ctx, (n,), npartitions=1) + n2 = SqlEngineNode( + ctx, + (p,), + sql_on_merged, + output_name=output_name, + cpu_limit=cpu_limit, + memory_limit=memory_limit, + ) + if output_dir is not None: + return DataSinkNode(ctx, (n2,), output_dir) + else: + return n2 + + +def execute_command(cmd: str, env: Dict[str, str] = None, shell=False): + with subprocess.Popen( + cmd.split(), + env=env, + shell=shell, + 
stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + universal_newlines=True, + encoding="utf8", + ) as proc: + for line in proc.stdout: + yield line.rstrip() + return_code = proc.wait() + if return_code != 0: + raise subprocess.CalledProcessError(return_code, cmd) + + +def cprofile_to_string( + perf_profile: cProfile.Profile, order_by=pstats.SortKey.TIME, top_k=20 +): + perf_profile.disable() + pstats_output = io.StringIO() + profile_stats = pstats.Stats(perf_profile, stream=pstats_output) + profile_stats.strip_dirs().sort_stats(order_by).print_stats(top_k) + return pstats_output.getvalue() + + +class Wrapper(object): + def __init__(self, base_obj: Any): + self._base_obj = base_obj + + def __str__(self) -> str: + return str(self._base_obj) + + def __repr__(self) -> str: + return repr(self._base_obj) + + def __getattr__(self, name): + return getattr(self._base_obj, name) + + def __setattr__(self, name: str, value: Any) -> None: + if name.startswith("_"): + super().__setattr__(name, value) + else: + return setattr(self._base_obj, name, value) + + +class ConcurrentIterError(Exception): + pass + + +class ConcurrentIter(object): + """ + Use a background thread to iterate over an iterable. + + Examples + -------- + The following code snippet is a common pattern to read record batches from parquet files asynchronously in arrow stream task. + ``` + from smallpond.utility import ConcurrentIter + + with ConcurrentIter(input_readers[0], max_buffer_size=1) as async_reader: + for batch_idx, batch in enumerate(async_reader): + # your code here + yield StreamOutput(output_table, batch_indices=[batch_idx]) + ``` + """ + + def __init__(self, iterable: Iterable, max_buffer_size=1) -> None: + assert isinstance( + iterable, Iterable + ), f"expect an iterable but found: {repr(iterable)}" + self.__iterable = iterable + self.__queue = queue.Queue(max_buffer_size) + self.__last = object() + self.__stop = threading.Event() + self.__thread = threading.Thread(target=self._producer) + + def __enter__(self): + self.__thread.start() + return iter(self) + + def __exit__(self, exc_type, exc_value, traceback): + self.join() + + def __iter__(self): + try: + yield from self._consumer() + finally: + self.join() + + def join(self): + self.__stop.set() + self.clear() + self.__thread.join(timeout=1) + if self.__thread.is_alive(): + print(f"waiting {self.__thread.name} of {self}", file=sys.stderr) + self.__thread.join() + print(f"joined {self.__thread.name} of {self}", file=sys.stderr) + + def clear(self): + try: + while self.__queue.get_nowait() is not None: + pass + except queue.Empty: + pass + + def _producer(self): + try: + for item in self.__iterable: + self.__queue.put(item) + if self.__stop.is_set(): + self.clear() + break + except Exception as ex: + print(f"Error in {self}: {ex}", file=sys.stderr) + self.clear() + self.__queue.put(ConcurrentIterError(ex)) + else: + self.__queue.put(self.__last) + + def _consumer(self): + while True: + item = self.__queue.get() + if item is self.__last: + break + if isinstance(item, ConcurrentIterError): + (ex,) = item.args + raise item from ex + yield item + + +class InterceptHandler(logging.Handler): + """ + Intercept standard logging messages toward loguru sinks. + See https://loguru.readthedocs.io/en/stable/overview.html#entirely-compatible-with-standard-logging + """ + + def emit(self, record: logging.LogRecord) -> None: + # Get corresponding Loguru level if it exists. 
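+        # Fall back to the numeric level if loguru does not recognize the name.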
+ level: str | int + try: + level = logger.level(record.levelname).name + except ValueError: + level = record.levelno + + # Find caller from where originated the logged message. + frame, depth = inspect.currentframe(), 0 + while frame and (depth == 0 or frame.f_code.co_filename == logging.__file__): + frame = frame.f_back + depth += 1 + + logger.opt(depth=depth, exception=record.exc_info).log( + level, record.getMessage() + ) diff --git a/smallpond/worker.py b/smallpond/worker.py new file mode 100644 index 0000000..891cc6e --- /dev/null +++ b/smallpond/worker.py @@ -0,0 +1,78 @@ +# this file is the entry point for smallpond workers + +import argparse +import os +import socket +import subprocess + +import psutil + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="smallpond worker") + parser.add_argument( + "--ray_address", + required=True, + help="The address of the Ray cluster to connect to", + ) + parser.add_argument( + "--log_dir", required=True, help="The directory where logs will be stored" + ) + parser.add_argument( + "--bind_numa_node", + action="store_true", + help="Bind executor processes to numa nodes", + ) + + args = parser.parse_args() + log_path = os.path.join(args.log_dir, f"{socket.gethostname()}.log") + + # limit the number of CPUs to the number of physical cores + cpu_count = psutil.cpu_count(logical=False) + memory = psutil.virtual_memory().total + + if args.bind_numa_node: + import numa + + numa_node_count = numa.info.get_num_configured_nodes() + cpu_count_per_socket = cpu_count // numa_node_count + memory_per_socket = memory // numa_node_count + for i in range(numa_node_count): + subprocess.run( + [ + "numactl", + "-N", + str(i), + "-m", + str(i), + "ray", + "start", + "--address", + args.ray_address, + "--num-cpus", + str(cpu_count_per_socket), + "--memory", + str(memory_per_socket), + ], + check=True, + ) + else: + subprocess.run( + [ + "ray", + "start", + "--address", + args.ray_address, + "--num-cpus", + str(cpu_count), + ], + check=True, + ) + + # keep printing logs + while True: + try: + subprocess.run(["tail", "-F", log_path], check=True) + except subprocess.CalledProcessError as e: + # XXX: sometimes it raises `No such file or directory` + # don't know why. 
just ignore it + print(e) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..83ca8a8 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,30 @@ +import os + +import pytest +import ray + +import smallpond + + +@pytest.fixture(scope="session") +def ray_address(): + """A global Ray instance for all tests""" + ray_address = ray.init( + address="local", + # disable dashboard in unit tests + include_dashboard=False, + ).address_info["gcs_address"] + yield ray_address + ray.shutdown() + + +@pytest.fixture +def sp(ray_address: str, request): + """A smallpond session for each test""" + runtime_root = os.getenv("TEST_RUNTIME_ROOT") or f"tests/runtime" + sp = smallpond.init( + data_root=os.path.join(runtime_root, request.node.name), + ray_address=ray_address, + ) + yield sp + sp.shutdown() diff --git a/tests/datagen.py b/tests/datagen.py new file mode 100644 index 0000000..1750137 --- /dev/null +++ b/tests/datagen.py @@ -0,0 +1,186 @@ +import base64 +import glob +import os +import random +import string +from concurrent.futures import ProcessPoolExecutor +from datetime import datetime, timedelta, timezone +from typing import Tuple + +import pandas as pd +import pyarrow as pa +import pyarrow.parquet as pq +from filelock import FileLock + + +def generate_url_and_domain() -> Tuple[str, str]: + domain_part = "".join( + random.choices(string.ascii_lowercase, k=random.randint(5, 15)) + ) + tld = random.choice(["com", "net", "org", "cn", "edu", "gov", "co", "io"]) + domain = f"www.{domain_part}.{tld}" + + path_segments = [] + for _ in range(random.randint(1, 3)): + segment = "".join( + random.choices( + string.ascii_lowercase + string.digits, k=random.randint(3, 10) + ) + ) + path_segments.append(segment) + path = "/" + "/".join(path_segments) + + protocol = random.choice(["http", "https"]) + + if random.random() < 0.3: + path += random.choice([".html", ".php", ".htm", ".aspx"]) + + return f"{protocol}://{domain}{path}", domain + + +def generate_random_date() -> str: + start = datetime(2023, 1, 1, tzinfo=timezone.utc) + end = datetime(2023, 12, 31, tzinfo=timezone.utc) + delta = end - start + random_date = start + timedelta( + seconds=random.randint(0, int(delta.total_seconds())) + ) + return random_date.strftime("%Y-%m-%dT%H:%M:%SZ") + + +def generate_content() -> bytes: + target_length = ( + random.randint(1000, 100000) + if random.random() < 0.8 + else random.randint(100000, 1000000) + ) + before = b"Random Page" + after = b"" + total_before_after = len(before) + len(after) + + fill_length = max(target_length - total_before_after, 0) + filler = "".join(random.choices(string.printable, k=fill_length)).encode("ascii")[ + :fill_length + ] + + return before + filler + after + + +def generate_arrow_parquet(path: str, num_rows=100): + data = [] + for _ in range(num_rows): + url, domain = generate_url_and_domain() + date = generate_random_date() + content = generate_content() + + data.append({"url": url, "domain": domain, "date": date, "content": content}) + + df = pd.DataFrame(data) + df.to_parquet(path, engine="pyarrow") + + +def generate_arrow_files(output_dir: str, num_files=10): + os.makedirs(output_dir, exist_ok=True) + with ProcessPoolExecutor(max_workers=10) as executor: + executor.map( + generate_arrow_parquet, + [f"{output_dir}/data{i}.parquet" for i in range(num_files)], + ) + + +def concat_arrow_files(input_dir: str, output_dir: str, repeat: int = 10): + 
os.makedirs(output_dir, exist_ok=True) + files = glob.glob(os.path.join(input_dir, "*.parquet")) + table = pa.concat_tables([pa.parquet.read_table(file) for file in files] * repeat) + pq.write_table(table, os.path.join(output_dir, "large_array.parquet")) + + +def generate_random_string(length: int) -> str: + """Generate a random string of a specified length""" + return "".join(random.choices(string.ascii_letters + string.digits, k=length)) + + +def generate_random_url() -> str: + """Generate a random URL""" + path = generate_random_string(random.randint(10, 20)) + return ( + f"com.{random.randint(10000, 999999)}.{random.randint(100, 9999)}/{path}.html" + ) + + +def generate_random_data() -> str: + """Generate random data""" + url = generate_random_url() + content = generate_random_string(random.randint(50, 100)) + encoded_content = base64.b64encode(content.encode()).decode() + return f"{url}\t{encoded_content}" + + +def generate_url_parquet(path: str, num_rows=100): + """Generate a parquet file with a specified number of random data lines""" + data = [] + for _ in range(num_rows): + url = generate_random_url() + host = url.split("/")[0] + data.append({"host": host, "url": url}) + + df = pd.DataFrame(data) + df.to_parquet(path, engine="pyarrow") + + +def generate_url_parquet_files(output_dir: str, num_files: int = 10): + """Generate multiple parquet files with a specified number of random data lines""" + os.makedirs(output_dir, exist_ok=True) + with ProcessPoolExecutor(max_workers=10) as executor: + executor.map( + generate_url_parquet, + [f"{output_dir}/urls{i}.parquet" for i in range(num_files)], + ) + + +def generate_url_tsv_files( + output_dir: str, num_files: int = 10, lines_per_file: int = 100 +): + """Generate multiple files, each containing a specified number of random data lines""" + os.makedirs(output_dir, exist_ok=True) + for i in range(num_files): + with open(f"{output_dir}/urls{i}.tsv", "w") as f: + for _ in range(lines_per_file): + f.write(generate_random_data() + "\n") + + +def generate_long_path_list(path: str, num_lines: int = 1048576): + """Generate a list of long paths""" + with open(path, "w", buffering=16 * 1024 * 1024) as f: + for i in range(num_lines): + path = os.path.abspath(f"tests/data/arrow/data{i % 10}.parquet") + f.write(f"{path}\n") + + +def generate_data(path: str = "tests/data"): + """ + Generate all data for testing. 
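+
+    A file lock is used so that concurrent test workers only generate the
+    datasets once; generation is skipped for data that already exists.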
+ """ + os.makedirs(path, exist_ok=True) + try: + with FileLock(path + "/data.lock"): + print("Generating data...") + if not os.path.exists(path + "/mock_urls"): + generate_url_tsv_files( + output_dir=path + "/mock_urls", num_files=10, lines_per_file=100 + ) + generate_url_parquet_files(output_dir=path + "/mock_urls", num_files=10) + if not os.path.exists(path + "/arrow"): + generate_arrow_files(output_dir=path + "/arrow", num_files=10) + if not os.path.exists(path + "/large_array"): + concat_arrow_files( + input_dir=path + "/arrow", output_dir=path + "/large_array" + ) + if not os.path.exists(path + "/long_path_list.txt"): + generate_long_path_list(path=path + "/long_path_list.txt") + except Exception as e: + print(f"Error generating data: {e}") + + +if __name__ == "__main__": + generate_data() diff --git a/tests/test_arrow.py b/tests/test_arrow.py new file mode 100644 index 0000000..4e5a13f --- /dev/null +++ b/tests/test_arrow.py @@ -0,0 +1,187 @@ +import glob +import os.path +import tempfile +import unittest + +import pyarrow.parquet as parquet +from loguru import logger + +from smallpond.io.arrow import ( + RowRange, + build_batch_reader_from_files, + cast_columns_to_large_string, + dump_to_parquet_files, + load_from_parquet_files, +) +from smallpond.utility import ConcurrentIter +from tests.test_fabric import TestFabric + + +class TestArrow(TestFabric, unittest.TestCase): + def test_load_from_parquet_files(self): + for dataset_path in ( + "tests/data/arrow/*.parquet", + "tests/data/large_array/*.parquet", + ): + with self.subTest(dataset_path=dataset_path): + parquet_files = glob.glob(dataset_path) + expected = self._load_parquet_files(parquet_files) + actual = load_from_parquet_files(parquet_files) + self._compare_arrow_tables(expected, actual) + + def test_load_parquet_row_ranges(self): + for dataset_path in ( + "tests/data/arrow/data0.parquet", + "tests/data/large_array/large_array.parquet", + ): + with self.subTest(dataset_path=dataset_path): + metadata = parquet.read_metadata(dataset_path) + file_num_rows = metadata.num_rows + data_size = sum( + metadata.row_group(i).total_byte_size + for i in range(metadata.num_row_groups) + ) + row_range = RowRange( + path=dataset_path, + begin=100, + end=200, + data_size=data_size, + file_num_rows=file_num_rows, + ) + expected = self._load_parquet_files([dataset_path]).slice( + offset=100, length=100 + ) + actual = load_from_parquet_files([row_range]) + self._compare_arrow_tables(expected, actual) + + def test_dump_to_parquet_files(self): + for dataset_path in ( + "tests/data/arrow/*.parquet", + "tests/data/large_array/*.parquet", + ): + with self.subTest(dataset_path=dataset_path): + parquet_files = glob.glob(dataset_path) + expected = self._load_parquet_files(parquet_files) + with tempfile.TemporaryDirectory( + dir=self.output_root_abspath + ) as output_dir: + ok = dump_to_parquet_files(expected, output_dir) + self.assertTrue(ok) + actual = self._load_parquet_files( + glob.glob(f"{output_dir}/*.parquet") + ) + self._compare_arrow_tables(expected, actual) + + def test_dump_load_empty_table(self): + # create empty table + empty_table = self._load_parquet_files( + ["tests/data/arrow/data0.parquet"] + ).slice(length=0) + self.assertEqual(empty_table.num_rows, 0) + # dump empty table + with tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir: + ok = dump_to_parquet_files(empty_table, output_dir) + self.assertTrue(ok) + parquet_files = glob.glob(f"{output_dir}/*.parquet") + # load empty table from file + actual_table = 
load_from_parquet_files(parquet_files) + self._compare_arrow_tables(empty_table, actual_table) + + def test_parquet_batch_reader(self): + for dataset_path in ( + "tests/data/arrow/*.parquet", + "tests/data/large_array/*.parquet", + ): + with self.subTest(dataset_path=dataset_path): + parquet_files = glob.glob(dataset_path) + expected_num_rows = sum( + parquet.read_metadata(file).num_rows for file in parquet_files + ) + with build_batch_reader_from_files( + parquet_files, + batch_size=expected_num_rows, + max_batch_byte_size=None, + ) as batch_reader, ConcurrentIter(batch_reader) as concurrent_iter: + total_num_rows = 0 + for batch in concurrent_iter: + print( + f"batch.num_rows {batch.num_rows}, max_batch_row_size {expected_num_rows}" + ) + self.assertLessEqual(batch.num_rows, expected_num_rows) + total_num_rows += batch.num_rows + self.assertEqual(total_num_rows, expected_num_rows) + + def test_table_to_batches(self): + for dataset_path in ( + "tests/data/arrow/*.parquet", + "tests/data/large_array/*.parquet", + ): + with self.subTest(dataset_path=dataset_path): + parquet_files = glob.glob(dataset_path) + table = self._load_parquet_files(parquet_files) + total_num_rows = 0 + for batch in table.to_batches(max_chunksize=table.num_rows): + print( + f"batch.num_rows {batch.num_rows}, max_batch_row_size {table.num_rows}" + ) + self.assertLessEqual(batch.num_rows, table.num_rows) + total_num_rows += batch.num_rows + self.assertEqual(total_num_rows, table.num_rows) + + def test_arrow_schema_metadata(self): + table = self._load_parquet_files(glob.glob("tests/data/arrow/*.parquet")) + metadata = {b"a": b"1", b"b": b"2"} + table_with_meta = table.replace_schema_metadata(metadata) + print(f"table_with_meta.schema.metadata {table_with_meta.schema.metadata}") + + with tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir: + self.assertTrue( + dump_to_parquet_files( + table_with_meta, output_dir, "arrow_schema_metadata", max_workers=2 + ) + ) + parquet_files = glob.glob( + os.path.join(output_dir, "arrow_schema_metadata*.parquet") + ) + loaded_table = load_from_parquet_files( + parquet_files, table.column_names[:1] + ) + print(f"loaded_table.schema.metadata {loaded_table.schema.metadata}") + self.assertEqual( + table_with_meta.schema.metadata, loaded_table.schema.metadata + ) + with parquet.ParquetFile(parquet_files[0]) as file: + print(f"file.schema_arrow.metadata {file.schema_arrow.metadata}") + self.assertEqual( + table_with_meta.schema.metadata, file.schema_arrow.metadata + ) + + def test_load_mixed_string_types(self): + parquet_paths = glob.glob("tests/data/arrow/*.parquet") + table = self._load_parquet_files(parquet_paths) + + with tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir: + dump_to_parquet_files(cast_columns_to_large_string(table), output_dir) + parquet_paths += glob.glob(os.path.join(output_dir, "*.parquet")) + loaded_table = load_from_parquet_files(parquet_paths) + self.assertEqual(table.num_rows * 2, loaded_table.num_rows) + batch_reader = build_batch_reader_from_files(parquet_paths) + self.assertEqual( + table.num_rows * 2, sum(batch.num_rows for batch in batch_reader) + ) + + @logger.catch(reraise=True, message="failed to load parquet files") + def _load_from_parquet_files_with_log(self, paths, columns): + load_from_parquet_files(paths, columns) + + def test_load_not_exist_column(self): + parquet_files = glob.glob("tests/data/arrow/*.parquet") + with self.assertRaises(AssertionError) as context: + 
self._load_from_parquet_files_with_log(parquet_files, ["not_exist_column"]) + + def test_change_ordering_of_columns(self): + parquet_files = glob.glob("tests/data/arrow/*.parquet") + loaded_table = load_from_parquet_files(parquet_files) + reversed_cols = list(reversed(loaded_table.column_names)) + loaded_table = load_from_parquet_files(parquet_files, reversed_cols) + self.assertEqual(loaded_table.column_names, reversed_cols) diff --git a/tests/test_bench.py b/tests/test_bench.py new file mode 100644 index 0000000..f0fe872 --- /dev/null +++ b/tests/test_bench.py @@ -0,0 +1,90 @@ +import shutil +import unittest + +from benchmarks.file_io_benchmark import file_io_benchmark +from benchmarks.gray_sort_benchmark import generate_random_records, gray_sort_benchmark +from benchmarks.hash_partition_benchmark import hash_partition_benchmark +from benchmarks.urls_sort_benchmark import urls_sort_benchmark +from smallpond.common import MB +from smallpond.logical.node import Context, LogicalPlan +from tests.test_fabric import TestFabric + + +class TestBench(TestFabric, unittest.TestCase): + + fault_inject_prob = 0.05 + + def test_file_io_benchmark(self): + for io_engine in ("duckdb", "arrow", "stream"): + with self.subTest(io_engine=io_engine): + plan = file_io_benchmark( + ["tests/data/mock_urls/*.parquet"], + npartitions=3, + io_engine=io_engine, + ) + self.execute_plan(plan, enable_profiling=True) + + def test_urls_sort_benchmark(self): + for engine_type in ("duckdb", "arrow"): + with self.subTest(engine_type=engine_type): + plan = urls_sort_benchmark( + ["tests/data/mock_urls/*.tsv"], + num_data_partitions=3, + num_hash_partitions=3, + engine_type=engine_type, + ) + self.execute_plan(plan, enable_profiling=True) + + @unittest.skipIf(shutil.which("gensort") is None, "gensort not found") + def test_gray_sort_benchmark(self): + record_nbytes = 100 + key_nbytes = 10 + total_data_nbytes = 100 * MB + gensort_batch_nbytes = 10 * MB + num_data_partitions = 5 + num_sort_partitions = 1 << 3 + for shuffle_engine in ("duckdb", "arrow"): + for sort_engine in ("duckdb", "arrow", "polars"): + with self.subTest( + shuffle_engine=shuffle_engine, sort_engine=sort_engine + ): + ctx = Context() + random_records = generate_random_records( + ctx, + record_nbytes, + key_nbytes, + total_data_nbytes, + gensort_batch_nbytes, + num_data_partitions, + num_sort_partitions, + ) + plan = LogicalPlan(ctx, random_records) + exec_plan = self.execute_plan(plan, enable_profiling=True) + + plan = gray_sort_benchmark( + record_nbytes, + key_nbytes, + total_data_nbytes, + gensort_batch_nbytes, + num_data_partitions, + num_sort_partitions, + input_paths=exec_plan.final_output.resolved_paths, + shuffle_engine=shuffle_engine, + sort_engine=sort_engine, + hive_partitioning=True, + validate_results=True, + ) + self.execute_plan(plan, enable_profiling=True) + + def test_hash_partition_benchmark(self): + for engine_type in ("duckdb", "arrow"): + with self.subTest(engine_type=engine_type): + plan = hash_partition_benchmark( + ["tests/data/mock_urls/*.parquet"], + npartitions=5, + hash_columns=["url"], + engine_type=engine_type, + hive_partitioning=True, + partition_stats=True, + ) + self.execute_plan(plan, enable_profiling=True) diff --git a/tests/test_common.py b/tests/test_common.py new file mode 100644 index 0000000..d0edbfd --- /dev/null +++ b/tests/test_common.py @@ -0,0 +1,73 @@ +import itertools +import unittest + +import numpy as np +from hypothesis import given +from hypothesis import strategies as st + +from smallpond.common import 
get_nth_partition, split_into_cols, split_into_rows +from tests.test_fabric import TestFabric + + +class TestCommon(TestFabric, unittest.TestCase): + def test_get_nth_partition(self): + items = [1, 2, 3] + # split into 1 partitions + self.assertListEqual([1, 2, 3], get_nth_partition(items, 0, 1)) + # split into 2 partitions + self.assertListEqual([1, 2], get_nth_partition(items, 0, 2)) + self.assertListEqual([3], get_nth_partition(items, 1, 2)) + # split into 3 partitions + self.assertListEqual([1], get_nth_partition(items, 0, 3)) + self.assertListEqual([2], get_nth_partition(items, 1, 3)) + self.assertListEqual([3], get_nth_partition(items, 2, 3)) + # split into 5 partitions + self.assertListEqual([1], get_nth_partition(items, 0, 5)) + self.assertListEqual([2], get_nth_partition(items, 1, 5)) + self.assertListEqual([3], get_nth_partition(items, 2, 5)) + self.assertListEqual([], get_nth_partition(items, 3, 5)) + self.assertListEqual([], get_nth_partition(items, 4, 5)) + + @given(st.data()) + def test_split_into_rows(self, data: st.data): + nelements = data.draw(st.integers(1, 100)) + npartitions = data.draw(st.integers(1, 2 * nelements)) + items = list(range(nelements)) + computed = split_into_rows(items, npartitions) + expected = [ + get_nth_partition(items, n, npartitions) for n in range(npartitions) + ] + self.assertEqual(expected, computed) + + @given(st.data()) + def test_split_into_cols(self, data: st.data): + nelements = data.draw(st.integers(1, 100)) + npartitions = data.draw(st.integers(1, 2 * nelements)) + items = list(range(nelements)) + chunks = split_into_cols(items, npartitions) + self.assertEqual(npartitions, len(chunks)) + self.assertListEqual( + items, + [x for row in itertools.zip_longest(*chunks) for x in row if x is not None], + ) + chunk_sizes = set(len(chk) for chk in chunks) + if len(chunk_sizes) > 1: + small_size, large_size = sorted(chunk_sizes) + self.assertEqual(small_size + 1, large_size) + else: + (chunk_size,) = chunk_sizes + self.assertEqual(len(items), chunk_size * npartitions) + + def test_split_into_rows_bench(self): + for nelements in [100000, 1000000]: + items = np.arange(nelements) + for npartitions in [1024, 4096, 10240, nelements, 2 * nelements]: + chunks = split_into_rows(items, npartitions) + self.assertEqual(npartitions, len(chunks)) + + def test_split_into_cols_bench(self): + for nelements in [100000, 1000000]: + items = np.arange(nelements) + for npartitions in [1024, 4096, 10240, nelements, 2 * nelements]: + chunks = split_into_cols(items, npartitions) + self.assertEqual(npartitions, len(chunks)) diff --git a/tests/test_dataframe.py b/tests/test_dataframe.py new file mode 100644 index 0000000..49aff23 --- /dev/null +++ b/tests/test_dataframe.py @@ -0,0 +1,223 @@ +from typing import List + +import pandas as pd +import pyarrow as pa +import pytest + +from smallpond.dataframe import Session + + +def test_pandas(sp: Session): + pandas_df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + df = sp.from_pandas(pandas_df) + assert df.to_pandas().equals(pandas_df) + + +def test_arrow(sp: Session): + arrow_table = pa.table({"a": [1, 2, 3], "b": [4, 5, 6]}) + df = sp.from_arrow(arrow_table) + assert df.to_arrow() == arrow_table + + +def test_items(sp: Session): + df = sp.from_items([1, 2, 3]) + assert df.take_all() == [{"item": 1}, {"item": 2}, {"item": 3}] + df = sp.from_items([{"a": 1, "b": 4}, {"a": 2, "b": 5}, {"a": 3, "b": 6}]) + assert df.take_all() == [{"a": 1, "b": 4}, {"a": 2, "b": 5}, {"a": 3, "b": 6}] + + +def test_csv(sp: Session): + df = 
sp.read_csv( + "tests/data/mock_urls/*.tsv", + schema={"urlstr": "varchar", "valstr": "varchar"}, + delim=r"\t", + ) + assert df.count() == 1000 + + +def test_parquet(sp: Session): + df = sp.read_parquet("tests/data/mock_urls/*.parquet") + assert df.count() == 1000 + + +def test_take(sp: Session): + df = sp.from_pandas(pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})) + assert df.take(2) == [{"a": 1, "b": 4}, {"a": 2, "b": 5}] + assert df.take_all() == [{"a": 1, "b": 4}, {"a": 2, "b": 5}, {"a": 3, "b": 6}] + + +def test_map(sp: Session): + df = sp.from_arrow(pa.table({"a": [1, 2, 3], "b": [4, 5, 6]})) + df1 = df.map("a + b as c") + assert df1.to_arrow() == pa.table({"c": [5, 7, 9]}) + df2 = df.map(lambda r: {"c": r["a"] + r["b"]}) + assert df2.to_arrow() == pa.table({"c": [5, 7, 9]}) + + # user need to specify the schema if can not be inferred from the mapping values + df3 = df.map( + lambda r: {"c": None if r["a"] == 1 else r["a"] + r["b"]}, + schema=pa.schema([("c", pa.int64())]), + ) + assert df3.to_arrow() == pa.table({"c": pa.array([None, 7, 9], type=pa.int64())}) + + +def test_flat_map(sp: Session): + df = sp.from_arrow(pa.table({"a": [1, 2, 3], "b": [4, 5, 6]})) + df1 = df.flat_map(lambda r: [{"c": r["a"]}, {"c": r["b"]}]) + assert df1.to_arrow() == pa.table({"c": [1, 4, 2, 5, 3, 6]}) + df2 = df.flat_map("unnest(array[a, b]) as c") + assert df2.to_arrow() == pa.table({"c": [1, 4, 2, 5, 3, 6]}) + + # user need to specify the schema if can not be inferred from the mapping values + df3 = df.flat_map(lambda r: [{"c": None}], schema=pa.schema([("c", pa.int64())])) + assert df3.to_arrow() == pa.table( + {"c": pa.array([None, None, None], type=pa.int64())} + ) + + +def test_map_batches(sp: Session): + df = sp.read_parquet("tests/data/mock_urls/*.parquet") + df = df.map_batches( + lambda batch: pa.table({"num_rows": [batch.num_rows]}), + batch_size=350, + ) + assert df.take_all() == [{"num_rows": 350}, {"num_rows": 350}, {"num_rows": 300}] + + +def test_filter(sp: Session): + df = sp.from_arrow(pa.table({"a": [1, 2, 3], "b": [4, 5, 6]})) + df1 = df.filter("a > 1") + assert df1.to_arrow() == pa.table({"a": [2, 3], "b": [5, 6]}) + df2 = df.filter(lambda r: r["a"] > 1) + assert df2.to_arrow() == pa.table({"a": [2, 3], "b": [5, 6]}) + + +def test_random_shuffle(sp: Session): + df = sp.from_items(list(range(1000))).repartition(10, by_rows=True) + df = df.random_shuffle() + shuffled = [d["item"] for d in df.take_all()] + assert sorted(shuffled) == list(range(1000)) + + def count_inversions(arr: List[int]) -> int: + return sum( + sum(1 for j in range(i + 1, len(arr)) if arr[i] > arr[j]) + for i in range(len(arr)) + ) + + # check the shuffle is random enough + # the expected number of inversions is n*(n-1)/4 = 249750 + assert 220000 <= count_inversions(shuffled) <= 280000 + + +def test_partition_by(sp: Session): + df = sp.from_items(list(range(1000))).repartition(10, by="item % 10") + df = df.map("min(item % 10) as min, max(item % 10) as max") + assert df.take_all() == [{"min": i, "max": i} for i in range(10)] + + +def test_partition_by_key_out_of_range(sp: Session): + df = sp.from_items(list(range(1000))).repartition(10, by="item % 11") + try: + df.to_arrow() + except Exception as ex: + assert "partition key 10 is out of range 0-9" in str(ex) + else: + assert False, "expected exception" + + +def test_partition_by_hash(sp: Session): + df = sp.from_items(list(range(1000))).repartition(10, hash_by="item") + items = [d["item"] for d in df.take_all()] + assert sorted(items) == list(range(1000)) + + +def 
test_count(sp: Session): + df = sp.from_items([1, 2, 3]) + assert df.count() == 3 + + +def test_limit(sp: Session): + df = sp.from_items(list(range(1000))).repartition(10, by_rows=True) + assert df.limit(2).count() == 2 + + +@pytest.mark.skip(reason="limit can not be pushed down to sql node for now") +@pytest.mark.timeout(10) +def test_limit_large(sp: Session): + # limit will be fused with the previous select + # otherwise, it will be timeout + df = sp.partial_sql("select * from range(1000000000)") + assert df.limit(2).count() == 2 + + +def test_partial_sql(sp: Session): + # no input deps + df = sp.partial_sql("select * from range(3)") + assert df.to_arrow() == pa.table({"range": [0, 1, 2]}) + + # join + df1 = sp.from_arrow(pa.table({"id1": [1, 2, 3], "val1": ["a", "b", "c"]})) + df2 = sp.from_arrow(pa.table({"id2": [1, 2, 3], "val2": ["d", "e", "f"]})) + joined = sp.partial_sql( + "select id1, val1, val2 from {0} join {1} on id1 = id2", df1, df2 + ) + assert joined.to_arrow() == pa.table( + {"id1": [1, 2, 3], "val1": ["a", "b", "c"], "val2": ["d", "e", "f"]}, + schema=pa.schema( + [ + ("id1", pa.int64()), + ("val1", pa.large_string()), + ("val2", pa.large_string()), + ] + ), + ) + + +def test_error_message(sp: Session): + df = sp.from_items([1, 2, 3]) + df = sp.partial_sql("select a,, from {0}", df) + try: + df.to_arrow() + except Exception as ex: + # sql query should be in the exception message + assert "select a,, from" in str(ex) + else: + assert False, "expected exception" + + +def test_unpicklable_task_exception(sp: Session): + from loguru import logger + + df = sp.from_items([1, 2, 3]) + try: + df.map(lambda x: logger.info("use outside logger")).to_arrow() + except Exception as ex: + assert "Can't pickle task" in str(ex) + assert ( + "HINT: DO NOT use externally imported loguru logger in your task. Please import it within the task." 
+ in str(ex) + ) + else: + assert False, "expected exception" + + +def test_log(sp: Session): + df = sp.from_items([1, 2, 3]) + + def log_record(x): + import logging + import sys + + from loguru import logger + + print("stdout") + print("stderr", file=sys.stderr) + logger.info("loguru") + logging.info("logging") + return x + + df.map(log_record).to_arrow() + + # TODO: check logs should be see in the log file + # FIXME: logs in unit test are not written to the log file + # because we share the same ray instance for all tests diff --git a/tests/test_dataset.py b/tests/test_dataset.py new file mode 100644 index 0000000..b9c7a05 --- /dev/null +++ b/tests/test_dataset.py @@ -0,0 +1,174 @@ +import glob +import os.path +import unittest +from pathlib import PurePath + +import duckdb +import pandas +import pyarrow as arrow +import pytest +from loguru import logger + +from smallpond.common import DEFAULT_ROW_GROUP_SIZE, MB +from smallpond.logical.dataset import ParquetDataSet +from smallpond.utility import ConcurrentIter +from tests.test_fabric import TestFabric + + +class TestDataSet(TestFabric, unittest.TestCase): + def test_parquet_file_created_by_pandas(self): + num_urls = 0 + for txt_file in glob.glob("tests/data/mock_urls/*.tsv"): + urls = pandas.read_csv(txt_file, delimiter="\t", names=["url"]) + urls.to_parquet( + os.path.join( + self.output_root_abspath, + PurePath(os.path.basename(txt_file)).with_suffix(".parquet"), + ) + ) + num_urls += urls.size + dataset = ParquetDataSet([os.path.join(self.output_root_abspath, "*.parquet")]) + self.assertEqual(num_urls, dataset.num_rows) + + def _generate_parquet_dataset( + self, output_path, npartitions, num_rows, row_group_size + ): + duckdb.sql( + f"""copy ( + select range as i, range % {npartitions} as partition from range(0, {num_rows}) ) + to '{output_path}' + (FORMAT PARQUET, ROW_GROUP_SIZE {row_group_size}, PARTITION_BY partition, OVERWRITE_OR_IGNORE true)""" + ) + return ParquetDataSet([f"{output_path}/**/*.parquet"]) + + def _check_partition_datasets( + self, orig_dataset: ParquetDataSet, partition_func, npartition + ): + # build partitioned datasets + partitioned_datasets = partition_func(npartition) + self.assertEqual(npartition, len(partitioned_datasets)) + self.assertEqual( + orig_dataset.num_rows, + sum(dataset.num_rows for dataset in partitioned_datasets), + ) + # load as arrow table + loaded_table = arrow.concat_tables( + [dataset.to_arrow_table(max_workers=1) for dataset in partitioned_datasets] + ) + self.assertEqual(orig_dataset.num_rows, loaded_table.num_rows) + # compare arrow tables + orig_table = orig_dataset.to_arrow_table(max_workers=1) + self.assertEqual(orig_table.shape, loaded_table.shape) + self.assertTrue(orig_table.sort_by("i").equals(loaded_table.sort_by("i"))) + # compare sql query results + join_query = f""" + select count(a.i) as num_rows + from {orig_dataset.sql_query_fragment()} as a + join ( {' union all '.join([dataset.sql_query_fragment() for dataset in partitioned_datasets])} ) as b on a.i = b.i""" + results = duckdb.sql(join_query).fetchall() + self.assertEqual(orig_dataset.num_rows, results[0][0]) + + def test_num_rows(self): + dataset = ParquetDataSet(["tests/data/arrow/*.parquet"]) + self.assertEqual(dataset.num_rows, 1000) + + def test_partition_by_files(self): + output_path = os.path.join(self.output_root_abspath, "test_partition_by_files") + orig_dataset = self._generate_parquet_dataset( + output_path, npartitions=11, num_rows=170 * 1000, row_group_size=10 * 1000 + ) + num_files = 
len(orig_dataset.resolved_paths) + for npartition in range(1, num_files + 1): + for random_shuffle in (False, True): + with self.subTest(npartition=npartition, random_shuffle=random_shuffle): + orig_dataset.reset(orig_dataset.paths, orig_dataset.root_dir) + self._check_partition_datasets( + orig_dataset, + lambda n: orig_dataset.partition_by_files( + n, random_shuffle=random_shuffle + ), + npartition, + ) + + def test_partition_by_rows(self): + output_path = os.path.join(self.output_root_abspath, "test_partition_by_rows") + orig_dataset = self._generate_parquet_dataset( + output_path, npartitions=11, num_rows=170 * 1000, row_group_size=10 * 1000 + ) + num_files = len(orig_dataset.resolved_paths) + for npartition in range(1, 2 * num_files + 1): + for random_shuffle in (False, True): + with self.subTest(npartition=npartition, random_shuffle=random_shuffle): + orig_dataset.reset(orig_dataset.paths, orig_dataset.root_dir) + self._check_partition_datasets( + orig_dataset, + lambda n: orig_dataset.partition_by_rows( + n, random_shuffle=random_shuffle + ), + npartition, + ) + + def test_resolved_many_paths(self): + with open("tests/data/long_path_list.txt", buffering=16 * MB) as fin: + filenames = list(map(os.path.basename, map(str.strip, fin.readlines()))) + logger.info(f"loaded {len(filenames)} filenames") + dataset = ParquetDataSet(filenames) + self.assertEqual(len(dataset.resolved_paths), len(filenames)) + + def test_paths_with_char_ranges(self): + dataset_with_char_ranges = ParquetDataSet( + ["tests/data/arrow/data[0-9].parquet"] + ) + dataset_with_wildcards = ParquetDataSet(["tests/data/arrow/*.parquet"]) + self.assertEqual( + len(dataset_with_char_ranges.resolved_paths), + len(dataset_with_wildcards.resolved_paths), + ) + + def test_to_arrow_table_batch_reader(self): + memdb = duckdb.connect( + database=":memory:", config={"arrow_large_buffer_size": "true"} + ) + for dataset_path in ( + "tests/data/arrow/*.parquet", + "tests/data/large_array/*.parquet", + ): + for conn in (None, memdb): + print(f"dataset_path: {dataset_path}, conn: {conn}") + with self.subTest(dataset_path=dataset_path, conn=conn): + dataset = ParquetDataSet([dataset_path]) + to_batches = dataset.to_arrow_table( + max_workers=1, conn=conn + ).to_batches(max_chunksize=DEFAULT_ROW_GROUP_SIZE * 2) + batch_reader = dataset.to_batch_reader( + batch_size=DEFAULT_ROW_GROUP_SIZE * 2, conn=conn + ) + with ConcurrentIter( + batch_reader, max_buffer_size=2 + ) as batch_reader: + for batch_iter in (to_batches, batch_reader): + total_num_rows = 0 + for batch in batch_iter: + print( + f"batch.num_rows {batch.num_rows}, max_batch_row_size {DEFAULT_ROW_GROUP_SIZE*2}" + ) + self.assertLessEqual( + batch.num_rows, DEFAULT_ROW_GROUP_SIZE * 2 + ) + total_num_rows += batch.num_rows + print(f"{dataset_path}: total_num_rows {total_num_rows}") + self.assertEqual(total_num_rows, dataset.num_rows) + + +@pytest.mark.parametrize("reader", ["arrow", "duckdb"]) +@pytest.mark.parametrize("dataset_path", ["tests/data/arrow/*.parquet"]) +# @pytest.mark.parametrize("dataset_path", ["tests/data/arrow/*.parquet", "tests/data/large_array/*.parquet"]) +def test_arrow_reader(benchmark, reader: str, dataset_path: str): + dataset = ParquetDataSet([dataset_path]) + conn = None + if reader == "duckdb": + conn = duckdb.connect( + database=":memory:", config={"arrow_large_buffer_size": "true"} + ) + benchmark(dataset.to_arrow_table, conn=conn) + # result: arrow reader is 4x faster than duckdb reader in small dataset, 1.4x faster in large dataset diff --git 
a/tests/test_deltalake.py b/tests/test_deltalake.py new file mode 100644 index 0000000..c6c2d4b --- /dev/null +++ b/tests/test_deltalake.py @@ -0,0 +1,60 @@ +import glob +import importlib +import tempfile +import unittest + +from smallpond.io.arrow import cast_columns_to_large_string +from tests.test_fabric import TestFabric + + +@unittest.skipUnless( + importlib.util.find_spec("deltalake") is not None, "cannot find deltalake" +) +class TestDeltaLake(TestFabric, unittest.TestCase): + def test_read_write_deltalake(self): + from deltalake import DeltaTable, write_deltalake + + for dataset_path in ( + "tests/data/arrow/*.parquet", + "tests/data/large_array/*.parquet", + ): + parquet_files = glob.glob(dataset_path) + expected = self._load_parquet_files(parquet_files) + with self.subTest(dataset_path=dataset_path), tempfile.TemporaryDirectory( + dir=self.output_root_abspath + ) as output_dir: + write_deltalake(output_dir, expected, large_dtypes=True) + dt = DeltaTable(output_dir) + self._compare_arrow_tables(expected, dt.to_pyarrow_table()) + + def test_load_mixed_large_dtypes(self): + from deltalake import DeltaTable, write_deltalake + + for dataset_path in ( + "tests/data/arrow/*.parquet", + "tests/data/large_array/*.parquet", + ): + parquet_files = glob.glob(dataset_path) + with self.subTest(dataset_path=dataset_path), tempfile.TemporaryDirectory( + dir=self.output_root_abspath + ) as output_dir: + table = cast_columns_to_large_string( + self._load_parquet_files(parquet_files) + ) + write_deltalake(output_dir, table, large_dtypes=True, mode="overwrite") + write_deltalake(output_dir, table, large_dtypes=False, mode="append") + loaded_table = DeltaTable(output_dir).to_pyarrow_table() + print("table:\n", table.schema) + print("loaded_table:\n", loaded_table.schema) + self.assertEqual(table.num_rows * 2, loaded_table.num_rows) + + def test_delete_update(self): + import pandas as pd + from deltalake import DeltaTable, write_deltalake + + with tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir: + df = pd.DataFrame({"num": [1, 2, 3], "animal": ["cat", "dog", "snake"]}) + write_deltalake(output_dir, df, mode="overwrite") + dt = DeltaTable(output_dir) + dt.delete("animal = 'cat'") + dt.update(predicate="num = 3", new_values={"animal": "fish"}) diff --git a/tests/test_driver.py b/tests/test_driver.py new file mode 100644 index 0000000..a9805e3 --- /dev/null +++ b/tests/test_driver.py @@ -0,0 +1,46 @@ +import os.path +import unittest +import uuid + +from loguru import logger + +from benchmarks.gray_sort_benchmark import gray_sort_benchmark +from examples.sort_mock_urls import sort_mock_urls +from smallpond.common import GB, MB +from smallpond.execution.driver import Driver +from tests.test_fabric import TestFabric + + +@unittest.skipUnless(os.getenv("ENABLE_DRIVER_TEST"), "unit test disabled") +class TestDriver(TestFabric, unittest.TestCase): + + fault_inject_prob = 0.05 + + def create_driver(self, num_executors: int): + cmdline = f"scheduler --job_id {str(uuid.uuid4())} --job_name {self._testMethodName} --data_root {self.output_root_abspath} --num_executors {num_executors} --fault_inject_prob {self.fault_inject_prob}" + driver = Driver() + driver.parse_arguments(args=cmdline.split()) + logger.info(f"{cmdline=} {driver.mode=} {driver.job_id=} {driver.data_root=}") + return driver + + def test_standalone_mode(self): + plan = sort_mock_urls(["tests/data/mock_urls/*.tsv"], npartitions=3) + driver = self.create_driver(num_executors=0) + exec_plan = driver.run(plan, 
stop_process_on_done=False) + self.assertTrue(exec_plan.successful) + self.assertGreater(exec_plan.final_output.num_files, 0) + + def test_run_on_remote_executors(self): + driver = self.create_driver(num_executors=2) + plan = gray_sort_benchmark( + record_nbytes=100, + key_nbytes=10, + total_data_nbytes=1 * GB, + gensort_batch_nbytes=100 * MB, + num_data_partitions=10, + num_sort_partitions=10, + validate_results=True, + ) + exec_plan = driver.run(plan, stop_process_on_done=False) + self.assertTrue(exec_plan.successful) + self.assertGreater(exec_plan.final_output.num_files, 0) diff --git a/tests/test_execution.py b/tests/test_execution.py new file mode 100644 index 0000000..b503b3b --- /dev/null +++ b/tests/test_execution.py @@ -0,0 +1,886 @@ +import functools +import os.path +import socket +import tempfile +import time +import unittest +from datetime import datetime +from typing import Iterable, List, Tuple + +import pandas +import pyarrow as arrow +from loguru import logger +from pandas.core.api import DataFrame as DataFrame + +from smallpond.common import GB, MB, split_into_rows +from smallpond.execution.task import ( + DataSinkTask, + DataSourceTask, + JobId, + PartitionInfo, + PythonScriptTask, + RuntimeContext, + StreamOutput, +) +from smallpond.execution.workqueue import WorkStatus +from smallpond.logical.dataset import ( + ArrowTableDataSet, + DataSet, + ParquetDataSet, + SqlQueryDataSet, +) +from smallpond.logical.node import ( + ArrowBatchNode, + ArrowComputeNode, + ArrowStreamNode, + Context, + DataSetPartitionNode, + DataSinkNode, + DataSourceNode, + EvenlyDistributedPartitionNode, + HashPartitionNode, + LogicalPlan, + Node, + PandasBatchNode, + PandasComputeNode, + ProjectionNode, + PythonScriptNode, + RootNode, + SqlEngineNode, +) +from smallpond.logical.udf import UDFListType, UDFType +from tests.test_fabric import TestFabric + + +class OutputMsgPythonTask(PythonScriptTask): + def __init__(self, msg: str, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.msg = msg + + def initialize(self): + pass + + def finalize(self): + pass + + def process( + self, + runtime_ctx: RuntimeContext, + input_datasets: List[DataSet], + output_path: str, + ) -> bool: + logger.info( + f"msg: {self.msg}, num files: {input_datasets[0].num_files}, local gpu ranks: {self.local_gpu_ranks}" + ) + self.inject_fault() + return True + + +# method1: inherit Task class and override spawn method +class OutputMsgPythonNode(PythonScriptNode): + def spawn(self, *args, **kwargs) -> OutputMsgPythonTask: + return OutputMsgPythonTask("python script", *args, **kwargs) + + +# method2: override process method +# this usage is not recommended and only for testing. use `process_func` instead. +class OutputMsgPythonNode2(PythonScriptNode): + def __init__(self, ctx: Context, input_deps: Tuple[Node, ...], msg: str) -> None: + super().__init__(ctx, input_deps) + self.msg = msg + + def process( + self, + runtime_ctx: RuntimeContext, + input_datasets: List[DataSet], + output_path: str, + ) -> bool: + logger.info(f"msg: {self.msg}, num files: {input_datasets[0].num_files}") + return True + + +# this usage is not recommended and only for testing. use `process_func` instead. 
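# (Annotation, not part of the original commit.) The preferred alternative is to pass
# a plain function via `process_func`, e.g.
#   ArrowComputeNode(ctx, (dep,), process_func=functools.partial(copy_input_arrow, msg="hello")),
# as test_arrow_task below does; the subclasses here exist only to exercise the
# override-`process` code path.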
+class CopyInputArrowNode(ArrowComputeNode): + def __init__(self, ctx: Context, input_deps: Tuple[Node, ...], msg: str) -> None: + super().__init__(ctx, input_deps) + self.msg = msg + + def process( + self, runtime_ctx: RuntimeContext, input_tables: List[arrow.Table] + ) -> arrow.Table: + return copy_input_arrow(runtime_ctx, input_tables, self.msg) + + +# this usage is not recommended and only for testing. use `process_func` instead. +class CopyInputStreamNode(ArrowStreamNode): + def __init__(self, ctx: Context, input_deps: Tuple[Node, ...], msg: str) -> None: + super().__init__(ctx, input_deps) + self.msg = msg + + def process( + self, runtime_ctx: RuntimeContext, input_readers: List[arrow.RecordBatchReader] + ) -> Iterable[arrow.Table]: + return copy_input_stream(runtime_ctx, input_readers, self.msg) + + +def copy_input_arrow( + runtime_ctx: RuntimeContext, input_tables: List[arrow.Table], msg: str +) -> arrow.Table: + logger.info(f"msg: {msg}, num rows: {input_tables[0].num_rows}") + time.sleep(runtime_ctx.secs_executor_probe_interval) + runtime_ctx.task.inject_fault() + return input_tables[0] + + +def copy_input_stream( + runtime_ctx: RuntimeContext, input_readers: List[arrow.RecordBatchReader], msg: str +) -> Iterable[arrow.Table]: + for index, batch in enumerate(input_readers[0]): + logger.info(f"msg: {msg}, batch index: {index}, num rows: {batch.num_rows}") + time.sleep(runtime_ctx.secs_executor_probe_interval) + yield StreamOutput( + arrow.Table.from_batches([batch]), + batch_indices=[index], + force_checkpoint=True, + ) + runtime_ctx.task.inject_fault() + + +def copy_input_batch( + runtime_ctx: RuntimeContext, input_batches: List[arrow.Table], msg: str +) -> arrow.Table: + logger.info(f"msg: {msg}, num rows: {input_batches[0].num_rows}") + time.sleep(runtime_ctx.secs_executor_probe_interval) + runtime_ctx.task.inject_fault() + return input_batches[0] + + +def copy_input_data_frame( + runtime_ctx: RuntimeContext, input_dfs: List[DataFrame] +) -> DataFrame: + runtime_ctx.task.inject_fault() + return input_dfs[0] + + +def copy_input_data_frame_batch( + runtime_ctx: RuntimeContext, input_dfs: List[DataFrame] +) -> DataFrame: + runtime_ctx.task.inject_fault() + return input_dfs[0] + + +def merge_input_tables( + runtime_ctx: RuntimeContext, input_batches: List[arrow.Table] +) -> arrow.Table: + runtime_ctx.task.inject_fault() + output = arrow.concat_tables(input_batches) + logger.info( + f"input rows: {[len(batch) for batch in input_batches]}, output rows: {len(output)}" + ) + return output + + +def merge_input_data_frames( + runtime_ctx: RuntimeContext, input_dfs: List[DataFrame] +) -> DataFrame: + runtime_ctx.task.inject_fault() + output = pandas.concat(input_dfs) + logger.info( + f"input rows: {[len(df) for df in input_dfs]}, output rows: {len(output)}" + ) + return output + + +def parse_url( + runtime_ctx: RuntimeContext, input_tables: List[arrow.Table] +) -> arrow.Table: + urls = input_tables[0].columns[0] + hosts = [url.as_py().split("/", maxsplit=2)[0] for url in urls] + return input_tables[0].append_column("host", arrow.array(hosts)) + + +def nonzero_exit_code( + runtime_ctx: RuntimeContext, input_datasets: List[DataSet], output_path: str +) -> bool: + import sys + + if runtime_ctx.task._memory_boost == 1: + sys.exit(1) + return True + + +# create an empty file with a fixed name +def empty_file( + runtime_ctx: RuntimeContext, input_datasets: List[DataSet], output_path: str +) -> bool: + import os + + with open(os.path.join(output_path, "file"), "w") as fout: + pass + return True 
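

# A minimal sketch, not one of the original tests: it wires the `copy_input_arrow`
# helper above into a logical plan using only names already imported at the top of
# this file and exercised by the tests below (e.g. test_arrow_task). The resulting
# plan could be run via TestExecution.execute_plan.
def _sketch_copy_input_plan() -> LogicalPlan:
    ctx = Context()
    dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
    data_files = DataSourceNode(ctx, dataset)
    copied = ArrowComputeNode(
        ctx,
        (data_files,),
        # reuse the module-level helper; functools.partial binds the extra `msg` argument
        process_func=functools.partial(copy_input_arrow, msg="sketch"),
    )
    return LogicalPlan(ctx, copied)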
+ + +def return_fake_gpus(count: int = 8): + import GPUtil + + return [GPUtil.GPU(i, *list(range(11))) for i in range(count)] + + +def split_url(urls: arrow.array) -> arrow.array: + url_parts = [url.as_py().split("/") for url in urls] + return arrow.array(url_parts, type=arrow.list_(arrow.string())) + + +def choose_random_urls( + runtime_ctx: RuntimeContext, input_tables: List[arrow.Table], k: int = 5 +) -> arrow.Table: + # get the current running task + runtime_task = runtime_ctx.task + # access task-specific attributes + cpu_limit = runtime_task.cpu_limit + random_gen = runtime_task.python_random_gen + # input data + (url_table,) = input_tables + hosts, urls = url_table.columns + logger.info(f"{cpu_limit=} {len(urls)=}") + # generate ramdom samples + random_urls = random_gen.choices(urls.to_pylist(), k=k) + return arrow.Table.from_arrays([arrow.array(random_urls)], names=["random_urls"]) + + +class TestExecution(TestFabric, unittest.TestCase): + + fault_inject_prob = 0.05 + + def test_arrow_task(self): + for use_duckdb_reader in (False, True): + with self.subTest(use_duckdb_reader=use_duckdb_reader): + with tempfile.TemporaryDirectory( + dir=self.output_root_abspath + ) as output_dir: + ctx = Context() + dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"]) + data_table = dataset.to_arrow_table() + data_files = DataSourceNode(ctx, dataset) + data_partitions = DataSetPartitionNode( + ctx, (data_files,), npartitions=7 + ) + if use_duckdb_reader: + data_partitions = ProjectionNode( + ctx, + data_partitions, + columns=["*", "string_split(url, '/')[0] as parsed_host"], + ) + arrow_compute = ArrowComputeNode( + ctx, + (data_partitions,), + process_func=functools.partial( + copy_input_arrow, msg="arrow compute" + ), + use_duckdb_reader=use_duckdb_reader, + output_name="arrow_compute", + output_path=output_dir, + cpu_limit=2, + ) + arrow_stream = ArrowStreamNode( + ctx, + (data_partitions,), + process_func=functools.partial( + copy_input_stream, msg="arrow stream" + ), + streaming_batch_size=10, + secs_checkpoint_interval=0.5, + use_duckdb_reader=use_duckdb_reader, + output_name="arrow_stream", + output_path=output_dir, + cpu_limit=2, + ) + arrow_batch = ArrowBatchNode( + ctx, + (data_partitions,), + process_func=functools.partial( + copy_input_batch, msg="arrow batch" + ), + streaming_batch_size=10, + secs_checkpoint_interval=0.5, + use_duckdb_reader=use_duckdb_reader, + output_name="arrow_batch", + output_path=output_dir, + cpu_limit=2, + ) + data_sink = DataSinkNode( + ctx, + (arrow_compute, arrow_stream, arrow_batch), + output_path=output_dir, + ) + plan = LogicalPlan(ctx, data_sink) + exec_plan = self.execute_plan( + plan, fault_inject_prob=0.1, secs_executor_probe_interval=0.5 + ) + self.assertTrue( + all(map(os.path.exists, exec_plan.final_output.resolved_paths)) + ) + arrow_compute_output = ParquetDataSet( + [os.path.join(output_dir, "arrow_compute", "**/*.parquet")], + recursive=True, + ) + arrow_stream_output = ParquetDataSet( + [os.path.join(output_dir, "arrow_stream", "**/*.parquet")], + recursive=True, + ) + arrow_batch_output = ParquetDataSet( + [os.path.join(output_dir, "arrow_batch", "**/*.parquet")], + recursive=True, + ) + self._compare_arrow_tables( + data_table, + arrow_compute_output.to_arrow_table().select( + data_table.column_names + ), + ) + self._compare_arrow_tables( + data_table, + arrow_stream_output.to_arrow_table().select( + data_table.column_names + ), + ) + self._compare_arrow_tables( + data_table, + arrow_batch_output.to_arrow_table().select( + 
data_table.column_names + ), + ) + + def test_pandas_task(self): + with tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir: + ctx = Context() + dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"]) + data_table = dataset.to_arrow_table() + data_files = DataSourceNode(ctx, dataset) + data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=7) + pandas_compute = PandasComputeNode( + ctx, + (data_partitions,), + process_func=copy_input_data_frame, + output_name="pandas_compute", + output_path=output_dir, + cpu_limit=2, + ) + pandas_batch = PandasBatchNode( + ctx, + (data_partitions,), + process_func=copy_input_data_frame_batch, + streaming_batch_size=10, + secs_checkpoint_interval=0.5, + output_name="pandas_batch", + output_path=output_dir, + cpu_limit=2, + ) + data_sink = DataSinkNode( + ctx, (pandas_compute, pandas_batch), output_path=output_dir + ) + plan = LogicalPlan(ctx, data_sink) + exec_plan = self.execute_plan( + plan, fault_inject_prob=0.1, secs_executor_probe_interval=0.5 + ) + self.assertTrue( + all(map(os.path.exists, exec_plan.final_output.resolved_paths)) + ) + pandas_compute_output = ParquetDataSet( + [os.path.join(output_dir, "pandas_compute", "**/*.parquet")], + recursive=True, + ) + pandas_batch_output = ParquetDataSet( + [os.path.join(output_dir, "pandas_batch", "**/*.parquet")], + recursive=True, + ) + self._compare_arrow_tables( + data_table, pandas_compute_output.to_arrow_table() + ) + self._compare_arrow_tables(data_table, pandas_batch_output.to_arrow_table()) + + def test_variable_length_input_datasets(self): + ctx = Context() + small_dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"]) + large_dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"] * 10) + small_partitions = DataSetPartitionNode( + ctx, (DataSourceNode(ctx, small_dataset),), npartitions=7 + ) + large_partitions = DataSetPartitionNode( + ctx, (DataSourceNode(ctx, large_dataset),), npartitions=7 + ) + arrow_batch = ArrowBatchNode( + ctx, + (small_partitions, large_partitions), + process_func=merge_input_tables, + streaming_batch_size=100, + secs_checkpoint_interval=0.5, + output_name="arrow_batch", + cpu_limit=2, + ) + pandas_batch = PandasBatchNode( + ctx, + (small_partitions, large_partitions), + process_func=merge_input_data_frames, + streaming_batch_size=100, + secs_checkpoint_interval=0.5, + output_name="pandas_batch", + cpu_limit=2, + ) + plan = LogicalPlan(ctx, RootNode(ctx, (arrow_batch, pandas_batch))) + exec_plan = self.execute_plan( + plan, fault_inject_prob=0.1, secs_executor_probe_interval=0.5 + ) + self.assertTrue(all(map(os.path.exists, exec_plan.final_output.resolved_paths))) + arrow_batch_output = ParquetDataSet( + [os.path.join(exec_plan.ctx.output_root, "arrow_batch", "**/*.parquet")], + recursive=True, + ) + pandas_batch_output = ParquetDataSet( + [os.path.join(exec_plan.ctx.output_root, "pandas_batch", "**/*.parquet")], + recursive=True, + ) + self.assertEqual( + small_dataset.num_rows + large_dataset.num_rows, arrow_batch_output.num_rows + ) + self.assertEqual( + small_dataset.num_rows + large_dataset.num_rows, + pandas_batch_output.num_rows, + ) + + def test_projection_task(self): + ctx = Context() + # select columns when defining dataset + dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"], columns=["url"]) + data_files = DataSourceNode(ctx, dataset) + data_partitions = DataSetPartitionNode( + ctx, (data_files,), npartitions=3, partition_by_rows=True + ) + # projection as input of arrow node + generated_columns = 
["filename", "file_row_number"] + urls_with_host = ArrowComputeNode( + ctx, + (ProjectionNode(ctx, data_partitions, ["url"], generated_columns),), + process_func=parse_url, + use_duckdb_reader=True, + ) + # projection as input of sql node + distinct_urls_with_host = SqlEngineNode( + ctx, + ( + ProjectionNode( + ctx, + data_partitions, + ["url", "string_split(url, '/')[0] as host"], + generated_columns, + ), + ), + r"select distinct host, url, filename from {0}", + ) + # unify different schemas + merged_diff_schemas = ProjectionNode( + ctx, + DataSetPartitionNode( + ctx, (distinct_urls_with_host, urls_with_host), npartitions=1 + ), + union_by_name=True, + ) + host_partitions = HashPartitionNode( + ctx, + (merged_diff_schemas,), + npartitions=3, + hash_columns=["host"], + engine_type="duckdb", + output_name="host_partitions", + ) + host_partitions.max_num_producer_tasks = 1 + plan = LogicalPlan(ctx, host_partitions) + final_output = self.execute_plan(plan, fault_inject_prob=0.1).final_output + final_table = final_output.to_arrow_table() + self.assertEqual( + sorted( + [ + "url", + "host", + *generated_columns, + HashPartitionNode.default_data_partition_column, + ] + ), + sorted(final_table.column_names), + ) + + def test_arrow_type_in_udfs(self): + ctx = Context() + dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"]) + data_files = DataSourceNode(ctx, dataset) + data_partitions = DataSetPartitionNode( + ctx, (data_files,), npartitions=dataset.num_files + ) + ctx.create_function( + "split_url", + split_url, + [UDFType.VARCHAR], + UDFListType(UDFType.VARCHAR), + use_arrow_type=True, + ) + uniq_hosts = SqlEngineNode( + ctx, + (data_partitions,), + r"select split_url(url) as url_parts from {0}", + udfs=["split_url"], + ) + plan = LogicalPlan(ctx, uniq_hosts) + self.execute_plan(plan) + + def test_many_simple_tasks(self): + ctx = Context() + npartitions = 1000 + dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"] * npartitions) + data_files = DataSourceNode(ctx, dataset) + data_partitions = EvenlyDistributedPartitionNode( + ctx, (data_files,), npartitions=npartitions + ) + output_msg = OutputMsgPythonNode(ctx, (data_partitions,)) + plan = LogicalPlan(ctx, output_msg) + self.execute_plan( + plan, + num_executors=10, + secs_executor_probe_interval=5, + enable_profiling=True, + ) + + def test_many_producers_and_partitions(self): + ctx = Context() + npartitions = 10000 + dataset = ParquetDataSet( + ["tests/data/mock_urls/*.parquet"] * (npartitions * 10) + ) + data_files = DataSourceNode(ctx, dataset) + data_partitions = EvenlyDistributedPartitionNode( + ctx, (data_files,), npartitions=npartitions, cpu_limit=1 + ) + data_partitions.max_num_producer_tasks = 20 + output_msg = OutputMsgPythonNode(ctx, (data_partitions,)) + plan = LogicalPlan(ctx, output_msg) + self.execute_plan( + plan, + num_executors=10, + secs_executor_probe_interval=5, + enable_profiling=True, + ) + + def test_local_gpu_rank(self): + ctx = Context() + dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"]) + data_files = DataSourceNode(ctx, dataset) + data_partitions = DataSetPartitionNode( + ctx, (data_files,), npartitions=dataset.num_files + ) + output_msg = OutputMsgPythonNode( + ctx, (data_partitions,), cpu_limit=1, gpu_limit=0.5 + ) + plan = LogicalPlan(ctx, output_msg) + runtime_ctx = RuntimeContext( + JobId.new(), + datetime.now(), + self.output_root_abspath, + console_log_level="WARNING", + ) + runtime_ctx.get_local_gpus = return_fake_gpus + runtime_ctx.initialize(socket.gethostname(), cleanup_root=True) 
+ self.execute_plan(plan, runtime_ctx=runtime_ctx) + + def test_python_node_with_process_method(self): + ctx = Context() + dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"]) + data_files = DataSourceNode(ctx, dataset) + copy_input_arrow_node = CopyInputArrowNode(ctx, (data_files,), "hello") + copy_input_stream_node = CopyInputStreamNode(ctx, (data_files,), "hello") + output_msg = OutputMsgPythonNode2( + ctx, (copy_input_arrow_node, copy_input_stream_node), "hello" + ) + plan = LogicalPlan(ctx, output_msg) + self.execute_plan(plan) + + def test_sql_engine_oom(self): + ctx = Context() + dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"]) + data_files = DataSourceNode(ctx, dataset) + uniq_urls = SqlEngineNode( + ctx, (data_files,), r"select distinct * from {0}", memory_limit=2 * MB + ) + uniq_url_partitions = DataSetPartitionNode(ctx, (uniq_urls,), 2) + uniq_url_count = SqlEngineNode( + ctx, + (uniq_url_partitions,), + sql_query=r"select count(distinct columns(*)) from {0}", + memory_limit=2 * MB, + ) + plan = LogicalPlan(ctx, uniq_url_count) + self.execute_plan(plan, max_fail_count=10) + + @unittest.skip("flaky on CI") + def test_enforce_memory_limit(self): + ctx = Context() + dataset = ParquetDataSet(["tests/data/arrow/*.parquet"]) + data_files = DataSourceNode(ctx, dataset) + arrow_compute = ArrowComputeNode( + ctx, + (data_files,), + process_func=functools.partial(copy_input_arrow, msg="arrow compute"), + memory_limit=1 * GB, + ) + arrow_stream = ArrowStreamNode( + ctx, + (data_files,), + process_func=functools.partial(copy_input_stream, msg="arrow stream"), + memory_limit=1 * GB, + ) + with tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir: + data_sink = DataSinkNode( + ctx, (arrow_compute, arrow_stream), output_path=output_dir + ) + plan = LogicalPlan(ctx, data_sink) + self.execute_plan( + plan, + max_fail_count=10, + enforce_memory_limit=True, + nonzero_exitcode_as_oom=True, + ) + + def test_task_crash_as_oom(self): + ctx = Context() + dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"]) + data_files = DataSourceNode(ctx, dataset) + nonzero_exitcode = PythonScriptNode( + ctx, (data_files,), process_func=nonzero_exit_code + ) + plan = LogicalPlan(ctx, nonzero_exitcode) + exec_plan = self.execute_plan( + plan, num_executors=1, check_result=False, nonzero_exitcode_as_oom=False + ) + self.assertFalse(exec_plan.successful) + exec_plan = self.execute_plan( + plan, num_executors=1, check_result=False, nonzero_exitcode_as_oom=True + ) + self.assertTrue(exec_plan.successful) + + def test_manifest_only_data_sink(self): + with open("tests/data/long_path_list.txt", buffering=16 * MB) as fin: + filenames = list(map(str.strip, fin.readlines())) + logger.info(f"loaded {len(filenames)} filenames") + + with tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir: + ctx = Context() + dataset = ParquetDataSet(filenames) + data_files = DataSourceNode(ctx, dataset) + data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=512) + data_sink = DataSinkNode( + ctx, (data_partitions,), output_path=output_dir, manifest_only=True + ) + plan = LogicalPlan(ctx, data_sink) + self.execute_plan(plan) + + with open( + os.path.join(output_dir, DataSinkTask.manifest_filename), + buffering=16 * MB, + ) as fin: + num_lines = len(fin.readlines()) + self.assertEqual(len(filenames), num_lines) + + def test_sql_batched_processing(self): + for materialize_in_memory in (False, True): + with self.subTest(materialize_in_memory=materialize_in_memory): + 
ctx = Context() + dataset = ParquetDataSet(["tests/data/large_array/*.parquet"] * 2) + data_files = DataSourceNode(ctx, dataset) + content_length = SqlEngineNode( + ctx, + (data_files,), + r"select url, octet_length(content) as content_len from {0}", + materialize_in_memory=materialize_in_memory, + batched_processing=True, + cpu_limit=2, + memory_limit=2 * GB, + ) + plan = LogicalPlan(ctx, content_length) + final_output: ParquetDataSet = self.execute_plan(plan).final_output + self.assertEqual(dataset.num_rows, final_output.num_rows) + + def test_multiple_sql_queries(self): + for materialize_in_memory in (False, True): + with self.subTest(materialize_in_memory=materialize_in_memory): + ctx = Context() + dataset = ParquetDataSet(["tests/data/large_array/*.parquet"] * 2) + data_files = DataSourceNode(ctx, dataset) + content_length = SqlEngineNode( + ctx, + (data_files,), + [ + r"create or replace temp table content_len_data as select url, octet_length(content) as content_len from {0}", + r"select * from content_len_data", + ], + materialize_in_memory=materialize_in_memory, + batched_processing=True, + cpu_limit=2, + memory_limit=2 * GB, + ) + plan = LogicalPlan(ctx, content_length) + final_output: ParquetDataSet = self.execute_plan(plan).final_output + self.assertEqual(dataset.num_rows, final_output.num_rows) + + def test_temp_outputs_in_final_results(self): + ctx = Context() + dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"]) + data_files = DataSourceNode(ctx, dataset) + data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=10) + url_counts = SqlEngineNode( + ctx, (data_partitions,), r"select count(url) as cnt from {0}" + ) + distinct_url_counts = SqlEngineNode( + ctx, (data_partitions,), r"select count(distinct url) as cnt from {0}" + ) + merged_counts = DataSetPartitionNode( + ctx, + ( + ProjectionNode(ctx, url_counts, ["cnt"]), + ProjectionNode(ctx, distinct_url_counts, ["cnt"]), + ), + npartitions=1, + ) + split_counts = DataSetPartitionNode(ctx, (merged_counts,), npartitions=10) + plan = LogicalPlan(ctx, split_counts) + final_output: ParquetDataSet = self.execute_plan(plan).final_output + self.assertEqual(data_partitions.npartitions * 2, final_output.num_rows) + + def test_override_output_path(self): + ctx = Context() + dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"]) + data_files = DataSourceNode(ctx, dataset) + data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=10) + url_counts = SqlEngineNode( + ctx, + (data_partitions,), + r"select count(url) as cnt from {0}", + output_name="url_counts", + ) + distinct_url_counts = SqlEngineNode( + ctx, (data_partitions,), r"select count(distinct url) as cnt from {0}" + ) + merged_counts = DataSetPartitionNode( + ctx, + ( + ProjectionNode(ctx, url_counts, ["cnt"]), + ProjectionNode(ctx, distinct_url_counts, ["cnt"]), + ), + npartitions=1, + ) + plan = LogicalPlan(ctx, merged_counts) + + output_path = os.path.join(self.runtime_ctx.output_root, "final_output") + final_output = self.execute_plan(plan, output_path=output_path).final_output + self.assertTrue(os.path.exists(os.path.join(output_path, "url_counts"))) + self.assertTrue(os.path.exists(os.path.join(output_path, "FinalResults"))) + + def test_data_sink_avoid_filename_conflicts(self): + ctx = Context() + dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"]) + data_files = DataSourceNode(ctx, dataset) + data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=10) + empty_files1 = PythonScriptNode( + ctx, 
(data_partitions,), process_func=empty_file + ) + empty_files2 = PythonScriptNode( + ctx, (data_partitions,), process_func=empty_file + ) + link_path = os.path.join(self.runtime_ctx.output_root, "link") + copy_path = os.path.join(self.runtime_ctx.output_root, "copy") + copy_input_path = os.path.join(self.runtime_ctx.output_root, "copy_input") + data_link = DataSinkNode( + ctx, (empty_files1, empty_files2), type="link", output_path=link_path + ) + data_copy = DataSinkNode( + ctx, (empty_files1, empty_files2), type="copy", output_path=copy_path + ) + data_copy_input = DataSinkNode( + ctx, (data_partitions,), type="copy", output_path=copy_input_path + ) + plan = LogicalPlan(ctx, RootNode(ctx, (data_link, data_copy, data_copy_input))) + + self.execute_plan(plan) + # there should be 21 files (20 input files + 1 manifest file) in the sink dir + self.assertEqual(21, len(os.listdir(link_path))) + self.assertEqual(21, len(os.listdir(copy_path))) + # file name should not be modified if no conflict + self.assertEqual( + set( + filename + for filename in os.listdir("tests/data/mock_urls") + if filename.endswith(".parquet") + ), + set( + filename + for filename in os.listdir(copy_input_path) + if filename.endswith(".parquet") + ), + ) + + def test_literal_datasets_as_data_sources(self): + ctx = Context() + num_rows = 10 + query_dataset = SqlQueryDataSet(f"select i from range({num_rows}) as x(i)") + table_dataset = ArrowTableDataSet( + arrow.Table.from_arrays([list(range(num_rows))], names=["i"]) + ) + query_source = DataSourceNode(ctx, query_dataset) + table_source = DataSourceNode(ctx, table_dataset) + query_partitions = DataSetPartitionNode( + ctx, (query_source,), npartitions=num_rows, partition_by_rows=True + ) + table_partitions = DataSetPartitionNode( + ctx, (table_source,), npartitions=num_rows, partition_by_rows=True + ) + joined_rows = SqlEngineNode( + ctx, + (query_partitions, table_partitions), + r"select a.i as i, b.i as j from {0} as a join {1} as b on a.i = b.i", + ) + plan = LogicalPlan(ctx, joined_rows) + final_output: ParquetDataSet = self.execute_plan(plan).final_output + self.assertEqual(num_rows, final_output.num_rows) + + def test_partial_process_func(self): + ctx = Context() + dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"]) + data_files = DataSourceNode(ctx, dataset) + data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=3) + # use default value of k + random_urls_k5 = ArrowComputeNode( + ctx, + (data_partitions,), + process_func=choose_random_urls, + output_name="random_urls_k5", + ) + # set value of k using functools.partial + random_urls_k10 = ArrowComputeNode( + ctx, + (data_partitions,), + process_func=functools.partial(choose_random_urls, k=10), + output_name="random_urls_k10", + ) + random_urls_all = SqlEngineNode( + ctx, + (random_urls_k5, random_urls_k10), + r"select * from {0} union select * from {1}", + output_name="random_urls_all", + ) + plan = LogicalPlan(ctx, random_urls_all) + exec_plan = self.execute_plan(plan) + self.assertEqual( + data_partitions.npartitions * 5, + exec_plan.get_output("random_urls_k5").to_arrow_table().num_rows, + ) + self.assertEqual( + data_partitions.npartitions * 10, + exec_plan.get_output("random_urls_k10").to_arrow_table().num_rows, + ) diff --git a/tests/test_fabric.py b/tests/test_fabric.py new file mode 100644 index 0000000..4fe9782 --- /dev/null +++ b/tests/test_fabric.py @@ -0,0 +1,294 @@ +import os.path +import queue +import sys +import unittest +from concurrent.futures import ThreadPoolExecutor +from 
datetime import datetime +from multiprocessing import Manager, Process +from typing import List, Optional + +import fsspec +import numpy as np +import psutil +import pyarrow as arrow +import pyarrow.compute as pc +import pyarrow.parquet as parquet +from loguru import logger + +from smallpond.common import DEFAULT_MAX_FAIL_COUNT, DEFAULT_MAX_RETRY_COUNT, GB, MB +from smallpond.execution.executor import Executor +from smallpond.execution.scheduler import Scheduler +from smallpond.execution.task import ExecutionPlan, JobId, RuntimeContext +from smallpond.io.arrow import cast_columns_to_large_string +from smallpond.logical.node import LogicalPlan +from smallpond.logical.planner import Planner +from tests.datagen import generate_data + +generate_data() + + +def run_scheduler( + runtime_ctx: RuntimeContext, scheduler: Scheduler, queue: queue.Queue +): + runtime_ctx.initialize("scheduler") + scheduler.add_state_observer(Scheduler.StateObserver(SaveSchedState(queue))) + retval = scheduler.run() + print(f"scheduler exited with value {retval}", file=sys.stderr) + + +def run_executor(runtime_ctx: RuntimeContext, executor: Executor): + runtime_ctx.initialize(executor.id) + retval = executor.run() + print(f"{executor.id} exited with value {retval}", file=sys.stderr) + + +class SaveSchedState: + """ + A state observer that push the scheduler state into a queue when finished. + """ + + def __init__(self, queue: queue.Queue): + self.queue = queue + + def __call__(self, sched_state: Scheduler) -> bool: + if sched_state.num_local_running_works == 0: + self.queue.put(sched_state) + return True + + +class TestFabric(unittest.TestCase): + """ + A helper class that includes boilerplate code to test a logical plan. + """ + + runtime_root = os.getenv("TEST_RUNTIME_ROOT") or f"tests/runtime" + runtime_ctx = None + fault_inject_prob = 0.00 + + queue_manager = None + sched_states: queue.Queue = None + latest_state: Scheduler = None + executors: List[Executor] = None + processes: List[Process] = None + + @property + def output_dir(self): + return os.path.join(self.__class__.__name__, self._testMethodName) + + @property + def output_root_abspath(self): + output_root = os.path.abspath(os.path.join(self.runtime_root, self.output_dir)) + os.makedirs(output_root, exist_ok=True) + return output_root + + def setUp(self) -> None: + try: + from pytest_cov.embed import cleanup_on_sigterm + except ImportError: + pass + else: + cleanup_on_sigterm() + self.runtime_ctx = RuntimeContext( + JobId.new(), + datetime.now(), + self.output_root_abspath, + console_log_level="WARNING", + ) + self.runtime_ctx.initialize("setup") + return super().setUp() + + def tearDown(self) -> None: + if self.sched_states is not None: + self.get_latest_sched_state() + assert self.sched_states.qsize() == 0 + self.sched_states = None + if self.queue_manager is not None: + self.queue_manager.shutdown() + self.queue_manager = None + return super().tearDown() + + def get_latest_sched_state(self) -> Scheduler: + while True: + try: + self.latest_state = self.sched_states.get(block=False) + except queue.Empty: + return self.latest_state + + def join_running_procs(self, timeout=30): + for i, process in enumerate(self.processes): + if process.is_alive(): + logger.info(f"join #{i} process: {process.name}") + process.join(timeout=None if i == 0 else timeout) + + if process.exitcode is None: + logger.info(f"terminate #{i} process: {process.name}") + process.terminate() + process.join(timeout=timeout) + + if process.exitcode is None: + logger.info(f"kill #{i} process: 
{process.name}") + process.kill() + process.join() + + logger.info( + f"#{i} process {process.name} exited with code {process.exitcode}" + ) + + def start_execution( + self, + plan: LogicalPlan, + num_executors: int = 2, + secs_wq_poll_interval: float = 0.1, + secs_executor_probe_interval: float = 1, + max_num_missed_probes: int = 10, + max_retry_count: int = DEFAULT_MAX_RETRY_COUNT, + max_fail_count: int = DEFAULT_MAX_FAIL_COUNT, + prioritize_retry=False, + speculative_exec="enable", + stop_executor_on_failure=False, + enforce_memory_limit=False, + nonzero_exitcode_as_oom=False, + fault_inject_prob=None, + enable_profiling=False, + enable_diagnostic_metrics=False, + remove_empty_parquet=False, + skip_task_with_empty_input=False, + console_log_level="WARNING", + file_log_level="DEBUG", + output_path: Optional[str] = None, + runtime_ctx: Optional[RuntimeContext] = None, + ): + """ + Start a scheduler and `num_executors` executors to execute `plan`. + When this function returns, the execution is mostly still running. + + Parameters + ---------- + plan + A logical plan. + num_executors, optional + The number of executors + console_log_level, optional + Set to logger.INFO if more verbose loguru is needed for debug, by default "WARNING". + + Returns + ------- + A 3-tuple of type (Scheduler, List[Executor], List[Process]). + """ + if runtime_ctx is None: + runtime_ctx = RuntimeContext( + JobId.new(), + datetime.now(), + self.output_root_abspath, + num_executors=num_executors, + random_seed=123456, + enforce_memory_limit=enforce_memory_limit, + max_usable_cpu_count=min(64, psutil.cpu_count(logical=False)), + max_usable_gpu_count=0, + max_usable_memory_size=min(64 * GB, psutil.virtual_memory().total), + secs_wq_poll_interval=secs_wq_poll_interval, + secs_executor_probe_interval=secs_executor_probe_interval, + max_num_missed_probes=max_num_missed_probes, + fault_inject_prob=( + fault_inject_prob + if fault_inject_prob is not None + else self.fault_inject_prob + ), + enable_profiling=enable_profiling, + enable_diagnostic_metrics=enable_diagnostic_metrics, + remove_empty_parquet=remove_empty_parquet, + skip_task_with_empty_input=skip_task_with_empty_input, + console_log_level=console_log_level, + file_log_level=file_log_level, + output_path=output_path, + ) + + self.queue_manager = Manager() + self.sched_states = self.queue_manager.Queue() + + exec_plan = Planner(runtime_ctx).create_exec_plan(plan) + scheduler = Scheduler( + exec_plan, + max_retry_count=max_retry_count, + max_fail_count=max_fail_count, + prioritize_retry=prioritize_retry, + speculative_exec=speculative_exec, + stop_executor_on_failure=stop_executor_on_failure, + nonzero_exitcode_as_oom=nonzero_exitcode_as_oom, + ) + self.latest_state = scheduler + self.executors = [ + Executor.create(runtime_ctx, f"executor-{i}") for i in range(num_executors) + ] + self.processes = [ + Process( + target=run_scheduler, + # XXX: on macOS, scheduler state observer will be cleared when cross-process + # so we pass the queue and add the observer in the new process + args=(runtime_ctx, scheduler, self.sched_states), + name="scheduler", + ) + ] + self.processes += [ + Process(target=run_executor, args=(runtime_ctx, executor), name=executor.id) + for executor in self.executors + ] + + for process in reversed(self.processes): + process.start() + + return self.sched_states, self.executors, self.processes + + def execute_plan(self, *args, check_result=True, **kvargs) -> ExecutionPlan: + """ + Start a scheduler and `num_executors` executors to execute `plan`, + 
and wait the execution completed, then assert if it succeeds. + + Parameters + ---------- + plan + A logical plan. + num_executors, optional + The number of executors + console_log_level, optional + Set to logger.INFO if more verbose loguru is needed for debug, by default "WARNING". + + Returns + ------- + The completed ExecutionPlan instance. + """ + self.start_execution(*args, **kvargs) + self.join_running_procs() + latest_state = self.get_latest_sched_state() + if check_result: + self.assertTrue(latest_state.success) + return latest_state.exec_plan + + def _load_parquet_files( + self, paths, filesystem: fsspec.AbstractFileSystem = None + ) -> arrow.Table: + def read_parquet_file(path): + return arrow.Table.from_batches( + parquet.ParquetFile( + path, buffer_size=16 * MB, filesystem=filesystem + ).iter_batches() + ) + + with ThreadPoolExecutor(16) as pool: + return arrow.concat_tables(pool.map(read_parquet_file, paths)) + + def _compare_arrow_tables(self, expected: arrow.Table, actual: arrow.Table): + def sorted_table(t: arrow.Table): + return t.sort_by([(col, "ascending") for col in t.column_names]) + + self.assertEqual(expected.shape, actual.shape) + self.assertEqual(expected.column_names, actual.column_names) + expected = sorted_table(cast_columns_to_large_string(expected)) + actual = sorted_table(cast_columns_to_large_string(actual)) + for col, x, y in zip(expected.column_names, expected.columns, actual.columns): + if not pc.equal(x, y): + x = x.to_numpy(zero_copy_only=False) + y = y.to_numpy(zero_copy_only=False) + logger.error(f" expect {col}: {x}") + logger.error(f" actual {col}: {y}") + np.testing.assert_array_equal(x, y, verbose=True) diff --git a/tests/test_filesystem.py b/tests/test_filesystem.py new file mode 100644 index 0000000..d8b1dda --- /dev/null +++ b/tests/test_filesystem.py @@ -0,0 +1,25 @@ +import os.path +import tempfile +import threading +import unittest + +from smallpond.io.filesystem import dump, load +from tests.test_fabric import TestFabric + + +class TestFilesystem(TestFabric, unittest.TestCase): + def test_pickle_runtime_ctx(self): + with tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir: + pickle_path = os.path.join(output_dir, "runtime_ctx.pickle") + dump(self.runtime_ctx, pickle_path) + runtime_ctx = load(pickle_path) + self.assertEqual(self.runtime_ctx.job_id, runtime_ctx.job_id) + + def test_pickle_trace(self): + with self.assertRaises(TypeError) as context: + with tempfile.TemporaryDirectory( + dir=self.output_root_abspath + ) as output_dir: + thread = threading.Thread() + pickle_path = os.path.join(output_dir, "thread.pickle") + dump(thread, pickle_path) diff --git a/tests/test_logical.py b/tests/test_logical.py new file mode 100644 index 0000000..a2901a2 --- /dev/null +++ b/tests/test_logical.py @@ -0,0 +1,103 @@ +import unittest + +from loguru import logger + +from smallpond.logical.dataset import ParquetDataSet +from smallpond.logical.node import ( + Context, + DataSetPartitionNode, + DataSourceNode, + EvenlyDistributedPartitionNode, + HashPartitionNode, + LogicalPlan, + SqlEngineNode, +) +from smallpond.logical.planner import Planner +from tests.test_fabric import TestFabric + + +class TestLogicalPlan(TestFabric, unittest.TestCase): + + def test_join_chunkmeta_inodes(self): + ctx = Context() + + chunkmeta_dump = DataSourceNode( + ctx, dataset=ParquetDataSet(["tests/data/chunkmeta*.parquet"]) + ) + chunkmeta_partitions = HashPartitionNode( + ctx, (chunkmeta_dump,), npartitions=2, hash_columns=["inodeId"] + ) + + inodes_dump = 
DataSourceNode( + ctx, dataset=ParquetDataSet(["tests/data/inodes*.parquet"]) + ) + inodes_partitions = HashPartitionNode( + ctx, (inodes_dump,), npartitions=2, hash_columns=["inode_id"] + ) + + num_gc_chunks = SqlEngineNode( + ctx, + (chunkmeta_partitions, inodes_partitions), + r""" + select count(chunkmeta_chunkId) from {0} + where chunkmeta.chunkmeta_chunkId NOT LIKE "F%" AND + chunkmeta.inodeId not in ( select distinct inode_id from {1} )""", + ) + + plan = LogicalPlan(ctx, num_gc_chunks) + logger.info(str(plan)) + exec_plan = Planner(self.runtime_ctx).create_exec_plan(plan) + logger.info(str(exec_plan)) + + def test_partition_dims_not_compatible(self): + ctx = Context() + parquet_dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"]) + data_source = DataSourceNode(ctx, parquet_dataset) + partition_dim_a = EvenlyDistributedPartitionNode( + ctx, (data_source,), npartitions=parquet_dataset.num_files, dimension="A" + ) + partition_dim_b = EvenlyDistributedPartitionNode( + ctx, (data_source,), npartitions=parquet_dataset.num_files, dimension="B" + ) + join_two_inputs = SqlEngineNode( + ctx, + (partition_dim_a, partition_dim_b), + r"select a.* from {0} as a join {1} as b on a.host = b.host", + ) + plan = LogicalPlan(ctx, join_two_inputs) + logger.info(str(plan)) + with self.assertRaises(AssertionError) as context: + Planner(self.runtime_ctx).create_exec_plan(plan) + + def test_npartitions_not_compatible(self): + ctx = Context() + parquet_dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"]) + data_source = DataSourceNode(ctx, parquet_dataset) + partition_dim_a = EvenlyDistributedPartitionNode( + ctx, (data_source,), npartitions=parquet_dataset.num_files, dimension="A" + ) + partition_dim_a2 = EvenlyDistributedPartitionNode( + ctx, + (data_source,), + npartitions=parquet_dataset.num_files * 2, + dimension="A", + ) + join_two_inputs1 = SqlEngineNode( + ctx, + (partition_dim_a, partition_dim_a2), + r"select a.* from {0} as a join {1} as b on a.host = b.host", + ) + join_two_inputs2 = SqlEngineNode( + ctx, + (partition_dim_a2, partition_dim_a), + r"select a.* from {0} as a join {1} as b on a.host = b.host", + ) + plan = LogicalPlan( + ctx, + DataSetPartitionNode( + ctx, (join_two_inputs1, join_two_inputs2), npartitions=1 + ), + ) + logger.info(str(plan)) + with self.assertRaises(AssertionError) as context: + Planner(self.runtime_ctx).create_exec_plan(plan) diff --git a/tests/test_partition.py b/tests/test_partition.py new file mode 100644 index 0000000..ac7704c --- /dev/null +++ b/tests/test_partition.py @@ -0,0 +1,659 @@ +import os.path +import tempfile +import unittest +from typing import List + +import pyarrow.compute as pc + +from smallpond.common import DATA_PARTITION_COLUMN_NAME, GB +from smallpond.execution.task import RuntimeContext +from smallpond.logical.dataset import DataSet, ParquetDataSet +from smallpond.logical.node import ( + ArrowComputeNode, + ConsolidateNode, + Context, + DataSetPartitionNode, + DataSinkNode, + DataSourceNode, + EvenlyDistributedPartitionNode, + HashPartitionNode, + LoadPartitionedDataSetNode, + LogicalPlan, + ProjectionNode, + SqlEngineNode, + UnionNode, + UserDefinedPartitionNode, + UserPartitionedDataSourceNode, +) +from tests.test_execution import parse_url +from tests.test_fabric import TestFabric + + +class CalculatePartitionFromFilename(UserDefinedPartitionNode): + def partition(self, runtime_ctx: RuntimeContext, dataset: DataSet) -> List[DataSet]: + partitioned_datasets: List[ParquetDataSet] = [ + ParquetDataSet([]) for _ in 
range(self.npartitions) + ] + for path in dataset.resolved_paths: + partition_idx = hash(path) % self.npartitions + partitioned_datasets[partition_idx].paths.append(path) + return partitioned_datasets + + +class TestPartition(TestFabric, unittest.TestCase): + def test_many_file_partitions(self): + ctx = Context() + dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"] * 10) + data_files = DataSourceNode(ctx, dataset) + data_partitions = DataSetPartitionNode( + ctx, (data_files,), npartitions=dataset.num_files + ) + count_rows = SqlEngineNode( + ctx, + (data_partitions,), + "select count(*) from {0}", + cpu_limit=1, + memory_limit=1 * GB, + ) + plan = LogicalPlan(ctx, count_rows) + self.execute_plan(plan) + + def test_many_row_partitions(self): + ctx = Context() + dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"]) + data_files = DataSourceNode(ctx, dataset) + data_partitions = DataSetPartitionNode( + ctx, (data_files,), npartitions=dataset.num_rows, partition_by_rows=True + ) + count_rows = SqlEngineNode( + ctx, + (data_partitions,), + "select count(*) from {0}", + cpu_limit=1, + memory_limit=1 * GB, + ) + plan = LogicalPlan(ctx, count_rows) + exec_plan = self.execute_plan(plan, num_executors=5) + self.assertEqual( + exec_plan.final_output.to_arrow_table().num_rows, dataset.num_rows + ) + + def test_empty_dataset_partition(self): + ctx = Context() + dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"]) + data_files = DataSourceNode(ctx, dataset) + # create more partitions than files + data_partitions = EvenlyDistributedPartitionNode( + ctx, (data_files,), npartitions=dataset.num_files * 2 + ) + data_partitions.max_num_producer_tasks = 3 + unique_urls = SqlEngineNode( + ctx, + (data_partitions,), + r"select distinct url from {0}", + cpu_limit=1, + memory_limit=1 * GB, + ) + # nested partition + nested_partitioned_urls = EvenlyDistributedPartitionNode( + ctx, (unique_urls,), npartitions=3, dimension="nested", nested=True + ) + parsed_urls = ArrowComputeNode( + ctx, + (nested_partitioned_urls,), + process_func=parse_url, + cpu_limit=1, + memory_limit=1 * GB, + ) + plan = LogicalPlan(ctx, parsed_urls) + final_output = self.execute_plan( + plan, remove_empty_parquet=True, skip_task_with_empty_input=True + ).final_output + self.assertTrue(isinstance(final_output, ParquetDataSet)) + self.assertEqual(dataset.num_rows, final_output.num_rows) + + def test_hash_partition(self): + for engine_type in ("duckdb", "arrow"): + for partition_by_rows in (False, True): + for hive_partitioning in ( + (False, True) if engine_type == "duckdb" else (False,) + ): + with self.subTest( + engine_type=engine_type, + partition_by_rows=partition_by_rows, + hive_partitioning=hive_partitioning, + ): + ctx = Context() + dataset = ParquetDataSet(["tests/data/arrow/*.parquet"]) + data_files = DataSourceNode(ctx, dataset) + npartitions = 3 + data_partitions = DataSetPartitionNode( + ctx, + (data_files,), + npartitions=npartitions, + partition_by_rows=partition_by_rows, + ) + hash_partitions = HashPartitionNode( + ctx, + (ProjectionNode(ctx, data_partitions, ["url"]),), + npartitions=npartitions, + hash_columns=["url"], + engine_type=engine_type, + hive_partitioning=hive_partitioning, + cpu_limit=2, + memory_limit=2 * GB, + output_name="hash_partitions", + ) + row_count = SqlEngineNode( + ctx, + (hash_partitions,), + r"select count(*) as row_count from {0}", + cpu_limit=1, + memory_limit=1 * GB, + ) + plan = LogicalPlan(ctx, row_count) + exec_plan = self.execute_plan(plan) + self.assertEqual( + 
dataset.num_rows, + pc.sum( + exec_plan.final_output.to_arrow_table().column( + "row_count" + ) + ).as_py(), + ) + self.assertEqual( + npartitions, + len( + exec_plan.final_output.load_partitioned_datasets( + npartitions, DATA_PARTITION_COLUMN_NAME + ) + ), + ) + self.assertEqual( + npartitions, + len( + exec_plan.get_output( + "hash_partitions" + ).load_partitioned_datasets( + npartitions, + DATA_PARTITION_COLUMN_NAME, + hive_partitioning, + ) + ), + ) + + def test_empty_hash_partition(self): + for engine_type in ("duckdb", "arrow"): + for partition_by_rows in (False, True): + for hive_partitioning in ( + (False, True) if engine_type == "duckdb" else (False,) + ): + with self.subTest( + engine_type=engine_type, + partition_by_rows=partition_by_rows, + hive_partitioning=hive_partitioning, + ): + ctx = Context() + dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"]) + data_files = DataSourceNode(ctx, dataset) + npartitions = 3 + npartitions_nested = 4 + num_rows = 1 + head_rows = SqlEngineNode( + ctx, (data_files,), f"select * from {{0}} limit {num_rows}" + ) + data_partitions = DataSetPartitionNode( + ctx, + (head_rows,), + npartitions=npartitions, + partition_by_rows=partition_by_rows, + ) + hash_partitions = HashPartitionNode( + ctx, + (data_partitions,), + npartitions=npartitions, + hash_columns=["url"], + data_partition_column="hash_partitions", + engine_type=engine_type, + hive_partitioning=hive_partitioning, + output_name="hash_partitions", + cpu_limit=2, + memory_limit=1 * GB, + ) + nested_hash_partitions = HashPartitionNode( + ctx, + (hash_partitions,), + npartitions=npartitions_nested, + hash_columns=["url"], + data_partition_column="nested_hash_partitions", + nested=True, + engine_type=engine_type, + hive_partitioning=hive_partitioning, + output_name="nested_hash_partitions", + cpu_limit=2, + memory_limit=1 * GB, + ) + select_every_row = SqlEngineNode( + ctx, + (nested_hash_partitions,), + r"select * from {0}", + cpu_limit=1, + memory_limit=1 * GB, + ) + plan = LogicalPlan(ctx, select_every_row) + exec_plan = self.execute_plan( + plan, skip_task_with_empty_input=True + ) + self.assertEqual(num_rows, exec_plan.final_output.num_rows) + self.assertEqual( + npartitions, + len( + exec_plan.final_output.load_partitioned_datasets( + npartitions, "hash_partitions" + ) + ), + ) + self.assertEqual( + npartitions_nested, + len( + exec_plan.final_output.load_partitioned_datasets( + npartitions_nested, "nested_hash_partitions" + ) + ), + ) + self.assertEqual( + npartitions, + len( + exec_plan.get_output( + "hash_partitions" + ).load_partitioned_datasets( + npartitions, "hash_partitions" + ) + ), + ) + self.assertEqual( + npartitions_nested, + len( + exec_plan.get_output( + "nested_hash_partitions" + ).load_partitioned_datasets( + npartitions_nested, "nested_hash_partitions" + ) + ), + ) + if hive_partitioning: + self.assertEqual( + npartitions, + len( + exec_plan.get_output( + "hash_partitions" + ).load_partitioned_datasets( + npartitions, + "hash_partitions", + hive_partitioning=True, + ) + ), + ) + self.assertEqual( + npartitions_nested, + len( + exec_plan.get_output( + "nested_hash_partitions" + ).load_partitioned_datasets( + npartitions_nested, + "nested_hash_partitions", + hive_partitioning=True, + ) + ), + ) + + def test_load_partitioned_datasets(self): + def run_test_plan( + npartitions: int, + data_partition_column: str, + engine_type: str, + hive_partitioning: bool, + ): + ctx = Context() + input_dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"]) + input_data_files 
= DataSourceNode(ctx, input_dataset) + # create hash partitions + input_partitions = HashPartitionNode( + ctx, + (input_data_files,), + npartitions=npartitions, + hash_columns=["url"], + data_partition_column=data_partition_column, + engine_type=engine_type, + hive_partitioning=hive_partitioning, + output_name="input_partitions", + cpu_limit=1, + memory_limit=1 * GB, + ) + split_urls = SqlEngineNode( + ctx, + (input_partitions,), + f"select url, string_split(url, '/')[0] as host from {{0}}", + cpu_limit=1, + memory_limit=1 * GB, + ) + plan = LogicalPlan(ctx, split_urls) + exec_plan = self.execute_plan(plan) + self.assertEqual( + npartitions, + len( + exec_plan.final_output.load_partitioned_datasets( + npartitions, data_partition_column + ) + ), + ) + self.assertEqual( + npartitions, + len( + exec_plan.get_output("input_partitions").load_partitioned_datasets( + npartitions, data_partition_column, hive_partitioning + ) + ), + ) + return exec_plan + + npartitions = 5 + data_partition_column = "_human_readable_column_name_" + + for engine_type in ("duckdb", "arrow"): + with self.subTest(engine_type=engine_type): + exec_plan1 = run_test_plan( + npartitions, + data_partition_column, + engine_type, + hive_partitioning=engine_type == "duckdb", + ) + exec_plan2 = run_test_plan( + npartitions, + data_partition_column, + engine_type, + hive_partitioning=False, + ) + + ctx = Context() + output1 = DataSourceNode( + ctx, dataset=exec_plan1.get_output("input_partitions") + ) + output2 = DataSourceNode( + ctx, dataset=exec_plan2.get_output("input_partitions") + ) + split_urls1 = LoadPartitionedDataSetNode( + ctx, + (output1,), + npartitions=npartitions, + data_partition_column=data_partition_column, + hive_partitioning=engine_type == "duckdb", + ) + split_urls2 = LoadPartitionedDataSetNode( + ctx, + (output2,), + npartitions=npartitions, + data_partition_column=data_partition_column, + hive_partitioning=False, + ) + split_urls3 = SqlEngineNode( + ctx, + (split_urls1, split_urls2), + f""" + select split_urls1.url, string_split(split_urls2.url, '/')[0] as host + from {{0}} as split_urls1 + join {{1}} as split_urls2 + on split_urls1.url = split_urls2.url + """, + cpu_limit=1, + memory_limit=1 * GB, + ) + plan = LogicalPlan(ctx, split_urls3) + exec_plan3 = self.execute_plan(plan) + # load each partition as arrow table and compare + final_output_partitions1 = ( + exec_plan1.final_output.load_partitioned_datasets( + npartitions, data_partition_column + ) + ) + final_output_partitions3 = ( + exec_plan3.final_output.load_partitioned_datasets( + npartitions, data_partition_column + ) + ) + self.assertEqual(npartitions, len(final_output_partitions3)) + for x, y in zip(final_output_partitions1, final_output_partitions3): + self._compare_arrow_tables(x.to_arrow_table(), y.to_arrow_table()) + + def test_nested_partition(self): + ctx = Context() + parquet_files = ParquetDataSet(["tests/data/mock_urls/*.parquet"]) + data_source = DataSourceNode(ctx, parquet_files) + + SqlEngineNode.default_cpu_limit = 1 + SqlEngineNode.default_memory_limit = 1 * GB + initial_reduce = r"select host, count(*) as cnt from {0} group by host" + combine_reduce_results = ( + r"select host, cast(sum(cnt) as bigint) as cnt from {0} group by host" + ) + join_query = r"select host, cnt from {0} where (exists (select * from {1} where {1}.host = {0}.host)) and (exists (select * from {2} where {2}.host = {0}.host))" + + partition_by_hosts = HashPartitionNode( + ctx, + (data_source,), + npartitions=3, + hash_columns=["host"], + 
data_partition_column="host_partition", + ) + partition_by_hosts_x_urls = HashPartitionNode( + ctx, + (partition_by_hosts,), + npartitions=5, + hash_columns=["url"], + data_partition_column="url_partition", + nested=True, + ) + url_count_by_hosts_x_urls1 = SqlEngineNode( + ctx, + (partition_by_hosts_x_urls,), + initial_reduce, + output_name="url_count_by_hosts_x_urls1", + ) + url_count_by_hosts1 = SqlEngineNode( + ctx, + (ConsolidateNode(ctx, url_count_by_hosts_x_urls1, ["host_partition"]),), + combine_reduce_results, + output_name="url_count_by_hosts1", + ) + join_count_by_hosts_x_urls1 = SqlEngineNode( + ctx, + (url_count_by_hosts_x_urls1, url_count_by_hosts1, data_source), + join_query, + output_name="join_count_by_hosts_x_urls1", + ) + + partitioned_urls = LoadPartitionedDataSetNode( + ctx, + (partition_by_hosts_x_urls,), + data_partition_column="url_partition", + npartitions=5, + ) + partitioned_hosts_x_urls = LoadPartitionedDataSetNode( + ctx, + (partitioned_urls,), + data_partition_column="host_partition", + npartitions=3, + nested=True, + ) + partitioned_3dims = EvenlyDistributedPartitionNode( + ctx, + (partitioned_hosts_x_urls,), + npartitions=2, + dimension="inner_partition", + partition_by_rows=True, + nested=True, + ) + url_count_by_3dims = SqlEngineNode(ctx, (partitioned_3dims,), initial_reduce) + url_count_by_hosts_x_urls2 = SqlEngineNode( + ctx, + ( + ConsolidateNode( + ctx, url_count_by_3dims, ["host_partition", "url_partition"] + ), + ), + combine_reduce_results, + output_name="url_count_by_hosts_x_urls2", + ) + url_count_by_hosts2 = SqlEngineNode( + ctx, + (ConsolidateNode(ctx, url_count_by_hosts_x_urls2, ["host_partition"]),), + combine_reduce_results, + output_name="url_count_by_hosts2", + ) + url_count_by_hosts_expected = SqlEngineNode( + ctx, + (data_source,), + initial_reduce, + per_thread_output=False, + output_name="url_count_by_hosts_expected", + ) + join_count_by_hosts_x_urls2 = SqlEngineNode( + ctx, + (url_count_by_hosts_x_urls2, url_count_by_hosts2, data_source), + join_query, + output_name="join_count_by_hosts_x_urls2", + ) + + union_url_count_by_hosts = UnionNode( + ctx, (url_count_by_hosts1, url_count_by_hosts2) + ) + union_url_count_by_hosts_x_urls = UnionNode( + ctx, + ( + url_count_by_hosts_x_urls1, + url_count_by_hosts_x_urls2, + join_count_by_hosts_x_urls1, + join_count_by_hosts_x_urls2, + ), + ) + + with tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir: + data_sink = DataSinkNode( + ctx, + ( + url_count_by_hosts_expected, + union_url_count_by_hosts, + union_url_count_by_hosts_x_urls, + ), + output_path=output_dir, + manifest_only=True, + ) + plan = LogicalPlan(ctx, data_sink) + exec_plan = self.execute_plan(plan, remove_empty_parquet=True) + # verify results + self._compare_arrow_tables( + exec_plan.get_output("url_count_by_hosts_x_urls1").to_arrow_table(), + exec_plan.get_output("url_count_by_hosts_x_urls2").to_arrow_table(), + ) + self._compare_arrow_tables( + exec_plan.get_output("join_count_by_hosts_x_urls1").to_arrow_table(), + exec_plan.get_output("join_count_by_hosts_x_urls2").to_arrow_table(), + ) + self._compare_arrow_tables( + exec_plan.get_output("url_count_by_hosts_x_urls1").to_arrow_table(), + exec_plan.get_output("join_count_by_hosts_x_urls1").to_arrow_table(), + ) + self._compare_arrow_tables( + exec_plan.get_output("url_count_by_hosts1").to_arrow_table(), + exec_plan.get_output("url_count_by_hosts2").to_arrow_table(), + ) + self._compare_arrow_tables( + 
exec_plan.get_output("url_count_by_hosts_expected").to_arrow_table(), + exec_plan.get_output("url_count_by_hosts1").to_arrow_table(), + ) + + def test_user_defined_partition(self): + ctx = Context() + parquet_files = ParquetDataSet(["tests/data/mock_urls/*.parquet"]) + data_source = DataSourceNode(ctx, parquet_files) + file_partitions1 = CalculatePartitionFromFilename( + ctx, (data_source,), npartitions=3, dimension="by_filename_hash1" + ) + url_count1 = SqlEngineNode( + ctx, + (file_partitions1,), + r"select host, count(*) as cnt from {0} group by host", + output_name="url_count1", + ) + file_partitions2 = CalculatePartitionFromFilename( + ctx, (url_count1,), npartitions=3, dimension="by_filename_hash2" + ) + url_count2 = SqlEngineNode( + ctx, + (file_partitions2,), + r"select host, cnt from {0}", + output_name="url_count2", + ) + plan = LogicalPlan(ctx, url_count2) + + exec_plan = self.execute_plan(plan, enable_diagnostic_metrics=True) + self._compare_arrow_tables( + exec_plan.get_output("url_count1").to_arrow_table(), + exec_plan.get_output("url_count2").to_arrow_table(), + ) + + def test_user_partitioned_data_source(self): + ctx = Context() + parquet_dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"]) + data_source = DataSourceNode(ctx, parquet_dataset) + evenly_dist_data_source = EvenlyDistributedPartitionNode( + ctx, (data_source,), npartitions=parquet_dataset.num_files + ) + + parquet_datasets = [ParquetDataSet([p]) for p in parquet_dataset.resolved_paths] + partitioned_data_source = UserPartitionedDataSourceNode(ctx, parquet_datasets) + + url_count_by_host1 = SqlEngineNode( + ctx, + (evenly_dist_data_source,), + r"select host, count(*) as cnt from {0} group by host", + output_name="url_count_by_host1", + cpu_limit=1, + memory_limit=1 * GB, + ) + + url_count_by_host2 = SqlEngineNode( + ctx, + (evenly_dist_data_source, partitioned_data_source), + r"select {1}.host, count(*) as cnt from {0} join {1} on {0}.host = {1}.host group by {1}.host", + output_name="url_count_by_host2", + cpu_limit=1, + memory_limit=1 * GB, + ) + + plan = LogicalPlan( + ctx, UnionNode(ctx, [url_count_by_host1, url_count_by_host2]) + ) + exec_plan = self.execute_plan(plan, enable_diagnostic_metrics=True) + self._compare_arrow_tables( + exec_plan.get_output("url_count_by_host1").to_arrow_table(), + exec_plan.get_output("url_count_by_host2").to_arrow_table(), + ) + + def test_partition_info_in_sql_query(self): + """ + User can refer to the partition info in the SQL query. 
+ """ + ctx = Context() + parquet_dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"]) + data_source = DataSourceNode(ctx, parquet_dataset) + evenly_dist_data_source = EvenlyDistributedPartitionNode( + ctx, (data_source,), npartitions=parquet_dataset.num_files + ) + sql_query = SqlEngineNode( + ctx, + (evenly_dist_data_source,), + r"select host, {__data_partition__} as partition_info from {0}", + ) + plan = LogicalPlan(ctx, sql_query) + exec_plan = self.execute_plan(plan) diff --git a/tests/test_plan.py b/tests/test_plan.py new file mode 100644 index 0000000..39f6fa8 --- /dev/null +++ b/tests/test_plan.py @@ -0,0 +1,70 @@ +import os +import tempfile +import unittest + +from examples.fstest import fstest +from examples.shuffle_data import shuffle_data +from examples.shuffle_mock_urls import shuffle_mock_urls +from examples.sort_mock_urls import sort_mock_urls +from examples.sort_mock_urls_v2 import sort_mock_urls_v2 +from smallpond.dataframe import Session +from tests.test_fabric import TestFabric + + +class TestPlan(TestFabric, unittest.TestCase): + def test_sort_mock_urls(self): + for engine_type in ("duckdb", "arrow"): + with self.subTest(engine_type=engine_type): + plan = sort_mock_urls( + ["tests/data/mock_urls/*.tsv"], + npartitions=3, + engine_type=engine_type, + ) + self.execute_plan(plan) + + def test_sort_mock_urls_external_output_path(self): + with tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir: + plan = sort_mock_urls( + ["tests/data/mock_urls/*.tsv"], + npartitions=3, + external_output_path=output_dir, + ) + self.execute_plan(plan) + + def test_shuffle_mock_urls(self): + for engine_type in ("duckdb", "arrow"): + with self.subTest(engine_type=engine_type): + plan = shuffle_mock_urls( + ["tests/data/mock_urls/*.parquet"], + npartitions=3, + sort_rand_keys=True, + ) + self.execute_plan(plan) + + def test_shuffle_data(self): + for engine_type in ("duckdb", "arrow"): + with self.subTest(engine_type=engine_type): + plan = shuffle_data( + ["tests/data/mock_urls/*.parquet"], + num_data_partitions=3, + num_out_data_partitions=3, + engine_type=engine_type, + ) + self.execute_plan(plan) + + +def test_fstest(sp: Session): + path = sp._runtime_ctx.output_root + fstest( + sp, + input_path=os.path.join(path, "*"), + output_path=path, + size="10M", + npartitions=3, + ) + + +def test_sort_mock_urls_v2(sp: Session): + sort_mock_urls_v2( + sp, ["tests/data/mock_urls/*.tsv"], sp._runtime_ctx.output_root, npartitions=3 + ) diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py new file mode 100644 index 0000000..85fb807 --- /dev/null +++ b/tests/test_scheduler.py @@ -0,0 +1,180 @@ +import os.path +import random +import time +import unittest +from typing import List, Tuple + +from loguru import logger + +from smallpond.execution.scheduler import ExecutorState +from smallpond.execution.task import PythonScriptTask, RuntimeContext +from smallpond.logical.dataset import DataSet, ParquetDataSet +from smallpond.logical.node import ( + Context, + DataSetPartitionNode, + DataSourceNode, + LogicalPlan, + Node, + PythonScriptNode, +) +from tests.test_fabric import TestFabric + + +class RandomSleepTask(PythonScriptTask): + def __init__( + self, *args, sleep_secs: float, fail_first_try: bool, **kwargs + ) -> None: + super().__init__(*args, **kwargs) + self.sleep_secs = sleep_secs + self.fail_first_try = fail_first_try + + def process( + self, + runtime_ctx: RuntimeContext, + input_datasets: List[DataSet], + output_path: str, + ) -> bool: + logger.info(f"sleeping 
{self.sleep_secs} secs") + time.sleep(self.sleep_secs) + with open(os.path.join(output_path, self.output_filename), "w") as fout: + fout.write(f"{repr(self)}") + if self.fail_first_try and self.retry_count == 0: + return False + return True + + +class RandomSleepNode(PythonScriptNode): + def __init__( + self, + ctx: Context, + input_deps: Tuple[Node, ...], + *, + max_sleep_secs=5, + fail_first_try=False, + **kwargs, + ): + super().__init__(ctx, input_deps, **kwargs) + self.max_sleep_secs = max_sleep_secs + self.fail_first_try = fail_first_try + + def spawn(self, *args, **kwargs) -> RandomSleepTask: + sleep_secs = ( + random.random() if len(self.generated_tasks) % 20 else self.max_sleep_secs + ) + return RandomSleepTask( + *args, **kwargs, sleep_secs=sleep_secs, fail_first_try=self.fail_first_try + ) + + +class TestScheduler(TestFabric, unittest.TestCase): + def create_random_sleep_plan( + self, npartitions, max_sleep_secs, fail_first_try=False + ): + ctx = Context() + dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"]) + data_files = DataSourceNode(ctx, dataset) + data_partitions = DataSetPartitionNode( + ctx, (data_files,), npartitions=npartitions, partition_by_rows=True + ) + random_sleep = RandomSleepNode( + ctx, + (data_partitions,), + max_sleep_secs=max_sleep_secs, + fail_first_try=fail_first_try, + ) + return LogicalPlan(ctx, random_sleep) + + def check_executor_state(self, target_state: ExecutorState, nloops=200): + for _ in range(nloops): + latest_sched_state = self.get_latest_sched_state() + if any( + executor.state == target_state + for executor in latest_sched_state.remote_executors + ): + logger.info( + f"found {target_state} executor in: {latest_sched_state.remote_executors}" + ) + break + time.sleep(0.1) + else: + self.assertTrue( + False, + f"cannot find any executor in state {target_state}: {latest_sched_state.remote_executors}", + ) + + def test_standalone_mode(self): + plan = self.create_random_sleep_plan(npartitions=10, max_sleep_secs=1) + self.execute_plan(plan, num_executors=0) + + def test_failed_executors(self): + num_exec = 6 + num_fail = 4 + plan = self.create_random_sleep_plan(npartitions=300, max_sleep_secs=10) + + _, executors, processes = self.start_execution( + plan, + num_executors=num_exec, + secs_wq_poll_interval=0.1, + secs_executor_probe_interval=0.5, + console_log_level="WARNING", + ) + latest_sched_state = self.get_latest_sched_state() + self.check_executor_state(ExecutorState.GOOD) + + for i, (executor, process) in enumerate( + random.sample(list(zip(executors, processes[1:])), k=num_fail) + ): + if i % 2 == 0: + logger.warning(f"kill executor: {executor}") + process.kill() + else: + logger.warning(f"skip probes: {executor}") + executor.skip_probes(latest_sched_state.ctx.max_num_missed_probes * 2) + + self.join_running_procs() + latest_sched_state = self.get_latest_sched_state() + self.assertTrue(latest_sched_state.success) + self.assertGreater(len(latest_sched_state.abandoned_tasks), 0) + self.assertLessEqual( + 1, + len(latest_sched_state.stopped_executors), + f"remote_executors: {latest_sched_state.remote_executors}", + ) + self.assertLessEqual( + num_fail / 2, + len(latest_sched_state.failed_executors), + f"remote_executors: {latest_sched_state.remote_executors}", + ) + + def test_speculative_scheduling(self): + for speculative_exec in ("disable", "enable", "aggressive"): + with self.subTest(speculative_exec=speculative_exec): + plan = self.create_random_sleep_plan(npartitions=100, max_sleep_secs=10) + self.execute_plan( + plan, + 
num_executors=3, + secs_wq_poll_interval=0.1, + secs_executor_probe_interval=0.5, + prioritize_retry=(speculative_exec == "aggressive"), + speculative_exec=speculative_exec, + ) + latest_sched_state = self.get_latest_sched_state() + if speculative_exec == "disable": + self.assertEqual(len(latest_sched_state.abandoned_tasks), 0) + else: + self.assertGreater(len(latest_sched_state.abandoned_tasks), 0) + + def test_stop_executor_on_failure(self): + plan = self.create_random_sleep_plan( + npartitions=3, max_sleep_secs=5, fail_first_try=True + ) + exec_plan = self.execute_plan( + plan, + num_executors=5, + secs_wq_poll_interval=0.1, + secs_executor_probe_interval=0.5, + check_result=False, + stop_executor_on_failure=True, + ) + latest_sched_state = self.get_latest_sched_state() + self.assertGreater(len(latest_sched_state.abandoned_tasks), 0) diff --git a/tests/test_session.py b/tests/test_session.py new file mode 100644 index 0000000..83df07a --- /dev/null +++ b/tests/test_session.py @@ -0,0 +1,54 @@ +import os + +from smallpond.dataframe import Session + + +def test_shutdown_cleanup(sp: Session): + assert os.path.exists(sp._runtime_ctx.queue_root), "queue directory should exist" + assert os.path.exists( + sp._runtime_ctx.staging_root + ), "staging directory should exist" + assert os.path.exists(sp._runtime_ctx.temp_root), "temp directory should exist" + + # create some tasks and complete them + df = sp.from_items([1, 2, 3]) + df.write_parquet(sp._runtime_ctx.output_root) + sp.shutdown() + + # shutdown should clean up directories + assert not os.path.exists( + sp._runtime_ctx.queue_root + ), "queue directory should be cleared" + assert not os.path.exists( + sp._runtime_ctx.staging_root + ), "staging directory should be cleared" + assert not os.path.exists( + sp._runtime_ctx.temp_root + ), "temp directory should be cleared" + with open(sp._runtime_ctx.job_status_path) as fin: + assert "success" in fin.read(), "job status should be success" + + +def test_shutdown_no_cleanup_on_failure(sp: Session): + df = sp.from_items([1, 2, 3]) + try: + # create a task that will fail + df.map(lambda x: x / 0).compute() + except Exception: + pass + else: + raise RuntimeError("task should fail") + sp.shutdown() + + # shutdown should not clean up directories + assert os.path.exists( + sp._runtime_ctx.queue_root + ), "queue directory should not be cleared" + assert os.path.exists( + sp._runtime_ctx.staging_root + ), "staging directory should not be cleared" + assert os.path.exists( + sp._runtime_ctx.temp_root + ), "temp directory should not be cleared" + with open(sp._runtime_ctx.job_status_path) as fin: + assert "failure" in fin.read(), "job status should be failure" diff --git a/tests/test_utility.py b/tests/test_utility.py new file mode 100644 index 0000000..821462a --- /dev/null +++ b/tests/test_utility.py @@ -0,0 +1,50 @@ +import random +import subprocess +import time +import unittest +from typing import Iterable + +from smallpond.utility import ConcurrentIter, execute_command +from tests.test_fabric import TestFabric + + +class TestUtility(TestFabric, unittest.TestCase): + def test_concurrent_iter_no_error(self): + def slow_iterator(iter: Iterable[int], sleep_ms: int): + for i in iter: + time.sleep(sleep_ms / 1000) + yield i + + for n in [1, 5, 10, 50, 100]: + with ConcurrentIter(slow_iterator(range(n), 2)) as iter1: + with ConcurrentIter(slow_iterator(iter1, 5)) as iter2: + self.assertEqual(sum(slow_iterator(iter2, 1)), sum(range(n))) + + def test_concurrent_iter_with_error(self): + def broken_iterator(iter: 
Iterable[int], sleep_ms: int): + for i in iter: + time.sleep(sleep_ms / 1000) + if random.randint(1, 10) == 1: + raise Exception("raised before yield") + yield i + if random.randint(1, 10) == 1: + raise Exception("raised after yield") + raise Exception("raised at the end") + + for n in [1, 5, 10, 50, 100]: + with self.assertRaises(Exception): + with ConcurrentIter(range(n)) as iter: + print(sum(broken_iterator(iter, 1))) + with self.assertRaises(Exception): + with ConcurrentIter(broken_iterator(range(n), 2)) as iter1: + with ConcurrentIter(broken_iterator(iter1, 5)) as iter2: + print(sum(iter2)) + + def test_execute_command(self): + with self.assertRaises(subprocess.CalledProcessError): + for line in execute_command("ls non_existent_file"): + print(line) + for line in execute_command("echo hello"): + print(line) + for line in execute_command("cat /dev/null"): + print(line) diff --git a/tests/test_workqueue.py b/tests/test_workqueue.py new file mode 100644 index 0000000..b9072e0 --- /dev/null +++ b/tests/test_workqueue.py @@ -0,0 +1,119 @@ +import multiprocessing +import multiprocessing.dummy +import multiprocessing.queues +import queue +import tempfile +import time +import unittest + +from loguru import logger + +from smallpond.execution.workqueue import ( + WorkItem, + WorkQueue, + WorkQueueInMemory, + WorkQueueOnFilesystem, +) +from tests.test_fabric import TestFabric + + +class PrintWork(WorkItem): + def __init__(self, name: str, message: str) -> None: + super().__init__(name, cpu_limit=1, gpu_limit=0, memory_limit=0) + self.message = message + + def run(self) -> bool: + logger.debug(f"{self.key}: {self.message}") + return True + + +def producer(wq: WorkQueue, id: int, numItems: int, numConsumers: int) -> None: + print(f"wq.outbound_works: {wq.outbound_works}") + for i in range(numItems): + wq.push(PrintWork(f"item-{i}", message="hello"), buffering=(i % 3 == 1)) + # wq.push(PrintWork(f"item-{i}", message="hello")) + if i % 5 == 0: + wq.flush() + for i in range(numConsumers): + wq.push(PrintWork(f"stop-{i}", message="stop")) + logger.success(f"producer {id} generated {numItems} items") + + +def consumer(wq: WorkQueue, id: int) -> int: + numItems = 0 + numWaits = 0 + running = True + while running: + items = wq.pop(count=1) + if not items: + numWaits += 1 + time.sleep(0.01) + continue + for item in items: + assert isinstance(item, PrintWork) + if item.message == "stop": + running = False + break + item.exec() + numItems += 1 + logger.success(f"consumer {id} collected {numItems} items, {numWaits} waits") + logger.complete() + return numItems + + +class WorkQueueTestBase(object): + + wq: WorkQueue = None + pool: multiprocessing.Pool = None + + def setUp(self) -> None: + logger.disable("smallpond.execution.workqueue") + return super().setUp() + + def test_basics(self): + numItems = 200 + for i in range(numItems): + self.wq.push(PrintWork(f"item-{i}", message="hello")) + numCollected = 0 + for _ in range(numItems): + items = self.wq.pop() + logger.info(f"{len(items)} items") + numCollected += len(items) + if numItems == numCollected: + break + + def test_multi_consumers(self): + numConsumers = 10 + numItems = 200 + result = self.pool.starmap_async( + consumer, [(self.wq, id) for id in range(numConsumers)] + ) + producer(self.wq, 0, numItems, numConsumers) + + logger.info("waiting for result") + numCollected = sum(result.get(timeout=20)) + logger.info(f"expected vs collected: {numItems} vs {numCollected}") + self.assertEqual(numItems, numCollected) + logger.success("all done") + + 
self.pool.terminate() + self.pool.join() + logger.success("workers stopped") + + +class TestWorkQueueInMemory(WorkQueueTestBase, TestFabric, unittest.TestCase): + def setUp(self) -> None: + super().setUp() + self.wq = WorkQueueInMemory(queue_type=queue.Queue) + self.pool = multiprocessing.dummy.Pool(10) + + +class TestWorkQueueOnFilesystem(WorkQueueTestBase, TestFabric, unittest.TestCase): + + workq_root: str + + def setUp(self) -> None: + super().setUp() + self.workq_root = tempfile.mkdtemp(dir=self.runtime_ctx.queue_root) + self.wq = WorkQueueOnFilesystem(self.workq_root, sort=True) + self.pool = multiprocessing.Pool(10)