# smallpond/tests/test_execution.py

import functools
import os.path
import socket
import tempfile
import time
import unittest
from datetime import datetime
from typing import Iterable, List, Tuple
import pandas
import pyarrow as arrow
from loguru import logger
from pandas.core.api import DataFrame as DataFrame
from smallpond.common import GB, MB, split_into_rows
from smallpond.execution.task import (
DataSinkTask,
DataSourceTask,
JobId,
PartitionInfo,
PythonScriptTask,
RuntimeContext,
StreamOutput,
)
from smallpond.execution.workqueue import WorkStatus
from smallpond.logical.dataset import (
ArrowTableDataSet,
DataSet,
ParquetDataSet,
SqlQueryDataSet,
)
from smallpond.logical.node import (
ArrowBatchNode,
ArrowComputeNode,
ArrowStreamNode,
Context,
DataSetPartitionNode,
DataSinkNode,
DataSourceNode,
EvenlyDistributedPartitionNode,
HashPartitionNode,
LogicalPlan,
Node,
PandasBatchNode,
PandasComputeNode,
ProjectionNode,
PythonScriptNode,
RootNode,
SqlEngineNode,
)
from smallpond.logical.udf import UDFListType, UDFType
from tests.test_fabric import TestFabric
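
# a minimal PythonScriptTask that logs a message and a few task attributes, then injects a random fault to exercise retries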
class OutputMsgPythonTask(PythonScriptTask):
def __init__(self, msg: str, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.msg = msg
def initialize(self):
pass
def finalize(self):
pass
def process(
self,
runtime_ctx: RuntimeContext,
input_datasets: List[DataSet],
output_path: str,
) -> bool:
logger.info(f"msg: {self.msg}, num files: {input_datasets[0].num_files}, local gpu ranks: {self.local_gpu_ranks}")
self.inject_fault()
return True
# method 1: inherit the task class and override the spawn method
class OutputMsgPythonNode(PythonScriptNode):
def spawn(self, *args, **kwargs) -> OutputMsgPythonTask:
return OutputMsgPythonTask("python script", *args, **kwargs)
# method 2: override the process method
# this usage is not recommended and is only for testing; use `process_func` instead
class OutputMsgPythonNode2(PythonScriptNode):
def __init__(self, ctx: Context, input_deps: Tuple[Node, ...], msg: str) -> None:
super().__init__(ctx, input_deps)
self.msg = msg
def process(
self,
runtime_ctx: RuntimeContext,
input_datasets: List[DataSet],
output_path: str,
) -> bool:
logger.info(f"msg: {self.msg}, num files: {input_datasets[0].num_files}")
return True
# this usage is not recommended and is only for testing; use `process_func` instead
class CopyInputArrowNode(ArrowComputeNode):
def __init__(self, ctx: Context, input_deps: Tuple[Node, ...], msg: str) -> None:
super().__init__(ctx, input_deps)
self.msg = msg
def process(self, runtime_ctx: RuntimeContext, input_tables: List[arrow.Table]) -> arrow.Table:
return copy_input_arrow(runtime_ctx, input_tables, self.msg)
# this usage is not recommended and is only for testing; use `process_func` instead
class CopyInputStreamNode(ArrowStreamNode):
def __init__(self, ctx: Context, input_deps: Tuple[Node, ...], msg: str) -> None:
super().__init__(ctx, input_deps)
self.msg = msg
def process(self, runtime_ctx: RuntimeContext, input_readers: List[arrow.RecordBatchReader]) -> Iterable[arrow.Table]:
return copy_input_stream(runtime_ctx, input_readers, self.msg)
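
# process_func helpers used by the tests below; the copy_* variants return their input unchanged and inject random faults to exercise retry handling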
def copy_input_arrow(runtime_ctx: RuntimeContext, input_tables: List[arrow.Table], msg: str) -> arrow.Table:
logger.info(f"msg: {msg}, num rows: {input_tables[0].num_rows}")
time.sleep(runtime_ctx.secs_executor_probe_interval)
runtime_ctx.task.inject_fault()
return input_tables[0]
def copy_input_stream(runtime_ctx: RuntimeContext, input_readers: List[arrow.RecordBatchReader], msg: str) -> Iterable[arrow.Table]:
for index, batch in enumerate(input_readers[0]):
logger.info(f"msg: {msg}, batch index: {index}, num rows: {batch.num_rows}")
time.sleep(runtime_ctx.secs_executor_probe_interval)
yield StreamOutput(
arrow.Table.from_batches([batch]),
batch_indices=[index],
force_checkpoint=True,
)
runtime_ctx.task.inject_fault()
def copy_input_batch(runtime_ctx: RuntimeContext, input_batches: List[arrow.Table], msg: str) -> arrow.Table:
logger.info(f"msg: {msg}, num rows: {input_batches[0].num_rows}")
time.sleep(runtime_ctx.secs_executor_probe_interval)
runtime_ctx.task.inject_fault()
return input_batches[0]
def copy_input_data_frame(runtime_ctx: RuntimeContext, input_dfs: List[DataFrame]) -> DataFrame:
runtime_ctx.task.inject_fault()
return input_dfs[0]
def copy_input_data_frame_batch(runtime_ctx: RuntimeContext, input_dfs: List[DataFrame]) -> DataFrame:
runtime_ctx.task.inject_fault()
return input_dfs[0]
def merge_input_tables(runtime_ctx: RuntimeContext, input_batches: List[arrow.Table]) -> arrow.Table:
runtime_ctx.task.inject_fault()
output = arrow.concat_tables(input_batches)
logger.info(f"input rows: {[len(batch) for batch in input_batches]}, output rows: {len(output)}")
return output
def merge_input_data_frames(runtime_ctx: RuntimeContext, input_dfs: List[DataFrame]) -> DataFrame:
runtime_ctx.task.inject_fault()
output = pandas.concat(input_dfs)
logger.info(f"input rows: {[len(df) for df in input_dfs]}, output rows: {len(output)}")
return output
def parse_url(runtime_ctx: RuntimeContext, input_tables: List[arrow.Table]) -> arrow.Table:
urls = input_tables[0].columns[0]
hosts = [url.as_py().split("/", maxsplit=2)[0] for url in urls]
return input_tables[0].append_column("host", arrow.array(hosts))
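
# exit with a nonzero code while the task's memory boost is still 1; used to test nonzero_exitcode_as_oom handling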
def nonzero_exit_code(runtime_ctx: RuntimeContext, input_datasets: List[DataSet], output_path: str) -> bool:
import sys
if runtime_ctx.task._memory_boost == 1:
sys.exit(1)
return True
# create an empty file with a fixed name
def empty_file(runtime_ctx: RuntimeContext, input_datasets: List[DataSet], output_path: str) -> bool:
import os
with open(os.path.join(output_path, "file"), "w") as fout:
pass
return True
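
# fabricate fake GPUtil.GPU objects so GPU rank assignment can be tested without real GPUs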
def return_fake_gpus(count: int = 8):
import GPUtil
return [GPUtil.GPU(i, *list(range(11))) for i in range(count)]
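
# UDF registered in test_arrow_type_in_udfs: split each url into a list of path segments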
def split_url(urls: arrow.array) -> arrow.array:
url_parts = [url.as_py().split("/") for url in urls]
return arrow.array(url_parts, type=arrow.list_(arrow.string()))
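
# process_func with an extra keyword argument; test_partial_process_func binds k via functools.partial or relies on its default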
def choose_random_urls(runtime_ctx: RuntimeContext, input_tables: List[arrow.Table], k: int = 5) -> arrow.Table:
# get the current running task
runtime_task = runtime_ctx.task
# access task-specific attributes
cpu_limit = runtime_task.cpu_limit
random_gen = runtime_task.python_random_gen
# input data
(url_table,) = input_tables
hosts, urls = url_table.columns
logger.info(f"{cpu_limit=} {len(urls)=}")
# generate random samples
random_urls = random_gen.choices(urls.to_pylist(), k=k)
return arrow.Table.from_arrays([arrow.array(random_urls)], names=["random_urls"])
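
# end-to-end execution tests: each test builds a LogicalPlan and runs it via TestFabric.execute_plan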
class TestExecution(TestFabric, unittest.TestCase):
fault_inject_prob = 0.05
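
# the arrow compute/stream/batch nodes should copy the input through unchanged, with and without the duckdb reader, despite injected faults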
def test_arrow_task(self):
for use_duckdb_reader in (False, True):
with self.subTest(use_duckdb_reader=use_duckdb_reader):
with tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir:
ctx = Context()
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
data_table = dataset.to_arrow_table()
data_files = DataSourceNode(ctx, dataset)
data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=7)
if use_duckdb_reader:
data_partitions = ProjectionNode(
ctx,
data_partitions,
columns=["*", "string_split(url, '/')[0] as parsed_host"],
)
arrow_compute = ArrowComputeNode(
ctx,
(data_partitions,),
process_func=functools.partial(copy_input_arrow, msg="arrow compute"),
use_duckdb_reader=use_duckdb_reader,
output_name="arrow_compute",
output_path=output_dir,
cpu_limit=2,
)
arrow_stream = ArrowStreamNode(
ctx,
(data_partitions,),
process_func=functools.partial(copy_input_stream, msg="arrow stream"),
streaming_batch_size=10,
secs_checkpoint_interval=0.5,
use_duckdb_reader=use_duckdb_reader,
output_name="arrow_stream",
output_path=output_dir,
cpu_limit=2,
)
arrow_batch = ArrowBatchNode(
ctx,
(data_partitions,),
process_func=functools.partial(copy_input_batch, msg="arrow batch"),
streaming_batch_size=10,
secs_checkpoint_interval=0.5,
use_duckdb_reader=use_duckdb_reader,
output_name="arrow_batch",
output_path=output_dir,
cpu_limit=2,
)
data_sink = DataSinkNode(
ctx,
(arrow_compute, arrow_stream, arrow_batch),
output_path=output_dir,
)
plan = LogicalPlan(ctx, data_sink)
exec_plan = self.execute_plan(plan, fault_inject_prob=0.1, secs_executor_probe_interval=0.5)
self.assertTrue(all(map(os.path.exists, exec_plan.final_output.resolved_paths)))
arrow_compute_output = ParquetDataSet(
[os.path.join(output_dir, "arrow_compute", "**/*.parquet")],
recursive=True,
)
arrow_stream_output = ParquetDataSet(
[os.path.join(output_dir, "arrow_stream", "**/*.parquet")],
recursive=True,
)
arrow_batch_output = ParquetDataSet(
[os.path.join(output_dir, "arrow_batch", "**/*.parquet")],
recursive=True,
)
self._compare_arrow_tables(
data_table,
arrow_compute_output.to_arrow_table().select(data_table.column_names),
)
self._compare_arrow_tables(
data_table,
arrow_stream_output.to_arrow_table().select(data_table.column_names),
)
self._compare_arrow_tables(
data_table,
arrow_batch_output.to_arrow_table().select(data_table.column_names),
)
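
# the pandas compute/batch nodes should copy the input through unchanged despite injected faults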
def test_pandas_task(self):
with tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir:
ctx = Context()
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
data_table = dataset.to_arrow_table()
data_files = DataSourceNode(ctx, dataset)
data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=7)
pandas_compute = PandasComputeNode(
ctx,
(data_partitions,),
process_func=copy_input_data_frame,
output_name="pandas_compute",
output_path=output_dir,
cpu_limit=2,
)
pandas_batch = PandasBatchNode(
ctx,
(data_partitions,),
process_func=copy_input_data_frame_batch,
streaming_batch_size=10,
secs_checkpoint_interval=0.5,
output_name="pandas_batch",
output_path=output_dir,
cpu_limit=2,
)
data_sink = DataSinkNode(ctx, (pandas_compute, pandas_batch), output_path=output_dir)
plan = LogicalPlan(ctx, data_sink)
exec_plan = self.execute_plan(plan, fault_inject_prob=0.1, secs_executor_probe_interval=0.5)
self.assertTrue(all(map(os.path.exists, exec_plan.final_output.resolved_paths)))
pandas_compute_output = ParquetDataSet(
[os.path.join(output_dir, "pandas_compute", "**/*.parquet")],
recursive=True,
)
pandas_batch_output = ParquetDataSet(
[os.path.join(output_dir, "pandas_batch", "**/*.parquet")],
recursive=True,
)
self._compare_arrow_tables(data_table, pandas_compute_output.to_arrow_table())
self._compare_arrow_tables(data_table, pandas_batch_output.to_arrow_table())
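
# nodes with input deps of very different sizes should merge all rows from both datasets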
def test_variable_length_input_datasets(self):
ctx = Context()
small_dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
large_dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"] * 10)
small_partitions = DataSetPartitionNode(ctx, (DataSourceNode(ctx, small_dataset),), npartitions=7)
large_partitions = DataSetPartitionNode(ctx, (DataSourceNode(ctx, large_dataset),), npartitions=7)
arrow_batch = ArrowBatchNode(
ctx,
(small_partitions, large_partitions),
process_func=merge_input_tables,
streaming_batch_size=100,
secs_checkpoint_interval=0.5,
output_name="arrow_batch",
cpu_limit=2,
)
pandas_batch = PandasBatchNode(
ctx,
(small_partitions, large_partitions),
process_func=merge_input_data_frames,
streaming_batch_size=100,
secs_checkpoint_interval=0.5,
output_name="pandas_batch",
cpu_limit=2,
)
plan = LogicalPlan(ctx, RootNode(ctx, (arrow_batch, pandas_batch)))
exec_plan = self.execute_plan(plan, fault_inject_prob=0.1, secs_executor_probe_interval=0.5)
self.assertTrue(all(map(os.path.exists, exec_plan.final_output.resolved_paths)))
arrow_batch_output = ParquetDataSet(
[os.path.join(exec_plan.ctx.output_root, "arrow_batch", "**/*.parquet")],
recursive=True,
)
pandas_batch_output = ParquetDataSet(
[os.path.join(exec_plan.ctx.output_root, "pandas_batch", "**/*.parquet")],
recursive=True,
)
self.assertEqual(small_dataset.num_rows + large_dataset.num_rows, arrow_batch_output.num_rows)
self.assertEqual(
small_dataset.num_rows + large_dataset.num_rows,
pandas_batch_output.num_rows,
)
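
# ProjectionNode can feed arrow and SQL nodes, add generated columns, and union mismatched schemas by name before hash partitioning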
def test_projection_task(self):
ctx = Context()
# select columns when defining dataset
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"], columns=["url"])
data_files = DataSourceNode(ctx, dataset)
data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=3, partition_by_rows=True)
# projection as input of arrow node
generated_columns = ["filename", "file_row_number"]
urls_with_host = ArrowComputeNode(
ctx,
(ProjectionNode(ctx, data_partitions, ["url"], generated_columns),),
process_func=parse_url,
use_duckdb_reader=True,
)
# projection as input of sql node
distinct_urls_with_host = SqlEngineNode(
ctx,
(
ProjectionNode(
ctx,
data_partitions,
["url", "string_split(url, '/')[0] as host"],
generated_columns,
),
),
r"select distinct host, url, filename from {0}",
)
# unify different schemas
merged_diff_schemas = ProjectionNode(
ctx,
DataSetPartitionNode(ctx, (distinct_urls_with_host, urls_with_host), npartitions=1),
union_by_name=True,
)
host_partitions = HashPartitionNode(
ctx,
(merged_diff_schemas,),
npartitions=3,
hash_columns=["host"],
engine_type="duckdb",
output_name="host_partitions",
)
host_partitions.max_num_producer_tasks = 1
plan = LogicalPlan(ctx, host_partitions)
final_output = self.execute_plan(plan, fault_inject_prob=0.1).final_output
final_table = final_output.to_arrow_table()
self.assertEqual(
sorted(
[
"url",
"host",
*generated_columns,
HashPartitionNode.default_data_partition_column,
]
),
sorted(final_table.column_names),
)
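
# a UDF registered with use_arrow_type=True should be usable from a SqlEngineNode query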
def test_arrow_type_in_udfs(self):
ctx = Context()
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
data_files = DataSourceNode(ctx, dataset)
data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=dataset.num_files)
ctx.create_function(
"split_url",
split_url,
[UDFType.VARCHAR],
UDFListType(UDFType.VARCHAR),
use_arrow_type=True,
)
uniq_hosts = SqlEngineNode(
ctx,
(data_partitions,),
r"select split_url(url) as url_parts from {0}",
udfs=["split_url"],
)
plan = LogicalPlan(ctx, uniq_hosts)
self.execute_plan(plan)
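
# 1000 small partition tasks spread over 10 executors should complete, with profiling enabled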
def test_many_simple_tasks(self):
ctx = Context()
npartitions = 1000
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"] * npartitions)
data_files = DataSourceNode(ctx, dataset)
data_partitions = EvenlyDistributedPartitionNode(ctx, (data_files,), npartitions=npartitions)
output_msg = OutputMsgPythonNode(ctx, (data_partitions,))
plan = LogicalPlan(ctx, output_msg)
self.execute_plan(
plan,
num_executors=10,
secs_executor_probe_interval=5,
enable_profiling=True,
)
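
# 10000 partitions produced by at most 20 producer tasks should still execute across 10 executors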
def test_many_producers_and_partitions(self):
ctx = Context()
npartitions = 10000
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"] * (npartitions * 10))
data_files = DataSourceNode(ctx, dataset)
data_partitions = EvenlyDistributedPartitionNode(ctx, (data_files,), npartitions=npartitions, cpu_limit=1)
data_partitions.max_num_producer_tasks = 20
output_msg = OutputMsgPythonNode(ctx, (data_partitions,))
plan = LogicalPlan(ctx, output_msg)
self.execute_plan(
plan,
num_executors=10,
secs_executor_probe_interval=5,
enable_profiling=True,
)
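
# with gpu_limit=0.5 each task should get local GPU ranks; fake GPUs are supplied by patching runtime_ctx.get_local_gpus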
def test_local_gpu_rank(self):
ctx = Context()
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
data_files = DataSourceNode(ctx, dataset)
data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=dataset.num_files)
output_msg = OutputMsgPythonNode(ctx, (data_partitions,), cpu_limit=1, gpu_limit=0.5)
plan = LogicalPlan(ctx, output_msg)
runtime_ctx = RuntimeContext(
JobId.new(),
datetime.now(),
self.output_root_abspath,
console_log_level="WARNING",
)
runtime_ctx.get_local_gpus = return_fake_gpus
runtime_ctx.initialize(socket.gethostname(), cleanup_root=True)
self.execute_plan(plan, runtime_ctx=runtime_ctx)
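
# nodes that override process() directly, instead of passing process_func, should still run end to end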
def test_python_node_with_process_method(self):
ctx = Context()
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
data_files = DataSourceNode(ctx, dataset)
copy_input_arrow_node = CopyInputArrowNode(ctx, (data_files,), "hello")
copy_input_stream_node = CopyInputStreamNode(ctx, (data_files,), "hello")
output_msg = OutputMsgPythonNode2(ctx, (copy_input_arrow_node, copy_input_stream_node), "hello")
plan = LogicalPlan(ctx, output_msg)
self.execute_plan(plan)
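
# SQL tasks constrained to a 2 MB memory limit should eventually succeed within max_fail_count retries after OOM failures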
def test_sql_engine_oom(self):
ctx = Context()
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
data_files = DataSourceNode(ctx, dataset)
uniq_urls = SqlEngineNode(ctx, (data_files,), r"select distinct * from {0}", memory_limit=2 * MB)
uniq_url_partitions = DataSetPartitionNode(ctx, (uniq_urls,), 2)
uniq_url_count = SqlEngineNode(
ctx,
(uniq_url_partitions,),
sql_query=r"select count(distinct columns(*)) from {0}",
memory_limit=2 * MB,
)
plan = LogicalPlan(ctx, uniq_url_count)
self.execute_plan(plan, max_fail_count=10)
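
# run the arrow nodes under an enforced 1 GB memory limit, treating nonzero exit codes as OOM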
@unittest.skip("flaky on CI")
def test_enforce_memory_limit(self):
ctx = Context()
dataset = ParquetDataSet(["tests/data/arrow/*.parquet"])
data_files = DataSourceNode(ctx, dataset)
arrow_compute = ArrowComputeNode(
ctx,
(data_files,),
process_func=functools.partial(copy_input_arrow, msg="arrow compute"),
memory_limit=1 * GB,
)
arrow_stream = ArrowStreamNode(
ctx,
(data_files,),
process_func=functools.partial(copy_input_stream, msg="arrow stream"),
memory_limit=1 * GB,
)
with tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir:
data_sink = DataSinkNode(ctx, (arrow_compute, arrow_stream), output_path=output_dir)
plan = LogicalPlan(ctx, data_sink)
self.execute_plan(
plan,
max_fail_count=10,
enforce_memory_limit=True,
nonzero_exitcode_as_oom=True,
)
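
# a nonzero exit code fails the plan by default, but succeeds when nonzero_exitcode_as_oom allows the task to be retried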
def test_task_crash_as_oom(self):
ctx = Context()
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
data_files = DataSourceNode(ctx, dataset)
nonzero_exitcode = PythonScriptNode(ctx, (data_files,), process_func=nonzero_exit_code)
plan = LogicalPlan(ctx, nonzero_exitcode)
exec_plan = self.execute_plan(plan, num_executors=1, check_result=False, nonzero_exitcode_as_oom=False)
self.assertFalse(exec_plan.successful)
exec_plan = self.execute_plan(plan, num_executors=1, check_result=False, nonzero_exitcode_as_oom=True)
self.assertTrue(exec_plan.successful)
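
# a manifest-only DataSinkNode should write a manifest file with one line per input file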
def test_manifest_only_data_sink(self):
with open("tests/data/long_path_list.txt", buffering=16 * MB) as fin:
filenames = list(map(str.strip, fin.readlines()))
logger.info(f"loaded {len(filenames)} filenames")
with tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir:
ctx = Context()
dataset = ParquetDataSet(filenames)
data_files = DataSourceNode(ctx, dataset)
data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=512)
data_sink = DataSinkNode(ctx, (data_partitions,), output_path=output_dir, manifest_only=True)
plan = LogicalPlan(ctx, data_sink)
self.execute_plan(plan)
with open(
os.path.join(output_dir, DataSinkTask.manifest_filename),
buffering=16 * MB,
) as fin:
num_lines = len(fin.readlines())
self.assertEqual(len(filenames), num_lines)
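
# batched SQL processing should preserve the input row count, with and without in-memory materialization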
def test_sql_batched_processing(self):
for materialize_in_memory in (False, True):
with self.subTest(materialize_in_memory=materialize_in_memory):
ctx = Context()
dataset = ParquetDataSet(["tests/data/large_array/*.parquet"] * 2)
data_files = DataSourceNode(ctx, dataset)
content_length = SqlEngineNode(
ctx,
(data_files,),
r"select url, octet_length(content) as content_len from {0}",
materialize_in_memory=materialize_in_memory,
batched_processing=True,
cpu_limit=2,
memory_limit=2 * GB,
)
plan = LogicalPlan(ctx, content_length)
final_output: ParquetDataSet = self.execute_plan(plan).final_output
self.assertEqual(dataset.num_rows, final_output.num_rows)
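
# a SqlEngineNode can execute a list of SQL statements (temp table creation followed by a select) and preserve the row count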
def test_multiple_sql_queries(self):
for materialize_in_memory in (False, True):
with self.subTest(materialize_in_memory=materialize_in_memory):
ctx = Context()
dataset = ParquetDataSet(["tests/data/large_array/*.parquet"] * 2)
data_files = DataSourceNode(ctx, dataset)
content_length = SqlEngineNode(
ctx,
(data_files,),
[
r"create or replace temp table content_len_data as select url, octet_length(content) as content_len from {0}",
r"select * from content_len_data",
],
materialize_in_memory=materialize_in_memory,
batched_processing=True,
cpu_limit=2,
memory_limit=2 * GB,
)
plan = LogicalPlan(ctx, content_length)
final_output: ParquetDataSet = self.execute_plan(plan).final_output
self.assertEqual(dataset.num_rows, final_output.num_rows)
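
# counts computed per partition are merged and re-split; the final output should contain npartitions * 2 rows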
def test_temp_outputs_in_final_results(self):
ctx = Context()
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
data_files = DataSourceNode(ctx, dataset)
data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=10)
url_counts = SqlEngineNode(ctx, (data_partitions,), r"select count(url) as cnt from {0}")
distinct_url_counts = SqlEngineNode(ctx, (data_partitions,), r"select count(distinct url) as cnt from {0}")
merged_counts = DataSetPartitionNode(
ctx,
(
ProjectionNode(ctx, url_counts, ["cnt"]),
ProjectionNode(ctx, distinct_url_counts, ["cnt"]),
),
npartitions=1,
)
split_counts = DataSetPartitionNode(ctx, (merged_counts,), npartitions=10)
plan = LogicalPlan(ctx, split_counts)
final_output: ParquetDataSet = self.execute_plan(plan).final_output
self.assertEqual(data_partitions.npartitions * 2, final_output.num_rows)
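
# overriding output_path should place named node outputs (url_counts) and FinalResults under the given directory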
def test_override_output_path(self):
ctx = Context()
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
data_files = DataSourceNode(ctx, dataset)
data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=10)
url_counts = SqlEngineNode(
ctx,
(data_partitions,),
r"select count(url) as cnt from {0}",
output_name="url_counts",
)
distinct_url_counts = SqlEngineNode(ctx, (data_partitions,), r"select count(distinct url) as cnt from {0}")
merged_counts = DataSetPartitionNode(
ctx,
(
ProjectionNode(ctx, url_counts, ["cnt"]),
ProjectionNode(ctx, distinct_url_counts, ["cnt"]),
),
npartitions=1,
)
plan = LogicalPlan(ctx, merged_counts)
output_path = os.path.join(self.runtime_ctx.output_root, "final_output")
final_output = self.execute_plan(plan, output_path=output_path).final_output
self.assertTrue(os.path.exists(os.path.join(output_path, "url_counts")))
self.assertTrue(os.path.exists(os.path.join(output_path, "FinalResults")))
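
# conflicting output file names (every task writes a file named "file") should be de-conflicted by the sink, while non-conflicting input file names are kept as-is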
def test_data_sink_avoid_filename_conflicts(self):
ctx = Context()
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
data_files = DataSourceNode(ctx, dataset)
data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=10)
empty_files1 = PythonScriptNode(ctx, (data_partitions,), process_func=empty_file)
empty_files2 = PythonScriptNode(ctx, (data_partitions,), process_func=empty_file)
link_path = os.path.join(self.runtime_ctx.output_root, "link")
copy_path = os.path.join(self.runtime_ctx.output_root, "copy")
copy_input_path = os.path.join(self.runtime_ctx.output_root, "copy_input")
data_link = DataSinkNode(ctx, (empty_files1, empty_files2), type="link", output_path=link_path)
data_copy = DataSinkNode(ctx, (empty_files1, empty_files2), type="copy", output_path=copy_path)
data_copy_input = DataSinkNode(ctx, (data_partitions,), type="copy", output_path=copy_input_path)
plan = LogicalPlan(ctx, RootNode(ctx, (data_link, data_copy, data_copy_input)))
self.execute_plan(plan)
# there should be 21 files (20 input files + 1 manifest file) in the sink dir
self.assertEqual(21, len(os.listdir(link_path)))
self.assertEqual(21, len(os.listdir(copy_path)))
# file names should not be modified if there is no conflict
self.assertEqual(
set(filename for filename in os.listdir("tests/data/mock_urls") if filename.endswith(".parquet")),
set(filename for filename in os.listdir(copy_input_path) if filename.endswith(".parquet")),
)
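
# SqlQueryDataSet and ArrowTableDataSet can serve as data sources; their row-wise partitions should join one to one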
def test_literal_datasets_as_data_sources(self):
ctx = Context()
num_rows = 10
query_dataset = SqlQueryDataSet(f"select i from range({num_rows}) as x(i)")
table_dataset = ArrowTableDataSet(arrow.Table.from_arrays([list(range(num_rows))], names=["i"]))
query_source = DataSourceNode(ctx, query_dataset)
table_source = DataSourceNode(ctx, table_dataset)
query_partitions = DataSetPartitionNode(ctx, (query_source,), npartitions=num_rows, partition_by_rows=True)
table_partitions = DataSetPartitionNode(ctx, (table_source,), npartitions=num_rows, partition_by_rows=True)
joined_rows = SqlEngineNode(
ctx,
(query_partitions, table_partitions),
r"select a.i as i, b.i as j from {0} as a join {1} as b on a.i = b.i",
)
plan = LogicalPlan(ctx, joined_rows)
final_output: ParquetDataSet = self.execute_plan(plan).final_output
self.assertEqual(num_rows, final_output.num_rows)
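
# the extra keyword argument k can use its default or be bound with functools.partial, scaling the output row count accordingly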
def test_partial_process_func(self):
ctx = Context()
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
data_files = DataSourceNode(ctx, dataset)
data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=3)
# use default value of k
random_urls_k5 = ArrowComputeNode(
ctx,
(data_partitions,),
process_func=choose_random_urls,
output_name="random_urls_k5",
)
# set value of k using functools.partial
random_urls_k10 = ArrowComputeNode(
ctx,
(data_partitions,),
process_func=functools.partial(choose_random_urls, k=10),
output_name="random_urls_k10",
)
random_urls_all = SqlEngineNode(
ctx,
(random_urls_k5, random_urls_k10),
r"select * from {0} union select * from {1}",
output_name="random_urls_all",
)
plan = LogicalPlan(ctx, random_urls_all)
exec_plan = self.execute_plan(plan)
self.assertEqual(
data_partitions.npartitions * 5,
exec_plan.get_output("random_urls_k5").to_arrow_table().num_rows,
)
self.assertEqual(
data_partitions.npartitions * 10,
exec_plan.get_output("random_urls_k10").to_arrow_table().num_rows,
)