reformat code with --line-length=150 (#18)

This commit is contained in:
Runji Wang
2025-03-05 22:46:23 +08:00
committed by GitHub
parent ed112db42a
commit 52ecc5e455
48 changed files with 794 additions and 2604 deletions

View File

@@ -64,9 +64,7 @@ def main():
driver = Driver()
driver.add_argument("-i", "--input_paths", nargs="+")
driver.add_argument("-n", "--npartitions", type=int, default=None)
driver.add_argument(
"-e", "--io_engine", default="duckdb", choices=("duckdb", "arrow", "stream")
)
driver.add_argument("-e", "--io_engine", default="duckdb", choices=("duckdb", "arrow", "stream"))
driver.add_argument("-b", "--batch_size", type=int, default=1024 * 1024)
driver.add_argument("-s", "--row_group_size", type=int, default=1024 * 1024)
driver.add_argument("-o", "--output_name", default="data")

View File

@@ -73,26 +73,18 @@ def generate_records(
subprocess.run(gensort_cmd.split()).check_returncode()
runtime_task.add_elapsed_time("generate records (secs)")
shm_file.seek(0)
buffer = arrow.py_buffer(
shm_file.read(record_count * record_nbytes)
)
buffer = arrow.py_buffer(shm_file.read(record_count * record_nbytes))
runtime_task.add_elapsed_time("read records (secs)")
# https://arrow.apache.org/docs/format/Columnar.html#fixed-size-primitive-layout
records = arrow.Array.from_buffers(
arrow.binary(record_nbytes), record_count, [None, buffer]
)
records = arrow.Array.from_buffers(arrow.binary(record_nbytes), record_count, [None, buffer])
keys = pc.binary_slice(records, 0, key_nbytes)
# take the first 2 bytes of each record and read them as a big-endian uint16 (reverse the two bytes, then reinterpret them as uint16)
binary_prefix = pc.binary_slice(records, 0, 2).cast(arrow.binary())
reversed_prefix = pc.binary_reverse(binary_prefix).cast(
arrow.binary(2)
)
reversed_prefix = pc.binary_reverse(binary_prefix).cast(arrow.binary(2))
uint16_prefix = reversed_prefix.view(arrow.uint16())
buckets = pc.shift_right(uint16_prefix, 16 - bucket_nbits)
runtime_task.add_elapsed_time("build arrow table (secs)")
yield arrow.Table.from_arrays(
[buckets, keys, records], schema=schema
)
yield arrow.Table.from_arrays([buckets, keys, records], schema=schema)
yield StreamOutput(
schema.empty_table(),
batch_indices=[batch_idx],
@@ -108,9 +100,7 @@ def sort_records(
write_io_nbytes=500 * MB,
) -> bool:
runtime_task: PythonScriptTask = runtime_ctx.task
data_file_path = os.path.join(
runtime_task.runtime_output_abspath, f"{runtime_task.output_filename}.dat"
)
data_file_path = os.path.join(runtime_task.runtime_output_abspath, f"{runtime_task.output_filename}.dat")
if sort_engine == "polars":
input_data = polars.read_parquet(
@@ -134,9 +124,7 @@ def sort_records(
record_arrays = sorted_table.column("records").chunks
runtime_task.add_elapsed_time("convert to chunks (secs)")
elif sort_engine == "duckdb":
with duckdb.connect(
database=":memory:", config={"allow_unsigned_extensions": "true"}
) as conn:
with duckdb.connect(database=":memory:", config={"allow_unsigned_extensions": "true"}) as conn:
runtime_task.prepare_connection(conn)
input_views = runtime_task.create_input_views(conn, input_datasets)
sql_query = "select records from {0} order by keys".format(*input_views)
@@ -154,8 +142,7 @@ def sort_records(
buffer_mem = memoryview(values)
total_write_nbytes = sum(
fout.write(buffer_mem[offset : offset + write_io_nbytes])
for offset in range(0, len(buffer_mem), write_io_nbytes)
fout.write(buffer_mem[offset : offset + write_io_nbytes]) for offset in range(0, len(buffer_mem), write_io_nbytes)
)
assert total_write_nbytes == len(buffer_mem)
@@ -164,16 +151,10 @@ def sort_records(
return True
def validate_records(
runtime_ctx: RuntimeContext, input_datasets: List[DataSet], output_path: str
) -> bool:
def validate_records(runtime_ctx: RuntimeContext, input_datasets: List[DataSet], output_path: str) -> bool:
for data_path in input_datasets[0].resolved_paths:
summary_path = os.path.join(
output_path, PurePath(data_path).with_suffix(".sum").name
)
cmdstr = (
f"{SortBenchTool.valsort_path} -o {summary_path} {data_path},buf,trans=10m"
)
summary_path = os.path.join(output_path, PurePath(data_path).with_suffix(".sum").name)
cmdstr = f"{SortBenchTool.valsort_path} -o {summary_path} {data_path},buf,trans=10m"
logging.debug(f"running command: {cmdstr}")
result = subprocess.run(cmdstr.split(), capture_output=True, encoding="utf8")
if result.stderr:
@@ -185,9 +166,7 @@ def validate_records(
return True
def validate_summary(
runtime_ctx: RuntimeContext, input_datasets: List[DataSet], output_path: str
) -> bool:
def validate_summary(runtime_ctx: RuntimeContext, input_datasets: List[DataSet], output_path: str) -> bool:
concated_summary_path = os.path.join(output_path, "merged.sum")
with open(concated_summary_path, "wb") as fout:
for path in input_datasets[0].resolved_paths:
@@ -224,22 +203,13 @@ def generate_random_records(
)
range_begin_at = [pos for pos in range(0, total_num_records, record_range_size)]
range_num_records = [
min(total_num_records, record_range_size * (range_idx + 1)) - begin_at
for range_idx, begin_at in enumerate(range_begin_at)
]
range_num_records = [min(total_num_records, record_range_size * (range_idx + 1)) - begin_at for range_idx, begin_at in enumerate(range_begin_at)]
assert sum(range_num_records) == total_num_records
record_range = DataSourceNode(
ctx,
ArrowTableDataSet(
arrow.Table.from_arrays(
[range_begin_at, range_num_records], names=["begin_at", "num_records"]
)
),
)
record_range_partitions = DataSetPartitionNode(
ctx, (record_range,), npartitions=num_data_partitions, partition_by_rows=True
ArrowTableDataSet(arrow.Table.from_arrays([range_begin_at, range_num_records], names=["begin_at", "num_records"])),
)
record_range_partitions = DataSetPartitionNode(ctx, (record_range,), npartitions=num_data_partitions, partition_by_rows=True)
random_records = ArrowStreamNode(
ctx,
@@ -288,9 +258,7 @@ def gray_sort_benchmark(
if input_paths:
input_dataset = ParquetDataSet(input_paths)
input_nbytes = sum(os.path.getsize(p) for p in input_dataset.resolved_paths)
logging.warning(
f"input data size: {input_nbytes/GB:.3f}GB, {input_dataset.num_files} files"
)
logging.warning(f"input data size: {input_nbytes/GB:.3f}GB, {input_dataset.num_files} files")
random_records = DataSourceNode(ctx, input_dataset)
else:
random_records = generate_random_records(
@@ -335,12 +303,8 @@ def gray_sort_benchmark(
process_func=validate_records,
output_name="partitioned_summaries",
)
merged_summaries = DataSetPartitionNode(
ctx, (partitioned_summaries,), npartitions=1
)
final_check = PythonScriptNode(
ctx, (merged_summaries,), process_func=validate_summary
)
merged_summaries = DataSetPartitionNode(ctx, (partitioned_summaries,), npartitions=1)
final_check = PythonScriptNode(ctx, (merged_summaries,), process_func=validate_summary)
root = final_check
else:
root = sorted_records
@@ -359,17 +323,11 @@ def main():
driver.add_argument("-n", "--num_data_partitions", type=int, default=None)
driver.add_argument("-t", "--num_sort_partitions", type=int, default=None)
driver.add_argument("-i", "--input_paths", nargs="+", default=[])
driver.add_argument(
"-e", "--shuffle_engine", default="duckdb", choices=("duckdb", "arrow")
)
driver.add_argument(
"-s", "--sort_engine", default="duckdb", choices=("duckdb", "arrow", "polars")
)
driver.add_argument("-e", "--shuffle_engine", default="duckdb", choices=("duckdb", "arrow"))
driver.add_argument("-s", "--sort_engine", default="duckdb", choices=("duckdb", "arrow", "polars"))
driver.add_argument("-H", "--hive_partitioning", action="store_true")
driver.add_argument("-V", "--validate_results", action="store_true")
driver.add_argument(
"-C", "--shuffle_cpu_limit", type=int, default=ShuffleNode.default_cpu_limit
)
driver.add_argument("-C", "--shuffle_cpu_limit", type=int, default=ShuffleNode.default_cpu_limit)
driver.add_argument(
"-M",
"--shuffle_memory_limit",
@@ -378,12 +336,8 @@ def main():
)
driver.add_argument("-TC", "--sort_cpu_limit", type=int, default=8)
driver.add_argument("-TM", "--sort_memory_limit", type=int, default=None)
driver.add_argument(
"-NC", "--cpus_per_node", type=int, default=psutil.cpu_count(logical=False)
)
driver.add_argument(
"-NM", "--memory_per_node", type=int, default=psutil.virtual_memory().total
)
driver.add_argument("-NC", "--cpus_per_node", type=int, default=psutil.cpu_count(logical=False))
driver.add_argument("-NM", "--memory_per_node", type=int, default=psutil.virtual_memory().total)
driver.add_argument("-CP", "--parquet_compression", default=None)
driver.add_argument("-LV", "--parquet_compression_level", type=int, default=None)
@@ -393,16 +347,9 @@ def main():
total_num_cpus = max(1, driver_args.num_executors) * user_args.cpus_per_node
memory_per_cpu = user_args.memory_per_node // user_args.cpus_per_node
user_args.sort_cpu_limit = (
1 if user_args.sort_engine == "arrow" else user_args.sort_cpu_limit
)
sort_memory_limit = (
user_args.sort_memory_limit or user_args.sort_cpu_limit * memory_per_cpu
)
user_args.total_data_nbytes = (
user_args.total_data_nbytes
or max(1, driver_args.num_executors) * user_args.memory_per_node
)
user_args.sort_cpu_limit = 1 if user_args.sort_engine == "arrow" else user_args.sort_cpu_limit
sort_memory_limit = user_args.sort_memory_limit or user_args.sort_cpu_limit * memory_per_cpu
user_args.total_data_nbytes = user_args.total_data_nbytes or max(1, driver_args.num_executors) * user_args.memory_per_node
user_args.num_data_partitions = user_args.num_data_partitions or total_num_cpus // 2
user_args.num_sort_partitions = user_args.num_sort_partitions or max(
total_num_cpus // user_args.sort_cpu_limit,

View File

@@ -70,18 +70,12 @@ def main():
driver.add_argument("-i", "--input_paths", nargs="+", required=True)
driver.add_argument("-n", "--npartitions", type=int, default=None)
driver.add_argument("-c", "--hash_columns", nargs="+", required=True)
driver.add_argument(
"-e", "--engine_type", default="duckdb", choices=("duckdb", "arrow")
)
driver.add_argument("-e", "--engine_type", default="duckdb", choices=("duckdb", "arrow"))
driver.add_argument("-S", "--partition_stats", action="store_true")
driver.add_argument("-W", "--use_parquet_writer", action="store_true")
driver.add_argument("-H", "--hive_partitioning", action="store_true")
driver.add_argument(
"-C", "--cpu_limit", type=int, default=HashPartitionNode.default_cpu_limit
)
driver.add_argument(
"-M", "--memory_limit", type=int, default=HashPartitionNode.default_memory_limit
)
driver.add_argument("-C", "--cpu_limit", type=int, default=HashPartitionNode.default_cpu_limit)
driver.add_argument("-M", "--memory_limit", type=int, default=HashPartitionNode.default_memory_limit)
driver.add_argument("-NC", "--cpus_per_node", type=int, default=192)
driver.add_argument("-NM", "--memory_per_node", type=int, default=2000 * GB)

View File

@@ -29,9 +29,7 @@ def urls_sort_benchmark(
delim=r"\t",
)
data_files = DataSourceNode(ctx, dataset)
data_partitions = DataSetPartitionNode(
ctx, (data_files,), npartitions=num_data_partitions
)
data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=num_data_partitions)
imported_urls = SqlEngineNode(
ctx,
@@ -80,16 +78,10 @@ def urls_sort_benchmark_v2(
sort_cpu_limit=8,
sort_memory_limit=None,
):
dataset = sp.read_csv(
input_paths, schema={"urlstr": "varchar", "valstr": "varchar"}, delim=r"\t"
)
dataset = sp.read_csv(input_paths, schema={"urlstr": "varchar", "valstr": "varchar"}, delim=r"\t")
data_partitions = dataset.repartition(num_data_partitions)
urls_partitions = data_partitions.repartition(
num_hash_partitions, hash_by="urlstr", engine_type=engine_type
)
sorted_urls = urls_partitions.partial_sort(
by="urlstr", cpu_limit=sort_cpu_limit, memory_limit=sort_memory_limit
)
urls_partitions = data_partitions.repartition(num_hash_partitions, hash_by="urlstr", engine_type=engine_type)
sorted_urls = urls_partitions.partial_sort(by="urlstr", cpu_limit=sort_cpu_limit, memory_limit=sort_memory_limit)
sorted_urls.write_parquet(output_path)
@@ -106,12 +98,8 @@ def main():
num_nodes = driver_args.num_executors
cpus_per_node = 120
partition_rounds = 2
user_args.num_data_partitions = (
user_args.num_data_partitions or num_nodes * cpus_per_node * partition_rounds
)
user_args.num_hash_partitions = (
user_args.num_hash_partitions or num_nodes * cpus_per_node
)
user_args.num_data_partitions = user_args.num_data_partitions or num_nodes * cpus_per_node * partition_rounds
user_args.num_hash_partitions = user_args.num_hash_partitions or num_nodes * cpus_per_node
# v1
plan = urls_sort_benchmark(**vars(user_args))