reformat code with --line-length=150 (#18)

This commit is contained in:
Runji Wang
2025-03-05 22:46:23 +08:00
committed by GitHub
parent ed112db42a
commit 52ecc5e455
48 changed files with 794 additions and 2604 deletions

View File

@@ -26,7 +26,7 @@ jobs:
- name: Check code formatting
run: |
black --check .
black --check --line-length=150 .
# - name: Check typos
# uses: crate-ci/typos@v1.29.10

2
Makefile Normal file
View File

@@ -0,0 +1,2 @@
fmt:
black --line-length=150 .

View File

@@ -64,9 +64,7 @@ def main():
driver = Driver()
driver.add_argument("-i", "--input_paths", nargs="+")
driver.add_argument("-n", "--npartitions", type=int, default=None)
driver.add_argument(
"-e", "--io_engine", default="duckdb", choices=("duckdb", "arrow", "stream")
)
driver.add_argument("-e", "--io_engine", default="duckdb", choices=("duckdb", "arrow", "stream"))
driver.add_argument("-b", "--batch_size", type=int, default=1024 * 1024)
driver.add_argument("-s", "--row_group_size", type=int, default=1024 * 1024)
driver.add_argument("-o", "--output_name", default="data")

View File

@@ -73,26 +73,18 @@ def generate_records(
subprocess.run(gensort_cmd.split()).check_returncode()
runtime_task.add_elapsed_time("generate records (secs)")
shm_file.seek(0)
buffer = arrow.py_buffer(
shm_file.read(record_count * record_nbytes)
)
buffer = arrow.py_buffer(shm_file.read(record_count * record_nbytes))
runtime_task.add_elapsed_time("read records (secs)")
# https://arrow.apache.org/docs/format/Columnar.html#fixed-size-primitive-layout
records = arrow.Array.from_buffers(
arrow.binary(record_nbytes), record_count, [None, buffer]
)
records = arrow.Array.from_buffers(arrow.binary(record_nbytes), record_count, [None, buffer])
keys = pc.binary_slice(records, 0, key_nbytes)
# get first 2 bytes and convert to big-endian uint16
binary_prefix = pc.binary_slice(records, 0, 2).cast(arrow.binary())
reversed_prefix = pc.binary_reverse(binary_prefix).cast(
arrow.binary(2)
)
reversed_prefix = pc.binary_reverse(binary_prefix).cast(arrow.binary(2))
uint16_prefix = reversed_prefix.view(arrow.uint16())
buckets = pc.shift_right(uint16_prefix, 16 - bucket_nbits)
runtime_task.add_elapsed_time("build arrow table (secs)")
yield arrow.Table.from_arrays(
[buckets, keys, records], schema=schema
)
yield arrow.Table.from_arrays([buckets, keys, records], schema=schema)
yield StreamOutput(
schema.empty_table(),
batch_indices=[batch_idx],
@@ -108,9 +100,7 @@ def sort_records(
write_io_nbytes=500 * MB,
) -> bool:
runtime_task: PythonScriptTask = runtime_ctx.task
data_file_path = os.path.join(
runtime_task.runtime_output_abspath, f"{runtime_task.output_filename}.dat"
)
data_file_path = os.path.join(runtime_task.runtime_output_abspath, f"{runtime_task.output_filename}.dat")
if sort_engine == "polars":
input_data = polars.read_parquet(
@@ -134,9 +124,7 @@ def sort_records(
record_arrays = sorted_table.column("records").chunks
runtime_task.add_elapsed_time("convert to chunks (secs)")
elif sort_engine == "duckdb":
with duckdb.connect(
database=":memory:", config={"allow_unsigned_extensions": "true"}
) as conn:
with duckdb.connect(database=":memory:", config={"allow_unsigned_extensions": "true"}) as conn:
runtime_task.prepare_connection(conn)
input_views = runtime_task.create_input_views(conn, input_datasets)
sql_query = "select records from {0} order by keys".format(*input_views)
@@ -154,8 +142,7 @@ def sort_records(
buffer_mem = memoryview(values)
total_write_nbytes = sum(
fout.write(buffer_mem[offset : offset + write_io_nbytes])
for offset in range(0, len(buffer_mem), write_io_nbytes)
fout.write(buffer_mem[offset : offset + write_io_nbytes]) for offset in range(0, len(buffer_mem), write_io_nbytes)
)
assert total_write_nbytes == len(buffer_mem)
@@ -164,16 +151,10 @@ def sort_records(
return True
def validate_records(
runtime_ctx: RuntimeContext, input_datasets: List[DataSet], output_path: str
) -> bool:
def validate_records(runtime_ctx: RuntimeContext, input_datasets: List[DataSet], output_path: str) -> bool:
for data_path in input_datasets[0].resolved_paths:
summary_path = os.path.join(
output_path, PurePath(data_path).with_suffix(".sum").name
)
cmdstr = (
f"{SortBenchTool.valsort_path} -o {summary_path} {data_path},buf,trans=10m"
)
summary_path = os.path.join(output_path, PurePath(data_path).with_suffix(".sum").name)
cmdstr = f"{SortBenchTool.valsort_path} -o {summary_path} {data_path},buf,trans=10m"
logging.debug(f"running command: {cmdstr}")
result = subprocess.run(cmdstr.split(), capture_output=True, encoding="utf8")
if result.stderr:
@@ -185,9 +166,7 @@ def validate_records(
return True
def validate_summary(
runtime_ctx: RuntimeContext, input_datasets: List[DataSet], output_path: str
) -> bool:
def validate_summary(runtime_ctx: RuntimeContext, input_datasets: List[DataSet], output_path: str) -> bool:
concated_summary_path = os.path.join(output_path, "merged.sum")
with open(concated_summary_path, "wb") as fout:
for path in input_datasets[0].resolved_paths:
@@ -224,22 +203,13 @@ def generate_random_records(
)
range_begin_at = [pos for pos in range(0, total_num_records, record_range_size)]
range_num_records = [
min(total_num_records, record_range_size * (range_idx + 1)) - begin_at
for range_idx, begin_at in enumerate(range_begin_at)
]
range_num_records = [min(total_num_records, record_range_size * (range_idx + 1)) - begin_at for range_idx, begin_at in enumerate(range_begin_at)]
assert sum(range_num_records) == total_num_records
record_range = DataSourceNode(
ctx,
ArrowTableDataSet(
arrow.Table.from_arrays(
[range_begin_at, range_num_records], names=["begin_at", "num_records"]
)
),
)
record_range_partitions = DataSetPartitionNode(
ctx, (record_range,), npartitions=num_data_partitions, partition_by_rows=True
ArrowTableDataSet(arrow.Table.from_arrays([range_begin_at, range_num_records], names=["begin_at", "num_records"])),
)
record_range_partitions = DataSetPartitionNode(ctx, (record_range,), npartitions=num_data_partitions, partition_by_rows=True)
random_records = ArrowStreamNode(
ctx,
@@ -288,9 +258,7 @@ def gray_sort_benchmark(
if input_paths:
input_dataset = ParquetDataSet(input_paths)
input_nbytes = sum(os.path.getsize(p) for p in input_dataset.resolved_paths)
logging.warning(
f"input data size: {input_nbytes/GB:.3f}GB, {input_dataset.num_files} files"
)
logging.warning(f"input data size: {input_nbytes/GB:.3f}GB, {input_dataset.num_files} files")
random_records = DataSourceNode(ctx, input_dataset)
else:
random_records = generate_random_records(
@@ -335,12 +303,8 @@ def gray_sort_benchmark(
process_func=validate_records,
output_name="partitioned_summaries",
)
merged_summaries = DataSetPartitionNode(
ctx, (partitioned_summaries,), npartitions=1
)
final_check = PythonScriptNode(
ctx, (merged_summaries,), process_func=validate_summary
)
merged_summaries = DataSetPartitionNode(ctx, (partitioned_summaries,), npartitions=1)
final_check = PythonScriptNode(ctx, (merged_summaries,), process_func=validate_summary)
root = final_check
else:
root = sorted_records
@@ -359,17 +323,11 @@ def main():
driver.add_argument("-n", "--num_data_partitions", type=int, default=None)
driver.add_argument("-t", "--num_sort_partitions", type=int, default=None)
driver.add_argument("-i", "--input_paths", nargs="+", default=[])
driver.add_argument(
"-e", "--shuffle_engine", default="duckdb", choices=("duckdb", "arrow")
)
driver.add_argument(
"-s", "--sort_engine", default="duckdb", choices=("duckdb", "arrow", "polars")
)
driver.add_argument("-e", "--shuffle_engine", default="duckdb", choices=("duckdb", "arrow"))
driver.add_argument("-s", "--sort_engine", default="duckdb", choices=("duckdb", "arrow", "polars"))
driver.add_argument("-H", "--hive_partitioning", action="store_true")
driver.add_argument("-V", "--validate_results", action="store_true")
driver.add_argument(
"-C", "--shuffle_cpu_limit", type=int, default=ShuffleNode.default_cpu_limit
)
driver.add_argument("-C", "--shuffle_cpu_limit", type=int, default=ShuffleNode.default_cpu_limit)
driver.add_argument(
"-M",
"--shuffle_memory_limit",
@@ -378,12 +336,8 @@ def main():
)
driver.add_argument("-TC", "--sort_cpu_limit", type=int, default=8)
driver.add_argument("-TM", "--sort_memory_limit", type=int, default=None)
driver.add_argument(
"-NC", "--cpus_per_node", type=int, default=psutil.cpu_count(logical=False)
)
driver.add_argument(
"-NM", "--memory_per_node", type=int, default=psutil.virtual_memory().total
)
driver.add_argument("-NC", "--cpus_per_node", type=int, default=psutil.cpu_count(logical=False))
driver.add_argument("-NM", "--memory_per_node", type=int, default=psutil.virtual_memory().total)
driver.add_argument("-CP", "--parquet_compression", default=None)
driver.add_argument("-LV", "--parquet_compression_level", type=int, default=None)
@@ -393,16 +347,9 @@ def main():
total_num_cpus = max(1, driver_args.num_executors) * user_args.cpus_per_node
memory_per_cpu = user_args.memory_per_node // user_args.cpus_per_node
user_args.sort_cpu_limit = (
1 if user_args.sort_engine == "arrow" else user_args.sort_cpu_limit
)
sort_memory_limit = (
user_args.sort_memory_limit or user_args.sort_cpu_limit * memory_per_cpu
)
user_args.total_data_nbytes = (
user_args.total_data_nbytes
or max(1, driver_args.num_executors) * user_args.memory_per_node
)
user_args.sort_cpu_limit = 1 if user_args.sort_engine == "arrow" else user_args.sort_cpu_limit
sort_memory_limit = user_args.sort_memory_limit or user_args.sort_cpu_limit * memory_per_cpu
user_args.total_data_nbytes = user_args.total_data_nbytes or max(1, driver_args.num_executors) * user_args.memory_per_node
user_args.num_data_partitions = user_args.num_data_partitions or total_num_cpus // 2
user_args.num_sort_partitions = user_args.num_sort_partitions or max(
total_num_cpus // user_args.sort_cpu_limit,

View File

@@ -70,18 +70,12 @@ def main():
driver.add_argument("-i", "--input_paths", nargs="+", required=True)
driver.add_argument("-n", "--npartitions", type=int, default=None)
driver.add_argument("-c", "--hash_columns", nargs="+", required=True)
driver.add_argument(
"-e", "--engine_type", default="duckdb", choices=("duckdb", "arrow")
)
driver.add_argument("-e", "--engine_type", default="duckdb", choices=("duckdb", "arrow"))
driver.add_argument("-S", "--partition_stats", action="store_true")
driver.add_argument("-W", "--use_parquet_writer", action="store_true")
driver.add_argument("-H", "--hive_partitioning", action="store_true")
driver.add_argument(
"-C", "--cpu_limit", type=int, default=HashPartitionNode.default_cpu_limit
)
driver.add_argument(
"-M", "--memory_limit", type=int, default=HashPartitionNode.default_memory_limit
)
driver.add_argument("-C", "--cpu_limit", type=int, default=HashPartitionNode.default_cpu_limit)
driver.add_argument("-M", "--memory_limit", type=int, default=HashPartitionNode.default_memory_limit)
driver.add_argument("-NC", "--cpus_per_node", type=int, default=192)
driver.add_argument("-NM", "--memory_per_node", type=int, default=2000 * GB)

View File

@@ -29,9 +29,7 @@ def urls_sort_benchmark(
delim=r"\t",
)
data_files = DataSourceNode(ctx, dataset)
data_partitions = DataSetPartitionNode(
ctx, (data_files,), npartitions=num_data_partitions
)
data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=num_data_partitions)
imported_urls = SqlEngineNode(
ctx,
@@ -80,16 +78,10 @@ def urls_sort_benchmark_v2(
sort_cpu_limit=8,
sort_memory_limit=None,
):
dataset = sp.read_csv(
input_paths, schema={"urlstr": "varchar", "valstr": "varchar"}, delim=r"\t"
)
dataset = sp.read_csv(input_paths, schema={"urlstr": "varchar", "valstr": "varchar"}, delim=r"\t")
data_partitions = dataset.repartition(num_data_partitions)
urls_partitions = data_partitions.repartition(
num_hash_partitions, hash_by="urlstr", engine_type=engine_type
)
sorted_urls = urls_partitions.partial_sort(
by="urlstr", cpu_limit=sort_cpu_limit, memory_limit=sort_memory_limit
)
urls_partitions = data_partitions.repartition(num_hash_partitions, hash_by="urlstr", engine_type=engine_type)
sorted_urls = urls_partitions.partial_sort(by="urlstr", cpu_limit=sort_cpu_limit, memory_limit=sort_memory_limit)
sorted_urls.write_parquet(output_path)
@@ -106,12 +98,8 @@ def main():
num_nodes = driver_args.num_executors
cpus_per_node = 120
partition_rounds = 2
user_args.num_data_partitions = (
user_args.num_data_partitions or num_nodes * cpus_per_node * partition_rounds
)
user_args.num_hash_partitions = (
user_args.num_hash_partitions or num_nodes * cpus_per_node
)
user_args.num_data_partitions = user_args.num_data_partitions or num_nodes * cpus_per_node * partition_rounds
user_args.num_hash_partitions = user_args.num_hash_partitions or num_nodes * cpus_per_node
# v1
plan = urls_sort_benchmark(**vars(user_args))

View File

@@ -80,9 +80,7 @@ def check_data(actual: bytes, expected: bytes, offset: int) -> None:
)
expected = expected[index : index + 16]
actual = actual[index : index + 16]
raise ValueError(
f"Data mismatch at offset {offset + index}.\nexpect: {expected}\nactual: {actual}"
)
raise ValueError(f"Data mismatch at offset {offset + index}.\nexpect: {expected}\nactual: {actual}")
def generate_data(offset: int, length: int) -> bytes:
@@ -92,16 +90,10 @@ def generate_data(offset: int, length: int) -> bytes:
"""
istart = offset // 4
iend = (offset + length + 3) // 4
return (
np.arange(istart, iend)
.astype(np.uint32)
.tobytes()[offset % 4 : offset % 4 + length]
)
return np.arange(istart, iend).astype(np.uint32).tobytes()[offset % 4 : offset % 4 + length]
def iter_io_slice(
offset: int, length: int, block_size: Union[int, Tuple[int, int]]
) -> Iterator[Tuple[int, int]]:
def iter_io_slice(offset: int, length: int, block_size: Union[int, Tuple[int, int]]) -> Iterator[Tuple[int, int]]:
"""
Generate the IO (offset, size) for the slice [offset, offset + length) with the given block size.
`block_size` can be an integer or a range [start, end]. If a range is provided, the IO size will be randomly selected from the range.
@@ -161,9 +153,7 @@ def fstest(
if output_path is not None:
os.makedirs(output_path, exist_ok=True)
df = sp.from_items(
[{"path": os.path.join(output_path, f"{i}")} for i in range(npartitions)]
)
df = sp.from_items([{"path": os.path.join(output_path, f"{i}")} for i in range(npartitions)])
df = df.repartition(npartitions, by_rows=True)
stats = df.map(lambda x: fswrite(x["path"], size, blocksize)).to_pandas()
logging.info(f"write stats:\n{stats}")
@@ -187,18 +177,14 @@ if __name__ == "__main__":
python example/fstest.py -o 'fstest' -j 8 -s 1G -i 'fstest/*'
"""
parser = argparse.ArgumentParser()
parser.add_argument(
"-o", "--output_path", type=str, help="The output path to write data to."
)
parser.add_argument("-o", "--output_path", type=str, help="The output path to write data to.")
parser.add_argument(
"-i",
"--input_path",
type=str,
help="The input path to read data from. If -o is provided, this is ignored.",
)
parser.add_argument(
"-j", "--npartitions", type=int, help="The number of parallel jobs", default=10
)
parser.add_argument("-j", "--npartitions", type=int, help="The number of parallel jobs", default=10)
parser.add_argument(
"-s",
"--size",

View File

@@ -52,9 +52,7 @@ def shuffle_data(
npartitions=num_out_data_partitions,
partition_by_rows=True,
)
shuffled_urls = StreamCopy(
ctx, (repartitioned,), output_name="data_copy", cpu_limit=1
)
shuffled_urls = StreamCopy(ctx, (repartitioned,), output_name="data_copy", cpu_limit=1)
plan = LogicalPlan(ctx, shuffled_urls)
return plan
@@ -66,9 +64,7 @@ def main():
driver.add_argument("-nd", "--num_data_partitions", type=int, default=1024)
driver.add_argument("-nh", "--num_hash_partitions", type=int, default=3840)
driver.add_argument("-no", "--num_out_data_partitions", type=int, default=1920)
driver.add_argument(
"-e", "--engine_type", default="duckdb", choices=("duckdb", "arrow")
)
driver.add_argument("-e", "--engine_type", default="duckdb", choices=("duckdb", "arrow"))
driver.add_argument("-x", "--skip_hash_partition", action="store_true")
plan = shuffle_data(**driver.get_arguments())
driver.run(plan)

View File

@@ -11,9 +11,7 @@ from smallpond.logical.node import (
)
def shuffle_mock_urls(
input_paths, npartitions: int = 10, sort_rand_keys=True, engine_type="duckdb"
) -> LogicalPlan:
def shuffle_mock_urls(input_paths, npartitions: int = 10, sort_rand_keys=True, engine_type="duckdb") -> LogicalPlan:
ctx = Context()
dataset = ParquetDataSet(input_paths)
data_files = DataSourceNode(ctx, dataset)
@@ -61,9 +59,7 @@ def main():
driver.add_argument("-i", "--input_paths", nargs="+")
driver.add_argument("-n", "--npartitions", type=int, default=500)
driver.add_argument("-s", "--sort_rand_keys", action="store_true")
driver.add_argument(
"-e", "--engine_type", default="duckdb", choices=("duckdb", "arrow")
)
driver.add_argument("-e", "--engine_type", default="duckdb", choices=("duckdb", "arrow"))
plan = shuffle_mock_urls(**driver.get_arguments())
driver.run(plan)

View File

@@ -20,9 +20,7 @@ from smallpond.logical.node import (
class SortUrlsNode(ArrowComputeNode):
def process(
self, runtime_ctx: RuntimeContext, input_tables: List[arrow.Table]
) -> arrow.Table:
def process(self, runtime_ctx: RuntimeContext, input_tables: List[arrow.Table]) -> arrow.Table:
logging.info(f"sorting urls by 'host', table shape: {input_tables[0].shape}")
return input_tables[0].sort_by("host")
@@ -90,9 +88,7 @@ def sort_mock_urls(
def main():
driver = Driver()
driver.add_argument(
"-i", "--input_paths", nargs="+", default=["tests/data/mock_urls/*.tsv"]
)
driver.add_argument("-i", "--input_paths", nargs="+", default=["tests/data/mock_urls/*.tsv"])
driver.add_argument("-n", "--npartitions", type=int, default=10)
driver.add_argument("-e", "--engine_type", default="duckdb")

View File

@@ -5,12 +5,8 @@ import smallpond
from smallpond.dataframe import Session
def sort_mock_urls_v2(
sp: Session, input_paths: List[str], output_path: str, npartitions: int
):
dataset = sp.read_csv(
input_paths, schema={"urlstr": "varchar", "valstr": "varchar"}, delim=r"\t"
).repartition(npartitions)
def sort_mock_urls_v2(sp: Session, input_paths: List[str], output_path: str, npartitions: int):
dataset = sp.read_csv(input_paths, schema={"urlstr": "varchar", "valstr": "varchar"}, delim=r"\t").repartition(npartitions)
urls = dataset.map(
"""
split_part(urlstr, '/', 1) as host,
@@ -25,9 +21,7 @@ def sort_mock_urls_v2(
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"-i", "--input_paths", nargs="+", default=["tests/data/mock_urls/*.tsv"]
)
parser.add_argument("-i", "--input_paths", nargs="+", default=["tests/data/mock_urls/*.tsv"])
parser.add_argument("-o", "--output_path", type=str, default="sort_mock_urls")
parser.add_argument("-n", "--npartitions", type=int, default=10)
args = parser.parse_args()

View File

@@ -41,9 +41,7 @@ def clamp_row_group_size(val, minval=DEFAULT_ROW_GROUP_SIZE, maxval=MAX_ROW_GROU
return clamp_value(val, minval, maxval)
def clamp_row_group_bytes(
val, minval=DEFAULT_ROW_GROUP_BYTES, maxval=MAX_ROW_GROUP_BYTES
):
def clamp_row_group_bytes(val, minval=DEFAULT_ROW_GROUP_BYTES, maxval=MAX_ROW_GROUP_BYTES):
return clamp_value(val, minval, maxval)
@@ -74,10 +72,7 @@ def first_value_in_dict(d: Dict[K, V]) -> V:
def split_into_cols(items: List[V], npartitions: int) -> List[List[V]]:
none = object()
chunks = [items[i : i + npartitions] for i in range(0, len(items), npartitions)]
return [
[x for x in col if x is not none]
for col in itertools.zip_longest([none] * npartitions, *chunks, fillvalue=none)
]
return [[x for x in col if x is not none] for col in itertools.zip_longest([none] * npartitions, *chunks, fillvalue=none)]
def split_into_rows(items: List[V], npartitions: int) -> List[List[V]]:
@@ -101,10 +96,7 @@ def get_nth_partition(items: List[V], n: int, npartitions: int) -> List[V]:
start = n * large_partition_size
items_in_partition = items[start : start + large_partition_size]
else:
start = (
large_partition_size * num_large_partitions
+ (n - num_large_partitions) * small_partition_size
)
start = large_partition_size * num_large_partitions + (n - num_large_partitions) * small_partition_size
items_in_partition = items[start : start + small_partition_size]
return items_in_partition

View File

@@ -8,17 +8,13 @@ from smallpond.logical.node import ArrowComputeNode, ArrowStreamNode
class CopyArrowTable(ArrowComputeNode):
def process(
self, runtime_ctx: RuntimeContext, input_tables: List[arrow.Table]
) -> arrow.Table:
def process(self, runtime_ctx: RuntimeContext, input_tables: List[arrow.Table]) -> arrow.Table:
logger.info(f"copying table: {input_tables[0].num_rows} rows ...")
return input_tables[0]
class StreamCopy(ArrowStreamNode):
def process(
self, runtime_ctx: RuntimeContext, input_readers: List[arrow.RecordBatchReader]
) -> Iterable[arrow.Table]:
def process(self, runtime_ctx: RuntimeContext, input_readers: List[arrow.RecordBatchReader]) -> Iterable[arrow.Table]:
for batch in input_readers[0]:
logger.info(f"copying batch: {batch.num_rows} rows ...")
yield arrow.Table.from_batches([batch])

View File

@@ -25,9 +25,7 @@ class LogDataSetTask(PythonScriptTask):
class LogDataSet(PythonScriptNode):
def __init__(
self, ctx: Context, input_deps: Tuple[Node, ...], num_rows=200, **kwargs
) -> None:
def __init__(self, ctx: Context, input_deps: Tuple[Node, ...], num_rows=200, **kwargs) -> None:
super().__init__(ctx, input_deps, **kwargs)
self.num_rows = num_rows

View File

@@ -29,16 +29,12 @@ class ImportWarcFiles(PythonScriptNode):
]
)
def import_warc_file(
self, warc_path: PurePath, parquet_path: PurePath
) -> Tuple[int, int]:
def import_warc_file(self, warc_path: PurePath, parquet_path: PurePath) -> Tuple[int, int]:
total_size = 0
docs = []
with open(warc_path, "rb") as warc_file:
zstd_reader = zstd.ZstdDecompressor().stream_reader(
warc_file, read_size=16 * MB
)
zstd_reader = zstd.ZstdDecompressor().stream_reader(warc_file, read_size=16 * MB)
for record in ArchiveIterator(zstd_reader):
if record.rec_type == "response":
url = record.rec_headers.get_header("WARC-Target-URI")
@@ -48,9 +44,7 @@ class ImportWarcFiles(PythonScriptNode):
total_size += len(content)
docs.append((url, domain, date, content))
table = arrow.Table.from_arrays(
[arrow.array(column) for column in zip(*docs)], schema=self.schema
)
table = arrow.Table.from_arrays([arrow.array(column) for column in zip(*docs)], schema=self.schema)
dump_to_parquet_files(table, parquet_path.parent, parquet_path.name)
return len(docs), total_size
@@ -60,14 +54,9 @@ class ImportWarcFiles(PythonScriptNode):
input_datasets: List[DataSet],
output_path: str,
) -> bool:
warc_paths = [
PurePath(warc_path)
for dataset in input_datasets
for warc_path in dataset.resolved_paths
]
warc_paths = [PurePath(warc_path) for dataset in input_datasets for warc_path in dataset.resolved_paths]
parquet_paths = [
PurePath(output_path)
/ f"data{path_index}-{PurePath(warc_path.name).with_suffix('.parquet')}"
PurePath(output_path) / f"data{path_index}-{PurePath(warc_path.name).with_suffix('.parquet')}"
for path_index, warc_path in enumerate(warc_paths)
]
@@ -75,24 +64,16 @@ class ImportWarcFiles(PythonScriptNode):
for warc_path, parquet_path in zip(warc_paths, parquet_paths):
try:
doc_count, total_size = self.import_warc_file(warc_path, parquet_path)
logger.info(
f"imported {doc_count} web pages ({total_size/MB:.3f}MB) from file '{warc_path}' to '{parquet_path}'"
)
logger.info(f"imported {doc_count} web pages ({total_size/MB:.3f}MB) from file '{warc_path}' to '{parquet_path}'")
except Exception as ex:
logger.opt(exception=ex).error(
f"failed to import web pages from file '{warc_path}'"
)
logger.opt(exception=ex).error(f"failed to import web pages from file '{warc_path}'")
return False
return True
class ExtractHtmlBody(ArrowStreamNode):
unicode_punctuation = "".join(
chr(i)
for i in range(sys.maxunicode)
if unicodedata.category(chr(i)).startswith("P")
)
unicode_punctuation = "".join(chr(i) for i in range(sys.maxunicode) if unicodedata.category(chr(i)).startswith("P"))
separator_str = string.whitespace + string.punctuation + unicode_punctuation
translator = str.maketrans(separator_str, " " * len(separator_str))
@@ -117,27 +98,19 @@ class ExtractHtmlBody(ArrowStreamNode):
tokens.extend(self.split_string(doc.get_text(" ", strip=True).lower()))
return tokens
except Exception as ex:
logger.opt(exception=ex).error(
f"failed to extract tokens from {url.as_py()}"
)
logger.opt(exception=ex).error(f"failed to extract tokens from {url.as_py()}")
return []
def process(
self, runtime_ctx: RuntimeContext, input_readers: List[arrow.RecordBatchReader]
) -> Iterable[arrow.Table]:
def process(self, runtime_ctx: RuntimeContext, input_readers: List[arrow.RecordBatchReader]) -> Iterable[arrow.Table]:
for batch in input_readers[0]:
urls, domains, dates, contents = batch.columns
doc_tokens = []
try:
for i, (url, content) in enumerate(zip(urls, contents)):
tokens = self.extract_tokens(url, content)
logger.info(
f"#{i}/{len(urls)} extracted {len(tokens)} tokens from {url}"
)
logger.info(f"#{i}/{len(urls)} extracted {len(tokens)} tokens from {url}")
doc_tokens.append(tokens)
yield arrow.Table.from_arrays(
[urls, domains, dates, arrow.array(doc_tokens)], schema=self.schema
)
yield arrow.Table.from_arrays([urls, domains, dates, arrow.array(doc_tokens)], schema=self.schema)
except Exception as ex:
logger.opt(exception=ex).error(f"failed to extract tokens")
break

View File

@@ -35,9 +35,7 @@ class Session(SessionBase):
Subsequent DataFrames can reuse the tasks to avoid recomputation.
"""
def read_csv(
self, paths: Union[str, List[str]], schema: Dict[str, str], delim=","
) -> DataFrame:
def read_csv(self, paths: Union[str, List[str]], schema: Dict[str, str], delim=",") -> DataFrame:
"""
Create a DataFrame from CSV files.
"""
@@ -55,15 +53,11 @@ class Session(SessionBase):
"""
Create a DataFrame from Parquet files.
"""
dataset = ParquetDataSet(
paths, columns=columns, union_by_name=union_by_name, recursive=recursive
)
dataset = ParquetDataSet(paths, columns=columns, union_by_name=union_by_name, recursive=recursive)
plan = DataSourceNode(self._ctx, dataset)
return DataFrame(self, plan)
def read_json(
self, paths: Union[str, List[str]], schema: Dict[str, str]
) -> DataFrame:
def read_json(self, paths: Union[str, List[str]], schema: Dict[str, str]) -> DataFrame:
"""
Create a DataFrame from JSON files.
"""
@@ -115,9 +109,7 @@ class Session(SessionBase):
c = sp.partial_sql("select * from {0} join {1} on a.id = b.id", a, b)
"""
plan = SqlEngineNode(
self._ctx, tuple(input.plan for input in inputs), query, **kwargs
)
plan = SqlEngineNode(self._ctx, tuple(input.plan for input in inputs), query, **kwargs)
recompute = any(input.need_recompute for input in inputs)
return DataFrame(self, plan, recompute=recompute)
@@ -177,26 +169,15 @@ class Session(SessionBase):
"""
Return the total number of tasks and the number of tasks that are finished.
"""
dataset_refs = [
task._dataset_ref
for tasks in self._node_to_tasks.values()
for task in tasks
if task._dataset_ref is not None
]
ready_tasks, _ = ray.wait(
dataset_refs, num_returns=len(dataset_refs), timeout=0, fetch_local=False
)
dataset_refs = [task._dataset_ref for tasks in self._node_to_tasks.values() for task in tasks if task._dataset_ref is not None]
ready_tasks, _ = ray.wait(dataset_refs, num_returns=len(dataset_refs), timeout=0, fetch_local=False)
return len(dataset_refs), len(ready_tasks)
def _all_tasks_finished(self) -> bool:
"""
Check if all tasks are finished.
"""
dataset_refs = [
task._dataset_ref
for tasks in self._node_to_tasks.values()
for task in tasks
]
dataset_refs = [task._dataset_ref for tasks in self._node_to_tasks.values() for task in tasks]
try:
ray.get(dataset_refs, timeout=0)
except Exception:
@@ -232,12 +213,8 @@ class DataFrame:
# optimize the plan
if self.optimized_plan is None:
logger.info(f"optimizing\n{LogicalPlan(self.session._ctx, self.plan)}")
self.optimized_plan = Optimizer(
exclude_nodes=set(self.session._node_to_tasks.keys())
).visit(self.plan)
logger.info(
f"optimized\n{LogicalPlan(self.session._ctx, self.optimized_plan)}"
)
self.optimized_plan = Optimizer(exclude_nodes=set(self.session._node_to_tasks.keys())).visit(self.plan)
logger.info(f"optimized\n{LogicalPlan(self.session._ctx, self.optimized_plan)}")
# return the tasks if already created
if tasks := self.session._node_to_tasks.get(self.optimized_plan):
return tasks
@@ -281,9 +258,7 @@ class DataFrame:
"""
for retry_count in range(3):
try:
return ray.get(
[task.run_on_ray() for task in self._get_or_create_tasks()]
)
return ray.get([task.run_on_ray() for task in self._get_or_create_tasks()])
except ray.exceptions.RuntimeEnvSetupError as e:
# XXX: Ray may raise this error when a worker is interrupted.
# ```
@@ -361,9 +336,7 @@ class DataFrame:
)
elif hash_by is not None:
hash_columns = [hash_by] if isinstance(hash_by, str) else hash_by
plan = HashPartitionNode(
self.session._ctx, (self.plan,), npartitions, hash_columns, **kwargs
)
plan = HashPartitionNode(self.session._ctx, (self.plan,), npartitions, hash_columns, **kwargs)
else:
plan = EvenlyDistributedPartitionNode(
self.session._ctx,
@@ -420,9 +393,7 @@ class DataFrame:
)
return DataFrame(self.session, plan, recompute=self.need_recompute)
def filter(
self, sql_or_func: Union[str, Callable[[Dict[str, Any]], bool]], **kwargs
) -> DataFrame:
def filter(self, sql_or_func: Union[str, Callable[[Dict[str, Any]], bool]], **kwargs) -> DataFrame:
"""
Filter out rows that don't satisfy the given predicate.
@@ -453,13 +424,9 @@ class DataFrame:
table = tables[0]
return table.filter([func(row) for row in table.to_pylist()])
plan = ArrowBatchNode(
self.session._ctx, (self.plan,), process_func=process_func, **kwargs
)
plan = ArrowBatchNode(self.session._ctx, (self.plan,), process_func=process_func, **kwargs)
else:
raise ValueError(
"condition must be a SQL expression or a predicate function"
)
raise ValueError("condition must be a SQL expression or a predicate function")
return DataFrame(self.session, plan, recompute=self.need_recompute)
def map(
@@ -510,18 +477,14 @@ class DataFrame:
"""
if isinstance(sql := sql_or_func, str):
plan = SqlEngineNode(
self.session._ctx, (self.plan,), f"select {sql} from {{0}}", **kwargs
)
plan = SqlEngineNode(self.session._ctx, (self.plan,), f"select {sql} from {{0}}", **kwargs)
elif isinstance(func := sql_or_func, Callable):
def process_func(_runtime_ctx, tables: List[arrow.Table]) -> arrow.Table:
output_rows = [func(row) for row in tables[0].to_pylist()]
return arrow.Table.from_pylist(output_rows, schema=schema)
plan = ArrowBatchNode(
self.session._ctx, (self.plan,), process_func=process_func, **kwargs
)
plan = ArrowBatchNode(self.session._ctx, (self.plan,), process_func=process_func, **kwargs)
else:
raise ValueError(f"must be a SQL expression or a function: {sql_or_func!r}")
return DataFrame(self.session, plan, recompute=self.need_recompute)
@@ -555,20 +518,14 @@ class DataFrame:
"""
if isinstance(sql := sql_or_func, str):
plan = SqlEngineNode(
self.session._ctx, (self.plan,), f"select {sql} from {{0}}", **kwargs
)
plan = SqlEngineNode(self.session._ctx, (self.plan,), f"select {sql} from {{0}}", **kwargs)
elif isinstance(func := sql_or_func, Callable):
def process_func(_runtime_ctx, tables: List[arrow.Table]) -> arrow.Table:
output_rows = [
item for row in tables[0].to_pylist() for item in func(row)
]
output_rows = [item for row in tables[0].to_pylist() for item in func(row)]
return arrow.Table.from_pylist(output_rows, schema=schema)
plan = ArrowBatchNode(
self.session._ctx, (self.plan,), process_func=process_func, **kwargs
)
plan = ArrowBatchNode(self.session._ctx, (self.plan,), process_func=process_func, **kwargs)
else:
raise ValueError(f"must be a SQL expression or a function: {sql_or_func!r}")
return DataFrame(self.session, plan, recompute=self.need_recompute)
@@ -642,9 +599,7 @@ class DataFrame:
sp.wait(o1, o2)
"""
plan = DataSinkNode(
self.session._ctx, (self.plan,), os.path.abspath(path), type="link_or_copy"
)
plan = DataSinkNode(self.session._ctx, (self.plan,), os.path.abspath(path), type="link_or_copy")
return DataFrame(self.session, plan, recompute=self.need_recompute)
# inspection
@@ -710,6 +665,4 @@ class DataFrame:
"""
datasets = self._compute()
with ThreadPoolExecutor() as pool:
return arrow.concat_tables(
pool.map(lambda dataset: dataset.to_arrow_table(), datasets)
)
return arrow.concat_tables(pool.map(lambda dataset: dataset.to_arrow_table(), datasets))

View File

@@ -28,43 +28,23 @@ class Driver(object):
self.all_args = None
def _create_driver_args_parser(self):
parser = argparse.ArgumentParser(
prog="driver.py", description="Smallpond Driver", add_help=False
)
parser.add_argument(
"mode", choices=["executor", "scheduler", "ray"], default="executor"
)
parser.add_argument(
"--exec_id", default=socket.gethostname(), help="Unique executor id"
)
parser = argparse.ArgumentParser(prog="driver.py", description="Smallpond Driver", add_help=False)
parser.add_argument("mode", choices=["executor", "scheduler", "ray"], default="executor")
parser.add_argument("--exec_id", default=socket.gethostname(), help="Unique executor id")
parser.add_argument("--job_id", type=str, help="Unique job id")
parser.add_argument(
"--job_time", type=float, help="Job create time (seconds since epoch)"
)
parser.add_argument(
"--job_name", default="smallpond", help="Display name of the job"
)
parser.add_argument("--job_time", type=float, help="Job create time (seconds since epoch)")
parser.add_argument("--job_name", default="smallpond", help="Display name of the job")
parser.add_argument(
"--job_priority",
type=int,
help="Job priority",
)
parser.add_argument("--resource_group", type=str, help="Resource group")
parser.add_argument(
"--env_variables", nargs="*", default=[], help="Env variables for the job"
)
parser.add_argument(
"--sidecars", nargs="*", default=[], help="Sidecars for the job"
)
parser.add_argument(
"--tags", nargs="*", default=[], help="Tags for submitted platform task"
)
parser.add_argument(
"--task_image", default="default", help="Container image of platform task"
)
parser.add_argument(
"--python_venv", type=str, help="Python virtual env for the job"
)
parser.add_argument("--env_variables", nargs="*", default=[], help="Env variables for the job")
parser.add_argument("--sidecars", nargs="*", default=[], help="Sidecars for the job")
parser.add_argument("--tags", nargs="*", default=[], help="Tags for submitted platform task")
parser.add_argument("--task_image", default="default", help="Container image of platform task")
parser.add_argument("--python_venv", type=str, help="Python virtual env for the job")
parser.add_argument(
"--data_root",
type=str,
@@ -257,9 +237,7 @@ class Driver(object):
default="DEBUG",
choices=log_level_choices,
)
parser.add_argument(
"--disable_log_rotation", action="store_true", help="Disable log rotation"
)
parser.add_argument("--disable_log_rotation", action="store_true", help="Disable log rotation")
parser.add_argument(
"--output_path",
help="Set the output directory of final results and all nodes that have output_name but no output_path specified",
@@ -279,9 +257,7 @@ class Driver(object):
def parse_arguments(self, args=None):
if self.user_args is None or self.driver_args is None:
args_parser = argparse.ArgumentParser(
parents=[self.driver_args_parser, self.user_args_parser]
)
args_parser = argparse.ArgumentParser(parents=[self.driver_args_parser, self.user_args_parser])
self.all_args = args_parser.parse_args(args)
self.user_args, other_args = self.user_args_parser.parse_known_args(args)
self.driver_args = self.driver_args_parser.parse_args(other_args)
@@ -349,9 +325,7 @@ class Driver(object):
DataFrame(sp, plan.root_node).compute()
retval = True
elif args.mode == "executor":
assert os.path.isfile(
args.runtime_ctx_path
), f"cannot find runtime context: {args.runtime_ctx_path}"
assert os.path.isfile(args.runtime_ctx_path), f"cannot find runtime context: {args.runtime_ctx_path}"
runtime_ctx: RuntimeContext = load(args.runtime_ctx_path)
if runtime_ctx.bind_numa_node:
@@ -371,9 +345,7 @@ class Driver(object):
retval = run_executor(runtime_ctx, args.exec_id)
elif args.mode == "scheduler":
assert plan is not None
jobmgr = JobManager(
args.data_root, args.python_venv, args.task_image, args.platform
)
jobmgr = JobManager(args.data_root, args.python_venv, args.task_image, args.platform)
exec_plan = jobmgr.run(
plan,
job_id=args.job_id,

View File

@@ -35,9 +35,7 @@ class SimplePoolTask(object):
def join(self, timeout=None):
# Wait for the worker process to exit, for at most `timeout` seconds when a timeout is given.
# If a timeout was given and the task is still not ready afterwards, log a warning,
# terminate the task and join the process again, this time without a timeout.
# NOTE(review): reformat diff hunk - the multi-line logger.warning(...) below and the
# one-line logger.warning(...) that follows it are two spellings of the same statement.
self.proc.join(timeout)
if not self.ready() and timeout is not None:
logger.warning(
f"worker process {self.proc.name}({self.proc.pid}) does not exit after {timeout} secs, stopping it"
)
logger.warning(f"worker process {self.proc.name}({self.proc.pid}) does not exit after {timeout} secs, stopping it")
self.terminate()
self.proc.join()
@@ -45,17 +43,11 @@ class SimplePoolTask(object):
return self.proc.pid and not self.proc.is_alive()
def exitcode(self):
# Return the exit code of the worker process (`self.proc.exitcode`).
# The task must already be ready (asserted via `self.ready()`), otherwise an AssertionError is raised.
# Logs at info level when the task was being stopped (`self.stopping`), or at error level
# when it was not being stopped but the exit code is non-zero.
# NOTE(review): reformat diff hunk - every multi-line statement below is immediately
# followed by its one-line (--line-length=150) spelling of the same statement.
assert (
self.ready()
), f"worker process {self.proc.name}({self.proc.pid}) has not exited yet"
assert self.ready(), f"worker process {self.proc.name}({self.proc.pid}) has not exited yet"
if self.stopping:
logger.info(
f"worker process stopped: {self.proc.name}({self.proc.pid}), exitcode: {self.proc.exitcode}"
)
logger.info(f"worker process stopped: {self.proc.name}({self.proc.pid}), exitcode: {self.proc.exitcode}")
elif self.proc.exitcode != 0:
logger.error(
f"worker process crashed: {self.proc.name}({self.proc.pid}), exitcode: {self.proc.exitcode}"
)
logger.error(f"worker process crashed: {self.proc.name}({self.proc.pid}), exitcode: {self.proc.exitcode}")
return self.proc.exitcode
@@ -79,9 +71,7 @@ class SimplePool(object):
def update_queue(self):
self.running_tasks = [t for t in self.running_tasks if not t.ready()]
tasks_to_run = self.queued_tasks[: self.pool_size - len(self.running_tasks)]
self.queued_tasks = self.queued_tasks[
self.pool_size - len(self.running_tasks) :
]
self.queued_tasks = self.queued_tasks[self.pool_size - len(self.running_tasks) :]
for task in tasks_to_run:
task.start()
self.running_tasks += tasks_to_run
@@ -97,9 +87,7 @@ class Executor(object):
The task executor.
"""
def __init__(
self, ctx: RuntimeContext, id: str, wq: WorkQueue, cq: WorkQueue
) -> None:
def __init__(self, ctx: RuntimeContext, id: str, wq: WorkQueue, cq: WorkQueue) -> None:
self.ctx = ctx
self.id = id
self.wq = wq
@@ -135,9 +123,7 @@ class Executor(object):
def process_work(item: WorkItem, cq: WorkQueue):
# Execute one work item and report it: calls `item.exec(cq)`, pushes the item onto `cq`,
# then logs the item's status and elapsed time.
# NOTE(review): reformat diff hunk - the multi-line logger.info(...) and the one-line
# logger.info(...) below are two spellings of the same statement.
item.exec(cq)
cq.push(item)
logger.info(
f"finished work: {repr(item)}, status: {item.status}, elapsed time: {item.elapsed_time:.3f} secs"
)
logger.info(f"finished work: {repr(item)}, status: {item.status}, elapsed time: {item.elapsed_time:.3f} secs")
# presumably flushes pending log messages before the worker returns - TODO confirm against the logger in use
logger.complete()
# for test
@@ -146,16 +132,12 @@ class Executor(object):
# for test
def skip_probes(self, epochs: int):
# Test helper: pushes a `Probe` onto this executor's work queue (`self.wq`).
# The probe is named ".FalseFail-<executor id>" and is created with epoch=0 and epochs_to_skip=epochs.
# NOTE(review): reformat diff hunk - the multi-line push(...) and the one-line push(...)
# below are two spellings of the same statement.
self.wq.push(
Probe(self.ctx, f".FalseFail-{self.id}", epoch=0, epochs_to_skip=epochs)
)
self.wq.push(Probe(self.ctx, f".FalseFail-{self.id}", epoch=0, epochs_to_skip=epochs))
@logger.catch(reraise=True, message="executor terminated unexpectedly")
def run(self) -> bool:
mp.current_process().name = "ExecutorMainProcess"
logger.info(
f"start to run executor {self.id} on numa node #{self.ctx.numa_node_id} of {socket.gethostname()}"
)
logger.info(f"start to run executor {self.id} on numa node #{self.ctx.numa_node_id} of {socket.gethostname()}")
with SimplePool(self.ctx.usable_cpu_count + 1) as pool:
retval = self.exec_loop(pool)
@@ -173,34 +155,20 @@ class Executor(object):
try:
items = self.wq.pop(count=self.ctx.usable_cpu_count)
except Exception as ex:
logger.opt(exception=ex).critical(
f"failed to pop from work queue: {self.wq}"
)
logger.opt(exception=ex).critical(f"failed to pop from work queue: {self.wq}")
self.running = False
items = []
if not items:
secs_quiet_period = time.time() - latest_probe_time
if (
secs_quiet_period > self.ctx.secs_executor_probe_interval * 2
and os.path.exists(self.ctx.job_status_path)
):
if secs_quiet_period > self.ctx.secs_executor_probe_interval * 2 and os.path.exists(self.ctx.job_status_path):
with open(self.ctx.job_status_path) as status_file:
if (
status := status_file.read().strip()
) and not status.startswith("running"):
logger.critical(
f"job scheduler already stopped: {status}, stopping executor"
)
if (status := status_file.read().strip()) and not status.startswith("running"):
logger.critical(f"job scheduler already stopped: {status}, stopping executor")
self.running = False
break
if (
secs_quiet_period > self.ctx.secs_executor_probe_timeout * 2
and not pytest_running()
):
logger.critical(
f"no probe received for {secs_quiet_period:.1f} secs, stopping executor"
)
if secs_quiet_period > self.ctx.secs_executor_probe_timeout * 2 and not pytest_running():
logger.critical(f"no probe received for {secs_quiet_period:.1f} secs, stopping executor")
self.running = False
break
# no pending works, so wait a few seconds before checking results
@@ -216,9 +184,7 @@ class Executor(object):
if isinstance(item, StopWorkItem):
running_work = self.running_works.get(item.work_to_stop, None)
if running_work is None:
logger.debug(
f"cannot find {item.work_to_stop} in running works of {self.id}"
)
logger.debug(f"cannot find {item.work_to_stop} in running works of {self.id}")
self.cq.push(item)
else:
logger.info(f"stopping work: {item.work_to_stop}")
@@ -250,20 +216,14 @@ class Executor(object):
self.collect_finished_works()
time.sleep(self.ctx.secs_wq_poll_interval)
item._local_gpu = granted_gpus
logger.info(
f"{repr(item)} is assigned to run on GPU: { {gpu.id: quota for gpu, quota in item._local_gpu.items()} }"
)
logger.info(f"{repr(item)} is assigned to run on GPU: { {gpu.id: quota for gpu, quota in item._local_gpu.items()} }")
# enqueue work item to the pool
self.running_works[item.key] = (
pool.apply_async(
func=Executor.process_work, args=(item, self.cq), name=item.key
),
pool.apply_async(func=Executor.process_work, args=(item, self.cq), name=item.key),
item,
)
logger.info(
f"started work: {repr(item)}, {len(self.running_works)} running works: {list(self.running_works.keys())[:10]}..."
)
logger.info(f"started work: {repr(item)}, {len(self.running_works)} running works: {list(self.running_works.keys())[:10]}...")
# start to run works
pool.update_queue()
@@ -287,15 +247,11 @@ class Executor(object):
work.join()
if (exitcode := work.exitcode()) != 0:
item.status = WorkStatus.CRASHED
item.exception = NonzeroExitCode(
f"worker process {work.proc.name}({work.proc.pid}) exited with non-zero code {exitcode}"
)
item.exception = NonzeroExitCode(f"worker process {work.proc.name}({work.proc.pid}) exited with non-zero code {exitcode}")
try:
self.cq.push(item)
except Exception as ex:
logger.opt(exception=ex).critical(
f"failed to push into completion queue: {self.cq}"
)
logger.opt(exception=ex).critical(f"failed to push into completion queue: {self.cq}")
self.running = False
finished_works.append(item)
@@ -304,9 +260,7 @@ class Executor(object):
self.running_works.pop(item.key)
if item._local_gpu:
self.release_gpu(item._local_gpu)
logger.info(
f"{repr(item)} released GPU: { {gpu.id: quota for gpu, quota in item._local_gpu.items()} }"
)
logger.info(f"{repr(item)} released GPU: { {gpu.id: quota for gpu, quota in item._local_gpu.items()} }")
def acquire_gpu(self, quota: float) -> Dict[GPU, float]:
"""
@@ -336,6 +290,4 @@ class Executor(object):
"""
for gpu, quota in gpus.items():
self.local_gpus[gpu] += quota
assert (
self.local_gpus[gpu] <= 1.0
), f"GPU {gpu} quota is greater than 1.0: {self.local_gpus[gpu]}"
assert self.local_gpus[gpu] <= 1.0, f"GPU {gpu} quota is greater than 1.0: {self.local_gpus[gpu]}"

View File

@@ -26,13 +26,9 @@ class SchedStateExporter(Scheduler.StateObserver):
if sched_state.large_runtime_state:
logger.debug(f"pause exporting scheduler state")
elif sched_state.num_local_running_works > 0:
logger.debug(
f"pause exporting scheduler state: {sched_state.num_local_running_works} local running works"
)
logger.debug(f"pause exporting scheduler state: {sched_state.num_local_running_works} local running works")
else:
dump(
sched_state, self.sched_state_path, buffering=32 * MB, atomic_write=True
)
dump(sched_state, self.sched_state_path, buffering=32 * MB, atomic_write=True)
sched_state.log_overall_progress()
logger.debug(f"exported scheduler state to {self.sched_state_path}")
@@ -97,12 +93,8 @@ class JobManager(object):
bind_numa_node=False,
enforce_memory_limit=False,
share_log_analytics: Optional[bool] = None,
console_log_level: Literal[
"CRITICAL", "ERROR", "WARNING", "SUCCESS", "INFO", "DEBUG", "TRACE"
] = "INFO",
file_log_level: Literal[
"CRITICAL", "ERROR", "WARNING", "SUCCESS", "INFO", "DEBUG", "TRACE"
] = "DEBUG",
console_log_level: Literal["CRITICAL", "ERROR", "WARNING", "SUCCESS", "INFO", "DEBUG", "TRACE"] = "INFO",
file_log_level: Literal["CRITICAL", "ERROR", "WARNING", "SUCCESS", "INFO", "DEBUG", "TRACE"] = "DEBUG",
disable_log_rotation=False,
sched_state_observers: Optional[List[Scheduler.StateObserver]] = None,
output_path: Optional[str] = None,
@@ -111,11 +103,7 @@ class JobManager(object):
logger.info(f"using platform: {self.platform}")
job_id = JobId(hex=job_id or self.platform.default_job_id())
job_time = (
datetime.fromtimestamp(job_time)
if job_time is not None
else self.platform.default_job_time()
)
job_time = datetime.fromtimestamp(job_time) if job_time is not None else self.platform.default_job_time()
malloc_path = ""
if memory_allocator == "system":
@@ -131,19 +119,10 @@ class JobManager(object):
arrow_default_malloc=memory_allocator,
).splitlines()
env_overrides = env_overrides + (env_variables or [])
env_overrides = dict(
tuple(kv.strip().split("=", maxsplit=1))
for kv in filter(None, env_overrides)
)
env_overrides = dict(tuple(kv.strip().split("=", maxsplit=1)) for kv in filter(None, env_overrides))
share_log_analytics = (
share_log_analytics
if share_log_analytics is not None
else self.platform.default_share_log_analytics()
)
shared_log_root = (
self.platform.shared_log_root() if share_log_analytics else None
)
share_log_analytics = share_log_analytics if share_log_analytics is not None else self.platform.default_share_log_analytics()
shared_log_root = self.platform.shared_log_root() if share_log_analytics else None
runtime_ctx = RuntimeContext(
job_id,
@@ -167,9 +146,7 @@ class JobManager(object):
**kwargs,
)
runtime_ctx.initialize(socket.gethostname(), root_exist_ok=True)
logger.info(
f"command-line arguments: {' '.join([os.path.basename(sys.argv[0]), *sys.argv[1:]])}"
)
logger.info(f"command-line arguments: {' '.join([os.path.basename(sys.argv[0]), *sys.argv[1:]])}")
dump(runtime_ctx, runtime_ctx.runtime_ctx_path, atomic_write=True)
logger.info(f"saved runtime context at {runtime_ctx.runtime_ctx_path}")
@@ -178,13 +155,9 @@ class JobManager(object):
logger.info(f"saved logcial plan at {runtime_ctx.logcial_plan_path}")
plan.graph().render(runtime_ctx.logcial_plan_graph_path, format="png")
logger.info(
f"saved logcial plan graph at {runtime_ctx.logcial_plan_graph_path}.png"
)
logger.info(f"saved logcial plan graph at {runtime_ctx.logcial_plan_graph_path}.png")
exec_plan = Planner(runtime_ctx).create_exec_plan(
plan, manifest_only_final_results
)
exec_plan = Planner(runtime_ctx).create_exec_plan(plan, manifest_only_final_results)
dump(exec_plan, runtime_ctx.exec_plan_path, atomic_write=True)
logger.info(f"saved execution plan at {runtime_ctx.exec_plan_path}")
@@ -229,9 +202,7 @@ class JobManager(object):
sched_state_observers.insert(0, sched_state_exporter)
if os.path.exists(runtime_ctx.sched_state_path):
logger.warning(
f"loading scheduler state from: {runtime_ctx.sched_state_path}"
)
logger.warning(f"loading scheduler state from: {runtime_ctx.sched_state_path}")
scheduler: Scheduler = load(runtime_ctx.sched_state_path)
scheduler.sched_epoch += 1
scheduler.sched_state_observers = sched_state_observers

View File

@@ -54,9 +54,7 @@ class ExecutorState(Enum):
class RemoteExecutor(object):
def __init__(
self, ctx: RuntimeContext, id: str, wq: WorkQueue, cq: WorkQueue, init_epoch=0
) -> None:
def __init__(self, ctx: RuntimeContext, id: str, wq: WorkQueue, cq: WorkQueue, init_epoch=0) -> None:
self.ctx = ctx
self.id = id
self.wq = wq
@@ -79,9 +77,7 @@ state={self.state}, probe={self.last_acked_probe}"
return f"RemoteExecutor({self.id}):{self.state}"
@staticmethod
def create(
ctx: RuntimeContext, id: str, queue_dir: str, init_epoch=0
) -> "RemoteExecutor":
def create(ctx: RuntimeContext, id: str, queue_dir: str, init_epoch=0) -> "RemoteExecutor":
wq = WorkQueueOnFilesystem(os.path.join(queue_dir, "wq"))
cq = WorkQueueOnFilesystem(os.path.join(queue_dir, "cq"))
return RemoteExecutor(ctx, id, wq, cq, init_epoch)
@@ -173,9 +169,7 @@ state={self.state}, probe={self.last_acked_probe}"
return self.cpu_count - self.cpu_count // 16
def add_running_work(self, item: WorkItem):
assert (
item.key not in self.running_works
), f"duplicate work item assigned to {repr(self)}: {item.key}"
assert item.key not in self.running_works, f"duplicate work item assigned to {repr(self)}: {item.key}"
self.running_works[item.key] = item
self._allocated_cpus += item.cpu_limit
self._allocated_gpus += item.gpu_limit
@@ -219,9 +213,7 @@ state={self.state}, probe={self.last_acked_probe}"
def push(self, item: WorkItem, buffering=False) -> bool:
if item.key in self.running_works:
logger.warning(
f"work item {item.key} already exists in running works of {self}"
)
logger.warning(f"work item {item.key} already exists in running works of {self}")
return False
item.start_time = time.time()
item.exec_id = self.id
@@ -250,9 +242,7 @@ state={self.state}, probe={self.last_acked_probe}"
elif num_missed_probes > self.ctx.max_num_missed_probes:
if self.state != ExecutorState.FAIL:
self.state = ExecutorState.FAIL
logger.error(
f"find failed executor: {self}, missed probes: {num_missed_probes}, current epoch: {current_epoch}"
)
logger.error(f"find failed executor: {self}, missed probes: {num_missed_probes}, current epoch: {current_epoch}")
return True
elif self.state == ExecutorState.STOPPING:
if self.stop_request_acked:
@@ -277,9 +267,7 @@ state={self.state}, probe={self.last_acked_probe}"
class LocalExecutor(RemoteExecutor):
def __init__(
self, ctx: RuntimeContext, id: str, wq: WorkQueue, cq: WorkQueue
) -> None:
def __init__(self, ctx: RuntimeContext, id: str, wq: WorkQueue, cq: WorkQueue) -> None:
super().__init__(ctx, id, wq, cq)
self.work = None
self.running = False
@@ -321,9 +309,7 @@ class LocalExecutor(RemoteExecutor):
if item.gpu_limit > 0:
assert len(local_gpus) > 0
item._local_gpu = local_gpus[0]
logger.info(
f"{repr(item)} is assigned to run on GPU #{item.local_rank}: {item.local_gpu}"
)
logger.info(f"{repr(item)} is assigned to run on GPU #{item.local_rank}: {item.local_gpu}")
item = copy.copy(item)
item.exec()
@@ -368,9 +354,7 @@ class Scheduler(object):
self.callback = callback
def __repr__(self) -> str:
# Represent the observer by the repr of its callback when one is set;
# with no callback, fall back to the parent class's default repr.
# NOTE(review): reformat diff hunk - the parenthesised multi-line return and the
# one-line return below are two spellings of the same statement.
return (
repr(self.callback) if self.callback is not None else super().__repr__()
)
return repr(self.callback) if self.callback is not None else super().__repr__()
__str__ = __repr__
@@ -403,9 +387,7 @@ class Scheduler(object):
self.stop_executor_on_failure = stop_executor_on_failure
self.nonzero_exitcode_as_oom = nonzero_exitcode_as_oom
self.remove_output_root = remove_output_root
self.sched_state_observers: List[Scheduler.StateObserver] = (
sched_state_observers or []
)
self.sched_state_observers: List[Scheduler.StateObserver] = sched_state_observers or []
self.secs_state_notify_interval = self.ctx.secs_executor_probe_interval * 2
# task states
self.local_queue: List[Task] = []
@@ -414,11 +396,7 @@ class Scheduler(object):
self.scheduled_tasks: Dict[TaskRuntimeId, Task] = OrderedDict()
self.finished_tasks: Dict[TaskRuntimeId, Task] = OrderedDict()
self.succeeded_tasks: Dict[str, Task] = OrderedDict()
self.nontrivial_tasks = dict(
(key, task)
for (key, task) in self.tasks.items()
if not task.exec_on_scheduler
)
self.nontrivial_tasks = dict((key, task) for (key, task) in self.tasks.items() if not task.exec_on_scheduler)
self.succeeded_nontrivial_tasks: Dict[str, Task] = OrderedDict()
# executor pool
self.local_executor = LocalExecutor.create(self.ctx, "localhost")
@@ -463,18 +441,11 @@ class Scheduler(object):
@property
def running_works(self) -> Iterable[WorkItem]:
# Lazily iterate over every work item recorded in `running_works` of the alive executors
# followed by the local executors. Returns a generator expression, so it can be consumed only once.
# NOTE(review): reformat diff hunk - the multi-line return and the one-line return below
# are two spellings of the same statement.
return (
work
for executor in (self.alive_executors + self.local_executors)
for work in executor.running_works.values()
)
return (work for executor in (self.alive_executors + self.local_executors) for work in executor.running_works.values())
@property
def num_running_works(self) -> int:
# Total number of entries in `running_works` summed over the alive executors and the local executors.
# NOTE(review): reformat diff hunk - the multi-line return and the one-line return below
# are two spellings of the same statement.
return sum(
len(executor.running_works)
for executor in (self.alive_executors + self.local_executors)
)
return sum(len(executor.running_works) for executor in (self.alive_executors + self.local_executors))
@property
def num_local_running_works(self) -> int:
@@ -489,11 +460,7 @@ class Scheduler(object):
@property
def pending_nontrivial_tasks(self) -> Dict[str, Task]:
# Build a new dict with the entries of `self.nontrivial_tasks` whose key is not yet
# present in `self.succeeded_nontrivial_tasks`. Recomputed on every access.
# NOTE(review): reformat diff hunk - the multi-line return and the one-line return below
# are two spellings of the same statement.
return dict(
(key, task)
for key, task in self.nontrivial_tasks.items()
if key not in self.succeeded_nontrivial_tasks
)
return dict((key, task) for key, task in self.nontrivial_tasks.items() if key not in self.succeeded_nontrivial_tasks)
@property
def num_pending_nontrivial_tasks(self) -> int:
@@ -504,33 +471,20 @@ class Scheduler(object):
@property
def succeeded_task_ids(self) -> Set[TaskRuntimeId]:
# Build the set of runtime ids of all succeeded tasks: one
# TaskRuntimeId(task.id, task.sched_epoch, task.retry_count) per value of `self.succeeded_tasks`.
# A fresh set is constructed on every access.
# NOTE(review): reformat diff hunk - the multi-line return and the one-line return below
# are two spellings of the same statement.
return set(
TaskRuntimeId(task.id, task.sched_epoch, task.retry_count)
for task in self.succeeded_tasks.values()
)
return set(TaskRuntimeId(task.id, task.sched_epoch, task.retry_count) for task in self.succeeded_tasks.values())
@property
def abandoned_tasks(self) -> List[Task]:
# List the scheduled or finished tasks whose `runtime_id` is not among the succeeded task ids.
# `scheduled_tasks` and `finished_tasks` are merged into one dict first, so a key present in
# both contributes only its `finished_tasks` entry.
# NOTE(review): reformat diff hunk - the multi-line return and the one-line return below
# are two spellings of the same statement.
# read the property once: it is evaluated here a single time instead of once per task in the filter below
succeeded_task_ids = self.succeeded_task_ids
return [
task
for task in {**self.scheduled_tasks, **self.finished_tasks}.values()
if task.runtime_id not in succeeded_task_ids
]
return [task for task in {**self.scheduled_tasks, **self.finished_tasks}.values() if task.runtime_id not in succeeded_task_ids]
@cached_property
def remote_executors(self) -> List[RemoteExecutor]:
# List the values of `self.available_executors` whose `local` attribute is falsy.
# Declared as a cached_property, so the list is computed once and then reused;
# presumably invalidated elsewhere when the executor set changes - TODO confirm.
# NOTE(review): reformat diff hunk - the multi-line return and the one-line return below
# are two spellings of the same statement.
return [
executor
for executor in self.available_executors.values()
if not executor.local
]
return [executor for executor in self.available_executors.values() if not executor.local]
@cached_property
def local_executors(self) -> List[RemoteExecutor]:
# List the values of `self.available_executors` whose `local` attribute is truthy.
# Declared as a cached_property, so the list is computed once and then reused;
# presumably invalidated elsewhere when the executor set changes - TODO confirm.
# NOTE(review): reformat diff hunk - the multi-line return and the one-line return below
# are two spellings of the same statement.
return [
executor for executor in self.available_executors.values() if executor.local
]
return [executor for executor in self.available_executors.values() if executor.local]
@cached_property
def working_executors(self) -> List[RemoteExecutor]:
@@ -592,10 +546,7 @@ class Scheduler(object):
def start_speculative_execution(self):
for executor in self.working_executors:
for idx, item in enumerate(executor.running_works.values()):
aggressive_retry = (
self.aggressive_speculative_exec
and len(self.good_executors) >= self.ctx.num_executors
)
aggressive_retry = self.aggressive_speculative_exec and len(self.good_executors) >= self.ctx.num_executors
short_sched_queue = len(self.sched_queue) < len(self.good_executors)
if (
isinstance(item, Task)
@@ -603,8 +554,7 @@ class Scheduler(object):
and item.allow_speculative_exec
and item.retry_count < self.max_retry_count
and item.retry_count == self.tasks[item.key].retry_count
and (logical_node := self.logical_nodes.get(item.node_id, None))
is not None
and (logical_node := self.logical_nodes.get(item.node_id, None)) is not None
):
perf_stats = logical_node.get_perf_stats("elapsed wall time (secs)")
if perf_stats is not None and perf_stats.cnt >= 20:
@@ -639,12 +589,8 @@ class Scheduler(object):
if entry.is_dir():
_, exec_id = os.path.split(entry.path)
if exec_id not in self.available_executors:
self.available_executors[exec_id] = RemoteExecutor.create(
self.ctx, exec_id, entry.path, self.probe_epoch
)
logger.info(
f"find a new executor #{len(self.available_executors)}: {self.available_executors[exec_id]}"
)
self.available_executors[exec_id] = RemoteExecutor.create(self.ctx, exec_id, entry.path, self.probe_epoch)
logger.info(f"find a new executor #{len(self.available_executors)}: {self.available_executors[exec_id]}")
self.clear_cached_executor_lists()
# start a new probe epoch
self.last_executor_probe_time = time.time()
@@ -668,9 +614,7 @@ class Scheduler(object):
item.status = WorkStatus.EXEC_FAILED
item.finish_time = time.time()
if isinstance(item, Task) and item.key not in self.succeeded_tasks:
logger.warning(
f"reschedule {repr(item)} on failed executor: {repr(executor)}"
)
logger.warning(f"reschedule {repr(item)} on failed executor: {repr(executor)}")
self.try_enqueue(self.get_retry_task(item.key))
if any(executor_state_changed):
@@ -690,9 +634,7 @@ class Scheduler(object):
# remove the reference to input deps
task.input_deps = {dep_key: None for dep_key in task.input_deps}
# feed input datasets
task.input_datasets = [
self.succeeded_tasks[dep_key].output for dep_key in task.input_deps
]
task.input_datasets = [self.succeeded_tasks[dep_key].output for dep_key in task.input_deps]
task.sched_epoch = self.sched_epoch
return task
@@ -713,9 +655,7 @@ class Scheduler(object):
task.dataset = finished_task.dataset
def get_runnable_tasks(self, finished_task: Task) -> Iterable[Task]:
assert (
finished_task.status == WorkStatus.SUCCEED
), f"task not succeeded: {finished_task}"
assert finished_task.status == WorkStatus.SUCCEED, f"task not succeeded: {finished_task}"
for output_key in finished_task.output_deps:
output_dep = self.tasks[output_key]
if all(key in self.succeeded_tasks for key in output_dep.input_deps):
@@ -730,14 +670,8 @@ class Scheduler(object):
for executor in self.remote_executors:
running_task = executor.running_works.get(task_key, None)
if running_task is not None:
logger.info(
f"try to stop {repr(running_task)} running on {repr(executor)}"
)
executor.wq.push(
StopWorkItem(
f".StopWorkItem-{repr(running_task)}", running_task.key
)
)
logger.info(f"try to stop {repr(running_task)} running on {repr(executor)}")
executor.wq.push(StopWorkItem(f".StopWorkItem-{repr(running_task)}", running_task.key))
def try_relax_memory_limit(self, task: Task, executor: RemoteExecutor) -> bool:
if task.memory_limit >= executor.memory_size:
@@ -745,9 +679,7 @@ class Scheduler(object):
return False
relaxed_memory_limit = min(executor.memory_size, task.memory_limit * 2)
task._memory_boost = relaxed_memory_limit / task._memory_limit
logger.warning(
f"relax memory limit of {task.key} to {task.memory_limit/GB:.3f}GB and retry ..."
)
logger.warning(f"relax memory limit of {task.key} to {task.memory_limit/GB:.3f}GB and retry ...")
return True
def try_boost_resource(self, item: WorkItem, executor: RemoteExecutor):
@@ -777,9 +709,7 @@ class Scheduler(object):
if item._cpu_limit < boost_cpu or item._memory_limit < boost_mem:
item._cpu_boost = boost_cpu / item._cpu_limit
item._memory_boost = boost_mem / item._memory_limit
logger.info(
f"boost resource usage of {repr(item)}: {item.cpu_limit} CPUs, {item.memory_limit/GB:.3f}GB"
)
logger.info(f"boost resource usage of {repr(item)}: {item.cpu_limit} CPUs, {item.memory_limit/GB:.3f}GB")
def get_retry_task(self, key: str) -> Task:
task = self.tasks[key]
@@ -794,9 +724,7 @@ class Scheduler(object):
remove_path(self.ctx.staging_root)
if abandoned_tasks := self.abandoned_tasks:
logger.info(
f"removing outputs of {len(abandoned_tasks)} abandoned tasks: {abandoned_tasks[:3]} ..."
)
logger.info(f"removing outputs of {len(abandoned_tasks)} abandoned tasks: {abandoned_tasks[:3]} ...")
assert list(pool.map(lambda t: t.clean_output(force=True), abandoned_tasks))
@logger.catch(reraise=pytest_running(), message="failed to export task metrics")
@@ -825,15 +753,9 @@ class Scheduler(object):
buffering=32 * MB,
)
task_props = arrow.array(
pristine_attrs_dict(task) for task in self.finished_tasks.values()
)
partition_infos = arrow.array(
task.partition_infos_as_dict for task in self.finished_tasks.values()
)
perf_metrics = arrow.array(
dict(task.perf_metrics) for task in self.finished_tasks.values()
)
task_props = arrow.array(pristine_attrs_dict(task) for task in self.finished_tasks.values())
partition_infos = arrow.array(task.partition_infos_as_dict for task in self.finished_tasks.values())
perf_metrics = arrow.array(dict(task.perf_metrics) for task in self.finished_tasks.values())
task_metrics = arrow.Table.from_arrays(
[task_props, partition_infos, perf_metrics],
names=["task_props", "partition_infos", "perf_metrics"],
@@ -862,12 +784,7 @@ class Scheduler(object):
[
dict(
task=repr(task),
node=(
repr(node)
if (node := self.logical_nodes.get(task.node_id, None))
is not None
else "StandaloneTasks"
),
node=(repr(node) if (node := self.logical_nodes.get(task.node_id, None)) is not None else "StandaloneTasks"),
status=str(task.status),
executor=task.exec_id,
start_time=datetime.fromtimestamp(task.start_time),
@@ -925,23 +842,16 @@ class Scheduler(object):
fig_filename, _ = fig_title.split(" - ", maxsplit=1)
fig_filename += ".html"
fig_path = os.path.join(self.ctx.log_root, fig_filename)
fig.update_yaxes(
autorange="reversed"
) # otherwise tasks are listed from the bottom up
fig.update_yaxes(autorange="reversed") # otherwise tasks are listed from the bottom up
fig.update_traces(marker_line_color="black", marker_line_width=1, opacity=1)
fig.write_html(
fig_path, include_plotlyjs="cdn" if pytest_running() else True
)
fig.write_html(fig_path, include_plotlyjs="cdn" if pytest_running() else True)
if self.ctx.shared_log_root:
shutil.copy(fig_path, self.ctx.shared_log_root)
logger.debug(f"exported timeline figure to {fig_path}")
def notify_state_observers(self, force_notify=False) -> bool:
secs_since_last_state_notify = time.time() - self.last_state_notify_time
if (
force_notify
or secs_since_last_state_notify >= self.secs_state_notify_interval
):
if force_notify or secs_since_last_state_notify >= self.secs_state_notify_interval:
self.last_state_notify_time = time.time()
for observer in self.sched_state_observers:
if force_notify or observer.enabled:
@@ -949,14 +859,10 @@ class Scheduler(object):
observer.update(self)
elapsed_time = time.time() - start_time
if elapsed_time >= self.ctx.secs_executor_probe_interval / 2:
self.secs_state_notify_interval = (
self.ctx.secs_executor_probe_timeout
)
self.secs_state_notify_interval = self.ctx.secs_executor_probe_timeout
if elapsed_time >= self.ctx.secs_executor_probe_interval:
observer.enabled = False
logger.warning(
f"disabled slow scheduler state observer (elapsed time: {elapsed_time:.1f} secs): {observer}"
)
logger.warning(f"disabled slow scheduler state observer (elapsed time: {elapsed_time:.1f} secs): {observer}")
return True
else:
return False
@@ -984,9 +890,7 @@ class Scheduler(object):
def run(self) -> bool:
mp.current_process().name = f"SchedulerMainProcess#{self.sched_epoch}"
logger.info(
f"start to run scheduler #{self.sched_epoch} on {socket.gethostname()}"
)
logger.info(f"start to run scheduler #{self.sched_epoch} on {socket.gethostname()}")
perf_profile = None
if self.ctx.enable_profiling:
@@ -1001,48 +905,30 @@ class Scheduler(object):
self.prioritize_retry |= self.sched_epoch > 0
if self.local_queue or self.sched_queue:
pending_tasks = [
item
for item in self.local_queue + self.sched_queue
if isinstance(item, Task)
]
pending_tasks = [item for item in self.local_queue + self.sched_queue if isinstance(item, Task)]
self.local_queue.clear()
self.sched_queue.clear()
logger.info(
f"requeue {len(pending_tasks)} pending tasks with latest epoch #{self.sched_epoch}: {pending_tasks[:3]} ..."
)
logger.info(f"requeue {len(pending_tasks)} pending tasks with latest epoch #{self.sched_epoch}: {pending_tasks[:3]} ...")
self.try_enqueue(pending_tasks)
if self.sched_epoch == 0:
leaf_tasks = self.exec_plan.leaves
logger.info(
f"enqueue {len(leaf_tasks)} leaf tasks: {leaf_tasks[:3]} ..."
)
logger.info(f"enqueue {len(leaf_tasks)} leaf tasks: {leaf_tasks[:3]} ...")
self.try_enqueue(leaf_tasks)
self.log_overall_progress()
while (num_finished_tasks := self.process_finished_tasks(pool)) > 0:
logger.info(
f"processed {num_finished_tasks} finished tasks during startup"
)
logger.info(f"processed {num_finished_tasks} finished tasks during startup")
self.log_overall_progress()
earlier_running_tasks = [
item for item in self.running_works if isinstance(item, Task)
]
earlier_running_tasks = [item for item in self.running_works if isinstance(item, Task)]
if earlier_running_tasks:
logger.info(
f"enqueue {len(earlier_running_tasks)} earlier running tasks: {earlier_running_tasks[:3]} ..."
)
logger.info(f"enqueue {len(earlier_running_tasks)} earlier running tasks: {earlier_running_tasks[:3]} ...")
self.try_enqueue(earlier_running_tasks)
self.suspend_good_executors()
self.add_state_observer(
Scheduler.StateObserver(Scheduler.log_current_status)
)
self.add_state_observer(
Scheduler.StateObserver(Scheduler.export_timeline_figs)
)
self.add_state_observer(Scheduler.StateObserver(Scheduler.log_current_status))
self.add_state_observer(Scheduler.StateObserver(Scheduler.export_timeline_figs))
self.notify_state_observers(force_notify=True)
try:
@@ -1063,14 +949,10 @@ class Scheduler(object):
if self.success:
self.clean_temp_files(pool)
logger.success(f"final output path: {self.exec_plan.final_output_path}")
logger.info(
f"analyzed plan:{os.linesep}{self.exec_plan.analyzed_logical_plan.explain_str()}"
)
logger.info(f"analyzed plan:{os.linesep}{self.exec_plan.analyzed_logical_plan.explain_str()}")
if perf_profile is not None:
logger.debug(
f"scheduler perf profile:{os.linesep}{cprofile_to_string(perf_profile)}"
)
logger.debug(f"scheduler perf profile:{os.linesep}{cprofile_to_string(perf_profile)}")
logger.info(f"scheduler of job {self.ctx.job_id} exits")
logger.complete()
@@ -1082,20 +964,14 @@ class Scheduler(object):
task = self.copy_task_for_execution(task)
if task.key in self.succeeded_tasks:
logger.debug(f"task {repr(task)} already succeeded, skipping")
self.try_enqueue(
self.get_runnable_tasks(self.succeeded_tasks[task.key])
)
self.try_enqueue(self.get_runnable_tasks(self.succeeded_tasks[task.key]))
continue
if task.runtime_id in self.scheduled_tasks:
logger.debug(f"task {repr(task)} already scheduled, skipping")
continue
# save enqueued task
self.scheduled_tasks[task.runtime_id] = task
if (
self.standalone_mode
or task.exec_on_scheduler
or task.skip_when_any_input_empty
):
if self.standalone_mode or task.exec_on_scheduler or task.skip_when_any_input_empty:
self.local_queue.append(task)
else:
self.sched_queue.append(task)
@@ -1114,34 +990,20 @@ class Scheduler(object):
if self.local_queue:
assert self.local_executor.alive
logger.info(
f"running {len(self.local_queue)} works on local executor: {self.local_queue[:3]} ..."
)
self.local_queue = [
item
for item in self.local_queue
if not self.local_executor.push(item, buffering=True)
]
logger.info(f"running {len(self.local_queue)} works on local executor: {self.local_queue[:3]} ...")
self.local_queue = [item for item in self.local_queue if not self.local_executor.push(item, buffering=True)]
self.local_executor.flush()
has_progress |= self.dispatch_tasks(pool) > 0
if len(
self.sched_queue
) == 0 and self.num_pending_nontrivial_tasks + 1 < len(self.good_executors):
if len(self.sched_queue) == 0 and self.num_pending_nontrivial_tasks + 1 < len(self.good_executors):
for executor in self.good_executors:
if executor.idle:
logger.info(
f"{len(self.good_executors)} remote executors running, stopping {executor}"
)
logger.info(f"{len(self.good_executors)} remote executors running, stopping {executor}")
executor.stop()
break
if (
len(self.sched_queue) == 0
and len(self.local_queue) == 0
and self.num_running_works == 0
):
if len(self.sched_queue) == 0 and len(self.local_queue) == 0 and self.num_running_works == 0:
self.log_overall_progress()
assert (
self.num_pending_tasks == 0
@@ -1166,29 +1028,13 @@ class Scheduler(object):
def dispatch_tasks(self, pool: ThreadPoolExecutor):
# sort pending tasks
item_sort_key = (
(lambda item: (-item.retry_count, item.id))
if self.prioritize_retry
else (lambda item: (item.retry_count, item.id))
)
item_sort_key = (lambda item: (-item.retry_count, item.id)) if self.prioritize_retry else (lambda item: (item.retry_count, item.id))
items_sorted_by_node_id = sorted(self.sched_queue, key=lambda t: t.node_id)
items_group_by_node_id = itertools.groupby(
items_sorted_by_node_id, key=lambda t: t.node_id
)
sorted_item_groups = [
sorted(items, key=item_sort_key) for _, items in items_group_by_node_id
]
self.sched_queue = [
item
for batch in itertools.zip_longest(*sorted_item_groups, fillvalue=None)
for item in batch
if item is not None
]
items_group_by_node_id = itertools.groupby(items_sorted_by_node_id, key=lambda t: t.node_id)
sorted_item_groups = [sorted(items, key=item_sort_key) for _, items in items_group_by_node_id]
self.sched_queue = [item for batch in itertools.zip_longest(*sorted_item_groups, fillvalue=None) for item in batch if item is not None]
final_phase = (
self.num_pending_nontrivial_tasks - self.num_running_works
<= len(self.good_executors) * 2
)
final_phase = self.num_pending_nontrivial_tasks - self.num_running_works <= len(self.good_executors) * 2
num_dispatched_tasks = 0
unassigned_tasks = []
@@ -1196,42 +1042,31 @@ class Scheduler(object):
first_item = self.sched_queue[0]
# assign tasks to executors in round-robin fashion
usable_executors = [
executor for executor in self.good_executors if not executor.busy
]
for executor in sorted(
usable_executors, key=lambda exec: len(exec.running_works)
):
usable_executors = [executor for executor in self.good_executors if not executor.busy]
for executor in sorted(usable_executors, key=lambda exec: len(exec.running_works)):
if not self.sched_queue:
break
item = self.sched_queue[0]
if item._memory_limit is None:
item._memory_limit = np.int64(
executor.memory_size * item._cpu_limit // executor.cpu_count
)
item._memory_limit = np.int64(executor.memory_size * item._cpu_limit // executor.cpu_count)
if item.key in self.succeeded_tasks:
logger.debug(f"task {repr(item)} already succeeded, skipping")
self.sched_queue.pop(0)
self.try_enqueue(
self.get_runnable_tasks(self.succeeded_tasks[item.key])
)
self.try_enqueue(self.get_runnable_tasks(self.succeeded_tasks[item.key]))
elif (
len(executor.running_works) < executor.max_running_works
and executor.allocated_cpus + item.cpu_limit <= executor.cpu_count
and executor.allocated_gpus + item.gpu_limit <= executor.gpu_count
and executor.allocated_memory + item.memory_limit
<= executor.memory_size
and executor.allocated_memory + item.memory_limit <= executor.memory_size
and item.key not in executor.running_works
):
if final_phase:
self.try_boost_resource(item, executor)
# push to wq of executor but not flushed yet
executor.push(item, buffering=True)
logger.info(
f"appended {repr(item)} ({item.cpu_limit} CPUs, {item.memory_limit/GB:.3f}GB) to the queue of {executor}"
)
logger.info(f"appended {repr(item)} ({item.cpu_limit} CPUs, {item.memory_limit/GB:.3f}GB) to the queue of {executor}")
self.sched_queue.pop(0)
num_dispatched_tasks += 1
@@ -1242,55 +1077,35 @@ class Scheduler(object):
self.sched_queue.extend(unassigned_tasks)
# flush the buffered work items into wq
assert all(
pool.map(RemoteExecutor.flush, self.good_executors)
), f"failed to flush work queues"
assert all(pool.map(RemoteExecutor.flush, self.good_executors)), f"failed to flush work queues"
return num_dispatched_tasks
def process_finished_tasks(self, pool: ThreadPoolExecutor) -> int:
pop_results = pool.map(RemoteExecutor.pop, self.available_executors.values())
num_finished_tasks = 0
for executor, finished_tasks in zip(
self.available_executors.values(), pop_results
):
for executor, finished_tasks in zip(self.available_executors.values(), pop_results):
for finished_task in finished_tasks:
assert isinstance(finished_task, Task)
scheduled_task = self.scheduled_tasks.get(
finished_task.runtime_id, None
)
scheduled_task = self.scheduled_tasks.get(finished_task.runtime_id, None)
if scheduled_task is None:
logger.info(
f"task not initiated by current scheduler: {finished_task}"
)
logger.info(f"task not initiated by current scheduler: {finished_task}")
if finished_task.status != WorkStatus.SUCCEED and (
missing_inputs := [
key
for key in finished_task.input_deps
if key not in self.succeeded_tasks
]
missing_inputs := [key for key in finished_task.input_deps if key not in self.succeeded_tasks]
):
logger.info(
f"ignore {repr(finished_task)} since some of the input deps are missing: {missing_inputs}"
)
logger.info(f"ignore {repr(finished_task)} since some of the input deps are missing: {missing_inputs}")
continue
if finished_task.status == WorkStatus.INCOMPLETE:
logger.trace(
f"{repr(finished_task)} checkpoint created on {executor.id}: {finished_task.runtime_state}"
)
self.tasks[finished_task.key].runtime_state = (
finished_task.runtime_state
)
logger.trace(f"{repr(finished_task)} checkpoint created on {executor.id}: {finished_task.runtime_state}")
self.tasks[finished_task.key].runtime_state = finished_task.runtime_state
continue
prior_task = self.finished_tasks.get(finished_task.runtime_id, None)
if prior_task is not None:
logger.info(
f"found duplicate tasks, current: {repr(finished_task)}, prior: {repr(prior_task)}"
)
logger.info(f"found duplicate tasks, current: {repr(finished_task)}, prior: {repr(prior_task)}")
continue
else:
self.finished_tasks[finished_task.runtime_id] = finished_task
@@ -1298,30 +1113,22 @@ class Scheduler(object):
succeeded_task = self.succeeded_tasks.get(finished_task.key, None)
if succeeded_task is not None:
logger.info(
f"task already succeeded, current: {repr(finished_task)}, succeeded: {repr(succeeded_task)}"
)
logger.info(f"task already succeeded, current: {repr(finished_task)}, succeeded: {repr(succeeded_task)}")
continue
if finished_task.status in (WorkStatus.FAILED, WorkStatus.CRASHED):
logger.warning(
f"task failed on {executor.id}: {finished_task}, error: {finished_task.exception}"
)
logger.warning(f"task failed on {executor.id}: {finished_task}, error: {finished_task.exception}")
finished_task.dump()
task = self.tasks[finished_task.key]
task.fail_count += 1
if task.fail_count > self.max_fail_count:
logger.critical(
f"task failed too many times: {finished_task}, stopping ..."
)
logger.critical(f"task failed too many times: {finished_task}, stopping ...")
self.stop_executors()
self.sched_running = False
if not executor.local and finished_task.oom(
self.nonzero_exitcode_as_oom
):
if not executor.local and finished_task.oom(self.nonzero_exitcode_as_oom):
if task._memory_limit is None:
task._memory_limit = finished_task._memory_limit
self.try_relax_memory_limit(task, executor)
@@ -1332,9 +1139,7 @@ class Scheduler(object):
self.try_enqueue(self.get_retry_task(finished_task.key))
else:
assert (
finished_task.status == WorkStatus.SUCCEED
), f"unexpected task status: {finished_task}"
assert finished_task.status == WorkStatus.SUCCEED, f"unexpected task status: {finished_task}"
logger.log(
"TRACE" if finished_task.exec_on_scheduler else "INFO",
"task succeeded on {}: {}",
@@ -1344,9 +1149,7 @@ class Scheduler(object):
self.succeeded_tasks[finished_task.key] = finished_task
if not finished_task.exec_on_scheduler:
self.succeeded_nontrivial_tasks[finished_task.key] = (
finished_task
)
self.succeeded_nontrivial_tasks[finished_task.key] = finished_task
# stop the redundant retries of finished task
self.stop_running_tasks(finished_task.key)
@@ -1356,9 +1159,7 @@ class Scheduler(object):
if finished_task.id == self.exec_plan.root_task.id:
self.sched_queue = []
self.stop_executors()
logger.success(
f"all tasks completed, root task: {finished_task}"
)
logger.success(f"all tasks completed, root task: {finished_task}")
logger.success(
f"{len(self.succeeded_tasks)} succeeded tasks, success: {self.success}, elapsed time: {self.elapsed_time:.3f} secs"
)

File diff suppressed because it is too large Load Diff

View File

@@ -55,9 +55,7 @@ class WorkItem(object):
) -> None:
self._cpu_limit = cpu_limit
self._gpu_limit = gpu_limit
self._memory_limit = (
np.int64(memory_limit) if memory_limit is not None else None
)
self._memory_limit = np.int64(memory_limit) if memory_limit is not None else None
self._cpu_boost = 1
self._memory_boost = 1
self._numa_node = None
@@ -88,11 +86,7 @@ class WorkItem(object):
@property
def memory_limit(self) -> np.int64:
return (
np.int64(self._memory_boost * self._memory_limit)
if self._memory_limit
else 0
)
return np.int64(self._memory_boost * self._memory_limit) if self._memory_limit else 0
@property
def elapsed_time(self) -> float:
@@ -142,13 +136,7 @@ class WorkItem(object):
return (
self._memory_limit is not None
and self.status == WorkStatus.CRASHED
and (
isinstance(self.exception, (OutOfMemory, MemoryError))
or (
isinstance(self.exception, NonzeroExitCode)
and nonzero_exitcode_as_oom
)
)
and (isinstance(self.exception, (OutOfMemory, MemoryError)) or (isinstance(self.exception, NonzeroExitCode) and nonzero_exitcode_as_oom))
)
def run(self) -> bool:
@@ -175,9 +163,7 @@ class WorkItem(object):
else:
self.status = WorkStatus.FAILED
except Exception as ex:
logger.opt(exception=ex).error(
f"{repr(self)} crashed with error. node location at {self.location}"
)
logger.opt(exception=ex).error(f"{repr(self)} crashed with error. node location at {self.location}")
self.status = WorkStatus.CRASHED
self.exception = ex
finally:
@@ -204,25 +190,18 @@ class WorkBatch(WorkItem):
cpu_limit = max(w.cpu_limit for w in works)
gpu_limit = max(w.gpu_limit for w in works)
memory_limit = max(w.memory_limit for w in works)
super().__init__(
f"{self.__class__.__name__}-{key}", cpu_limit, gpu_limit, memory_limit
)
super().__init__(f"{self.__class__.__name__}-{key}", cpu_limit, gpu_limit, memory_limit)
self.works = works
def __str__(self) -> str:
return (
super().__str__()
+ f", works[{len(self.works)}]={self.works[:1]}...{self.works[-1:]}"
)
return super().__str__() + f", works[{len(self.works)}]={self.works[:1]}...{self.works[-1:]}"
def run(self) -> bool:
logger.info(f"processing {len(self.works)} works in the batch")
for index, work in enumerate(self.works):
work.exec_id = self.exec_id
if work.exec(self.exec_cq) != WorkStatus.SUCCEED:
logger.error(
f"work item #{index+1}/{len(self.works)} in {self.key} failed: {work}"
)
logger.error(f"work item #{index+1}/{len(self.works)} in {self.key} failed: {work}")
return False
logger.info(f"done {len(self.works)} works in the batch")
return True
@@ -375,9 +354,7 @@ class WorkQueueOnFilesystem(WorkQueue):
os.rename(tempfile_path, enqueued_path)
return True
except OSError as err:
logger.critical(
f"failed to rename {tempfile_path} to {enqueued_path}: {err}"
)
logger.critical(f"failed to rename {tempfile_path} to {enqueued_path}: {err}")
return False
@@ -405,27 +382,17 @@ def count_objects(obj, object_cnt=None, visited_objs=None, depth=0):
object_cnt[class_name] = (cnt + 1, size + sys.getsizeof(obj))
key_attributes = ("__self__", "__dict__", "__slots__")
if not isinstance(obj, (bool, str, int, float, type(None))) and any(
attr_name in key_attributes for attr_name in dir(obj)
):
if not isinstance(obj, (bool, str, int, float, type(None))) and any(attr_name in key_attributes for attr_name in dir(obj)):
logger.debug(f"{' ' * depth}{class_name}@{id(obj):x}")
for attr_name in dir(obj):
try:
if (
not attr_name.startswith("__") or attr_name in key_attributes
) and not isinstance(
if (not attr_name.startswith("__") or attr_name in key_attributes) and not isinstance(
getattr(obj.__class__, attr_name, None), property
):
logger.debug(
f"{' ' * depth}{class_name}.{attr_name}@{id(obj):x}"
)
count_objects(
getattr(obj, attr_name), object_cnt, visited_objs, depth + 1
)
logger.debug(f"{' ' * depth}{class_name}.{attr_name}@{id(obj):x}")
count_objects(getattr(obj, attr_name), object_cnt, visited_objs, depth + 1)
except Exception as ex:
logger.warning(
f"failed to get '{attr_name}' from {repr(obj)}: {ex}"
)
logger.warning(f"failed to get '{attr_name}' from {repr(obj)}: {ex}")
def main():
@@ -433,23 +400,13 @@ def main():
from smallpond.execution.task import Probe
parser = argparse.ArgumentParser(
prog="workqueue.py", description="Work Queue Reader"
)
parser = argparse.ArgumentParser(prog="workqueue.py", description="Work Queue Reader")
parser.add_argument("wq_root", help="Work queue root path")
parser.add_argument("-f", "--work_filter", default="", help="Work item filter")
parser.add_argument(
"-x", "--expand_batch", action="store_true", help="Expand batched works"
)
parser.add_argument(
"-c", "--count_object", action="store_true", help="Count number of objects"
)
parser.add_argument(
"-n", "--top_n_class", default=20, type=int, help="Show the top n classes"
)
parser.add_argument(
"-l", "--log_level", default="INFO", help="Logging message level"
)
parser.add_argument("-x", "--expand_batch", action="store_true", help="Expand batched works")
parser.add_argument("-c", "--count_object", action="store_true", help="Count number of objects")
parser.add_argument("-n", "--top_n_class", default=20, type=int, help="Show the top n classes")
parser.add_argument("-l", "--log_level", default="INFO", help="Logging message level")
args = parser.parse_args()
logger.remove()
@@ -468,9 +425,7 @@ def main():
if args.count_object:
object_cnt = {}
count_objects(work, object_cnt)
sorted_counts = sorted(
[(v, k) for k, v in object_cnt.items()], reverse=True
)
sorted_counts = sorted([(v, k) for k, v in object_cnt.items()], reverse=True)
for count, class_name in sorted_counts[: args.top_n_class]:
logger.info(f" {class_name}: {count}")

View File

@@ -44,11 +44,7 @@ class RowRange:
@property
def estimated_data_size(self) -> int:
"""The estimated uncompressed data size in bytes."""
return (
self.data_size * self.num_rows // self.file_num_rows
if self.file_num_rows > 0
else 0
)
return self.data_size * self.num_rows // self.file_num_rows if self.file_num_rows > 0 else 0
def take(self, num_rows: int) -> "RowRange":
"""
@@ -62,9 +58,7 @@ class RowRange:
return head
@staticmethod
def partition_by_rows(
row_ranges: List["RowRange"], npartition: int
) -> List[List["RowRange"]]:
def partition_by_rows(row_ranges: List["RowRange"], npartition: int) -> List[List["RowRange"]]:
"""Evenly split a list of row ranges into `npartition` partitions."""
# NOTE: `row_ranges` should not be modified by this function
row_ranges = copy.deepcopy(row_ranges)
@@ -128,9 +122,7 @@ def convert_types_to_large_string(schema: arrow.Schema) -> arrow.Schema:
new_fields = []
for field in schema:
new_type = convert_type_to_large(field.type)
new_field = arrow.field(
field.name, new_type, nullable=field.nullable, metadata=field.metadata
)
new_field = arrow.field(field.name, new_type, nullable=field.nullable, metadata=field.metadata)
new_fields.append(new_field)
return arrow.schema(new_fields, metadata=schema.metadata)
@@ -151,11 +143,7 @@ def filter_schema(
if included_cols is not None:
fields = [schema.field(col_name) for col_name in included_cols]
if excluded_cols is not None:
fields = [
schema.field(col_name)
for col_name in schema.names
if col_name not in excluded_cols
]
fields = [schema.field(col_name) for col_name in schema.names if col_name not in excluded_cols]
return arrow.schema(fields, metadata=schema.metadata)
@@ -172,9 +160,7 @@ def _iter_record_batches(
current_offset = 0
required_l, required_r = offset, offset + length
for batch in file.iter_batches(
batch_size=batch_size, columns=columns, use_threads=False
):
for batch in file.iter_batches(batch_size=batch_size, columns=columns, use_threads=False):
current_l, current_r = current_offset, current_offset + batch.num_rows
# check if intersection is null
if current_r <= required_l:
@@ -184,9 +170,7 @@ def _iter_record_batches(
else:
intersection_l = max(required_l, current_l)
intersection_r = min(required_r, current_r)
trimmed = batch.slice(
intersection_l - current_offset, intersection_r - intersection_l
)
trimmed = batch.slice(intersection_l - current_offset, intersection_r - intersection_l)
assert (
trimmed.num_rows == intersection_r - intersection_l
), f"trimmed.num_rows {trimmed.num_rows} != batch_length {intersection_r - intersection_l}"
@@ -204,9 +188,7 @@ def build_batch_reader_from_files(
) -> arrow.RecordBatchReader:
assert len(paths_or_ranges) > 0, "paths_or_ranges must be a non-empty list"
schema = _read_schema_from_file(paths_or_ranges[0], columns, filesystem)
iterator = _iter_record_batches_from_files(
paths_or_ranges, columns, batch_size, max_batch_byte_size, filesystem
)
iterator = _iter_record_batches_from_files(paths_or_ranges, columns, batch_size, max_batch_byte_size, filesystem)
return arrow.RecordBatchReader.from_batches(schema, iterator)
@@ -216,9 +198,7 @@ def _read_schema_from_file(
filesystem: fsspec.AbstractFileSystem = None,
) -> arrow.Schema:
path = path_or_range.path if isinstance(path_or_range, RowRange) else path_or_range
schema = parquet.read_schema(
filesystem.unstrip_protocol(path) if filesystem else path, filesystem=filesystem
)
schema = parquet.read_schema(filesystem.unstrip_protocol(path) if filesystem else path, filesystem=filesystem)
if columns is not None:
assert all(
c in schema.names for c in columns
@@ -253,9 +233,7 @@ def _iter_record_batches_from_files(
yield from table.combine_chunks().to_batches(batch_size)
for path_or_range in paths_or_ranges:
path = (
path_or_range.path if isinstance(path_or_range, RowRange) else path_or_range
)
path = path_or_range.path if isinstance(path_or_range, RowRange) else path_or_range
with parquet.ParquetFile(
filesystem.unstrip_protocol(path) if filesystem else path,
buffer_size=16 * MB,
@@ -265,23 +243,16 @@ def _iter_record_batches_from_files(
offset, length = path_or_range.begin, path_or_range.num_rows
else:
offset, length = 0, file.metadata.num_rows
for batch in _iter_record_batches(
file, columns, offset, length, batch_size
):
for batch in _iter_record_batches(file, columns, offset, length, batch_size):
batch_size_exceeded = batch.num_rows + buffered_rows >= batch_size
batch_byte_size_exceeded = (
max_batch_byte_size is not None
and batch.nbytes + buffered_bytes >= max_batch_byte_size
)
batch_byte_size_exceeded = max_batch_byte_size is not None and batch.nbytes + buffered_bytes >= max_batch_byte_size
if not batch_size_exceeded and not batch_byte_size_exceeded:
buffered_batches.append(batch)
buffered_rows += batch.num_rows
buffered_bytes += batch.nbytes
else:
if batch_size_exceeded:
buffered_batches.append(
batch.slice(0, batch_size - buffered_rows)
)
buffered_batches.append(batch.slice(0, batch_size - buffered_rows))
batch = batch.slice(batch_size - buffered_rows)
if buffered_batches:
yield from combine_buffered_batches(buffered_batches)
@@ -298,9 +269,7 @@ def read_parquet_files_into_table(
columns: List[str] = None,
filesystem: fsspec.AbstractFileSystem = None,
) -> arrow.Table:
batch_reader = build_batch_reader_from_files(
paths_or_ranges, columns=columns, filesystem=filesystem
)
batch_reader = build_batch_reader_from_files(paths_or_ranges, columns=columns, filesystem=filesystem)
return batch_reader.read_all()
@@ -312,37 +281,22 @@ def load_from_parquet_files(
) -> arrow.Table:
start_time = time.time()
assert len(paths_or_ranges) > 0, "paths_or_ranges must be a non-empty list"
paths = [
path_or_range.path if isinstance(path_or_range, RowRange) else path_or_range
for path_or_range in paths_or_ranges
]
paths = [path_or_range.path if isinstance(path_or_range, RowRange) else path_or_range for path_or_range in paths_or_ranges]
total_compressed_size = sum(
(
path_or_range.data_size
if isinstance(path_or_range, RowRange)
else os.path.getsize(path_or_range)
)
for path_or_range in paths_or_ranges
)
logger.debug(
f"loading {len(paths)} parquet files (compressed size: {total_compressed_size/MB:.3f}MB): {paths[:3]}..."
(path_or_range.data_size if isinstance(path_or_range, RowRange) else os.path.getsize(path_or_range)) for path_or_range in paths_or_ranges
)
logger.debug(f"loading {len(paths)} parquet files (compressed size: {total_compressed_size/MB:.3f}MB): {paths[:3]}...")
num_workers = min(len(paths), max_workers)
with ThreadPoolExecutor(num_workers) as pool:
running_works = [
pool.submit(read_parquet_files_into_table, batch, columns, filesystem)
for batch in split_into_rows(paths_or_ranges, num_workers)
pool.submit(read_parquet_files_into_table, batch, columns, filesystem) for batch in split_into_rows(paths_or_ranges, num_workers)
]
tables = [work.result() for work in running_works]
logger.debug(
f"collected {len(tables)} tables from: {paths[:3]}... (elapsed: {time.time() - start_time:.3f} secs)"
)
logger.debug(f"collected {len(tables)} tables from: {paths[:3]}... (elapsed: {time.time() - start_time:.3f} secs)")
return arrow.concat_tables(tables)
def parquet_write_table(
table, where, filesystem: fsspec.AbstractFileSystem = None, **write_table_args
) -> int:
def parquet_write_table(table, where, filesystem: fsspec.AbstractFileSystem = None, **write_table_args) -> int:
if filesystem is not None:
return parquet.write_table(
table,
@@ -388,10 +342,7 @@ def dump_to_parquet_files(
num_workers = min(len(batches), max_workers)
num_tables = max(math.ceil(table.nbytes / MAX_PARQUET_FILE_BYTES), num_workers)
logger.debug(f"evenly distributed {len(batches)} batches into {num_tables} files")
tables = [
arrow.Table.from_batches(batch, table.schema)
for batch in split_into_rows(batches, num_tables)
]
tables = [arrow.Table.from_batches(batch, table.schema) for batch in split_into_rows(batches, num_tables)]
assert sum(t.num_rows for t in tables) == table.num_rows
logger.debug(f"writing {len(tables)} files to {output_dir}")
@@ -413,7 +364,5 @@ def dump_to_parquet_files(
]
assert all(work.result() or True for work in running_works)
logger.debug(
f"finished writing {len(tables)} files to {output_dir} (elapsed: {time.time() - start_time:.3f} secs)"
)
logger.debug(f"finished writing {len(tables)} files to {output_dir} (elapsed: {time.time() - start_time:.3f} secs)")
return True

View File

@@ -44,9 +44,7 @@ def remove_path(path: str):
os.symlink(realpath, link)
return
except Exception as ex:
logger.opt(exception=ex).debug(
f"fast recursive remove failed, fall back to shutil.rmtree('{realpath}')"
)
logger.opt(exception=ex).debug(f"fast recursive remove failed, fall back to shutil.rmtree('{realpath}')")
shutil.rmtree(realpath, ignore_errors=True)
@@ -94,9 +92,7 @@ def dump(obj: Any, path: str, buffering=2 * MB, atomic_write=False) -> int:
raise
except Exception as ex:
trace_str, trace_err = get_pickle_trace(obj)
logger.opt(exception=ex).error(
f"pickle trace of {repr(obj)}:{os.linesep}{trace_str}"
)
logger.opt(exception=ex).error(f"pickle trace of {repr(obj)}:{os.linesep}{trace_str}")
if trace_err is None:
raise
else:
@@ -107,9 +103,7 @@ def dump(obj: Any, path: str, buffering=2 * MB, atomic_write=False) -> int:
if atomic_write:
directory, filename = os.path.split(path)
with tempfile.NamedTemporaryFile(
"wb", buffering=buffering, dir=directory, prefix=filename, delete=False
) as fout:
with tempfile.NamedTemporaryFile("wb", buffering=buffering, dir=directory, prefix=filename, delete=False) as fout:
write_to_file(fout)
fout.seek(0, os.SEEK_END)
size = fout.tell()

View File

@@ -206,9 +206,7 @@ class DataSet(object):
resolved_paths.append(path)
if wildcard_paths:
if len(wildcard_paths) == 1:
expanded_paths = glob.glob(
wildcard_paths[0], recursive=self.recursive
)
expanded_paths = glob.glob(wildcard_paths[0], recursive=self.recursive)
else:
logger.debug(
"resolving {} paths with wildcards in {}",
@@ -247,9 +245,7 @@ class DataSet(object):
if self.root_dir is None:
self._absolute_paths = sorted(self.paths)
else:
self._absolute_paths = [
os.path.join(self.root_dir, p) for p in sorted(self.paths)
]
self._absolute_paths = [os.path.join(self.root_dir, p) for p in sorted(self.paths)]
return self._absolute_paths
def sql_query_fragment(
@@ -340,23 +336,15 @@ class DataSet(object):
return file_partitions
@functools.lru_cache
def partition_by_files(
self, npartition: int, random_shuffle: bool = False
) -> "List[DataSet]":
def partition_by_files(self, npartition: int, random_shuffle: bool = False) -> "List[DataSet]":
"""
Evenly split into `npartition` datasets by files.
"""
assert npartition > 0, f"npartition has negative value: {npartition}"
if npartition > self.num_files:
logger.debug(
f"number of partitions {npartition} is greater than the number of files {self.num_files}"
)
logger.debug(f"number of partitions {npartition} is greater than the number of files {self.num_files}")
resolved_paths = (
random.sample(self.resolved_paths, len(self.resolved_paths))
if random_shuffle
else self.resolved_paths
)
resolved_paths = random.sample(self.resolved_paths, len(self.resolved_paths)) if random_shuffle else self.resolved_paths
evenly_split_groups = split_into_rows(resolved_paths, npartition)
num_paths_in_groups = list(map(len, evenly_split_groups))
@@ -367,11 +355,7 @@ class DataSet(object):
logger.debug(
f"created {npartition} file partitions (min #files: {min(num_paths_in_groups)}, max #files: {max(num_paths_in_groups)}, avg #files: {sum(num_paths_in_groups)/len(num_paths_in_groups):.3f}) from {self}"
)
return (
random.sample(file_partitions, len(file_partitions))
if random_shuffle
else file_partitions
)
return random.sample(file_partitions, len(file_partitions)) if random_shuffle else file_partitions
class PartitionedDataSet(DataSet):
@@ -463,9 +447,7 @@ class CsvDataSet(DataSet):
union_by_name=False,
) -> None:
super().__init__(paths, root_dir, recursive, columns, union_by_name)
assert isinstance(
schema, OrderedDict
), f"type of csv schema is not OrderedDict: {type(schema)}"
assert isinstance(schema, OrderedDict), f"type of csv schema is not OrderedDict: {type(schema)}"
self.schema = schema
self.delim = delim
self.max_line_size = max_line_size
@@ -492,14 +474,8 @@ class CsvDataSet(DataSet):
filesystem: fsspec.AbstractFileSystem = None,
conn: duckdb.DuckDBPyConnection = None,
) -> str:
schema_str = ", ".join(
map(lambda x: f"'{x[0]}': '{x[1]}'", self.schema.items())
)
max_line_size_str = (
f"max_line_size={self.max_line_size}, "
if self.max_line_size is not None
else ""
)
schema_str = ", ".join(map(lambda x: f"'{x[0]}': '{x[1]}'", self.schema.items()))
max_line_size_str = f"max_line_size={self.max_line_size}, " if self.max_line_size is not None else ""
return (
f"( select {self._column_str} from read_csv([ {self._resolved_path_str} ], delim='{self.delim}', columns={{ {schema_str} }}, header={self.header}, "
f"{max_line_size_str} parallel={self.parallel}, union_by_name={self.union_by_name}) )"
@@ -552,9 +528,7 @@ class JsonDataSet(DataSet):
filesystem: fsspec.AbstractFileSystem = None,
conn: duckdb.DuckDBPyConnection = None,
) -> str:
schema_str = ", ".join(
map(lambda x: f"'{x[0]}': '{x[1]}'", self.schema.items())
)
schema_str = ", ".join(map(lambda x: f"'{x[0]}': '{x[1]}'", self.schema.items()))
return (
f"( select {self._column_str} from read_json([ {self._resolved_path_str} ], format='{self.format}', columns={{ {schema_str} }}, "
f"maximum_object_size={self.max_object_size}, union_by_name={self.union_by_name}) )"
@@ -614,11 +588,7 @@ class ParquetDataSet(DataSet):
)
# merge row ranges if any dataset has resolved row ranges
if any(dataset._resolved_row_ranges is not None for dataset in datasets):
dataset._resolved_row_ranges = [
row_range
for dataset in datasets
for row_range in dataset.resolved_row_ranges
]
dataset._resolved_row_ranges = [row_range for dataset in datasets for row_range in dataset.resolved_row_ranges]
return dataset
@staticmethod
@@ -655,10 +625,7 @@ class ParquetDataSet(DataSet):
# read parquet metadata to get number of rows
metadata = parquet.read_metadata(path)
num_rows = metadata.num_rows
uncompressed_data_size = sum(
metadata.row_group(i).total_byte_size
for i in range(metadata.num_row_groups)
)
uncompressed_data_size = sum(metadata.row_group(i).total_byte_size for i in range(metadata.num_row_groups))
return RowRange(
path,
data_size=uncompressed_data_size,
@@ -667,20 +634,14 @@ class ParquetDataSet(DataSet):
end=num_rows,
)
with ThreadPoolExecutor(
max_workers=min(32, len(self.resolved_paths))
) as pool:
self._resolved_row_ranges = list(
pool.map(resolve_row_range, self.resolved_paths)
)
with ThreadPoolExecutor(max_workers=min(32, len(self.resolved_paths))) as pool:
self._resolved_row_ranges = list(pool.map(resolve_row_range, self.resolved_paths))
return self._resolved_row_ranges
@property
def num_rows(self) -> int:
if self._resolved_num_rows is None:
self._resolved_num_rows = sum(
row_range.num_rows for row_range in self.resolved_row_ranges
)
self._resolved_num_rows = sum(row_range.num_rows for row_range in self.resolved_row_ranges)
return self._resolved_num_rows
@property
@@ -695,29 +656,19 @@ class ParquetDataSet(DataSet):
"""
Return the estimated data size in bytes.
"""
return sum(
row_range.estimated_data_size for row_range in self.resolved_row_ranges
)
return sum(row_range.estimated_data_size for row_range in self.resolved_row_ranges)
def sql_query_fragment(
self,
filesystem: fsspec.AbstractFileSystem = None,
conn: duckdb.DuckDBPyConnection = None,
) -> str:
extra_parameters = (
"".join(f", {col}=true" for col in self.generated_columns)
if self.generated_columns
else ""
)
extra_parameters = "".join(f", {col}=true" for col in self.generated_columns) if self.generated_columns else ""
parquet_file_queries = []
full_row_ranges = []
for row_range in self.resolved_row_ranges:
path = (
filesystem.unstrip_protocol(row_range.path)
if filesystem
else row_range.path
)
path = filesystem.unstrip_protocol(row_range.path) if filesystem else row_range.path
if row_range.num_rows == row_range.file_num_rows:
full_row_ranges.append(row_range)
else:
@@ -744,9 +695,7 @@ class ParquetDataSet(DataSet):
full_row_ranges[largest_index],
full_row_ranges[0],
)
parquet_file_str = ",\n ".join(
map(lambda x: f"'{x.path}'", full_row_ranges)
)
parquet_file_str = ",\n ".join(map(lambda x: f"'{x.path}'", full_row_ranges))
parquet_file_queries.insert(
0,
f"""
@@ -771,11 +720,7 @@ class ParquetDataSet(DataSet):
tables = []
if self.resolved_row_ranges:
tables.append(
load_from_parquet_files(
self.resolved_row_ranges, self.columns, max_workers, filesystem
)
)
tables.append(load_from_parquet_files(self.resolved_row_ranges, self.columns, max_workers, filesystem))
return arrow.concat_tables(tables)
def to_batch_reader(
@@ -795,18 +740,14 @@ class ParquetDataSet(DataSet):
)
@functools.lru_cache
def partition_by_files(
self, npartition: int, random_shuffle: bool = False
) -> "List[ParquetDataSet]":
def partition_by_files(self, npartition: int, random_shuffle: bool = False) -> "List[ParquetDataSet]":
if self._resolved_row_ranges is not None:
return self.partition_by_rows(npartition, random_shuffle)
else:
return super().partition_by_files(npartition, random_shuffle)
@functools.lru_cache
def partition_by_rows(
self, npartition: int, random_shuffle: bool = False
) -> "List[ParquetDataSet]":
def partition_by_rows(self, npartition: int, random_shuffle: bool = False) -> "List[ParquetDataSet]":
"""
Evenly split the dataset into `npartition` partitions by rows.
If `random_shuffle` is True, shuffle the files before partitioning.
@@ -814,11 +755,7 @@ class ParquetDataSet(DataSet):
assert npartition > 0, f"npartition has negative value: {npartition}"
resolved_row_ranges = self.resolved_row_ranges
resolved_row_ranges = (
random.sample(resolved_row_ranges, len(resolved_row_ranges))
if random_shuffle
else resolved_row_ranges
)
resolved_row_ranges = random.sample(resolved_row_ranges, len(resolved_row_ranges)) if random_shuffle else resolved_row_ranges
def create_dataset(row_ranges: List[RowRange]) -> ParquetDataSet:
row_ranges = sorted(row_ranges, key=lambda x: x.path)
@@ -833,12 +770,7 @@ class ParquetDataSet(DataSet):
dataset._resolved_row_ranges = row_ranges
return dataset
return [
create_dataset(row_ranges)
for row_ranges in RowRange.partition_by_rows(
resolved_row_ranges, npartition
)
]
return [create_dataset(row_ranges) for row_ranges in RowRange.partition_by_rows(resolved_row_ranges, npartition)]
@functools.lru_cache
def partition_by_size(self, max_partition_size: int) -> "List[ParquetDataSet]":
@@ -847,16 +779,12 @@ class ParquetDataSet(DataSet):
"""
if self.empty:
return []
estimated_data_size = sum(
row_range.estimated_data_size for row_range in self.resolved_row_ranges
)
estimated_data_size = sum(row_range.estimated_data_size for row_range in self.resolved_row_ranges)
npartition = estimated_data_size // max_partition_size + 1
return self.partition_by_rows(npartition)
@staticmethod
def _read_partition_key(
path: str, data_partition_column: str, hive_partitioning: bool
) -> int:
def _read_partition_key(path: str, data_partition_column: str, hive_partitioning: bool) -> int:
"""
Get the partition key of the parquet file.
@@ -874,9 +802,7 @@ class ParquetDataSet(DataSet):
try:
return int(key)
except ValueError:
logger.error(
f"cannot parse partition key '{data_partition_column}' of {path} from: {key}"
)
logger.error(f"cannot parse partition key '{data_partition_column}' of {path} from: {key}")
raise
if hive_partitioning:
@@ -884,9 +810,7 @@ class ParquetDataSet(DataSet):
for part in path.split(os.path.sep):
if part.startswith(path_part_prefix):
return parse_partition_key(part[len(path_part_prefix) :])
raise RuntimeError(
f"cannot extract hive partition key '{data_partition_column}' from path: {path}"
)
raise RuntimeError(f"cannot extract hive partition key '{data_partition_column}' from path: {path}")
with parquet.ParquetFile(path) as file:
kv_metadata = file.schema_arrow.metadata or file.metadata.metadata
@@ -896,36 +820,22 @@ class ParquetDataSet(DataSet):
if key == PARQUET_METADATA_KEY_PREFIX + data_partition_column:
return parse_partition_key(val)
if file.metadata.num_rows == 0:
logger.warning(
f"cannot read partition keys from empty parquet file: {path}"
)
logger.warning(f"cannot read partition keys from empty parquet file: {path}")
return None
for batch in file.iter_batches(
batch_size=128, columns=[data_partition_column], use_threads=False
):
assert (
data_partition_column in batch.column_names
), f"cannot find column '{data_partition_column}' in {batch.column_names}"
assert (
batch.num_columns == 1
), f"unexpected num of columns: {batch.column_names}"
for batch in file.iter_batches(batch_size=128, columns=[data_partition_column], use_threads=False):
assert data_partition_column in batch.column_names, f"cannot find column '{data_partition_column}' in {batch.column_names}"
assert batch.num_columns == 1, f"unexpected num of columns: {batch.column_names}"
uniq_partition_keys = set(batch.columns[0].to_pylist())
assert (
uniq_partition_keys and len(uniq_partition_keys) == 1
), f"partition keys found in {path} not unique: {uniq_partition_keys}"
assert uniq_partition_keys and len(uniq_partition_keys) == 1, f"partition keys found in {path} not unique: {uniq_partition_keys}"
return uniq_partition_keys.pop()
def load_partitioned_datasets(
self, npartition: int, data_partition_column: str, hive_partitioning=False
) -> "List[ParquetDataSet]":
def load_partitioned_datasets(self, npartition: int, data_partition_column: str, hive_partitioning=False) -> "List[ParquetDataSet]":
"""
Split the dataset into a list of partitioned datasets.
"""
assert npartition > 0, f"npartition has negative value: {npartition}"
if npartition > self.num_files:
logger.debug(
f"number of partitions {npartition} is greater than the number of files {self.num_files}"
)
logger.debug(f"number of partitions {npartition} is greater than the number of files {self.num_files}")
file_partitions: List[ParquetDataSet] = self._init_file_partitions(npartition)
for dataset in file_partitions:
@@ -940,17 +850,13 @@ class ParquetDataSet(DataSet):
with ThreadPoolExecutor(min(32, len(self.resolved_paths))) as pool:
partition_keys = pool.map(
lambda path: ParquetDataSet._read_partition_key(
path, data_partition_column, hive_partitioning
),
lambda path: ParquetDataSet._read_partition_key(path, data_partition_column, hive_partitioning),
self.resolved_paths,
)
for row_range, partition_key in zip(self.resolved_row_ranges, partition_keys):
if partition_key is not None:
assert (
0 <= partition_key <= npartition
), f"invalid partition key {partition_key} found in {row_range.path}"
assert 0 <= partition_key <= npartition, f"invalid partition key {partition_key} found in {row_range.path}"
dataset = file_partitions[partition_key]
dataset.paths.append(row_range.path)
dataset._absolute_paths.append(row_range.path)
@@ -964,20 +870,14 @@ class ParquetDataSet(DataSet):
"""
Remove empty parquet files from the dataset.
"""
new_row_ranges = [
row_range
for row_range in self.resolved_row_ranges
if row_range.num_rows > 0
]
new_row_ranges = [row_range for row_range in self.resolved_row_ranges if row_range.num_rows > 0]
if len(new_row_ranges) == 0:
# keep at least one file to avoid empty dataset
new_row_ranges = self.resolved_row_ranges[:1]
if len(new_row_ranges) == len(self.resolved_row_ranges):
# no empty files found
return
logger.info(
f"removed {len(self.resolved_row_ranges) - len(new_row_ranges)}/{len(self.resolved_row_ranges)} empty parquet files from {self}"
)
logger.info(f"removed {len(self.resolved_row_ranges) - len(new_row_ranges)}/{len(self.resolved_row_ranges)} empty parquet files from {self}")
self._resolved_row_ranges = new_row_ranges
self._resolved_paths = [row_range.path for row_range in new_row_ranges]
self._absolute_paths = self._resolved_paths
@@ -997,9 +897,7 @@ class SqlQueryDataSet(DataSet):
def __init__(
self,
sql_query: str,
query_builder: Callable[
[duckdb.DuckDBPyConnection, fsspec.AbstractFileSystem], str
] = None,
query_builder: Callable[[duckdb.DuckDBPyConnection, fsspec.AbstractFileSystem], str] = None,
) -> None:
super().__init__([])
self.sql_query = sql_query
@@ -1007,9 +905,7 @@ class SqlQueryDataSet(DataSet):
@property
def num_rows(self) -> int:
num_rows = duckdb.sql(
f"select count(*) as num_rows from {self.sql_query_fragment()}"
).fetchall()
num_rows = duckdb.sql(f"select count(*) as num_rows from {self.sql_query_fragment()}").fetchall()
return num_rows[0][0]
def sql_query_fragment(
@@ -1017,11 +913,7 @@ class SqlQueryDataSet(DataSet):
filesystem: fsspec.AbstractFileSystem = None,
conn: duckdb.DuckDBPyConnection = None,
) -> str:
sql_query = (
self.sql_query
if self.query_builder is None
else self.query_builder(conn, filesystem)
)
sql_query = self.sql_query if self.query_builder is None else self.query_builder(conn, filesystem)
return f"( {sql_query} )"

View File

@@ -132,9 +132,7 @@ class Context(object):
-------
The unique function name.
"""
self.udfs[name] = PythonUDFContext(
name, func, params, return_type, use_arrow_type
)
self.udfs[name] = PythonUDFContext(name, func, params, return_type, use_arrow_type)
return name
def create_external_module(self, module_path: str, name: str = None) -> str:
@@ -213,18 +211,10 @@ class Node(object):
This is a resource requirement specified by the user and used to guide
task scheduling. smallpond does NOT enforce this limit.
"""
assert isinstance(
input_deps, Iterable
), f"input_deps is not iterable: {input_deps}"
assert all(
isinstance(node, Node) for node in input_deps
), f"some of input_deps are not instances of Node: {input_deps}"
assert output_name is None or re.match(
"[a-zA-Z0-9_]+", output_name
), f"output_name has invalid format: {output_name}"
assert output_path is None or os.path.isabs(
output_path
), f"output_path is not an absolute path: {output_path}"
assert isinstance(input_deps, Iterable), f"input_deps is not iterable: {input_deps}"
assert all(isinstance(node, Node) for node in input_deps), f"some of input_deps are not instances of Node: {input_deps}"
assert output_name is None or re.match("[a-zA-Z0-9_]+", output_name), f"output_name has invalid format: {output_name}"
assert output_path is None or os.path.isabs(output_path), f"output_path is not an absolute path: {output_path}"
self.ctx = ctx
self.id = self.ctx._new_node_id()
self.input_deps = input_deps
@@ -238,10 +228,7 @@ class Node(object):
self.perf_metrics: Dict[str, List[float]] = defaultdict(list)
# record the location where the node is constructed in user code
frame = next(
frame
for frame in reversed(traceback.extract_stack())
if frame.filename != __file__
and not frame.filename.endswith("/dataframe.py")
frame for frame in reversed(traceback.extract_stack()) if frame.filename != __file__ and not frame.filename.endswith("/dataframe.py")
)
self.location = f"{frame.filename}:{frame.lineno}"
@@ -290,9 +277,7 @@ class Node(object):
values = self.perf_metrics[name]
min, max, avg = np.min(values), np.max(values), np.average(values)
p50, p75, p95, p99 = np.percentile(values, (50, 75, 95, 99))
self.perf_stats[name] = PerfStats(
len(values), sum(values), min, max, avg, p50, p75, p95, p99
)
self.perf_stats[name] = PerfStats(len(values), sum(values), min, max, avg, p50, p75, p95, p99)
return self.perf_stats[name]
@property
@@ -378,9 +363,7 @@ class DataSinkNode(Node):
"link_or_copy",
"manifest",
), f"invalid sink type: {type}"
super().__init__(
ctx, input_deps, None, output_path, cpu_limit=1, gpu_limit=0, memory_limit=0
)
super().__init__(ctx, input_deps, None, output_path, cpu_limit=1, gpu_limit=0, memory_limit=0)
self.type: DataSinkType = "manifest" if manifest_only else type
self.is_final_node = is_final_node
@@ -402,12 +385,7 @@ class DataSinkNode(Node):
if self.type == "copy" or self.type == "link_or_copy":
# so we create two phase tasks:
# phase1: copy data to a temp directory, for each input partition in parallel
input_deps = [
self._create_phase1_task(
runtime_ctx, task, [PartitionInfo(i, len(input_deps))]
)
for i, task in enumerate(input_deps)
]
input_deps = [self._create_phase1_task(runtime_ctx, task, [PartitionInfo(i, len(input_deps))]) for i, task in enumerate(input_deps)]
# phase2: resolve file name conflicts, hard link files, create manifest file, and clean up temp directory
return DataSinkTask(
runtime_ctx,
@@ -445,9 +423,7 @@ class DataSinkNode(Node):
input_dep: Task,
partition_infos: List[PartitionInfo],
) -> DataSinkTask:
return DataSinkTask(
runtime_ctx, [input_dep], partition_infos, self.output_path, type=self.type
)
return DataSinkTask(runtime_ctx, [input_dep], partition_infos, self.output_path, type=self.type)
class PythonScriptNode(Node):
@@ -467,9 +443,7 @@ class PythonScriptNode(Node):
ctx: Context,
input_deps: Tuple[Node, ...],
*,
process_func: Optional[
Callable[[RuntimeContext, List[DataSet], str], bool]
] = None,
process_func: Optional[Callable[[RuntimeContext, List[DataSet], str], bool]] = None,
output_name: Optional[str] = None,
output_path: Optional[str] = None,
cpu_limit: int = 1,
@@ -646,9 +620,7 @@ class ArrowComputeNode(Node):
gpu_limit,
memory_limit,
)
self.parquet_row_group_size = (
parquet_row_group_size or self.default_row_group_size
)
self.parquet_row_group_size = parquet_row_group_size or self.default_row_group_size
self.parquet_dictionary_encoding = parquet_dictionary_encoding
self.parquet_compression = parquet_compression
self.parquet_compression_level = parquet_compression_level
@@ -708,9 +680,7 @@ class ArrowComputeNode(Node):
"""
return ArrowComputeTask(*args, **kwargs)
def process(
self, runtime_ctx: RuntimeContext, input_tables: List[arrow.Table]
) -> arrow.Table:
def process(self, runtime_ctx: RuntimeContext, input_tables: List[arrow.Table]) -> arrow.Table:
"""
Put user-defined code here.
@@ -749,9 +719,7 @@ class ArrowStreamNode(Node):
ctx: Context,
input_deps: Tuple[Node, ...],
*,
process_func: Callable[
[RuntimeContext, List[arrow.RecordBatchReader]], Iterable[arrow.Table]
] = None,
process_func: Callable[[RuntimeContext, List[arrow.RecordBatchReader]], Iterable[arrow.Table]] = None,
background_io_thread=True,
streaming_batch_size: int = None,
secs_checkpoint_interval: int = None,
@@ -816,12 +784,9 @@ class ArrowStreamNode(Node):
self.background_io_thread = background_io_thread and self.cpu_limit > 1
self.streaming_batch_size = streaming_batch_size or self.default_batch_size
self.secs_checkpoint_interval = secs_checkpoint_interval or math.ceil(
self.default_secs_checkpoint_interval
/ min(6, self.gpu_limit + 2, self.cpu_limit)
)
self.parquet_row_group_size = (
parquet_row_group_size or self.default_row_group_size
self.default_secs_checkpoint_interval / min(6, self.gpu_limit + 2, self.cpu_limit)
)
self.parquet_row_group_size = parquet_row_group_size or self.default_row_group_size
self.parquet_dictionary_encoding = parquet_dictionary_encoding
self.parquet_compression = parquet_compression
self.parquet_compression_level = parquet_compression_level
@@ -890,9 +855,7 @@ class ArrowStreamNode(Node):
"""
return ArrowStreamTask(*args, **kwargs)
def process(
self, runtime_ctx: RuntimeContext, input_readers: List[arrow.RecordBatchReader]
) -> Iterable[arrow.Table]:
def process(self, runtime_ctx: RuntimeContext, input_readers: List[arrow.RecordBatchReader]) -> Iterable[arrow.Table]:
"""
Put user-defined code here.
@@ -918,9 +881,7 @@ class ArrowBatchNode(ArrowStreamNode):
def spawn(self, *args, **kwargs) -> ArrowBatchTask:
return ArrowBatchTask(*args, **kwargs)
def process(
self, runtime_ctx: RuntimeContext, input_tables: List[arrow.Table]
) -> arrow.Table:
def process(self, runtime_ctx: RuntimeContext, input_tables: List[arrow.Table]) -> arrow.Table:
raise NotImplementedError
@@ -932,9 +893,7 @@ class PandasComputeNode(ArrowComputeNode):
def spawn(self, *args, **kwargs) -> PandasComputeTask:
return PandasComputeTask(*args, **kwargs)
def process(
self, runtime_ctx: RuntimeContext, input_dfs: List[pd.DataFrame]
) -> pd.DataFrame:
def process(self, runtime_ctx: RuntimeContext, input_dfs: List[pd.DataFrame]) -> pd.DataFrame:
raise NotImplementedError
@@ -946,9 +905,7 @@ class PandasBatchNode(ArrowStreamNode):
def spawn(self, *args, **kwargs) -> PandasBatchTask:
return PandasBatchTask(*args, **kwargs)
def process(
self, runtime_ctx: RuntimeContext, input_dfs: List[pd.DataFrame]
) -> pd.DataFrame:
def process(self, runtime_ctx: RuntimeContext, input_dfs: List[pd.DataFrame]) -> pd.DataFrame:
raise NotImplementedError
@@ -1064,13 +1021,8 @@ class SqlEngineNode(Node):
cpu_limit = cpu_limit or self.default_cpu_limit
memory_limit = memory_limit or self.default_memory_limit
if udfs is not None:
if (
self.max_udf_cpu_limit is not None
and cpu_limit > self.max_udf_cpu_limit
):
warnings.warn(
f"UDF execution is not highly paralleled, downgrade cpu_limit from {cpu_limit} to {self.max_udf_cpu_limit}"
)
if self.max_udf_cpu_limit is not None and cpu_limit > self.max_udf_cpu_limit:
warnings.warn(f"UDF execution is not highly paralleled, downgrade cpu_limit from {cpu_limit} to {self.max_udf_cpu_limit}")
cpu_limit = self.max_udf_cpu_limit
memory_limit = None
if relax_memory_if_oom is not None:
@@ -1080,10 +1032,7 @@ class SqlEngineNode(Node):
stacklevel=3,
)
assert isinstance(sql_query, str) or (
isinstance(sql_query, Iterable)
and all(isinstance(q, str) for q in sql_query)
)
assert isinstance(sql_query, str) or (isinstance(sql_query, Iterable) and all(isinstance(q, str) for q in sql_query))
super().__init__(
ctx,
input_deps,
@@ -1092,17 +1041,13 @@ class SqlEngineNode(Node):
cpu_limit=cpu_limit,
memory_limit=memory_limit,
)
self.sql_queries = (
[sql_query] if isinstance(sql_query, str) else list(sql_query)
)
self.udfs = [
ctx.create_duckdb_extension(path) for path in extension_paths or []
] + [ctx.create_external_module(path) for path in udf_module_paths or []]
self.sql_queries = [sql_query] if isinstance(sql_query, str) else list(sql_query)
self.udfs = [ctx.create_duckdb_extension(path) for path in extension_paths or []] + [
ctx.create_external_module(path) for path in udf_module_paths or []
]
for udf in udfs or []:
if isinstance(udf, UserDefinedFunction):
name = ctx.create_function(
udf.name, udf.func, udf.params, udf.return_type, udf.use_arrow_type
)
name = ctx.create_function(udf.name, udf.func, udf.params, udf.return_type, udf.use_arrow_type)
else:
assert isinstance(udf, str), f"udf must be a string: {udf}"
if udf in ctx.udfs:
@@ -1120,9 +1065,7 @@ class SqlEngineNode(Node):
self.materialize_in_memory = materialize_in_memory
self.batched_processing = batched_processing and len(input_deps) == 1
self.enable_temp_directory = enable_temp_directory
self.parquet_row_group_size = (
parquet_row_group_size or self.default_row_group_size
)
self.parquet_row_group_size = parquet_row_group_size or self.default_row_group_size
self.parquet_dictionary_encoding = parquet_dictionary_encoding
self.parquet_compression = parquet_compression
self.parquet_compression_level = parquet_compression_level
@@ -1130,17 +1073,11 @@ class SqlEngineNode(Node):
self.memory_overcommit_ratio = memory_overcommit_ratio
def __str__(self) -> str:
return (
super().__str__()
+ f", sql_query=<{self.oneline_query[:100]}...>, udfs={self.udfs}, batched_processing={self.batched_processing}"
)
return super().__str__() + f", sql_query=<{self.oneline_query[:100]}...>, udfs={self.udfs}, batched_processing={self.batched_processing}"
@property
def oneline_query(self) -> str:
return "; ".join(
" ".join(filter(None, map(str.strip, query.splitlines())))
for query in self.sql_queries
)
return "; ".join(" ".join(filter(None, map(str.strip, query.splitlines()))) for query in self.sql_queries)
@Node.task_factory
def create_task(
@@ -1224,12 +1161,8 @@ class ConsolidateNode(Node):
dimensions
Partitions would be grouped by these `dimensions` and consolidated into larger partitions.
"""
assert isinstance(
dimensions, Iterable
), f"dimensions is not iterable: {dimensions}"
assert all(
isinstance(dim, str) for dim in dimensions
), f"some dimensions are not strings: {dimensions}"
assert isinstance(dimensions, Iterable), f"dimensions is not iterable: {dimensions}"
assert all(isinstance(dim, str) for dim in dimensions), f"some dimensions are not strings: {dimensions}"
super().__init__(ctx, [input_dep])
self.dimensions = set(list(dimensions) + [PartitionInfo.toplevel_dimension])
@@ -1283,29 +1216,16 @@ class PartitionNode(Node):
See unit tests in `test/test_partition.py`. For nested partition see `test_nested_partition`.
Why nested partition? See **5.1 Partial Partitioning** of [Advanced partitioning techniques for massively distributed computation](https://dl.acm.org/doi/10.1145/2213836.2213839).
"""
assert isinstance(
npartitions, int
), f"npartitions is not an integer: {npartitions}"
assert dimension is None or re.match(
"[a-zA-Z0-9_]+", dimension
), f"dimension has invalid format: {dimension}"
assert not (
nested and dimension is None
), f"nested partition should have dimension"
super().__init__(
ctx, input_deps, output_name, output_path, cpu_limit, 0, memory_limit
)
assert isinstance(npartitions, int), f"npartitions is not an integer: {npartitions}"
assert dimension is None or re.match("[a-zA-Z0-9_]+", dimension), f"dimension has invalid format: {dimension}"
assert not (nested and dimension is None), f"nested partition should have dimension"
super().__init__(ctx, input_deps, output_name, output_path, cpu_limit, 0, memory_limit)
self.npartitions = npartitions
self.dimension = (
dimension if dimension is not None else PartitionInfo.default_dimension
)
self.dimension = dimension if dimension is not None else PartitionInfo.default_dimension
self.nested = nested
def __str__(self) -> str:
return (
super().__str__()
+ f", npartitions={self.npartitions}, dimension={self.dimension}, nested={self.nested}"
)
return super().__str__() + f", npartitions={self.npartitions}, dimension={self.dimension}, nested={self.nested}"
@Node.task_factory
def create_producer_task(
@@ -1441,12 +1361,8 @@ class UserDefinedPartitionNode(PartitionNode):
class UserPartitionedDataSourceNode(UserDefinedPartitionNode):
max_num_producer_tasks = 1
def __init__(
self, ctx: Context, partitioned_datasets: List[DataSet], dimension: str = None
) -> None:
assert isinstance(partitioned_datasets, Iterable) and all(
isinstance(dataset, DataSet) for dataset in partitioned_datasets
)
def __init__(self, ctx: Context, partitioned_datasets: List[DataSet], dimension: str = None) -> None:
assert isinstance(partitioned_datasets, Iterable) and all(isinstance(dataset, DataSet) for dataset in partitioned_datasets)
super().__init__(
ctx,
[DataSourceNode(ctx, dataset=None)],
@@ -1507,10 +1423,7 @@ class EvenlyDistributedPartitionNode(PartitionNode):
self.random_shuffle = random_shuffle
def __str__(self) -> str:
return (
super().__str__()
+ f", partition_by_rows={self.partition_by_rows}, random_shuffle={self.random_shuffle}"
)
return super().__str__() + f", partition_by_rows={self.partition_by_rows}, random_shuffle={self.random_shuffle}"
@Node.task_factory
def create_producer_task(
@@ -1551,9 +1464,7 @@ class LoadPartitionedDataSetNode(PartitionNode):
cpu_limit: int = 1,
memory_limit: Optional[int] = None,
) -> None:
assert (
dimension or data_partition_column
), f"Both 'dimension' and 'data_partition_column' are none or empty"
assert dimension or data_partition_column, f"Both 'dimension' and 'data_partition_column' are none or empty"
super().__init__(
ctx,
input_deps,
@@ -1567,10 +1478,7 @@ class LoadPartitionedDataSetNode(PartitionNode):
self.hive_partitioning = hive_partitioning
def __str__(self) -> str:
return (
super().__str__()
+ f", data_partition_column={self.data_partition_column}, hive_partitioning={self.hive_partitioning}"
)
return super().__str__() + f", data_partition_column={self.data_partition_column}, hive_partitioning={self.hive_partitioning}"
@Node.task_factory
def create_producer_task(
@@ -1620,9 +1528,7 @@ def DataSetPartitionNode(
--------
See unit test `test_load_partitioned_datasets` in `test/test_partition.py`.
"""
assert not (
partition_by_rows and data_partition_column
), "partition_by_rows and data_partition_column cannot be set at the same time"
assert not (partition_by_rows and data_partition_column), "partition_by_rows and data_partition_column cannot be set at the same time"
if data_partition_column is None:
partition_node = EvenlyDistributedPartitionNode(
ctx,
@@ -1720,12 +1626,8 @@ class HashPartitionNode(PartitionNode):
Specify if we should use dictionary encoding in general or only for some columns.
See `use_dictionary` in https://arrow.apache.org/docs/python/generated/pyarrow.parquet.ParquetWriter.html.
"""
assert (
not random_shuffle or not shuffle_only
), f"random_shuffle and shuffle_only cannot be enabled at the same time"
assert (
not shuffle_only or data_partition_column is not None
), f"data_partition_column not specified for shuffle-only partitioning"
assert not random_shuffle or not shuffle_only, f"random_shuffle and shuffle_only cannot be enabled at the same time"
assert not shuffle_only or data_partition_column is not None, f"data_partition_column not specified for shuffle-only partitioning"
assert data_partition_column is None or re.match(
"[a-zA-Z0-9_]+", data_partition_column
), f"data_partition_column has invalid format: {data_partition_column}"
@@ -1734,9 +1636,7 @@ class HashPartitionNode(PartitionNode):
"duckdb",
"arrow",
), f"unknown query engine type: {engine_type}"
data_partition_column = (
data_partition_column or self.default_data_partition_column
)
data_partition_column = data_partition_column or self.default_data_partition_column
super().__init__(
ctx,
input_deps,
@@ -1756,9 +1656,7 @@ class HashPartitionNode(PartitionNode):
self.drop_partition_column = drop_partition_column
self.use_parquet_writer = use_parquet_writer
self.hive_partitioning = hive_partitioning and self.engine_type == "duckdb"
self.parquet_row_group_size = (
parquet_row_group_size or self.default_row_group_size
)
self.parquet_row_group_size = parquet_row_group_size or self.default_row_group_size
self.parquet_dictionary_encoding = parquet_dictionary_encoding
self.parquet_compression = parquet_compression
self.parquet_compression_level = parquet_compression_level
@@ -1929,22 +1827,15 @@ class ProjectionNode(Node):
"""
columns = columns or ["*"]
generated_columns = generated_columns or []
assert all(
col in GENERATED_COLUMNS for col in generated_columns
), f"invalid values found in generated columns: {generated_columns}"
assert not (
set(columns) & set(generated_columns)
), f"columns {columns} and generated columns {generated_columns} share common columns"
assert all(col in GENERATED_COLUMNS for col in generated_columns), f"invalid values found in generated columns: {generated_columns}"
assert not (set(columns) & set(generated_columns)), f"columns {columns} and generated columns {generated_columns} share common columns"
super().__init__(ctx, [input_dep])
self.columns = columns
self.generated_columns = generated_columns
self.union_by_name = union_by_name
def __str__(self) -> str:
return (
super().__str__()
+ f", columns={self.columns}, generated_columns={self.generated_columns}, union_by_name={self.union_by_name}"
)
return super().__str__() + f", columns={self.columns}, generated_columns={self.generated_columns}, union_by_name={self.union_by_name}"
@Node.task_factory
def create_task(
@@ -2100,10 +1991,7 @@ class LogicalPlan(object):
if node.id in visited:
return lines + [" " * depth + " (omitted ...)"]
visited.add(node.id)
lines += [
" " * depth + f" | {name}: {stats}"
for name, stats in node.perf_stats.items()
]
lines += [" " * depth + f" | {name}: {stats}" for name, stats in node.perf_stats.items()]
for dep in node.input_deps:
lines.extend(to_str(dep, depth + 1))
return lines

View File

@@ -32,9 +32,7 @@ class Optimizer(LogicalPlanVisitor[Node]):
def visit_query_engine_node(self, node: SqlEngineNode, depth: int) -> Node:
# fuse consecutive SqlEngineNodes
if len(node.input_deps) == 1 and isinstance(
child := self.visit(node.input_deps[0], depth + 1), SqlEngineNode
):
if len(node.input_deps) == 1 and isinstance(child := self.visit(node.input_deps[0], depth + 1), SqlEngineNode):
fused = copy.copy(node)
fused.input_deps = child.input_deps
fused.udfs = node.udfs + child.udfs
@@ -52,8 +50,6 @@ class Optimizer(LogicalPlanVisitor[Node]):
# node.sql_queries = ["select a, b from {0}"]
# fused.sql_queries = ["select a, b from (select * from {0})"]
# ```
fused.sql_queries = child.sql_queries[:-1] + [
query.format(f"({child.sql_queries[-1]})") for query in node.sql_queries
]
fused.sql_queries = child.sql_queries[:-1] + [query.format(f"({child.sql_queries[-1]})") for query in node.sql_queries]
return fused
return self.generic_visit(node, depth)

View File

@@ -14,30 +14,20 @@ class Planner(LogicalPlanVisitor[TaskGroup]):
self.node_to_tasks: Dict[Node, TaskGroup] = {}
@logger.catch(reraise=True, message="failed to build computation graph")
def create_exec_plan(
self, logical_plan: LogicalPlan, manifest_only_final_results=True
) -> ExecutionPlan:
def create_exec_plan(self, logical_plan: LogicalPlan, manifest_only_final_results=True) -> ExecutionPlan:
logical_plan = copy.deepcopy(logical_plan)
# if --output_path is specified, copy files to the output path
# otherwise, create manifest files only
sink_type = (
"copy" if self.runtime_ctx.final_output_path is not None else "manifest"
)
final_sink_type = (
"copy"
if self.runtime_ctx.final_output_path is not None
else "manifest" if manifest_only_final_results else "link"
)
sink_type = "copy" if self.runtime_ctx.final_output_path is not None else "manifest"
final_sink_type = "copy" if self.runtime_ctx.final_output_path is not None else "manifest" if manifest_only_final_results else "link"
# create DataSinkNode for each named output node (same name share the same sink node)
nodes_groupby_output_name: Dict[str, List[Node]] = defaultdict(list)
for node in logical_plan.nodes.values():
if node.output_name is not None:
if node.output_name in nodes_groupby_output_name:
warnings.warn(
f"{node} has duplicate output name: {node.output_name}"
)
warnings.warn(f"{node} has duplicate output name: {node.output_name}")
nodes_groupby_output_name[node.output_name].append(node)
sink_nodes = {} # { output_name: DataSinkNode }
for output_name, nodes in nodes_groupby_output_name.items():
@@ -45,9 +35,7 @@ class Planner(LogicalPlanVisitor[TaskGroup]):
self.runtime_ctx.final_output_path or self.runtime_ctx.output_root,
output_name,
)
sink_nodes[output_name] = DataSinkNode(
logical_plan.ctx, tuple(nodes), output_path, type=sink_type
)
sink_nodes[output_name] = DataSinkNode(logical_plan.ctx, tuple(nodes), output_path, type=sink_type)
# create DataSinkNode for root node
# XXX: special case optimization to avoid copying files twice
@@ -63,9 +51,7 @@ class Planner(LogicalPlanVisitor[TaskGroup]):
):
sink_nodes["FinalResults"] = DataSinkNode(
logical_plan.ctx,
tuple(
sink_nodes[node.output_name] for node in partition_node.input_deps
),
tuple(sink_nodes[node.output_name] for node in partition_node.input_deps),
output_path=os.path.join(
self.runtime_ctx.final_output_path or self.runtime_ctx.output_root,
"FinalResults",
@@ -124,38 +110,24 @@ class Planner(LogicalPlanVisitor[TaskGroup]):
return [node.create_task(self.runtime_ctx, [], [PartitionInfo()])]
def visit_data_sink_node(self, node: DataSinkNode, depth: int) -> TaskGroup:
all_input_deps = [
task for dep in node.input_deps for task in self.visit(dep, depth + 1)
]
all_input_deps = [task for dep in node.input_deps for task in self.visit(dep, depth + 1)]
return [node.create_task(self.runtime_ctx, all_input_deps, [PartitionInfo()])]
def visit_root_node(self, node: RootNode, depth: int) -> TaskGroup:
all_input_deps = [
task for dep in node.input_deps for task in self.visit(dep, depth + 1)
]
all_input_deps = [task for dep in node.input_deps for task in self.visit(dep, depth + 1)]
return [RootTask(self.runtime_ctx, all_input_deps, [PartitionInfo()])]
def visit_union_node(self, node: UnionNode, depth: int) -> TaskGroup:
all_input_deps = [
task for dep in node.input_deps for task in self.visit(dep, depth + 1)
]
all_input_deps = [task for dep in node.input_deps for task in self.visit(dep, depth + 1)]
unique_partition_dims = set(task.partition_dims for task in all_input_deps)
assert (
len(unique_partition_dims) == 1
), f"cannot union partitions with different dimensions: {unique_partition_dims}"
assert len(unique_partition_dims) == 1, f"cannot union partitions with different dimensions: {unique_partition_dims}"
return all_input_deps
def visit_consolidate_node(self, node: ConsolidateNode, depth: int) -> TaskGroup:
input_deps_taskgroups = [self.visit(dep, depth + 1) for dep in node.input_deps]
assert (
len(input_deps_taskgroups) == 1
), f"consolidate node only accepts one input node, but found: {input_deps_taskgroups}"
unique_partition_dims = set(
task.partition_dims for task in input_deps_taskgroups[0]
)
assert (
len(unique_partition_dims) == 1
), f"cannot consolidate partitions with different dimensions: {unique_partition_dims}"
assert len(input_deps_taskgroups) == 1, f"consolidate node only accepts one input node, but found: {input_deps_taskgroups}"
unique_partition_dims = set(task.partition_dims for task in input_deps_taskgroups[0])
assert len(unique_partition_dims) == 1, f"cannot consolidate partitions with different dimensions: {unique_partition_dims}"
existing_dimensions = set(unique_partition_dims.pop())
assert (
node.dimensions.intersection(existing_dimensions) == node.dimensions
@@ -163,46 +135,30 @@ class Planner(LogicalPlanVisitor[TaskGroup]):
# group tasks by partitions
input_deps_groupby_partitions: Dict[Tuple, List[Task]] = defaultdict(list)
for task in input_deps_taskgroups[0]:
partition_infos = tuple(
info
for info in task.partition_infos
if info.dimension in node.dimensions
)
partition_infos = tuple(info for info in task.partition_infos if info.dimension in node.dimensions)
input_deps_groupby_partitions[partition_infos].append(task)
return [
node.create_task(self.runtime_ctx, input_deps, partition_infos)
for partition_infos, input_deps in input_deps_groupby_partitions.items()
node.create_task(self.runtime_ctx, input_deps, partition_infos) for partition_infos, input_deps in input_deps_groupby_partitions.items()
]
def visit_partition_node(self, node: PartitionNode, depth: int) -> TaskGroup:
all_input_deps = [
task for dep in node.input_deps for task in self.visit(dep, depth + 1)
]
all_input_deps = [task for dep in node.input_deps for task in self.visit(dep, depth + 1)]
unique_partition_dims = set(task.partition_dims for task in all_input_deps)
assert (
len(unique_partition_dims) == 1
), f"cannot partition input_deps with different dimensions: {unique_partition_dims}"
assert len(unique_partition_dims) == 1, f"cannot partition input_deps with different dimensions: {unique_partition_dims}"
if node.nested:
assert (
node.dimension not in unique_partition_dims
), f"found duplicate partition dimension '{node.dimension}', existing dimensions: {unique_partition_dims}"
assert (
len(all_input_deps) * node.npartitions
<= node.max_card_of_producers_x_consumers
len(all_input_deps) * node.npartitions <= node.max_card_of_producers_x_consumers
), f"{len(all_input_deps)=} * {node.npartitions=} > {node.max_card_of_producers_x_consumers=}"
producer_tasks = [
node.create_producer_task(
self.runtime_ctx, [task], task.partition_infos
)
for task in all_input_deps
]
producer_tasks = [node.create_producer_task(self.runtime_ctx, [task], task.partition_infos) for task in all_input_deps]
return [
node.create_consumer_task(
self.runtime_ctx,
[producer],
list(producer.partition_infos)
+ [PartitionInfo(partition_idx, node.npartitions, node.dimension)],
list(producer.partition_infos) + [PartitionInfo(partition_idx, node.npartitions, node.dimension)],
)
for producer in producer_tasks
for partition_idx in range(node.npartitions)
@@ -212,16 +168,10 @@ class Planner(LogicalPlanVisitor[TaskGroup]):
node.max_num_producer_tasks,
math.ceil(node.max_card_of_producers_x_consumers / node.npartitions),
)
num_parallel_tasks = (
2
* self.runtime_ctx.num_executors
* math.ceil(self.runtime_ctx.usable_cpu_count / node.cpu_limit)
)
num_parallel_tasks = 2 * self.runtime_ctx.num_executors * math.ceil(self.runtime_ctx.usable_cpu_count / node.cpu_limit)
num_producer_tasks = max(1, min(max_num_producer_tasks, num_parallel_tasks))
if len(all_input_deps) < num_producer_tasks:
merge_datasets_task = node.create_merge_task(
self.runtime_ctx, all_input_deps, [PartitionInfo()]
)
merge_datasets_task = node.create_merge_task(self.runtime_ctx, all_input_deps, [PartitionInfo()])
split_dataset_tasks = [
node.create_split_task(
self.runtime_ctx,
@@ -237,15 +187,10 @@ class Planner(LogicalPlanVisitor[TaskGroup]):
tasks,
[PartitionInfo(partition_idx, num_producer_tasks)],
)
for partition_idx, tasks in enumerate(
split_into_rows(all_input_deps, num_producer_tasks)
)
for partition_idx, tasks in enumerate(split_into_rows(all_input_deps, num_producer_tasks))
]
producer_tasks = [
node.create_producer_task(
self.runtime_ctx, [split_dataset], split_dataset.partition_infos
)
for split_dataset in split_dataset_tasks
node.create_producer_task(self.runtime_ctx, [split_dataset], split_dataset.partition_infos) for split_dataset in split_dataset_tasks
]
return [
node.create_consumer_task(
@@ -284,11 +229,7 @@ class Planner(LogicalPlanVisitor[TaskGroup]):
for main_input in input_deps_most_ndims:
input_deps = []
for input_deps_dims, input_deps_map in input_deps_maps:
partition_infos = tuple(
info
for info in main_input.partition_infos
if info.dimension in input_deps_dims
)
partition_infos = tuple(info for info in main_input.partition_infos if info.dimension in input_deps_dims)
input_dep = input_deps_map.get(partition_infos, None)
assert (
input_dep is not None
@@ -299,50 +240,32 @@ class Planner(LogicalPlanVisitor[TaskGroup]):
def visit_python_script_node(self, node: PythonScriptNode, depth: int) -> TaskGroup:
return [
node.create_task(self.runtime_ctx, input_deps, partition_infos)
for input_deps, partition_infos in self.broadcast_input_deps(node, depth)
node.create_task(self.runtime_ctx, input_deps, partition_infos) for input_deps, partition_infos in self.broadcast_input_deps(node, depth)
]
def visit_arrow_compute_node(self, node: ArrowComputeNode, depth: int) -> TaskGroup:
return [
node.create_task(self.runtime_ctx, input_deps, partition_infos)
for input_deps, partition_infos in self.broadcast_input_deps(node, depth)
node.create_task(self.runtime_ctx, input_deps, partition_infos) for input_deps, partition_infos in self.broadcast_input_deps(node, depth)
]
def visit_arrow_stream_node(self, node: ArrowStreamNode, depth: int) -> TaskGroup:
return [
node.create_task(self.runtime_ctx, input_deps, partition_infos)
for input_deps, partition_infos in self.broadcast_input_deps(node, depth)
node.create_task(self.runtime_ctx, input_deps, partition_infos) for input_deps, partition_infos in self.broadcast_input_deps(node, depth)
]
def visit_query_engine_node(self, node: SqlEngineNode, depth: int) -> TaskGroup:
return [
node.create_task(self.runtime_ctx, input_deps, partition_infos)
for input_deps, partition_infos in self.broadcast_input_deps(node, depth)
node.create_task(self.runtime_ctx, input_deps, partition_infos) for input_deps, partition_infos in self.broadcast_input_deps(node, depth)
]
def visit_projection_node(self, node: ProjectionNode, depth: int) -> TaskGroup:
assert (
len(node.input_deps) == 1
), f"projection node only accepts one input node, but found: {node.input_deps}"
return [
node.create_task(self.runtime_ctx, [task], task.partition_infos)
for task in self.visit(node.input_deps[0], depth + 1)
]
assert len(node.input_deps) == 1, f"projection node only accepts one input node, but found: {node.input_deps}"
return [node.create_task(self.runtime_ctx, [task], task.partition_infos) for task in self.visit(node.input_deps[0], depth + 1)]
def visit_limit_node(self, node: LimitNode, depth: int) -> TaskGroup:
assert (
len(node.input_deps) == 1
), f"limit node only accepts one input node, but found: {node.input_deps}"
assert len(node.input_deps) == 1, f"limit node only accepts one input node, but found: {node.input_deps}"
all_input_deps = self.visit(node.input_deps[0], depth + 1)
partial_limit_tasks = [
node.create_task(self.runtime_ctx, [task], task.partition_infos)
for task in all_input_deps
]
merge_task = node.create_merge_task(
self.runtime_ctx, partial_limit_tasks, [PartitionInfo()]
)
global_limit_task = node.create_task(
self.runtime_ctx, [merge_task], merge_task.partition_infos
)
partial_limit_tasks = [node.create_task(self.runtime_ctx, [task], task.partition_infos) for task in all_input_deps]
merge_task = node.create_merge_task(self.runtime_ctx, partial_limit_tasks, [PartitionInfo()])
global_limit_task = node.create_task(self.runtime_ctx, [merge_task], merge_task.partition_infos)
return [global_limit_task]

View File

@@ -264,6 +264,4 @@ def udf(
See `Context.create_function` for more details.
"""
return lambda func: UserDefinedFunction(
name or func.__name__, func, params, return_type, use_arrow_type
)
return lambda func: UserDefinedFunction(name or func.__name__, func, params, return_type, use_arrow_type)

View File

@@ -57,9 +57,7 @@ class SessionBase:
logger.info(f"session config: {self.config}")
def setup_worker():
runtime_ctx._init_logs(
exec_id=socket.gethostname(), capture_stdout_stderr=True
)
runtime_ctx._init_logs(exec_id=socket.gethostname(), capture_stdout_stderr=True)
if self.config.ray_address is None:
# find the memory allocator
@@ -72,9 +70,7 @@ class SessionBase:
malloc_path = shutil.which("libmimalloc.so.2.1")
assert malloc_path is not None, "mimalloc is not installed"
else:
raise ValueError(
f"unsupported memory allocator: {self.config.memory_allocator}"
)
raise ValueError(f"unsupported memory allocator: {self.config.memory_allocator}")
memory_purge_delay = 10000
# start ray head node
@@ -84,11 +80,7 @@ class SessionBase:
# start a new local cluster
address="local",
# disable local CPU resource if not running on localhost
num_cpus=(
0
if self.config.num_executors > 0
else self._runtime_ctx.usable_cpu_count
),
num_cpus=(0 if self.config.num_executors > 0 else self._runtime_ctx.usable_cpu_count),
# set the memory limit to the available memory size
_memory=self._runtime_ctx.usable_memory_size,
# setup logging for workers
@@ -142,9 +134,7 @@ class SessionBase:
# spawn a thread to periodically dump metrics
self._stop_event = threading.Event()
self._dump_thread = threading.Thread(
name="dump_thread", target=self._dump_periodically, daemon=True
)
self._dump_thread = threading.Thread(name="dump_thread", target=self._dump_periodically, daemon=True)
self._dump_thread.start()
def shutdown(self):
@@ -184,11 +174,7 @@ class SessionBase:
extra_opts=dict(
tags=["smallpond", "scheduler", smallpond.__version__],
),
envs={
k: v
for k, v in os.environ.items()
if k.startswith("SP_") and k != "SP_SPAWN"
},
envs={k: v for k, v in os.environ.items() if k.startswith("SP_") and k != "SP_SPAWN"},
)
def _start_prometheus(self) -> Optional[subprocess.Popen]:
@@ -233,8 +219,7 @@ class SessionBase:
stdout=open(f"{self._runtime_ctx.log_root}/grafana/grafana.log", "w"),
env={
"GF_SERVER_HTTP_PORT": "8122", # redirect to an available port
"GF_SERVER_ROOT_URL": os.environ.get("RAY_GRAFANA_IFRAME_HOST")
or "http://localhost:8122",
"GF_SERVER_ROOT_URL": os.environ.get("RAY_GRAFANA_IFRAME_HOST") or "http://localhost:8122",
"GF_PATHS_DATA": f"{self._runtime_ctx.log_root}/grafana/data",
},
)
@@ -309,12 +294,8 @@ class SessionBase:
self.dump_graph()
self.dump_timeline()
num_total_tasks, num_finished_tasks = self._summarize_task()
percent = (
num_finished_tasks / num_total_tasks * 100 if num_total_tasks > 0 else 0
)
logger.info(
f"progress: {num_finished_tasks}/{num_total_tasks} tasks ({percent:.1f}%)"
)
percent = num_finished_tasks / num_total_tasks * 100 if num_total_tasks > 0 else 0
logger.info(f"progress: {num_finished_tasks}/{num_total_tasks} tasks ({percent:.1f}%)")
@dataclass
@@ -360,20 +341,12 @@ class Config:
platform = get_platform(get_env("PLATFORM") or platform)
job_id = get_env("JOBID") or job_id or platform.default_job_id()
job_time = (
get_env("JOB_TIME", datetime.fromisoformat)
or job_time
or platform.default_job_time()
)
job_time = get_env("JOB_TIME", datetime.fromisoformat) or job_time or platform.default_job_time()
data_root = get_env("DATA_ROOT") or data_root or platform.default_data_root()
num_executors = get_env("NUM_EXECUTORS", int) or num_executors or 0
ray_address = get_env("RAY_ADDRESS") or ray_address
bind_numa_node = get_env("BIND_NUMA_NODE") == "1" or bind_numa_node
memory_allocator = (
get_env("MEMORY_ALLOCATOR")
or memory_allocator
or platform.default_memory_allocator()
)
memory_allocator = get_env("MEMORY_ALLOCATOR") or memory_allocator or platform.default_memory_allocator()
config = Config(
job_id=job_id,

View File

@@ -24,9 +24,7 @@ def overall_stats(
):
from smallpond.logical.node import DataSetPartitionNode, DataSinkNode, SqlEngineNode
n = SqlEngineNode(
ctx, inp, sql_per_part, cpu_limit=cpu_limit, memory_limit=memory_limit
)
n = SqlEngineNode(ctx, inp, sql_per_part, cpu_limit=cpu_limit, memory_limit=memory_limit)
p = DataSetPartitionNode(ctx, (n,), npartitions=1)
n2 = SqlEngineNode(
ctx,
@@ -59,9 +57,7 @@ def execute_command(cmd: str, env: Dict[str, str] = None, shell=False):
raise subprocess.CalledProcessError(return_code, cmd)
def cprofile_to_string(
perf_profile: cProfile.Profile, order_by=pstats.SortKey.TIME, top_k=20
):
def cprofile_to_string(perf_profile: cProfile.Profile, order_by=pstats.SortKey.TIME, top_k=20):
perf_profile.disable()
pstats_output = io.StringIO()
profile_stats = pstats.Stats(perf_profile, stream=pstats_output)
@@ -111,9 +107,7 @@ class ConcurrentIter(object):
"""
def __init__(self, iterable: Iterable, max_buffer_size=1) -> None:
assert isinstance(
iterable, Iterable
), f"expect an iterable but found: {repr(iterable)}"
assert isinstance(iterable, Iterable), f"expect an iterable but found: {repr(iterable)}"
self.__iterable = iterable
self.__queue = queue.Queue(max_buffer_size)
self.__last = object()
@@ -194,6 +188,4 @@ class InterceptHandler(logging.Handler):
frame = frame.f_back
depth += 1
logger.opt(depth=depth, exception=record.exc_info).log(
level, record.getMessage()
)
logger.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage())

View File

@@ -14,9 +14,7 @@ if __name__ == "__main__":
required=True,
help="The address of the Ray cluster to connect to",
)
parser.add_argument(
"--log_dir", required=True, help="The directory where logs will be stored"
)
parser.add_argument("--log_dir", required=True, help="The directory where logs will be stored")
parser.add_argument(
"--bind_numa_node",
action="store_true",

View File

@@ -14,19 +14,13 @@ from filelock import FileLock
def generate_url_and_domain() -> Tuple[str, str]:
domain_part = "".join(
random.choices(string.ascii_lowercase, k=random.randint(5, 15))
)
domain_part = "".join(random.choices(string.ascii_lowercase, k=random.randint(5, 15)))
tld = random.choice(["com", "net", "org", "cn", "edu", "gov", "co", "io"])
domain = f"www.{domain_part}.{tld}"
path_segments = []
for _ in range(random.randint(1, 3)):
segment = "".join(
random.choices(
string.ascii_lowercase + string.digits, k=random.randint(3, 10)
)
)
segment = "".join(random.choices(string.ascii_lowercase + string.digits, k=random.randint(3, 10)))
path_segments.append(segment)
path = "/" + "/".join(path_segments)
@@ -42,26 +36,18 @@ def generate_random_date() -> str:
start = datetime(2023, 1, 1, tzinfo=timezone.utc)
end = datetime(2023, 12, 31, tzinfo=timezone.utc)
delta = end - start
random_date = start + timedelta(
seconds=random.randint(0, int(delta.total_seconds()))
)
random_date = start + timedelta(seconds=random.randint(0, int(delta.total_seconds())))
return random_date.strftime("%Y-%m-%dT%H:%M:%SZ")
def generate_content() -> bytes:
target_length = (
random.randint(1000, 100000)
if random.random() < 0.8
else random.randint(100000, 1000000)
)
target_length = random.randint(1000, 100000) if random.random() < 0.8 else random.randint(100000, 1000000)
before = b"<!DOCTYPE html><html><head><title>Random Page</title></head><body>"
after = b"</body></html>"
total_before_after = len(before) + len(after)
fill_length = max(target_length - total_before_after, 0)
filler = "".join(random.choices(string.printable, k=fill_length)).encode("ascii")[
:fill_length
]
filler = "".join(random.choices(string.printable, k=fill_length)).encode("ascii")[:fill_length]
return before + filler + after
@@ -103,9 +89,7 @@ def generate_random_string(length: int) -> str:
def generate_random_url() -> str:
"""Generate a random URL"""
path = generate_random_string(random.randint(10, 20))
return (
f"com.{random.randint(10000, 999999)}.{random.randint(100, 9999)}/{path}.html"
)
return f"com.{random.randint(10000, 999999)}.{random.randint(100, 9999)}/{path}.html"
def generate_random_data() -> str:
@@ -138,9 +122,7 @@ def generate_url_parquet_files(output_dir: str, num_files: int = 10):
)
def generate_url_tsv_files(
output_dir: str, num_files: int = 10, lines_per_file: int = 100
):
def generate_url_tsv_files(output_dir: str, num_files: int = 10, lines_per_file: int = 100):
"""Generate multiple files, each containing a specified number of random data lines"""
os.makedirs(output_dir, exist_ok=True)
for i in range(num_files):
@@ -166,16 +148,12 @@ def generate_data(path: str = "tests/data"):
with FileLock(path + "/data.lock"):
print("Generating data...")
if not os.path.exists(path + "/mock_urls"):
generate_url_tsv_files(
output_dir=path + "/mock_urls", num_files=10, lines_per_file=100
)
generate_url_tsv_files(output_dir=path + "/mock_urls", num_files=10, lines_per_file=100)
generate_url_parquet_files(output_dir=path + "/mock_urls", num_files=10)
if not os.path.exists(path + "/arrow"):
generate_arrow_files(output_dir=path + "/arrow", num_files=10)
if not os.path.exists(path + "/large_array"):
concat_arrow_files(
input_dir=path + "/arrow", output_dir=path + "/large_array"
)
concat_arrow_files(input_dir=path + "/arrow", output_dir=path + "/large_array")
if not os.path.exists(path + "/long_path_list.txt"):
generate_long_path_list(path=path + "/long_path_list.txt")
except Exception as e:

View File

@@ -37,10 +37,7 @@ class TestArrow(TestFabric, unittest.TestCase):
with self.subTest(dataset_path=dataset_path):
metadata = parquet.read_metadata(dataset_path)
file_num_rows = metadata.num_rows
data_size = sum(
metadata.row_group(i).total_byte_size
for i in range(metadata.num_row_groups)
)
data_size = sum(metadata.row_group(i).total_byte_size for i in range(metadata.num_row_groups))
row_range = RowRange(
path=dataset_path,
begin=100,
@@ -48,9 +45,7 @@ class TestArrow(TestFabric, unittest.TestCase):
data_size=data_size,
file_num_rows=file_num_rows,
)
expected = self._load_parquet_files([dataset_path]).slice(
offset=100, length=100
)
expected = self._load_parquet_files([dataset_path]).slice(offset=100, length=100)
actual = load_from_parquet_files([row_range])
self._compare_arrow_tables(expected, actual)
@@ -62,21 +57,15 @@ class TestArrow(TestFabric, unittest.TestCase):
with self.subTest(dataset_path=dataset_path):
parquet_files = glob.glob(dataset_path)
expected = self._load_parquet_files(parquet_files)
with tempfile.TemporaryDirectory(
dir=self.output_root_abspath
) as output_dir:
with tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir:
ok = dump_to_parquet_files(expected, output_dir)
self.assertTrue(ok)
actual = self._load_parquet_files(
glob.glob(f"{output_dir}/*.parquet")
)
actual = self._load_parquet_files(glob.glob(f"{output_dir}/*.parquet"))
self._compare_arrow_tables(expected, actual)
def test_dump_load_empty_table(self):
# create empty table
empty_table = self._load_parquet_files(
["tests/data/arrow/data0.parquet"]
).slice(length=0)
empty_table = self._load_parquet_files(["tests/data/arrow/data0.parquet"]).slice(length=0)
self.assertEqual(empty_table.num_rows, 0)
# dump empty table
with tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir:
@@ -94,9 +83,7 @@ class TestArrow(TestFabric, unittest.TestCase):
):
with self.subTest(dataset_path=dataset_path):
parquet_files = glob.glob(dataset_path)
expected_num_rows = sum(
parquet.read_metadata(file).num_rows for file in parquet_files
)
expected_num_rows = sum(parquet.read_metadata(file).num_rows for file in parquet_files)
with build_batch_reader_from_files(
parquet_files,
batch_size=expected_num_rows,
@@ -104,9 +91,7 @@ class TestArrow(TestFabric, unittest.TestCase):
) as batch_reader, ConcurrentIter(batch_reader) as concurrent_iter:
total_num_rows = 0
for batch in concurrent_iter:
print(
f"batch.num_rows {batch.num_rows}, max_batch_row_size {expected_num_rows}"
)
print(f"batch.num_rows {batch.num_rows}, max_batch_row_size {expected_num_rows}")
self.assertLessEqual(batch.num_rows, expected_num_rows)
total_num_rows += batch.num_rows
self.assertEqual(total_num_rows, expected_num_rows)
@@ -121,9 +106,7 @@ class TestArrow(TestFabric, unittest.TestCase):
table = self._load_parquet_files(parquet_files)
total_num_rows = 0
for batch in table.to_batches(max_chunksize=table.num_rows):
print(
f"batch.num_rows {batch.num_rows}, max_batch_row_size {table.num_rows}"
)
print(f"batch.num_rows {batch.num_rows}, max_batch_row_size {table.num_rows}")
self.assertLessEqual(batch.num_rows, table.num_rows)
total_num_rows += batch.num_rows
self.assertEqual(total_num_rows, table.num_rows)
@@ -135,26 +118,14 @@ class TestArrow(TestFabric, unittest.TestCase):
print(f"table_with_meta.schema.metadata {table_with_meta.schema.metadata}")
with tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir:
self.assertTrue(
dump_to_parquet_files(
table_with_meta, output_dir, "arrow_schema_metadata", max_workers=2
)
)
parquet_files = glob.glob(
os.path.join(output_dir, "arrow_schema_metadata*.parquet")
)
loaded_table = load_from_parquet_files(
parquet_files, table.column_names[:1]
)
self.assertTrue(dump_to_parquet_files(table_with_meta, output_dir, "arrow_schema_metadata", max_workers=2))
parquet_files = glob.glob(os.path.join(output_dir, "arrow_schema_metadata*.parquet"))
loaded_table = load_from_parquet_files(parquet_files, table.column_names[:1])
print(f"loaded_table.schema.metadata {loaded_table.schema.metadata}")
self.assertEqual(
table_with_meta.schema.metadata, loaded_table.schema.metadata
)
self.assertEqual(table_with_meta.schema.metadata, loaded_table.schema.metadata)
with parquet.ParquetFile(parquet_files[0]) as file:
print(f"file.schema_arrow.metadata {file.schema_arrow.metadata}")
self.assertEqual(
table_with_meta.schema.metadata, file.schema_arrow.metadata
)
self.assertEqual(table_with_meta.schema.metadata, file.schema_arrow.metadata)
def test_load_mixed_string_types(self):
parquet_paths = glob.glob("tests/data/arrow/*.parquet")
@@ -166,9 +137,7 @@ class TestArrow(TestFabric, unittest.TestCase):
loaded_table = load_from_parquet_files(parquet_paths)
self.assertEqual(table.num_rows * 2, loaded_table.num_rows)
batch_reader = build_batch_reader_from_files(parquet_paths)
self.assertEqual(
table.num_rows * 2, sum(batch.num_rows for batch in batch_reader)
)
self.assertEqual(table.num_rows * 2, sum(batch.num_rows for batch in batch_reader))
@logger.catch(reraise=True, message="failed to load parquet files")
def _load_from_parquet_files_with_log(self, paths, columns):

View File

@@ -45,9 +45,7 @@ class TestBench(TestFabric, unittest.TestCase):
num_sort_partitions = 1 << 3
for shuffle_engine in ("duckdb", "arrow"):
for sort_engine in ("duckdb", "arrow", "polars"):
with self.subTest(
shuffle_engine=shuffle_engine, sort_engine=sort_engine
):
with self.subTest(shuffle_engine=shuffle_engine, sort_engine=sort_engine):
ctx = Context()
random_records = generate_random_records(
ctx,

View File

@@ -34,9 +34,7 @@ class TestCommon(TestFabric, unittest.TestCase):
npartitions = data.draw(st.integers(1, 2 * nelements))
items = list(range(nelements))
computed = split_into_rows(items, npartitions)
expected = [
get_nth_partition(items, n, npartitions) for n in range(npartitions)
]
expected = [get_nth_partition(items, n, npartitions) for n in range(npartitions)]
self.assertEqual(expected, computed)
@given(st.data())

View File

@@ -70,9 +70,7 @@ def test_flat_map(sp: Session):
# user need to specify the schema if can not be inferred from the mapping values
df3 = df.flat_map(lambda r: [{"c": None}], schema=pa.schema([("c", pa.int64())]))
assert df3.to_arrow() == pa.table(
{"c": pa.array([None, None, None], type=pa.int64())}
)
assert df3.to_arrow() == pa.table({"c": pa.array([None, None, None], type=pa.int64())})
def test_map_batches(sp: Session):
@@ -99,10 +97,7 @@ def test_random_shuffle(sp: Session):
assert sorted(shuffled) == list(range(1000))
def count_inversions(arr: List[int]) -> int:
return sum(
sum(1 for j in range(i + 1, len(arr)) if arr[i] > arr[j])
for i in range(len(arr))
)
return sum(sum(1 for j in range(i + 1, len(arr)) if arr[i] > arr[j]) for i in range(len(arr)))
# check the shuffle is random enough
# the expected number of inversions is n*(n-1)/4 = 249750
@@ -158,9 +153,7 @@ def test_partial_sql(sp: Session):
# join
df1 = sp.from_arrow(pa.table({"id1": [1, 2, 3], "val1": ["a", "b", "c"]}))
df2 = sp.from_arrow(pa.table({"id2": [1, 2, 3], "val2": ["d", "e", "f"]}))
joined = sp.partial_sql(
"select id1, val1, val2 from {0} join {1} on id1 = id2", df1, df2
)
joined = sp.partial_sql("select id1, val1, val2 from {0} join {1} on id1 = id2", df1, df2)
assert joined.to_arrow() == pa.table(
{"id1": [1, 2, 3], "val1": ["a", "b", "c"], "val2": ["d", "e", "f"]},
schema=pa.schema(
@@ -193,10 +186,7 @@ def test_unpicklable_task_exception(sp: Session):
df.map(lambda x: logger.info("use outside logger")).to_arrow()
except Exception as ex:
assert "Can't pickle task" in str(ex)
assert (
"HINT: DO NOT use externally imported loguru logger in your task. Please import it within the task."
in str(ex)
)
assert "HINT: DO NOT use externally imported loguru logger in your task. Please import it within the task." in str(ex)
else:
assert False, "expected exception"

View File

@@ -30,9 +30,7 @@ class TestDataSet(TestFabric, unittest.TestCase):
dataset = ParquetDataSet([os.path.join(self.output_root_abspath, "*.parquet")])
self.assertEqual(num_urls, dataset.num_rows)
def _generate_parquet_dataset(
self, output_path, npartitions, num_rows, row_group_size
):
def _generate_parquet_dataset(self, output_path, npartitions, num_rows, row_group_size):
duckdb.sql(
f"""copy (
select range as i, range % {npartitions} as partition from range(0, {num_rows}) )
@@ -41,9 +39,7 @@ class TestDataSet(TestFabric, unittest.TestCase):
)
return ParquetDataSet([f"{output_path}/**/*.parquet"])
def _check_partition_datasets(
self, orig_dataset: ParquetDataSet, partition_func, npartition
):
def _check_partition_datasets(self, orig_dataset: ParquetDataSet, partition_func, npartition):
# build partitioned datasets
partitioned_datasets = partition_func(npartition)
self.assertEqual(npartition, len(partitioned_datasets))
@@ -52,9 +48,7 @@ class TestDataSet(TestFabric, unittest.TestCase):
sum(dataset.num_rows for dataset in partitioned_datasets),
)
# load as arrow table
loaded_table = arrow.concat_tables(
[dataset.to_arrow_table(max_workers=1) for dataset in partitioned_datasets]
)
loaded_table = arrow.concat_tables([dataset.to_arrow_table(max_workers=1) for dataset in partitioned_datasets])
self.assertEqual(orig_dataset.num_rows, loaded_table.num_rows)
# compare arrow tables
orig_table = orig_dataset.to_arrow_table(max_workers=1)
@@ -74,9 +68,7 @@ class TestDataSet(TestFabric, unittest.TestCase):
def test_partition_by_files(self):
output_path = os.path.join(self.output_root_abspath, "test_partition_by_files")
orig_dataset = self._generate_parquet_dataset(
output_path, npartitions=11, num_rows=170 * 1000, row_group_size=10 * 1000
)
orig_dataset = self._generate_parquet_dataset(output_path, npartitions=11, num_rows=170 * 1000, row_group_size=10 * 1000)
num_files = len(orig_dataset.resolved_paths)
for npartition in range(1, num_files + 1):
for random_shuffle in (False, True):
@@ -84,17 +76,13 @@ class TestDataSet(TestFabric, unittest.TestCase):
orig_dataset.reset(orig_dataset.paths, orig_dataset.root_dir)
self._check_partition_datasets(
orig_dataset,
lambda n: orig_dataset.partition_by_files(
n, random_shuffle=random_shuffle
),
lambda n: orig_dataset.partition_by_files(n, random_shuffle=random_shuffle),
npartition,
)
def test_partition_by_rows(self):
output_path = os.path.join(self.output_root_abspath, "test_partition_by_rows")
orig_dataset = self._generate_parquet_dataset(
output_path, npartitions=11, num_rows=170 * 1000, row_group_size=10 * 1000
)
orig_dataset = self._generate_parquet_dataset(output_path, npartitions=11, num_rows=170 * 1000, row_group_size=10 * 1000)
num_files = len(orig_dataset.resolved_paths)
for npartition in range(1, 2 * num_files + 1):
for random_shuffle in (False, True):
@@ -102,9 +90,7 @@ class TestDataSet(TestFabric, unittest.TestCase):
orig_dataset.reset(orig_dataset.paths, orig_dataset.root_dir)
self._check_partition_datasets(
orig_dataset,
lambda n: orig_dataset.partition_by_rows(
n, random_shuffle=random_shuffle
),
lambda n: orig_dataset.partition_by_rows(n, random_shuffle=random_shuffle),
npartition,
)
@@ -116,9 +102,7 @@ class TestDataSet(TestFabric, unittest.TestCase):
self.assertEqual(len(dataset.resolved_paths), len(filenames))
def test_paths_with_char_ranges(self):
dataset_with_char_ranges = ParquetDataSet(
["tests/data/arrow/data[0-9].parquet"]
)
dataset_with_char_ranges = ParquetDataSet(["tests/data/arrow/data[0-9].parquet"])
dataset_with_wildcards = ParquetDataSet(["tests/data/arrow/*.parquet"])
self.assertEqual(
len(dataset_with_char_ranges.resolved_paths),
@@ -126,9 +110,7 @@ class TestDataSet(TestFabric, unittest.TestCase):
)
def test_to_arrow_table_batch_reader(self):
memdb = duckdb.connect(
database=":memory:", config={"arrow_large_buffer_size": "true"}
)
memdb = duckdb.connect(database=":memory:", config={"arrow_large_buffer_size": "true"})
for dataset_path in (
"tests/data/arrow/*.parquet",
"tests/data/large_array/*.parquet",
@@ -137,24 +119,14 @@ class TestDataSet(TestFabric, unittest.TestCase):
print(f"dataset_path: {dataset_path}, conn: {conn}")
with self.subTest(dataset_path=dataset_path, conn=conn):
dataset = ParquetDataSet([dataset_path])
to_batches = dataset.to_arrow_table(
max_workers=1, conn=conn
).to_batches(max_chunksize=DEFAULT_ROW_GROUP_SIZE * 2)
batch_reader = dataset.to_batch_reader(
batch_size=DEFAULT_ROW_GROUP_SIZE * 2, conn=conn
)
with ConcurrentIter(
batch_reader, max_buffer_size=2
) as batch_reader:
to_batches = dataset.to_arrow_table(max_workers=1, conn=conn).to_batches(max_chunksize=DEFAULT_ROW_GROUP_SIZE * 2)
batch_reader = dataset.to_batch_reader(batch_size=DEFAULT_ROW_GROUP_SIZE * 2, conn=conn)
with ConcurrentIter(batch_reader, max_buffer_size=2) as batch_reader:
for batch_iter in (to_batches, batch_reader):
total_num_rows = 0
for batch in batch_iter:
print(
f"batch.num_rows {batch.num_rows}, max_batch_row_size {DEFAULT_ROW_GROUP_SIZE*2}"
)
self.assertLessEqual(
batch.num_rows, DEFAULT_ROW_GROUP_SIZE * 2
)
print(f"batch.num_rows {batch.num_rows}, max_batch_row_size {DEFAULT_ROW_GROUP_SIZE*2}")
self.assertLessEqual(batch.num_rows, DEFAULT_ROW_GROUP_SIZE * 2)
total_num_rows += batch.num_rows
print(f"{dataset_path}: total_num_rows {total_num_rows}")
self.assertEqual(total_num_rows, dataset.num_rows)
@@ -167,8 +139,6 @@ def test_arrow_reader(benchmark, reader: str, dataset_path: str):
dataset = ParquetDataSet([dataset_path])
conn = None
if reader == "duckdb":
conn = duckdb.connect(
database=":memory:", config={"arrow_large_buffer_size": "true"}
)
conn = duckdb.connect(database=":memory:", config={"arrow_large_buffer_size": "true"})
benchmark(dataset.to_arrow_table, conn=conn)
# result: arrow reader is 4x faster than duckdb reader in small dataset, 1.4x faster in large dataset

View File

@@ -7,9 +7,7 @@ from smallpond.io.arrow import cast_columns_to_large_string
from tests.test_fabric import TestFabric
@unittest.skipUnless(
importlib.util.find_spec("deltalake") is not None, "cannot find deltalake"
)
@unittest.skipUnless(importlib.util.find_spec("deltalake") is not None, "cannot find deltalake")
class TestDeltaLake(TestFabric, unittest.TestCase):
def test_read_write_deltalake(self):
from deltalake import DeltaTable, write_deltalake
@@ -20,9 +18,7 @@ class TestDeltaLake(TestFabric, unittest.TestCase):
):
parquet_files = glob.glob(dataset_path)
expected = self._load_parquet_files(parquet_files)
with self.subTest(dataset_path=dataset_path), tempfile.TemporaryDirectory(
dir=self.output_root_abspath
) as output_dir:
with self.subTest(dataset_path=dataset_path), tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir:
write_deltalake(output_dir, expected, large_dtypes=True)
dt = DeltaTable(output_dir)
self._compare_arrow_tables(expected, dt.to_pyarrow_table())
@@ -35,12 +31,8 @@ class TestDeltaLake(TestFabric, unittest.TestCase):
"tests/data/large_array/*.parquet",
):
parquet_files = glob.glob(dataset_path)
with self.subTest(dataset_path=dataset_path), tempfile.TemporaryDirectory(
dir=self.output_root_abspath
) as output_dir:
table = cast_columns_to_large_string(
self._load_parquet_files(parquet_files)
)
with self.subTest(dataset_path=dataset_path), tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir:
table = cast_columns_to_large_string(self._load_parquet_files(parquet_files))
write_deltalake(output_dir, table, large_dtypes=True, mode="overwrite")
write_deltalake(output_dir, table, large_dtypes=False, mode="append")
loaded_table = DeltaTable(output_dir).to_pyarrow_table()

View File

@@ -69,9 +69,7 @@ class OutputMsgPythonTask(PythonScriptTask):
input_datasets: List[DataSet],
output_path: str,
) -> bool:
logger.info(
f"msg: {self.msg}, num files: {input_datasets[0].num_files}, local gpu ranks: {self.local_gpu_ranks}"
)
logger.info(f"msg: {self.msg}, num files: {input_datasets[0].num_files}, local gpu ranks: {self.local_gpu_ranks}")
self.inject_fault()
return True
@@ -105,9 +103,7 @@ class CopyInputArrowNode(ArrowComputeNode):
super().__init__(ctx, input_deps)
self.msg = msg
def process(
self, runtime_ctx: RuntimeContext, input_tables: List[arrow.Table]
) -> arrow.Table:
def process(self, runtime_ctx: RuntimeContext, input_tables: List[arrow.Table]) -> arrow.Table:
return copy_input_arrow(runtime_ctx, input_tables, self.msg)
@@ -117,24 +113,18 @@ class CopyInputStreamNode(ArrowStreamNode):
super().__init__(ctx, input_deps)
self.msg = msg
def process(
self, runtime_ctx: RuntimeContext, input_readers: List[arrow.RecordBatchReader]
) -> Iterable[arrow.Table]:
def process(self, runtime_ctx: RuntimeContext, input_readers: List[arrow.RecordBatchReader]) -> Iterable[arrow.Table]:
return copy_input_stream(runtime_ctx, input_readers, self.msg)
def copy_input_arrow(
runtime_ctx: RuntimeContext, input_tables: List[arrow.Table], msg: str
) -> arrow.Table:
def copy_input_arrow(runtime_ctx: RuntimeContext, input_tables: List[arrow.Table], msg: str) -> arrow.Table:
logger.info(f"msg: {msg}, num rows: {input_tables[0].num_rows}")
time.sleep(runtime_ctx.secs_executor_probe_interval)
runtime_ctx.task.inject_fault()
return input_tables[0]
def copy_input_stream(
runtime_ctx: RuntimeContext, input_readers: List[arrow.RecordBatchReader], msg: str
) -> Iterable[arrow.Table]:
def copy_input_stream(runtime_ctx: RuntimeContext, input_readers: List[arrow.RecordBatchReader], msg: str) -> Iterable[arrow.Table]:
for index, batch in enumerate(input_readers[0]):
logger.info(f"msg: {msg}, batch index: {index}, num rows: {batch.num_rows}")
time.sleep(runtime_ctx.secs_executor_probe_interval)
@@ -146,62 +136,44 @@ def copy_input_stream(
runtime_ctx.task.inject_fault()
def copy_input_batch(
runtime_ctx: RuntimeContext, input_batches: List[arrow.Table], msg: str
) -> arrow.Table:
def copy_input_batch(runtime_ctx: RuntimeContext, input_batches: List[arrow.Table], msg: str) -> arrow.Table:
logger.info(f"msg: {msg}, num rows: {input_batches[0].num_rows}")
time.sleep(runtime_ctx.secs_executor_probe_interval)
runtime_ctx.task.inject_fault()
return input_batches[0]
def copy_input_data_frame(
runtime_ctx: RuntimeContext, input_dfs: List[DataFrame]
) -> DataFrame:
def copy_input_data_frame(runtime_ctx: RuntimeContext, input_dfs: List[DataFrame]) -> DataFrame:
runtime_ctx.task.inject_fault()
return input_dfs[0]
def copy_input_data_frame_batch(
runtime_ctx: RuntimeContext, input_dfs: List[DataFrame]
) -> DataFrame:
def copy_input_data_frame_batch(runtime_ctx: RuntimeContext, input_dfs: List[DataFrame]) -> DataFrame:
runtime_ctx.task.inject_fault()
return input_dfs[0]
def merge_input_tables(
runtime_ctx: RuntimeContext, input_batches: List[arrow.Table]
) -> arrow.Table:
def merge_input_tables(runtime_ctx: RuntimeContext, input_batches: List[arrow.Table]) -> arrow.Table:
runtime_ctx.task.inject_fault()
output = arrow.concat_tables(input_batches)
logger.info(
f"input rows: {[len(batch) for batch in input_batches]}, output rows: {len(output)}"
)
logger.info(f"input rows: {[len(batch) for batch in input_batches]}, output rows: {len(output)}")
return output
def merge_input_data_frames(
runtime_ctx: RuntimeContext, input_dfs: List[DataFrame]
) -> DataFrame:
def merge_input_data_frames(runtime_ctx: RuntimeContext, input_dfs: List[DataFrame]) -> DataFrame:
runtime_ctx.task.inject_fault()
output = pandas.concat(input_dfs)
logger.info(
f"input rows: {[len(df) for df in input_dfs]}, output rows: {len(output)}"
)
logger.info(f"input rows: {[len(df) for df in input_dfs]}, output rows: {len(output)}")
return output
def parse_url(
runtime_ctx: RuntimeContext, input_tables: List[arrow.Table]
) -> arrow.Table:
def parse_url(runtime_ctx: RuntimeContext, input_tables: List[arrow.Table]) -> arrow.Table:
urls = input_tables[0].columns[0]
hosts = [url.as_py().split("/", maxsplit=2)[0] for url in urls]
return input_tables[0].append_column("host", arrow.array(hosts))
def nonzero_exit_code(
runtime_ctx: RuntimeContext, input_datasets: List[DataSet], output_path: str
) -> bool:
def nonzero_exit_code(runtime_ctx: RuntimeContext, input_datasets: List[DataSet], output_path: str) -> bool:
import sys
if runtime_ctx.task._memory_boost == 1:
@@ -210,9 +182,7 @@ def nonzero_exit_code(
# create an empty file with a fixed name
def empty_file(
runtime_ctx: RuntimeContext, input_datasets: List[DataSet], output_path: str
) -> bool:
def empty_file(runtime_ctx: RuntimeContext, input_datasets: List[DataSet], output_path: str) -> bool:
import os
with open(os.path.join(output_path, "file"), "w") as fout:
@@ -231,9 +201,7 @@ def split_url(urls: arrow.array) -> arrow.array:
return arrow.array(url_parts, type=arrow.list_(arrow.string()))
def choose_random_urls(
runtime_ctx: RuntimeContext, input_tables: List[arrow.Table], k: int = 5
) -> arrow.Table:
def choose_random_urls(runtime_ctx: RuntimeContext, input_tables: List[arrow.Table], k: int = 5) -> arrow.Table:
# get the current running task
runtime_task = runtime_ctx.task
# access task-specific attributes
@@ -255,16 +223,12 @@ class TestExecution(TestFabric, unittest.TestCase):
def test_arrow_task(self):
for use_duckdb_reader in (False, True):
with self.subTest(use_duckdb_reader=use_duckdb_reader):
with tempfile.TemporaryDirectory(
dir=self.output_root_abspath
) as output_dir:
with tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir:
ctx = Context()
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
data_table = dataset.to_arrow_table()
data_files = DataSourceNode(ctx, dataset)
data_partitions = DataSetPartitionNode(
ctx, (data_files,), npartitions=7
)
data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=7)
if use_duckdb_reader:
data_partitions = ProjectionNode(
ctx,
@@ -274,9 +238,7 @@ class TestExecution(TestFabric, unittest.TestCase):
arrow_compute = ArrowComputeNode(
ctx,
(data_partitions,),
process_func=functools.partial(
copy_input_arrow, msg="arrow compute"
),
process_func=functools.partial(copy_input_arrow, msg="arrow compute"),
use_duckdb_reader=use_duckdb_reader,
output_name="arrow_compute",
output_path=output_dir,
@@ -285,9 +247,7 @@ class TestExecution(TestFabric, unittest.TestCase):
arrow_stream = ArrowStreamNode(
ctx,
(data_partitions,),
process_func=functools.partial(
copy_input_stream, msg="arrow stream"
),
process_func=functools.partial(copy_input_stream, msg="arrow stream"),
streaming_batch_size=10,
secs_checkpoint_interval=0.5,
use_duckdb_reader=use_duckdb_reader,
@@ -298,9 +258,7 @@ class TestExecution(TestFabric, unittest.TestCase):
arrow_batch = ArrowBatchNode(
ctx,
(data_partitions,),
process_func=functools.partial(
copy_input_batch, msg="arrow batch"
),
process_func=functools.partial(copy_input_batch, msg="arrow batch"),
streaming_batch_size=10,
secs_checkpoint_interval=0.5,
use_duckdb_reader=use_duckdb_reader,
@@ -314,12 +272,8 @@ class TestExecution(TestFabric, unittest.TestCase):
output_path=output_dir,
)
plan = LogicalPlan(ctx, data_sink)
exec_plan = self.execute_plan(
plan, fault_inject_prob=0.1, secs_executor_probe_interval=0.5
)
self.assertTrue(
all(map(os.path.exists, exec_plan.final_output.resolved_paths))
)
exec_plan = self.execute_plan(plan, fault_inject_prob=0.1, secs_executor_probe_interval=0.5)
self.assertTrue(all(map(os.path.exists, exec_plan.final_output.resolved_paths)))
arrow_compute_output = ParquetDataSet(
[os.path.join(output_dir, "arrow_compute", "**/*.parquet")],
recursive=True,
@@ -334,21 +288,15 @@ class TestExecution(TestFabric, unittest.TestCase):
)
self._compare_arrow_tables(
data_table,
arrow_compute_output.to_arrow_table().select(
data_table.column_names
),
arrow_compute_output.to_arrow_table().select(data_table.column_names),
)
self._compare_arrow_tables(
data_table,
arrow_stream_output.to_arrow_table().select(
data_table.column_names
),
arrow_stream_output.to_arrow_table().select(data_table.column_names),
)
self._compare_arrow_tables(
data_table,
arrow_batch_output.to_arrow_table().select(
data_table.column_names
),
arrow_batch_output.to_arrow_table().select(data_table.column_names),
)
def test_pandas_task(self):
@@ -376,16 +324,10 @@ class TestExecution(TestFabric, unittest.TestCase):
output_path=output_dir,
cpu_limit=2,
)
data_sink = DataSinkNode(
ctx, (pandas_compute, pandas_batch), output_path=output_dir
)
data_sink = DataSinkNode(ctx, (pandas_compute, pandas_batch), output_path=output_dir)
plan = LogicalPlan(ctx, data_sink)
exec_plan = self.execute_plan(
plan, fault_inject_prob=0.1, secs_executor_probe_interval=0.5
)
self.assertTrue(
all(map(os.path.exists, exec_plan.final_output.resolved_paths))
)
exec_plan = self.execute_plan(plan, fault_inject_prob=0.1, secs_executor_probe_interval=0.5)
self.assertTrue(all(map(os.path.exists, exec_plan.final_output.resolved_paths)))
pandas_compute_output = ParquetDataSet(
[os.path.join(output_dir, "pandas_compute", "**/*.parquet")],
recursive=True,
@@ -394,21 +336,15 @@ class TestExecution(TestFabric, unittest.TestCase):
[os.path.join(output_dir, "pandas_batch", "**/*.parquet")],
recursive=True,
)
self._compare_arrow_tables(
data_table, pandas_compute_output.to_arrow_table()
)
self._compare_arrow_tables(data_table, pandas_compute_output.to_arrow_table())
self._compare_arrow_tables(data_table, pandas_batch_output.to_arrow_table())
def test_variable_length_input_datasets(self):
ctx = Context()
small_dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
large_dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"] * 10)
small_partitions = DataSetPartitionNode(
ctx, (DataSourceNode(ctx, small_dataset),), npartitions=7
)
large_partitions = DataSetPartitionNode(
ctx, (DataSourceNode(ctx, large_dataset),), npartitions=7
)
small_partitions = DataSetPartitionNode(ctx, (DataSourceNode(ctx, small_dataset),), npartitions=7)
large_partitions = DataSetPartitionNode(ctx, (DataSourceNode(ctx, large_dataset),), npartitions=7)
arrow_batch = ArrowBatchNode(
ctx,
(small_partitions, large_partitions),
@@ -428,9 +364,7 @@ class TestExecution(TestFabric, unittest.TestCase):
cpu_limit=2,
)
plan = LogicalPlan(ctx, RootNode(ctx, (arrow_batch, pandas_batch)))
exec_plan = self.execute_plan(
plan, fault_inject_prob=0.1, secs_executor_probe_interval=0.5
)
exec_plan = self.execute_plan(plan, fault_inject_prob=0.1, secs_executor_probe_interval=0.5)
self.assertTrue(all(map(os.path.exists, exec_plan.final_output.resolved_paths)))
arrow_batch_output = ParquetDataSet(
[os.path.join(exec_plan.ctx.output_root, "arrow_batch", "**/*.parquet")],
@@ -440,9 +374,7 @@ class TestExecution(TestFabric, unittest.TestCase):
[os.path.join(exec_plan.ctx.output_root, "pandas_batch", "**/*.parquet")],
recursive=True,
)
self.assertEqual(
small_dataset.num_rows + large_dataset.num_rows, arrow_batch_output.num_rows
)
self.assertEqual(small_dataset.num_rows + large_dataset.num_rows, arrow_batch_output.num_rows)
self.assertEqual(
small_dataset.num_rows + large_dataset.num_rows,
pandas_batch_output.num_rows,
@@ -453,9 +385,7 @@ class TestExecution(TestFabric, unittest.TestCase):
# select columns when defining dataset
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"], columns=["url"])
data_files = DataSourceNode(ctx, dataset)
data_partitions = DataSetPartitionNode(
ctx, (data_files,), npartitions=3, partition_by_rows=True
)
data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=3, partition_by_rows=True)
# projection as input of arrow node
generated_columns = ["filename", "file_row_number"]
urls_with_host = ArrowComputeNode(
@@ -480,9 +410,7 @@ class TestExecution(TestFabric, unittest.TestCase):
# unify different schemas
merged_diff_schemas = ProjectionNode(
ctx,
DataSetPartitionNode(
ctx, (distinct_urls_with_host, urls_with_host), npartitions=1
),
DataSetPartitionNode(ctx, (distinct_urls_with_host, urls_with_host), npartitions=1),
union_by_name=True,
)
host_partitions = HashPartitionNode(
@@ -513,9 +441,7 @@ class TestExecution(TestFabric, unittest.TestCase):
ctx = Context()
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
data_files = DataSourceNode(ctx, dataset)
data_partitions = DataSetPartitionNode(
ctx, (data_files,), npartitions=dataset.num_files
)
data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=dataset.num_files)
ctx.create_function(
"split_url",
split_url,
@@ -537,9 +463,7 @@ class TestExecution(TestFabric, unittest.TestCase):
npartitions = 1000
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"] * npartitions)
data_files = DataSourceNode(ctx, dataset)
data_partitions = EvenlyDistributedPartitionNode(
ctx, (data_files,), npartitions=npartitions
)
data_partitions = EvenlyDistributedPartitionNode(ctx, (data_files,), npartitions=npartitions)
output_msg = OutputMsgPythonNode(ctx, (data_partitions,))
plan = LogicalPlan(ctx, output_msg)
self.execute_plan(
@@ -552,13 +476,9 @@ class TestExecution(TestFabric, unittest.TestCase):
def test_many_producers_and_partitions(self):
ctx = Context()
npartitions = 10000
dataset = ParquetDataSet(
["tests/data/mock_urls/*.parquet"] * (npartitions * 10)
)
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"] * (npartitions * 10))
data_files = DataSourceNode(ctx, dataset)
data_partitions = EvenlyDistributedPartitionNode(
ctx, (data_files,), npartitions=npartitions, cpu_limit=1
)
data_partitions = EvenlyDistributedPartitionNode(ctx, (data_files,), npartitions=npartitions, cpu_limit=1)
data_partitions.max_num_producer_tasks = 20
output_msg = OutputMsgPythonNode(ctx, (data_partitions,))
plan = LogicalPlan(ctx, output_msg)
@@ -573,12 +493,8 @@ class TestExecution(TestFabric, unittest.TestCase):
ctx = Context()
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
data_files = DataSourceNode(ctx, dataset)
data_partitions = DataSetPartitionNode(
ctx, (data_files,), npartitions=dataset.num_files
)
output_msg = OutputMsgPythonNode(
ctx, (data_partitions,), cpu_limit=1, gpu_limit=0.5
)
data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=dataset.num_files)
output_msg = OutputMsgPythonNode(ctx, (data_partitions,), cpu_limit=1, gpu_limit=0.5)
plan = LogicalPlan(ctx, output_msg)
runtime_ctx = RuntimeContext(
JobId.new(),
@@ -596,9 +512,7 @@ class TestExecution(TestFabric, unittest.TestCase):
data_files = DataSourceNode(ctx, dataset)
copy_input_arrow_node = CopyInputArrowNode(ctx, (data_files,), "hello")
copy_input_stream_node = CopyInputStreamNode(ctx, (data_files,), "hello")
output_msg = OutputMsgPythonNode2(
ctx, (copy_input_arrow_node, copy_input_stream_node), "hello"
)
output_msg = OutputMsgPythonNode2(ctx, (copy_input_arrow_node, copy_input_stream_node), "hello")
plan = LogicalPlan(ctx, output_msg)
self.execute_plan(plan)
@@ -606,9 +520,7 @@ class TestExecution(TestFabric, unittest.TestCase):
ctx = Context()
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
data_files = DataSourceNode(ctx, dataset)
uniq_urls = SqlEngineNode(
ctx, (data_files,), r"select distinct * from {0}", memory_limit=2 * MB
)
uniq_urls = SqlEngineNode(ctx, (data_files,), r"select distinct * from {0}", memory_limit=2 * MB)
uniq_url_partitions = DataSetPartitionNode(ctx, (uniq_urls,), 2)
uniq_url_count = SqlEngineNode(
ctx,
@@ -637,9 +549,7 @@ class TestExecution(TestFabric, unittest.TestCase):
memory_limit=1 * GB,
)
with tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir:
data_sink = DataSinkNode(
ctx, (arrow_compute, arrow_stream), output_path=output_dir
)
data_sink = DataSinkNode(ctx, (arrow_compute, arrow_stream), output_path=output_dir)
plan = LogicalPlan(ctx, data_sink)
self.execute_plan(
plan,
@@ -652,17 +562,11 @@ class TestExecution(TestFabric, unittest.TestCase):
ctx = Context()
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
data_files = DataSourceNode(ctx, dataset)
nonzero_exitcode = PythonScriptNode(
ctx, (data_files,), process_func=nonzero_exit_code
)
nonzero_exitcode = PythonScriptNode(ctx, (data_files,), process_func=nonzero_exit_code)
plan = LogicalPlan(ctx, nonzero_exitcode)
exec_plan = self.execute_plan(
plan, num_executors=1, check_result=False, nonzero_exitcode_as_oom=False
)
exec_plan = self.execute_plan(plan, num_executors=1, check_result=False, nonzero_exitcode_as_oom=False)
self.assertFalse(exec_plan.successful)
exec_plan = self.execute_plan(
plan, num_executors=1, check_result=False, nonzero_exitcode_as_oom=True
)
exec_plan = self.execute_plan(plan, num_executors=1, check_result=False, nonzero_exitcode_as_oom=True)
self.assertTrue(exec_plan.successful)
def test_manifest_only_data_sink(self):
@@ -675,9 +579,7 @@ class TestExecution(TestFabric, unittest.TestCase):
dataset = ParquetDataSet(filenames)
data_files = DataSourceNode(ctx, dataset)
data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=512)
data_sink = DataSinkNode(
ctx, (data_partitions,), output_path=output_dir, manifest_only=True
)
data_sink = DataSinkNode(ctx, (data_partitions,), output_path=output_dir, manifest_only=True)
plan = LogicalPlan(ctx, data_sink)
self.execute_plan(plan)
@@ -734,12 +636,8 @@ class TestExecution(TestFabric, unittest.TestCase):
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
data_files = DataSourceNode(ctx, dataset)
data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=10)
url_counts = SqlEngineNode(
ctx, (data_partitions,), r"select count(url) as cnt from {0}"
)
distinct_url_counts = SqlEngineNode(
ctx, (data_partitions,), r"select count(distinct url) as cnt from {0}"
)
url_counts = SqlEngineNode(ctx, (data_partitions,), r"select count(url) as cnt from {0}")
distinct_url_counts = SqlEngineNode(ctx, (data_partitions,), r"select count(distinct url) as cnt from {0}")
merged_counts = DataSetPartitionNode(
ctx,
(
@@ -764,9 +662,7 @@ class TestExecution(TestFabric, unittest.TestCase):
r"select count(url) as cnt from {0}",
output_name="url_counts",
)
distinct_url_counts = SqlEngineNode(
ctx, (data_partitions,), r"select count(distinct url) as cnt from {0}"
)
distinct_url_counts = SqlEngineNode(ctx, (data_partitions,), r"select count(distinct url) as cnt from {0}")
merged_counts = DataSetPartitionNode(
ctx,
(
@@ -787,24 +683,14 @@ class TestExecution(TestFabric, unittest.TestCase):
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
data_files = DataSourceNode(ctx, dataset)
data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=10)
empty_files1 = PythonScriptNode(
ctx, (data_partitions,), process_func=empty_file
)
empty_files2 = PythonScriptNode(
ctx, (data_partitions,), process_func=empty_file
)
empty_files1 = PythonScriptNode(ctx, (data_partitions,), process_func=empty_file)
empty_files2 = PythonScriptNode(ctx, (data_partitions,), process_func=empty_file)
link_path = os.path.join(self.runtime_ctx.output_root, "link")
copy_path = os.path.join(self.runtime_ctx.output_root, "copy")
copy_input_path = os.path.join(self.runtime_ctx.output_root, "copy_input")
data_link = DataSinkNode(
ctx, (empty_files1, empty_files2), type="link", output_path=link_path
)
data_copy = DataSinkNode(
ctx, (empty_files1, empty_files2), type="copy", output_path=copy_path
)
data_copy_input = DataSinkNode(
ctx, (data_partitions,), type="copy", output_path=copy_input_path
)
data_link = DataSinkNode(ctx, (empty_files1, empty_files2), type="link", output_path=link_path)
data_copy = DataSinkNode(ctx, (empty_files1, empty_files2), type="copy", output_path=copy_path)
data_copy_input = DataSinkNode(ctx, (data_partitions,), type="copy", output_path=copy_input_path)
plan = LogicalPlan(ctx, RootNode(ctx, (data_link, data_copy, data_copy_input)))
self.execute_plan(plan)
@@ -813,33 +699,19 @@ class TestExecution(TestFabric, unittest.TestCase):
self.assertEqual(21, len(os.listdir(copy_path)))
# file name should not be modified if no conflict
self.assertEqual(
set(
filename
for filename in os.listdir("tests/data/mock_urls")
if filename.endswith(".parquet")
),
set(
filename
for filename in os.listdir(copy_input_path)
if filename.endswith(".parquet")
),
set(filename for filename in os.listdir("tests/data/mock_urls") if filename.endswith(".parquet")),
set(filename for filename in os.listdir(copy_input_path) if filename.endswith(".parquet")),
)
def test_literal_datasets_as_data_sources(self):
ctx = Context()
num_rows = 10
query_dataset = SqlQueryDataSet(f"select i from range({num_rows}) as x(i)")
table_dataset = ArrowTableDataSet(
arrow.Table.from_arrays([list(range(num_rows))], names=["i"])
)
table_dataset = ArrowTableDataSet(arrow.Table.from_arrays([list(range(num_rows))], names=["i"]))
query_source = DataSourceNode(ctx, query_dataset)
table_source = DataSourceNode(ctx, table_dataset)
query_partitions = DataSetPartitionNode(
ctx, (query_source,), npartitions=num_rows, partition_by_rows=True
)
table_partitions = DataSetPartitionNode(
ctx, (table_source,), npartitions=num_rows, partition_by_rows=True
)
query_partitions = DataSetPartitionNode(ctx, (query_source,), npartitions=num_rows, partition_by_rows=True)
table_partitions = DataSetPartitionNode(ctx, (table_source,), npartitions=num_rows, partition_by_rows=True)
joined_rows = SqlEngineNode(
ctx,
(query_partitions, table_partitions),

View File

@@ -27,9 +27,7 @@ from tests.datagen import generate_data
generate_data()
def run_scheduler(
runtime_ctx: RuntimeContext, scheduler: Scheduler, queue: queue.Queue
):
def run_scheduler(runtime_ctx: RuntimeContext, scheduler: Scheduler, queue: queue.Queue):
runtime_ctx.initialize("scheduler")
scheduler.add_state_observer(Scheduler.StateObserver(SaveSchedState(queue)))
retval = scheduler.run()
@@ -130,9 +128,7 @@ class TestFabric(unittest.TestCase):
process.kill()
process.join()
logger.info(
f"#{i} process {process.name} exited with code {process.exitcode}"
)
logger.info(f"#{i} process {process.name} exited with code {process.exitcode}")
def start_execution(
self,
@@ -189,11 +185,7 @@ class TestFabric(unittest.TestCase):
secs_wq_poll_interval=secs_wq_poll_interval,
secs_executor_probe_interval=secs_executor_probe_interval,
max_num_missed_probes=max_num_missed_probes,
fault_inject_prob=(
fault_inject_prob
if fault_inject_prob is not None
else self.fault_inject_prob
),
fault_inject_prob=(fault_inject_prob if fault_inject_prob is not None else self.fault_inject_prob),
enable_profiling=enable_profiling,
enable_diagnostic_metrics=enable_diagnostic_metrics,
remove_empty_parquet=remove_empty_parquet,
@@ -217,9 +209,7 @@ class TestFabric(unittest.TestCase):
nonzero_exitcode_as_oom=nonzero_exitcode_as_oom,
)
self.latest_state = scheduler
self.executors = [
Executor.create(runtime_ctx, f"executor-{i}") for i in range(num_executors)
]
self.executors = [Executor.create(runtime_ctx, f"executor-{i}") for i in range(num_executors)]
self.processes = [
Process(
target=run_scheduler,
@@ -229,10 +219,7 @@ class TestFabric(unittest.TestCase):
name="scheduler",
)
]
self.processes += [
Process(target=run_executor, args=(runtime_ctx, executor), name=executor.id)
for executor in self.executors
]
self.processes += [Process(target=run_executor, args=(runtime_ctx, executor), name=executor.id) for executor in self.executors]
for process in reversed(self.processes):
process.start()
@@ -264,15 +251,9 @@ class TestFabric(unittest.TestCase):
self.assertTrue(latest_state.success)
return latest_state.exec_plan
def _load_parquet_files(
self, paths, filesystem: fsspec.AbstractFileSystem = None
) -> arrow.Table:
def _load_parquet_files(self, paths, filesystem: fsspec.AbstractFileSystem = None) -> arrow.Table:
def read_parquet_file(path):
return arrow.Table.from_batches(
parquet.ParquetFile(
path, buffer_size=16 * MB, filesystem=filesystem
).iter_batches()
)
return arrow.Table.from_batches(parquet.ParquetFile(path, buffer_size=16 * MB, filesystem=filesystem).iter_batches())
with ThreadPoolExecutor(16) as pool:
return arrow.concat_tables(pool.map(read_parquet_file, paths))

View File

@@ -17,9 +17,7 @@ class TestFilesystem(TestFabric, unittest.TestCase):
def test_pickle_trace(self):
with self.assertRaises(TypeError) as context:
with tempfile.TemporaryDirectory(
dir=self.output_root_abspath
) as output_dir:
with tempfile.TemporaryDirectory(dir=self.output_root_abspath) as output_dir:
thread = threading.Thread()
pickle_path = os.path.join(output_dir, "thread.pickle")
dump(thread, pickle_path)

View File

@@ -21,19 +21,11 @@ class TestLogicalPlan(TestFabric, unittest.TestCase):
def test_join_chunkmeta_inodes(self):
ctx = Context()
chunkmeta_dump = DataSourceNode(
ctx, dataset=ParquetDataSet(["tests/data/chunkmeta*.parquet"])
)
chunkmeta_partitions = HashPartitionNode(
ctx, (chunkmeta_dump,), npartitions=2, hash_columns=["inodeId"]
)
chunkmeta_dump = DataSourceNode(ctx, dataset=ParquetDataSet(["tests/data/chunkmeta*.parquet"]))
chunkmeta_partitions = HashPartitionNode(ctx, (chunkmeta_dump,), npartitions=2, hash_columns=["inodeId"])
inodes_dump = DataSourceNode(
ctx, dataset=ParquetDataSet(["tests/data/inodes*.parquet"])
)
inodes_partitions = HashPartitionNode(
ctx, (inodes_dump,), npartitions=2, hash_columns=["inode_id"]
)
inodes_dump = DataSourceNode(ctx, dataset=ParquetDataSet(["tests/data/inodes*.parquet"]))
inodes_partitions = HashPartitionNode(ctx, (inodes_dump,), npartitions=2, hash_columns=["inode_id"])
num_gc_chunks = SqlEngineNode(
ctx,
@@ -53,12 +45,8 @@ class TestLogicalPlan(TestFabric, unittest.TestCase):
ctx = Context()
parquet_dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
data_source = DataSourceNode(ctx, parquet_dataset)
partition_dim_a = EvenlyDistributedPartitionNode(
ctx, (data_source,), npartitions=parquet_dataset.num_files, dimension="A"
)
partition_dim_b = EvenlyDistributedPartitionNode(
ctx, (data_source,), npartitions=parquet_dataset.num_files, dimension="B"
)
partition_dim_a = EvenlyDistributedPartitionNode(ctx, (data_source,), npartitions=parquet_dataset.num_files, dimension="A")
partition_dim_b = EvenlyDistributedPartitionNode(ctx, (data_source,), npartitions=parquet_dataset.num_files, dimension="B")
join_two_inputs = SqlEngineNode(
ctx,
(partition_dim_a, partition_dim_b),
@@ -73,9 +61,7 @@ class TestLogicalPlan(TestFabric, unittest.TestCase):
ctx = Context()
parquet_dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
data_source = DataSourceNode(ctx, parquet_dataset)
partition_dim_a = EvenlyDistributedPartitionNode(
ctx, (data_source,), npartitions=parquet_dataset.num_files, dimension="A"
)
partition_dim_a = EvenlyDistributedPartitionNode(ctx, (data_source,), npartitions=parquet_dataset.num_files, dimension="A")
partition_dim_a2 = EvenlyDistributedPartitionNode(
ctx,
(data_source,),
@@ -94,9 +80,7 @@ class TestLogicalPlan(TestFabric, unittest.TestCase):
)
plan = LogicalPlan(
ctx,
DataSetPartitionNode(
ctx, (join_two_inputs1, join_two_inputs2), npartitions=1
),
DataSetPartitionNode(ctx, (join_two_inputs1, join_two_inputs2), npartitions=1),
)
logger.info(str(plan))
with self.assertRaises(AssertionError) as context:

View File

@@ -31,9 +31,7 @@ from tests.test_fabric import TestFabric
class CalculatePartitionFromFilename(UserDefinedPartitionNode):
def partition(self, runtime_ctx: RuntimeContext, dataset: DataSet) -> List[DataSet]:
partitioned_datasets: List[ParquetDataSet] = [
ParquetDataSet([]) for _ in range(self.npartitions)
]
partitioned_datasets: List[ParquetDataSet] = [ParquetDataSet([]) for _ in range(self.npartitions)]
for path in dataset.resolved_paths:
partition_idx = hash(path) % self.npartitions
partitioned_datasets[partition_idx].paths.append(path)
@@ -45,9 +43,7 @@ class TestPartition(TestFabric, unittest.TestCase):
ctx = Context()
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"] * 10)
data_files = DataSourceNode(ctx, dataset)
data_partitions = DataSetPartitionNode(
ctx, (data_files,), npartitions=dataset.num_files
)
data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=dataset.num_files)
count_rows = SqlEngineNode(
ctx,
(data_partitions,),
@@ -62,9 +58,7 @@ class TestPartition(TestFabric, unittest.TestCase):
ctx = Context()
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
data_files = DataSourceNode(ctx, dataset)
data_partitions = DataSetPartitionNode(
ctx, (data_files,), npartitions=dataset.num_rows, partition_by_rows=True
)
data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=dataset.num_rows, partition_by_rows=True)
count_rows = SqlEngineNode(
ctx,
(data_partitions,),
@@ -74,18 +68,14 @@ class TestPartition(TestFabric, unittest.TestCase):
)
plan = LogicalPlan(ctx, count_rows)
exec_plan = self.execute_plan(plan, num_executors=5)
self.assertEqual(
exec_plan.final_output.to_arrow_table().num_rows, dataset.num_rows
)
self.assertEqual(exec_plan.final_output.to_arrow_table().num_rows, dataset.num_rows)
def test_empty_dataset_partition(self):
ctx = Context()
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
data_files = DataSourceNode(ctx, dataset)
# create more partitions than files
data_partitions = EvenlyDistributedPartitionNode(
ctx, (data_files,), npartitions=dataset.num_files * 2
)
data_partitions = EvenlyDistributedPartitionNode(ctx, (data_files,), npartitions=dataset.num_files * 2)
data_partitions.max_num_producer_tasks = 3
unique_urls = SqlEngineNode(
ctx,
@@ -95,9 +85,7 @@ class TestPartition(TestFabric, unittest.TestCase):
memory_limit=1 * GB,
)
# nested partition
nested_partitioned_urls = EvenlyDistributedPartitionNode(
ctx, (unique_urls,), npartitions=3, dimension="nested", nested=True
)
nested_partitioned_urls = EvenlyDistributedPartitionNode(ctx, (unique_urls,), npartitions=3, dimension="nested", nested=True)
parsed_urls = ArrowComputeNode(
ctx,
(nested_partitioned_urls,),
@@ -106,18 +94,14 @@ class TestPartition(TestFabric, unittest.TestCase):
memory_limit=1 * GB,
)
plan = LogicalPlan(ctx, parsed_urls)
final_output = self.execute_plan(
plan, remove_empty_parquet=True, skip_task_with_empty_input=True
).final_output
final_output = self.execute_plan(plan, remove_empty_parquet=True, skip_task_with_empty_input=True).final_output
self.assertTrue(isinstance(final_output, ParquetDataSet))
self.assertEqual(dataset.num_rows, final_output.num_rows)
def test_hash_partition(self):
for engine_type in ("duckdb", "arrow"):
for partition_by_rows in (False, True):
for hive_partitioning in (
(False, True) if engine_type == "duckdb" else (False,)
):
for hive_partitioning in (False, True) if engine_type == "duckdb" else (False,):
with self.subTest(
engine_type=engine_type,
partition_by_rows=partition_by_rows,
@@ -155,26 +139,16 @@ class TestPartition(TestFabric, unittest.TestCase):
exec_plan = self.execute_plan(plan)
self.assertEqual(
dataset.num_rows,
pc.sum(
exec_plan.final_output.to_arrow_table().column(
"row_count"
)
).as_py(),
pc.sum(exec_plan.final_output.to_arrow_table().column("row_count")).as_py(),
)
self.assertEqual(
npartitions,
len(exec_plan.final_output.load_partitioned_datasets(npartitions, DATA_PARTITION_COLUMN_NAME)),
)
self.assertEqual(
npartitions,
len(
exec_plan.final_output.load_partitioned_datasets(
npartitions, DATA_PARTITION_COLUMN_NAME
)
),
)
self.assertEqual(
npartitions,
len(
exec_plan.get_output(
"hash_partitions"
).load_partitioned_datasets(
exec_plan.get_output("hash_partitions").load_partitioned_datasets(
npartitions,
DATA_PARTITION_COLUMN_NAME,
hive_partitioning,
@@ -185,9 +159,7 @@ class TestPartition(TestFabric, unittest.TestCase):
def test_empty_hash_partition(self):
for engine_type in ("duckdb", "arrow"):
for partition_by_rows in (False, True):
for hive_partitioning in (
(False, True) if engine_type == "duckdb" else (False,)
):
for hive_partitioning in (False, True) if engine_type == "duckdb" else (False,):
with self.subTest(
engine_type=engine_type,
partition_by_rows=partition_by_rows,
@@ -199,9 +171,7 @@ class TestPartition(TestFabric, unittest.TestCase):
npartitions = 3
npartitions_nested = 4
num_rows = 1
head_rows = SqlEngineNode(
ctx, (data_files,), f"select * from {{0}} limit {num_rows}"
)
head_rows = SqlEngineNode(ctx, (data_files,), f"select * from {{0}} limit {num_rows}")
data_partitions = DataSetPartitionNode(
ctx,
(head_rows,),
@@ -241,53 +211,31 @@ class TestPartition(TestFabric, unittest.TestCase):
memory_limit=1 * GB,
)
plan = LogicalPlan(ctx, select_every_row)
exec_plan = self.execute_plan(
plan, skip_task_with_empty_input=True
)
exec_plan = self.execute_plan(plan, skip_task_with_empty_input=True)
self.assertEqual(num_rows, exec_plan.final_output.num_rows)
self.assertEqual(
npartitions,
len(
exec_plan.final_output.load_partitioned_datasets(
npartitions, "hash_partitions"
)
),
len(exec_plan.final_output.load_partitioned_datasets(npartitions, "hash_partitions")),
)
self.assertEqual(
npartitions_nested,
len(
exec_plan.final_output.load_partitioned_datasets(
npartitions_nested, "nested_hash_partitions"
)
),
len(exec_plan.final_output.load_partitioned_datasets(npartitions_nested, "nested_hash_partitions")),
)
self.assertEqual(
npartitions,
len(
exec_plan.get_output(
"hash_partitions"
).load_partitioned_datasets(
npartitions, "hash_partitions"
)
),
len(exec_plan.get_output("hash_partitions").load_partitioned_datasets(npartitions, "hash_partitions")),
)
self.assertEqual(
npartitions_nested,
len(
exec_plan.get_output(
"nested_hash_partitions"
).load_partitioned_datasets(
npartitions_nested, "nested_hash_partitions"
)
exec_plan.get_output("nested_hash_partitions").load_partitioned_datasets(npartitions_nested, "nested_hash_partitions")
),
)
if hive_partitioning:
self.assertEqual(
npartitions,
len(
exec_plan.get_output(
"hash_partitions"
).load_partitioned_datasets(
exec_plan.get_output("hash_partitions").load_partitioned_datasets(
npartitions,
"hash_partitions",
hive_partitioning=True,
@@ -297,9 +245,7 @@ class TestPartition(TestFabric, unittest.TestCase):
self.assertEqual(
npartitions_nested,
len(
exec_plan.get_output(
"nested_hash_partitions"
).load_partitioned_datasets(
exec_plan.get_output("nested_hash_partitions").load_partitioned_datasets(
npartitions_nested,
"nested_hash_partitions",
hive_partitioning=True,
@@ -341,19 +287,11 @@ class TestPartition(TestFabric, unittest.TestCase):
exec_plan = self.execute_plan(plan)
self.assertEqual(
npartitions,
len(
exec_plan.final_output.load_partitioned_datasets(
npartitions, data_partition_column
)
),
len(exec_plan.final_output.load_partitioned_datasets(npartitions, data_partition_column)),
)
self.assertEqual(
npartitions,
len(
exec_plan.get_output("input_partitions").load_partitioned_datasets(
npartitions, data_partition_column, hive_partitioning
)
),
len(exec_plan.get_output("input_partitions").load_partitioned_datasets(npartitions, data_partition_column, hive_partitioning)),
)
return exec_plan
@@ -376,12 +314,8 @@ class TestPartition(TestFabric, unittest.TestCase):
)
ctx = Context()
output1 = DataSourceNode(
ctx, dataset=exec_plan1.get_output("input_partitions")
)
output2 = DataSourceNode(
ctx, dataset=exec_plan2.get_output("input_partitions")
)
output1 = DataSourceNode(ctx, dataset=exec_plan1.get_output("input_partitions"))
output2 = DataSourceNode(ctx, dataset=exec_plan2.get_output("input_partitions"))
split_urls1 = LoadPartitionedDataSetNode(
ctx,
(output1,),
@@ -411,16 +345,8 @@ class TestPartition(TestFabric, unittest.TestCase):
plan = LogicalPlan(ctx, split_urls3)
exec_plan3 = self.execute_plan(plan)
# load each partition as arrow table and compare
final_output_partitions1 = (
exec_plan1.final_output.load_partitioned_datasets(
npartitions, data_partition_column
)
)
final_output_partitions3 = (
exec_plan3.final_output.load_partitioned_datasets(
npartitions, data_partition_column
)
)
final_output_partitions1 = exec_plan1.final_output.load_partitioned_datasets(npartitions, data_partition_column)
final_output_partitions3 = exec_plan3.final_output.load_partitioned_datasets(npartitions, data_partition_column)
self.assertEqual(npartitions, len(final_output_partitions3))
for x, y in zip(final_output_partitions1, final_output_partitions3):
self._compare_arrow_tables(x.to_arrow_table(), y.to_arrow_table())
@@ -433,9 +359,7 @@ class TestPartition(TestFabric, unittest.TestCase):
SqlEngineNode.default_cpu_limit = 1
SqlEngineNode.default_memory_limit = 1 * GB
initial_reduce = r"select host, count(*) as cnt from {0} group by host"
combine_reduce_results = (
r"select host, cast(sum(cnt) as bigint) as cnt from {0} group by host"
)
combine_reduce_results = r"select host, cast(sum(cnt) as bigint) as cnt from {0} group by host"
join_query = r"select host, cnt from {0} where (exists (select * from {1} where {1}.host = {0}.host)) and (exists (select * from {2} where {2}.host = {0}.host))"
partition_by_hosts = HashPartitionNode(
@@ -496,11 +420,7 @@ class TestPartition(TestFabric, unittest.TestCase):
url_count_by_3dims = SqlEngineNode(ctx, (partitioned_3dims,), initial_reduce)
url_count_by_hosts_x_urls2 = SqlEngineNode(
ctx,
(
ConsolidateNode(
ctx, url_count_by_3dims, ["host_partition", "url_partition"]
),
),
(ConsolidateNode(ctx, url_count_by_3dims, ["host_partition", "url_partition"]),),
combine_reduce_results,
output_name="url_count_by_hosts_x_urls2",
)
@@ -524,9 +444,7 @@ class TestPartition(TestFabric, unittest.TestCase):
output_name="join_count_by_hosts_x_urls2",
)
union_url_count_by_hosts = UnionNode(
ctx, (url_count_by_hosts1, url_count_by_hosts2)
)
union_url_count_by_hosts = UnionNode(ctx, (url_count_by_hosts1, url_count_by_hosts2))
union_url_count_by_hosts_x_urls = UnionNode(
ctx,
(
@@ -576,18 +494,14 @@ class TestPartition(TestFabric, unittest.TestCase):
ctx = Context()
parquet_files = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
data_source = DataSourceNode(ctx, parquet_files)
file_partitions1 = CalculatePartitionFromFilename(
ctx, (data_source,), npartitions=3, dimension="by_filename_hash1"
)
file_partitions1 = CalculatePartitionFromFilename(ctx, (data_source,), npartitions=3, dimension="by_filename_hash1")
url_count1 = SqlEngineNode(
ctx,
(file_partitions1,),
r"select host, count(*) as cnt from {0} group by host",
output_name="url_count1",
)
file_partitions2 = CalculatePartitionFromFilename(
ctx, (url_count1,), npartitions=3, dimension="by_filename_hash2"
)
file_partitions2 = CalculatePartitionFromFilename(ctx, (url_count1,), npartitions=3, dimension="by_filename_hash2")
url_count2 = SqlEngineNode(
ctx,
(file_partitions2,),
@@ -606,9 +520,7 @@ class TestPartition(TestFabric, unittest.TestCase):
ctx = Context()
parquet_dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
data_source = DataSourceNode(ctx, parquet_dataset)
evenly_dist_data_source = EvenlyDistributedPartitionNode(
ctx, (data_source,), npartitions=parquet_dataset.num_files
)
evenly_dist_data_source = EvenlyDistributedPartitionNode(ctx, (data_source,), npartitions=parquet_dataset.num_files)
parquet_datasets = [ParquetDataSet([p]) for p in parquet_dataset.resolved_paths]
partitioned_data_source = UserPartitionedDataSourceNode(ctx, parquet_datasets)
@@ -631,9 +543,7 @@ class TestPartition(TestFabric, unittest.TestCase):
memory_limit=1 * GB,
)
plan = LogicalPlan(
ctx, UnionNode(ctx, [url_count_by_host1, url_count_by_host2])
)
plan = LogicalPlan(ctx, UnionNode(ctx, [url_count_by_host1, url_count_by_host2]))
exec_plan = self.execute_plan(plan, enable_diagnostic_metrics=True)
self._compare_arrow_tables(
exec_plan.get_output("url_count_by_host1").to_arrow_table(),
@@ -647,9 +557,7 @@ class TestPartition(TestFabric, unittest.TestCase):
ctx = Context()
parquet_dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
data_source = DataSourceNode(ctx, parquet_dataset)
evenly_dist_data_source = EvenlyDistributedPartitionNode(
ctx, (data_source,), npartitions=parquet_dataset.num_files
)
evenly_dist_data_source = EvenlyDistributedPartitionNode(ctx, (data_source,), npartitions=parquet_dataset.num_files)
sql_query = SqlEngineNode(
ctx,
(evenly_dist_data_source,),

View File

@@ -65,6 +65,4 @@ def test_fstest(sp: Session):
def test_sort_mock_urls_v2(sp: Session):
sort_mock_urls_v2(
sp, ["tests/data/mock_urls/*.tsv"], sp._runtime_ctx.output_root, npartitions=3
)
sort_mock_urls_v2(sp, ["tests/data/mock_urls/*.tsv"], sp._runtime_ctx.output_root, npartitions=3)

View File

@@ -21,9 +21,7 @@ from tests.test_fabric import TestFabric
class RandomSleepTask(PythonScriptTask):
def __init__(
self, *args, sleep_secs: float, fail_first_try: bool, **kwargs
) -> None:
def __init__(self, *args, sleep_secs: float, fail_first_try: bool, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.sleep_secs = sleep_secs
self.fail_first_try = fail_first_try
@@ -58,24 +56,16 @@ class RandomSleepNode(PythonScriptNode):
self.fail_first_try = fail_first_try
def spawn(self, *args, **kwargs) -> RandomSleepTask:
sleep_secs = (
random.random() if len(self.generated_tasks) % 20 else self.max_sleep_secs
)
return RandomSleepTask(
*args, **kwargs, sleep_secs=sleep_secs, fail_first_try=self.fail_first_try
)
sleep_secs = random.random() if len(self.generated_tasks) % 20 else self.max_sleep_secs
return RandomSleepTask(*args, **kwargs, sleep_secs=sleep_secs, fail_first_try=self.fail_first_try)
class TestScheduler(TestFabric, unittest.TestCase):
def create_random_sleep_plan(
self, npartitions, max_sleep_secs, fail_first_try=False
):
def create_random_sleep_plan(self, npartitions, max_sleep_secs, fail_first_try=False):
ctx = Context()
dataset = ParquetDataSet(["tests/data/mock_urls/*.parquet"])
data_files = DataSourceNode(ctx, dataset)
data_partitions = DataSetPartitionNode(
ctx, (data_files,), npartitions=npartitions, partition_by_rows=True
)
data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=npartitions, partition_by_rows=True)
random_sleep = RandomSleepNode(
ctx,
(data_partitions,),
@@ -87,13 +77,8 @@ class TestScheduler(TestFabric, unittest.TestCase):
def check_executor_state(self, target_state: ExecutorState, nloops=200):
for _ in range(nloops):
latest_sched_state = self.get_latest_sched_state()
if any(
executor.state == target_state
for executor in latest_sched_state.remote_executors
):
logger.info(
f"found {target_state} executor in: {latest_sched_state.remote_executors}"
)
if any(executor.state == target_state for executor in latest_sched_state.remote_executors):
logger.info(f"found {target_state} executor in: {latest_sched_state.remote_executors}")
break
time.sleep(0.1)
else:
@@ -121,9 +106,7 @@ class TestScheduler(TestFabric, unittest.TestCase):
latest_sched_state = self.get_latest_sched_state()
self.check_executor_state(ExecutorState.GOOD)
for i, (executor, process) in enumerate(
random.sample(list(zip(executors, processes[1:])), k=num_fail)
):
for i, (executor, process) in enumerate(random.sample(list(zip(executors, processes[1:])), k=num_fail)):
if i % 2 == 0:
logger.warning(f"kill executor: {executor}")
process.kill()
@@ -165,9 +148,7 @@ class TestScheduler(TestFabric, unittest.TestCase):
self.assertGreater(len(latest_sched_state.abandoned_tasks), 0)
def test_stop_executor_on_failure(self):
plan = self.create_random_sleep_plan(
npartitions=3, max_sleep_secs=5, fail_first_try=True
)
plan = self.create_random_sleep_plan(npartitions=3, max_sleep_secs=5, fail_first_try=True)
exec_plan = self.execute_plan(
plan,
num_executors=5,

View File

@@ -5,9 +5,7 @@ from smallpond.dataframe import Session
def test_shutdown_cleanup(sp: Session):
assert os.path.exists(sp._runtime_ctx.queue_root), "queue directory should exist"
assert os.path.exists(
sp._runtime_ctx.staging_root
), "staging directory should exist"
assert os.path.exists(sp._runtime_ctx.staging_root), "staging directory should exist"
assert os.path.exists(sp._runtime_ctx.temp_root), "temp directory should exist"
# create some tasks and complete them
@@ -16,15 +14,9 @@ def test_shutdown_cleanup(sp: Session):
sp.shutdown()
# shutdown should clean up directories
assert not os.path.exists(
sp._runtime_ctx.queue_root
), "queue directory should be cleared"
assert not os.path.exists(
sp._runtime_ctx.staging_root
), "staging directory should be cleared"
assert not os.path.exists(
sp._runtime_ctx.temp_root
), "temp directory should be cleared"
assert not os.path.exists(sp._runtime_ctx.queue_root), "queue directory should be cleared"
assert not os.path.exists(sp._runtime_ctx.staging_root), "staging directory should be cleared"
assert not os.path.exists(sp._runtime_ctx.temp_root), "temp directory should be cleared"
with open(sp._runtime_ctx.job_status_path) as fin:
assert "success" in fin.read(), "job status should be success"
@@ -41,14 +33,8 @@ def test_shutdown_no_cleanup_on_failure(sp: Session):
sp.shutdown()
# shutdown should not clean up directories
assert os.path.exists(
sp._runtime_ctx.queue_root
), "queue directory should not be cleared"
assert os.path.exists(
sp._runtime_ctx.staging_root
), "staging directory should not be cleared"
assert os.path.exists(
sp._runtime_ctx.temp_root
), "temp directory should not be cleared"
assert os.path.exists(sp._runtime_ctx.queue_root), "queue directory should not be cleared"
assert os.path.exists(sp._runtime_ctx.staging_root), "staging directory should not be cleared"
assert os.path.exists(sp._runtime_ctx.temp_root), "temp directory should not be cleared"
with open(sp._runtime_ctx.job_status_path) as fin:
assert "failure" in fin.read(), "job status should be failure"

View File

@@ -85,9 +85,7 @@ class WorkQueueTestBase(object):
def test_multi_consumers(self):
numConsumers = 10
numItems = 200
result = self.pool.starmap_async(
consumer, [(self.wq, id) for id in range(numConsumers)]
)
result = self.pool.starmap_async(consumer, [(self.wq, id) for id in range(numConsumers)])
producer(self.wq, 0, numItems, numConsumers)
logger.info("waiting for result")