mirror of
https://github.com/deepseek-ai/smallpond
synced 2025-06-26 18:27:45 +00:00
reformat code with --line-length=150 (#18)
This commit is contained in:
@@ -64,9 +64,7 @@ def main():
|
||||
driver = Driver()
|
||||
driver.add_argument("-i", "--input_paths", nargs="+")
|
||||
driver.add_argument("-n", "--npartitions", type=int, default=None)
|
||||
driver.add_argument(
|
||||
"-e", "--io_engine", default="duckdb", choices=("duckdb", "arrow", "stream")
|
||||
)
|
||||
driver.add_argument("-e", "--io_engine", default="duckdb", choices=("duckdb", "arrow", "stream"))
|
||||
driver.add_argument("-b", "--batch_size", type=int, default=1024 * 1024)
|
||||
driver.add_argument("-s", "--row_group_size", type=int, default=1024 * 1024)
|
||||
driver.add_argument("-o", "--output_name", default="data")
|
||||
|
||||
@@ -73,26 +73,18 @@ def generate_records(
|
||||
subprocess.run(gensort_cmd.split()).check_returncode()
|
||||
runtime_task.add_elapsed_time("generate records (secs)")
|
||||
shm_file.seek(0)
|
||||
buffer = arrow.py_buffer(
|
||||
shm_file.read(record_count * record_nbytes)
|
||||
)
|
||||
buffer = arrow.py_buffer(shm_file.read(record_count * record_nbytes))
|
||||
runtime_task.add_elapsed_time("read records (secs)")
|
||||
# https://arrow.apache.org/docs/format/Columnar.html#fixed-size-primitive-layout
|
||||
records = arrow.Array.from_buffers(
|
||||
arrow.binary(record_nbytes), record_count, [None, buffer]
|
||||
)
|
||||
records = arrow.Array.from_buffers(arrow.binary(record_nbytes), record_count, [None, buffer])
|
||||
keys = pc.binary_slice(records, 0, key_nbytes)
|
||||
# get first 2 bytes and convert to big-endian uint16
|
||||
binary_prefix = pc.binary_slice(records, 0, 2).cast(arrow.binary())
|
||||
reversed_prefix = pc.binary_reverse(binary_prefix).cast(
|
||||
arrow.binary(2)
|
||||
)
|
||||
reversed_prefix = pc.binary_reverse(binary_prefix).cast(arrow.binary(2))
|
||||
uint16_prefix = reversed_prefix.view(arrow.uint16())
|
||||
buckets = pc.shift_right(uint16_prefix, 16 - bucket_nbits)
|
||||
runtime_task.add_elapsed_time("build arrow table (secs)")
|
||||
yield arrow.Table.from_arrays(
|
||||
[buckets, keys, records], schema=schema
|
||||
)
|
||||
yield arrow.Table.from_arrays([buckets, keys, records], schema=schema)
|
||||
yield StreamOutput(
|
||||
schema.empty_table(),
|
||||
batch_indices=[batch_idx],
|
||||
@@ -108,9 +100,7 @@ def sort_records(
|
||||
write_io_nbytes=500 * MB,
|
||||
) -> bool:
|
||||
runtime_task: PythonScriptTask = runtime_ctx.task
|
||||
data_file_path = os.path.join(
|
||||
runtime_task.runtime_output_abspath, f"{runtime_task.output_filename}.dat"
|
||||
)
|
||||
data_file_path = os.path.join(runtime_task.runtime_output_abspath, f"{runtime_task.output_filename}.dat")
|
||||
|
||||
if sort_engine == "polars":
|
||||
input_data = polars.read_parquet(
|
||||
@@ -134,9 +124,7 @@ def sort_records(
|
||||
record_arrays = sorted_table.column("records").chunks
|
||||
runtime_task.add_elapsed_time("convert to chunks (secs)")
|
||||
elif sort_engine == "duckdb":
|
||||
with duckdb.connect(
|
||||
database=":memory:", config={"allow_unsigned_extensions": "true"}
|
||||
) as conn:
|
||||
with duckdb.connect(database=":memory:", config={"allow_unsigned_extensions": "true"}) as conn:
|
||||
runtime_task.prepare_connection(conn)
|
||||
input_views = runtime_task.create_input_views(conn, input_datasets)
|
||||
sql_query = "select records from {0} order by keys".format(*input_views)
|
||||
@@ -154,8 +142,7 @@ def sort_records(
|
||||
buffer_mem = memoryview(values)
|
||||
|
||||
total_write_nbytes = sum(
|
||||
fout.write(buffer_mem[offset : offset + write_io_nbytes])
|
||||
for offset in range(0, len(buffer_mem), write_io_nbytes)
|
||||
fout.write(buffer_mem[offset : offset + write_io_nbytes]) for offset in range(0, len(buffer_mem), write_io_nbytes)
|
||||
)
|
||||
assert total_write_nbytes == len(buffer_mem)
|
||||
|
||||
@@ -164,16 +151,10 @@ def sort_records(
|
||||
return True
|
||||
|
||||
|
||||
def validate_records(
|
||||
runtime_ctx: RuntimeContext, input_datasets: List[DataSet], output_path: str
|
||||
) -> bool:
|
||||
def validate_records(runtime_ctx: RuntimeContext, input_datasets: List[DataSet], output_path: str) -> bool:
|
||||
for data_path in input_datasets[0].resolved_paths:
|
||||
summary_path = os.path.join(
|
||||
output_path, PurePath(data_path).with_suffix(".sum").name
|
||||
)
|
||||
cmdstr = (
|
||||
f"{SortBenchTool.valsort_path} -o {summary_path} {data_path},buf,trans=10m"
|
||||
)
|
||||
summary_path = os.path.join(output_path, PurePath(data_path).with_suffix(".sum").name)
|
||||
cmdstr = f"{SortBenchTool.valsort_path} -o {summary_path} {data_path},buf,trans=10m"
|
||||
logging.debug(f"running command: {cmdstr}")
|
||||
result = subprocess.run(cmdstr.split(), capture_output=True, encoding="utf8")
|
||||
if result.stderr:
|
||||
@@ -185,9 +166,7 @@ def validate_records(
|
||||
return True
|
||||
|
||||
|
||||
def validate_summary(
|
||||
runtime_ctx: RuntimeContext, input_datasets: List[DataSet], output_path: str
|
||||
) -> bool:
|
||||
def validate_summary(runtime_ctx: RuntimeContext, input_datasets: List[DataSet], output_path: str) -> bool:
|
||||
concated_summary_path = os.path.join(output_path, "merged.sum")
|
||||
with open(concated_summary_path, "wb") as fout:
|
||||
for path in input_datasets[0].resolved_paths:
|
||||
@@ -224,22 +203,13 @@ def generate_random_records(
|
||||
)
|
||||
|
||||
range_begin_at = [pos for pos in range(0, total_num_records, record_range_size)]
|
||||
range_num_records = [
|
||||
min(total_num_records, record_range_size * (range_idx + 1)) - begin_at
|
||||
for range_idx, begin_at in enumerate(range_begin_at)
|
||||
]
|
||||
range_num_records = [min(total_num_records, record_range_size * (range_idx + 1)) - begin_at for range_idx, begin_at in enumerate(range_begin_at)]
|
||||
assert sum(range_num_records) == total_num_records
|
||||
record_range = DataSourceNode(
|
||||
ctx,
|
||||
ArrowTableDataSet(
|
||||
arrow.Table.from_arrays(
|
||||
[range_begin_at, range_num_records], names=["begin_at", "num_records"]
|
||||
)
|
||||
),
|
||||
)
|
||||
record_range_partitions = DataSetPartitionNode(
|
||||
ctx, (record_range,), npartitions=num_data_partitions, partition_by_rows=True
|
||||
ArrowTableDataSet(arrow.Table.from_arrays([range_begin_at, range_num_records], names=["begin_at", "num_records"])),
|
||||
)
|
||||
record_range_partitions = DataSetPartitionNode(ctx, (record_range,), npartitions=num_data_partitions, partition_by_rows=True)
|
||||
|
||||
random_records = ArrowStreamNode(
|
||||
ctx,
|
||||
@@ -288,9 +258,7 @@ def gray_sort_benchmark(
|
||||
if input_paths:
|
||||
input_dataset = ParquetDataSet(input_paths)
|
||||
input_nbytes = sum(os.path.getsize(p) for p in input_dataset.resolved_paths)
|
||||
logging.warning(
|
||||
f"input data size: {input_nbytes/GB:.3f}GB, {input_dataset.num_files} files"
|
||||
)
|
||||
logging.warning(f"input data size: {input_nbytes/GB:.3f}GB, {input_dataset.num_files} files")
|
||||
random_records = DataSourceNode(ctx, input_dataset)
|
||||
else:
|
||||
random_records = generate_random_records(
|
||||
@@ -335,12 +303,8 @@ def gray_sort_benchmark(
|
||||
process_func=validate_records,
|
||||
output_name="partitioned_summaries",
|
||||
)
|
||||
merged_summaries = DataSetPartitionNode(
|
||||
ctx, (partitioned_summaries,), npartitions=1
|
||||
)
|
||||
final_check = PythonScriptNode(
|
||||
ctx, (merged_summaries,), process_func=validate_summary
|
||||
)
|
||||
merged_summaries = DataSetPartitionNode(ctx, (partitioned_summaries,), npartitions=1)
|
||||
final_check = PythonScriptNode(ctx, (merged_summaries,), process_func=validate_summary)
|
||||
root = final_check
|
||||
else:
|
||||
root = sorted_records
|
||||
@@ -359,17 +323,11 @@ def main():
|
||||
driver.add_argument("-n", "--num_data_partitions", type=int, default=None)
|
||||
driver.add_argument("-t", "--num_sort_partitions", type=int, default=None)
|
||||
driver.add_argument("-i", "--input_paths", nargs="+", default=[])
|
||||
driver.add_argument(
|
||||
"-e", "--shuffle_engine", default="duckdb", choices=("duckdb", "arrow")
|
||||
)
|
||||
driver.add_argument(
|
||||
"-s", "--sort_engine", default="duckdb", choices=("duckdb", "arrow", "polars")
|
||||
)
|
||||
driver.add_argument("-e", "--shuffle_engine", default="duckdb", choices=("duckdb", "arrow"))
|
||||
driver.add_argument("-s", "--sort_engine", default="duckdb", choices=("duckdb", "arrow", "polars"))
|
||||
driver.add_argument("-H", "--hive_partitioning", action="store_true")
|
||||
driver.add_argument("-V", "--validate_results", action="store_true")
|
||||
driver.add_argument(
|
||||
"-C", "--shuffle_cpu_limit", type=int, default=ShuffleNode.default_cpu_limit
|
||||
)
|
||||
driver.add_argument("-C", "--shuffle_cpu_limit", type=int, default=ShuffleNode.default_cpu_limit)
|
||||
driver.add_argument(
|
||||
"-M",
|
||||
"--shuffle_memory_limit",
|
||||
@@ -378,12 +336,8 @@ def main():
|
||||
)
|
||||
driver.add_argument("-TC", "--sort_cpu_limit", type=int, default=8)
|
||||
driver.add_argument("-TM", "--sort_memory_limit", type=int, default=None)
|
||||
driver.add_argument(
|
||||
"-NC", "--cpus_per_node", type=int, default=psutil.cpu_count(logical=False)
|
||||
)
|
||||
driver.add_argument(
|
||||
"-NM", "--memory_per_node", type=int, default=psutil.virtual_memory().total
|
||||
)
|
||||
driver.add_argument("-NC", "--cpus_per_node", type=int, default=psutil.cpu_count(logical=False))
|
||||
driver.add_argument("-NM", "--memory_per_node", type=int, default=psutil.virtual_memory().total)
|
||||
driver.add_argument("-CP", "--parquet_compression", default=None)
|
||||
driver.add_argument("-LV", "--parquet_compression_level", type=int, default=None)
|
||||
|
||||
@@ -393,16 +347,9 @@ def main():
|
||||
total_num_cpus = max(1, driver_args.num_executors) * user_args.cpus_per_node
|
||||
memory_per_cpu = user_args.memory_per_node // user_args.cpus_per_node
|
||||
|
||||
user_args.sort_cpu_limit = (
|
||||
1 if user_args.sort_engine == "arrow" else user_args.sort_cpu_limit
|
||||
)
|
||||
sort_memory_limit = (
|
||||
user_args.sort_memory_limit or user_args.sort_cpu_limit * memory_per_cpu
|
||||
)
|
||||
user_args.total_data_nbytes = (
|
||||
user_args.total_data_nbytes
|
||||
or max(1, driver_args.num_executors) * user_args.memory_per_node
|
||||
)
|
||||
user_args.sort_cpu_limit = 1 if user_args.sort_engine == "arrow" else user_args.sort_cpu_limit
|
||||
sort_memory_limit = user_args.sort_memory_limit or user_args.sort_cpu_limit * memory_per_cpu
|
||||
user_args.total_data_nbytes = user_args.total_data_nbytes or max(1, driver_args.num_executors) * user_args.memory_per_node
|
||||
user_args.num_data_partitions = user_args.num_data_partitions or total_num_cpus // 2
|
||||
user_args.num_sort_partitions = user_args.num_sort_partitions or max(
|
||||
total_num_cpus // user_args.sort_cpu_limit,
|
||||
|
||||
@@ -70,18 +70,12 @@ def main():
|
||||
driver.add_argument("-i", "--input_paths", nargs="+", required=True)
|
||||
driver.add_argument("-n", "--npartitions", type=int, default=None)
|
||||
driver.add_argument("-c", "--hash_columns", nargs="+", required=True)
|
||||
driver.add_argument(
|
||||
"-e", "--engine_type", default="duckdb", choices=("duckdb", "arrow")
|
||||
)
|
||||
driver.add_argument("-e", "--engine_type", default="duckdb", choices=("duckdb", "arrow"))
|
||||
driver.add_argument("-S", "--partition_stats", action="store_true")
|
||||
driver.add_argument("-W", "--use_parquet_writer", action="store_true")
|
||||
driver.add_argument("-H", "--hive_partitioning", action="store_true")
|
||||
driver.add_argument(
|
||||
"-C", "--cpu_limit", type=int, default=HashPartitionNode.default_cpu_limit
|
||||
)
|
||||
driver.add_argument(
|
||||
"-M", "--memory_limit", type=int, default=HashPartitionNode.default_memory_limit
|
||||
)
|
||||
driver.add_argument("-C", "--cpu_limit", type=int, default=HashPartitionNode.default_cpu_limit)
|
||||
driver.add_argument("-M", "--memory_limit", type=int, default=HashPartitionNode.default_memory_limit)
|
||||
driver.add_argument("-NC", "--cpus_per_node", type=int, default=192)
|
||||
driver.add_argument("-NM", "--memory_per_node", type=int, default=2000 * GB)
|
||||
|
||||
|
||||
@@ -29,9 +29,7 @@ def urls_sort_benchmark(
|
||||
delim=r"\t",
|
||||
)
|
||||
data_files = DataSourceNode(ctx, dataset)
|
||||
data_partitions = DataSetPartitionNode(
|
||||
ctx, (data_files,), npartitions=num_data_partitions
|
||||
)
|
||||
data_partitions = DataSetPartitionNode(ctx, (data_files,), npartitions=num_data_partitions)
|
||||
|
||||
imported_urls = SqlEngineNode(
|
||||
ctx,
|
||||
@@ -80,16 +78,10 @@ def urls_sort_benchmark_v2(
|
||||
sort_cpu_limit=8,
|
||||
sort_memory_limit=None,
|
||||
):
|
||||
dataset = sp.read_csv(
|
||||
input_paths, schema={"urlstr": "varchar", "valstr": "varchar"}, delim=r"\t"
|
||||
)
|
||||
dataset = sp.read_csv(input_paths, schema={"urlstr": "varchar", "valstr": "varchar"}, delim=r"\t")
|
||||
data_partitions = dataset.repartition(num_data_partitions)
|
||||
urls_partitions = data_partitions.repartition(
|
||||
num_hash_partitions, hash_by="urlstr", engine_type=engine_type
|
||||
)
|
||||
sorted_urls = urls_partitions.partial_sort(
|
||||
by="urlstr", cpu_limit=sort_cpu_limit, memory_limit=sort_memory_limit
|
||||
)
|
||||
urls_partitions = data_partitions.repartition(num_hash_partitions, hash_by="urlstr", engine_type=engine_type)
|
||||
sorted_urls = urls_partitions.partial_sort(by="urlstr", cpu_limit=sort_cpu_limit, memory_limit=sort_memory_limit)
|
||||
sorted_urls.write_parquet(output_path)
|
||||
|
||||
|
||||
@@ -106,12 +98,8 @@ def main():
|
||||
num_nodes = driver_args.num_executors
|
||||
cpus_per_node = 120
|
||||
partition_rounds = 2
|
||||
user_args.num_data_partitions = (
|
||||
user_args.num_data_partitions or num_nodes * cpus_per_node * partition_rounds
|
||||
)
|
||||
user_args.num_hash_partitions = (
|
||||
user_args.num_hash_partitions or num_nodes * cpus_per_node
|
||||
)
|
||||
user_args.num_data_partitions = user_args.num_data_partitions or num_nodes * cpus_per_node * partition_rounds
|
||||
user_args.num_hash_partitions = user_args.num_hash_partitions or num_nodes * cpus_per_node
|
||||
|
||||
# v1
|
||||
plan = urls_sort_benchmark(**vars(user_args))
|
||||
|
||||
Reference in New Issue
Block a user